This commit is contained in:
2021-03-27 16:30:53 +01:00
parent fdbd865889
commit 3615c56152
14 changed files with 280 additions and 110 deletions

View File

@@ -13,19 +13,30 @@ sockets: Dict[str, HTTPClient] = {}
def create(command: str, url: str, port): def create(command: str, url: str, port):
"""
Create a corresponding Command instance of the specified HTTP `command` with the specified `url` and `port`.
@param command: The command type to create
@param url: The url for the command
@param port: The port for the command
"""
uri = parser.get_uri(url)
if command == "GET": if command == "GET":
return GetCommand(url, port) return GetCommand(uri, port)
elif command == "HEAD": elif command == "HEAD":
return HeadCommand(url, port) return HeadCommand(uri, port)
elif command == "POST": elif command == "POST":
return PostCommand(url, port) return PostCommand(uri, port)
elif command == "PUT": elif command == "PUT":
return PutCommand(url, port) return PutCommand(uri, port)
else: else:
raise ValueError() raise ValueError()
class AbstractCommand(ABC): class AbstractCommand(ABC):
"""
A class representing the command for sending an HTTP command.
"""
uri: str uri: str
host: str host: str
path: str path: str
@@ -111,6 +122,9 @@ class AbstractCommand(ABC):
class AbstractWithBodyCommand(AbstractCommand, ABC): class AbstractWithBodyCommand(AbstractCommand, ABC):
"""
The building block for creating an HTTP message for an HTTP command with a body.
"""
def _build_message(self, message: str) -> bytes: def _build_message(self, message: str) -> bytes:
body = input(f"Enter {self.command} data: ").encode(FORMAT) body = input(f"Enter {self.command} data: ").encode(FORMAT)
@@ -127,12 +141,19 @@ class AbstractWithBodyCommand(AbstractCommand, ABC):
class HeadCommand(AbstractCommand): class HeadCommand(AbstractCommand):
"""
A Command for sending a `HEAD` message.
"""
@property @property
def command(self): def command(self):
return "HEAD" return "HEAD"
class GetCommand(AbstractCommand): class GetCommand(AbstractCommand):
"""
A Command for sending a `GET` message.
"""
def __init__(self, uri: str, port, dir=None): def __init__(self, uri: str, port, dir=None):
super().__init__(uri, port) super().__init__(uri, port)
@@ -160,12 +181,20 @@ class GetCommand(AbstractCommand):
class PostCommand(AbstractWithBodyCommand): class PostCommand(AbstractWithBodyCommand):
"""
A command for sending a `POST` command.
"""
@property @property
def command(self): def command(self):
return "POST" return "POST"
class PutCommand(AbstractWithBodyCommand): class PutCommand(AbstractWithBodyCommand):
"""
A command for sending a `PUT` command.
"""
@property @property
def command(self): def command(self):
return "PUT" return "PUT"

View File

@@ -1,6 +1,6 @@
import socket import socket
from httplib.httpsocket import HTTPSocket from httplib.httpsocket import HTTPSocket, InvalidResponse
BUFSIZE = 4096 BUFSIZE = 4096
TIMEOUT = 3 TIMEOUT = 3
@@ -13,3 +13,9 @@ class HTTPClient(HTTPSocket):
def __init__(self, host: str): def __init__(self, host: str):
super().__init__(socket.socket(socket.AF_INET, socket.SOCK_STREAM), host) super().__init__(socket.socket(socket.AF_INET, socket.SOCK_STREAM), host)
def read_line(self):
try:
return super().read_line()
except UnicodeDecodeError:
raise InvalidResponse("Unexpected decoding error")

View File

@@ -14,7 +14,7 @@ from httplib.message import ClientMessage as Message
from httplib.retriever import Retriever from httplib.retriever import Retriever
def handle(client: HTTPClient, msg: Message, command: AbstractCommand, dir=None): def handle(client: HTTPClient, msg: Message, command: AbstractCommand, directory=None):
handler = BasicResponseHandler(client, msg, command) handler = BasicResponseHandler(client, msg, command)
retriever = handler.handle() retriever = handler.handle()
@@ -23,9 +23,9 @@ def handle(client: HTTPClient, msg: Message, command: AbstractCommand, dir=None)
content_type = msg.headers.get("content-type") content_type = msg.headers.get("content-type")
if content_type and "text/html" in content_type: if content_type and "text/html" in content_type:
handler = HTMLDownloadHandler(retriever, client, msg, command, dir) handler = HTMLDownloadHandler(retriever, client, msg, command, directory)
else: else:
handler = RawDownloadHandler(retriever, client, msg, command, dir) handler = RawDownloadHandler(retriever, client, msg, command, directory)
return handler.handle() return handler.handle()
@@ -130,20 +130,20 @@ class BasicResponseHandler(ResponseHandler):
class DownloadHandler(ResponseHandler, ABC): class DownloadHandler(ResponseHandler, ABC):
def __init__(self, retriever: Retriever, client: HTTPClient, msg, cmd, dir=None): def __init__(self, retriever: Retriever, client: HTTPClient, msg, cmd, directory=None):
super().__init__(retriever, client, msg, cmd) super().__init__(retriever, client, msg, cmd)
if not dir: if not directory:
dir = self._create_directory() directory = self._create_directory()
self.path = self._get_duplicate_name(os.path.join(dir, self.get_filename())) self.path = self._get_duplicate_name(os.path.join(directory, self.get_filename()))
@staticmethod @staticmethod
def create(retriever: Retriever, client: HTTPClient, msg, cmd, dir=None): def create(retriever: Retriever, client: HTTPClient, msg, cmd, directory=None):
content_type = msg.headers.get("content-type") content_type = msg.headers.get("content-type")
if content_type and "text/html" in content_type: if content_type and "text/html" in content_type:
return HTMLDownloadHandler(retriever, client, msg, cmd, dir) return HTMLDownloadHandler(retriever, client, msg, cmd, directory)
return RawDownloadHandler(retriever, client, msg, cmd, dir) return RawDownloadHandler(retriever, client, msg, cmd, directory)
def _create_directory(self): def _create_directory(self):
path = self._get_duplicate_name(os.path.abspath(self.client.host)) path = self._get_duplicate_name(os.path.abspath(self.client.host))
@@ -194,14 +194,14 @@ class RawDownloadHandler(DownloadHandler):
class HTMLDownloadHandler(DownloadHandler): class HTMLDownloadHandler(DownloadHandler):
def __init__(self, retriever: Retriever, client: HTTPClient, msg: Message, cmd: AbstractCommand, dir=None): def __init__(self, retriever: Retriever, client: HTTPClient, msg: Message, cmd: AbstractCommand, directory=None):
super().__init__(retriever, client, msg, cmd, dir) super().__init__(retriever, client, msg, cmd, directory)
def handle(self) -> str: def handle(self) -> str:
(dir, file) = os.path.split(self.path) (directory, file) = os.path.split(self.path)
tmp_filename = f".{file}.tmp" tmp_filename = f".{file}.tmp"
tmp_path = os.path.join(dir, tmp_filename) tmp_path = os.path.join(directory, tmp_filename)
file = open(tmp_path, "wb") file = open(tmp_path, "wb")
for buffer in self.retriever.retrieve(): for buffer in self.retriever.retrieve():
@@ -217,11 +217,11 @@ class HTMLDownloadHandler(DownloadHandler):
with open(tmp_filename, "rb") as fp: with open(tmp_filename, "rb") as fp:
soup = BeautifulSoup(fp, 'lxml') soup = BeautifulSoup(fp, 'lxml')
base_url = parser.base_url(self.cmd.uri)
base_element = soup.find("base") base_element = soup.find("base")
base_url = self.cmd.uri
if base_element: if base_element:
base_url = f"http://{self.cmd.host}" + base_element["href"] base_url = parser.urljoin(self.cmd.uri, base_element["href"])
processed = {} processed = {}
tag: Tag tag: Tag
@@ -241,22 +241,18 @@ class HTMLDownloadHandler(DownloadHandler):
logging.error("Failed to download image: %s, skipping...", tag["src"], exc_info=e) logging.error("Failed to download image: %s, skipping...", tag["src"], exc_info=e)
with open(target_filename, 'w') as file: with open(target_filename, 'w') as file:
file.write(str(soup)) file.write(soup.prettify(formatter="minimal"))
def __download_image(self, img_src, base_url): def __download_image(self, img_src, base_url):
"""
Download image from the specified `img_src` and `base_url`.
If the image is available, it will be downloaded to the directory of `self.path`
"""
logging.info("Downloading image: %s", img_src) logging.info("Downloading image: %s", img_src)
parsed = urlsplit(img_src) parsed = urlsplit(img_src)
img_src = parser.urljoin(base_url, img_src)
if parsed.scheme not in ("", "http", "https"):
# Not a valid url
return None
if parsed.hostname is None:
if img_src[0] == "/":
img_src = f"http://{self.cmd.host}{img_src}"
else:
img_src = parser.absolute_url(base_url, img_src)
if parsed.hostname is None or parsed.hostname == self.cmd.host: if parsed.hostname is None or parsed.hostname == self.cmd.host:
port = self.cmd.port port = self.cmd.port

View File

@@ -17,7 +17,7 @@ class InvalidStatusLine(HTTPException):
class UnsupportedEncoding(HTTPException): class UnsupportedEncoding(HTTPException):
""" Reponse Encoding not support """ """ Encoding not supported """
def __init(self, enc_type, encoding): def __init(self, enc_type, encoding):
self.enc_type = enc_type self.enc_type = enc_type
@@ -39,12 +39,28 @@ class HTTPServerException(Exception):
self.body = body self.body = body
class BadRequest(HTTPServerException): class HTTPServerCloseException(HTTPServerException):
""" When thrown, the connection should be closed """
class BadRequest(HTTPServerCloseException):
""" Malformed HTTP request""" """ Malformed HTTP request"""
status_code = 400 status_code = 400
message = "Bad Request" message = "Bad Request"
class Forbidden(HTTPServerException):
""" Request not allowed """
status_code = 403
message = "Forbidden"
class NotFound(HTTPServerException):
""" Resource not found """
status_code = 404
message = "Not Found"
class MethodNotAllowed(HTTPServerException): class MethodNotAllowed(HTTPServerException):
""" Method is not allowed """ """ Method is not allowed """
status_code = 405 status_code = 405
@@ -54,7 +70,7 @@ class MethodNotAllowed(HTTPServerException):
self.allowed_methods = allowed_methods self.allowed_methods = allowed_methods
class InternalServerError(HTTPServerException): class InternalServerError(HTTPServerCloseException):
""" Internal Server Error """ """ Internal Server Error """
status_code = 500 status_code = 500
message = "Internal Server Error" message = "Internal Server Error"
@@ -66,16 +82,10 @@ class NotImplemented(HTTPServerException):
message = "Not Implemented" message = "Not Implemented"
class NotFound(HTTPServerException): class HTTPVersionNotSupported(HTTPServerCloseException):
""" Resource not found """ """ The server does not support the major version HTTP used in the request message """
status_code = 404 status_code = 505
message = "Not Found" message = "HTTP Version Not Supported"
class Forbidden(HTTPServerException):
""" Request not allowed """
status_code = 403
message = "Forbidden"
class Conflict(HTTPServerException): class Conflict(HTTPServerException):
@@ -84,10 +94,10 @@ class Conflict(HTTPServerException):
message = "Conflict" message = "Conflict"
class HTTPVersionNotSupported(HTTPServerException): class NotModified(HTTPServerException):
""" The server does not support the major version HTTP used in the request message """ """ Requested resource was not modified """
status_code = 505 status_code = 304
message = "HTTP Version Not Supported" message = "Not Modified"
class InvalidRequestLine(BadRequest): class InvalidRequestLine(BadRequest):

View File

@@ -26,42 +26,26 @@ class HTTPSocket:
self.file = self.conn.makefile("rb") self.file = self.conn.makefile("rb")
def close(self): def close(self):
"""
Close this socket
"""
self.file.close() self.file.close()
# self.conn.shutdown(socket.SHUT_RDWR)
self.conn.close() self.conn.close()
def is_closed(self): def is_closed(self):
return self.file is None return self.file is None
def reset_request(self): def reset_request(self):
"""
Close the file handle of this socket and create a new one.
"""
self.file.close() self.file.close()
self.file = self.conn.makefile("rb") self.file = self.conn.makefile("rb")
def __do_receive(self):
if self.conn.fileno() == -1:
raise Exception("Connection closed")
result = self.conn.recv(BUFSIZE)
return result
def receive(self):
"""Receive data from the client up to BUFSIZE
"""
count = 0
while True:
count += 1
try:
return self.__do_receive()
except socket.timeout:
logging.debug("Socket receive timed out after %s seconds", TIMEOUT)
if count == 3:
break
logging.debug("Retrying %s", count)
logging.debug("Timed out after waiting %s seconds for response", TIMEOUT * count)
raise TimeoutError("Request timed out")
def read(self, size=BUFSIZE, blocking=True) -> bytes: def read(self, size=BUFSIZE, blocking=True) -> bytes:
"""
Read bytes up to the specified buffer size. This method will block when `blocking` is set to True (Default).
"""
if blocking: if blocking:
buffer = self.file.read(size) buffer = self.file.read(size)
else: else:
@@ -72,14 +56,18 @@ class HTTPSocket:
return buffer return buffer
def read_line(self): def read_line(self):
try: """
line = str(self.read_bytes_line(), FORMAT) Read a line decoded as `httpsocket.FORMAT`.
except UnicodeDecodeError: @return: the decoded line
# Expected UTF-8 @raise: UnicodeDecodeError
raise BadRequest() """
return line return str(self.read_bytes_line(), FORMAT)
def read_bytes_line(self) -> bytes: def read_bytes_line(self) -> bytes:
"""
Read a line as bytes.
"""
line = self.file.readline(MAXLINE + 1) line = self.file.readline(MAXLINE + 1)
if len(line) > MAXLINE: if len(line) > MAXLINE:
raise InvalidResponse("Line too long") raise InvalidResponse("Line too long")

View File

@@ -23,6 +23,7 @@ class ClientMessage(Message):
def __init__(self, version: str, status: int, msg: str, headers: Dict[str, str], raw=None, body: bytes = None): def __init__(self, version: str, status: int, msg: str, headers: Dict[str, str], raw=None, body: bytes = None):
super().__init__(version, headers, raw, body) super().__init__(version, headers, raw, body)
self.status = status self.status = status
self.msg = msg
class ServerMessage(Message): class ServerMessage(Message):

View File

@@ -1,6 +1,7 @@
import logging import logging
import os.path import os.path
import re import re
import urllib
from urllib.parse import urlparse, urlsplit from urllib.parse import urlparse, urlsplit
from httplib.exceptions import InvalidStatusLine, InvalidResponse, BadRequest, InvalidRequestLine from httplib.exceptions import InvalidStatusLine, InvalidResponse, BadRequest, InvalidRequestLine
@@ -255,6 +256,19 @@ def parse_uri(uri: str):
return host, port, path return host, port, path
def get_uri(url: str):
"""
Returns a valid URI of the specified URL.
"""
parsed = urlsplit(url)
result = f"http://{parsed.netloc}{parsed.path}"
if parsed.query != '':
result = f"{result}?{parsed.query}"
return result
def base_url(uri: str): def base_url(uri: str):
parsed = urlsplit(uri) parsed = urlsplit(uri)
path = parsed.path.rsplit("/", 1)[0] path = parsed.path.rsplit("/", 1)[0]
@@ -265,3 +279,7 @@ def absolute_url(uri: str, rel_path: str):
parsed = urlsplit(uri) parsed = urlsplit(uri)
path = os.path.normpath(os.path.join(parsed.path, rel_path)) path = os.path.normpath(os.path.join(parsed.path, rel_path))
return f"{parsed.scheme}://{parsed.hostname}{path}" return f"{parsed.scheme}://{parsed.hostname}{path}"
def urljoin(base, url):
return urllib.parse.urljoin(base, url)

View File

@@ -48,6 +48,7 @@
<div> <div>
<h2>Local image</h2> <h2>Local image</h2>
<img width="200px" src="ulyssis.png"> <img width="200px" src="ulyssis.png">
</div>
</body> </body>
</html> </html>

View File

@@ -4,11 +4,10 @@ import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from time import mktime from time import mktime
from typing import Dict
from wsgiref.handlers import format_date_time from wsgiref.handlers import format_date_time
from client.httpclient import FORMAT from client.httpclient import FORMAT
from httplib.exceptions import NotFound, Conflict, Forbidden from httplib.exceptions import NotFound, Forbidden, NotModified
from httplib.message import ServerMessage as Message from httplib.message import ServerMessage as Message
root = os.path.join(os.path.dirname(sys.argv[0]), "public") root = os.path.join(os.path.dirname(sys.argv[0]), "public")
@@ -21,7 +20,6 @@ status_message = {
400: "Bad Request", 400: "Bad Request",
404: "Not Found", 404: "Not Found",
500: "Internal Server Error", 500: "Internal Server Error",
} }
@@ -40,7 +38,6 @@ def create(message: Message):
class AbstractCommand(ABC): class AbstractCommand(ABC):
path: str path: str
headers: Dict[str, str]
msg: Message msg: Message
def __init__(self, message: Message): def __init__(self, message: Message):
@@ -52,7 +49,15 @@ class AbstractCommand(ABC):
def command(self): def command(self):
pass pass
@property
@abstractmethod
def _conditional_headers(self):
pass
def _get_date(self): def _get_date(self):
"""
Returns a string representation of the current date according to RFC 1123
"""
now = datetime.now() now = datetime.now()
stamp = mktime(now.timetuple()) stamp = mktime(now.timetuple())
return format_date_time(stamp) return format_date_time(stamp)
@@ -61,7 +66,12 @@ class AbstractCommand(ABC):
def execute(self): def execute(self):
pass pass
def _build_message(self, status: int, content_type: str, body: bytes): def _build_message(self, status: int, content_type: str, body: bytes, extra_headers=None):
if extra_headers is None:
extra_headers = {}
self._process_conditional_headers()
message = f"HTTP/1.1 {status} {status_message[status]}\r\n" message = f"HTTP/1.1 {status} {status_message[status]}\r\n"
message += self._get_date() + "\r\n" message += self._get_date() + "\r\n"
@@ -72,15 +82,17 @@ class AbstractCommand(ABC):
message += f"Content-Type: {content_type}" message += f"Content-Type: {content_type}"
if content_type.startswith("text"): if content_type.startswith("text"):
message += "; charset=UTF-8" message += "; charset=UTF-8"
message += "\r\n" message += "\r\n"
elif content_length > 0: elif content_length > 0:
message += f"Content-Type: application/octet-stream" message += f"Content-Type: application/octet-stream\r\n"
for header in extra_headers:
message += f"{header}: {extra_headers[header]}\r\n"
message += "\r\n" message += "\r\n"
message = message.encode(FORMAT) message = message.encode(FORMAT)
if content_length > 0: if content_length > 0:
message += body message += body
message += b"\r\n"
return message return message
@@ -97,6 +109,30 @@ class AbstractCommand(ABC):
return path return path
def _process_conditional_headers(self):
for header in self._conditional_headers:
tmp = self.msg.headers.get(header)
if not tmp:
continue
self._conditional_headers[header]()
def _if_modified_since(self):
date_val = self.msg.headers.get("if-modified-since")
if not date_val:
return True
modified = datetime.utcfromtimestamp(os.path.getmtime(self._get_path(False)))
try:
min_date = datetime.strptime(date_val, '%a, %d %b %Y %H:%M:%S GMT')
except ValueError:
return True
if modified <= min_date:
raise NotModified()
return True
class AbstractModifyCommand(AbstractCommand, ABC): class AbstractModifyCommand(AbstractCommand, ABC):
@@ -105,6 +141,10 @@ class AbstractModifyCommand(AbstractCommand, ABC):
def _file_mode(self): def _file_mode(self):
pass pass
@property
def _conditional_headers(self):
return {}
def execute(self): def execute(self):
path = self._get_path(False) path = self._get_path(False)
dir = os.path.dirname(path) dir = os.path.dirname(path)
@@ -114,31 +154,47 @@ class AbstractModifyCommand(AbstractCommand, ABC):
if os.path.exists(dir) and not os.path.isdir(dir): if os.path.exists(dir) and not os.path.isdir(dir):
raise Forbidden("Target directory is an existing file!") raise Forbidden("Target directory is an existing file!")
exists = os.path.exists(path)
try: try:
with open(path, mode=f"{self._file_mode}b") as file: with open(path, mode=f"{self._file_mode}b") as file:
file.write(self.msg.body) file.write(self.msg.body)
except IsADirectoryError: except IsADirectoryError:
raise Forbidden("The target resource is a directory!") raise Forbidden("The target resource is a directory!")
if exists:
status = 204
else:
status = 201
return self._build_message(status, None, )
class HeadCommand(AbstractCommand): class HeadCommand(AbstractCommand):
@property
def command(self):
return "HEAD"
@property
def _conditional_headers(self):
return {'if-modified-since': self._if_modified_since}
def execute(self): def execute(self):
path = self._get_path() path = self._get_path()
mime = mimetypes.guess_type(path)[0] mime = mimetypes.guess_type(path)[0]
return self._build_message(200, mime, b"") return self._build_message(200, mime, b"")
@property
def command(self):
return "HEAD"
class GetCommand(AbstractCommand): class GetCommand(AbstractCommand):
@property @property
def command(self): def command(self):
return "GET" return "GET"
@property
def _conditional_headers(self):
return {'if-modified-since': self._if_modified_since}
def get_mimetype(self, path): def get_mimetype(self, path):
mime = mimetypes.guess_type(path)[0] mime = mimetypes.guess_type(path)[0]

View File

@@ -10,6 +10,9 @@ from server import worker
class HTTPServer: class HTTPServer:
"""
"""
address: str address: str
port: int port: int
workers = [] workers = []
@@ -20,6 +23,13 @@ class HTTPServer:
_stop_event: Event _stop_event: Event
def __init__(self, address: str, port: int, worker_count, logging_level): def __init__(self, address: str, port: int, worker_count, logging_level):
"""
Initialize a HTTP server with the specified address, port, worker_count and logging_level
@param address: the address to listen on for connections
@param port: the port to listen on for connections
@param worker_count:
@param logging_level:
"""
self.address = address self.address = address
self.port = port self.port = port
self.worker_count = worker_count self.worker_count = worker_count
@@ -30,24 +40,39 @@ class HTTPServer:
self._stop_event = mp.Event() self._stop_event = mp.Event()
def start(self): def start(self):
"""
Start the HTTP server.
"""
try: try:
self.__do_start() self.__do_start()
except KeyboardInterrupt: except KeyboardInterrupt:
self.__shutdown() self.__shutdown()
def __do_start(self): def __do_start(self):
"""
Internal method to start the server.
@raise:
"""
# Create socket # Create socket
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.server.bind((self.address, self.port)) self.server.bind((self.address, self.port))
# Create workers processes to handle requests
self.__create_workers() self.__create_workers()
self.__listen() self.__listen()
def __listen(self): def __listen(self):
"""
Start listening for new connections
If a connection is received, it will be dispatched to the worker queue, and picked up by a worker process.
"""
self.server.listen() self.server.listen()
logging.debug("Listening for connections") logging.debug("Listening on %s:%d", self.address, self.port)
while True: while True:
if self._dispatch_queue.qsize() > self.worker_count: if self._dispatch_queue.qsize() > self.worker_count:
@@ -62,6 +87,11 @@ class HTTPServer:
logging.debug("Dispatched connection %s", addr) logging.debug("Dispatched connection %s", addr)
def __shutdown(self): def __shutdown(self):
"""
Cleanly shutdown the server
Notifies the worker processes to shutdown and eventually closes the server socket
"""
# Set stop event # Set stop event
self._stop_event.set() self._stop_event.set()
@@ -85,10 +115,18 @@ class HTTPServer:
self.server.close() self.server.close()
def __create_workers(self): def __create_workers(self):
"""
Create worker processes up to `self.worker_count`.
A worker process is created with start method "spawn", target `worker.worker` and the `self.logging_level`
is passed along with the `self.dispatch_queue` and `self._stop_event`
"""
for i in range(self.worker_count): for i in range(self.worker_count):
logging.debug("Creating worker: %d", i + 1) logging.debug("Creating worker: %d", i + 1)
p = mp.Process(target=worker.worker, p = mp.Process(target=worker.worker,
args=(f"{self.address}:{self.port}", i + 1, self.logging_level, self._dispatch_queue, self._stop_event)) args=(f"{self.address}:{self.port}", i + 1, self.logging_level, self._dispatch_queue,
self._stop_event))
p.start() p.start()
self.workers.append(p) self.workers.append(p)

View File

@@ -1,7 +1,6 @@
import logging import logging
import os import os
import sys import sys
import time
from datetime import datetime from datetime import datetime
from socket import socket from socket import socket
from time import mktime from time import mktime
@@ -16,6 +15,7 @@ from httplib.httpsocket import HTTPSocket, FORMAT
from httplib.message import ServerMessage as Message from httplib.message import ServerMessage as Message
from httplib.retriever import Retriever, PreambleRetriever from httplib.retriever import Retriever, PreambleRetriever
from server import command from server import command
from server.serversocket import ServerSocket
METHODS = ("GET", "HEAD", "PUT", "POST") METHODS = ("GET", "HEAD", "PUT", "POST")
@@ -25,7 +25,7 @@ class RequestHandler:
root = os.path.join(os.path.dirname(sys.argv[0]), "public") root = os.path.join(os.path.dirname(sys.argv[0]), "public")
def __init__(self, conn: socket, host): def __init__(self, conn: socket, host):
self.conn = HTTPSocket(conn, host) self.conn = ServerSocket(conn, host)
def listen(self): def listen(self):
@@ -68,6 +68,7 @@ class RequestHandler:
cmd = command.create(message) cmd = command.create(message)
msg = cmd.execute() msg = cmd.execute()
logging.debug("---response begin---\r\n%s---response end---", msg)
self.conn.conn.sendall(msg) self.conn.conn.sendall(msg)
def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version): def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version):

18
server/serversocket.py Normal file
View File

@@ -0,0 +1,18 @@
import socket
from httplib.exceptions import BadRequest
from httplib.httpsocket import HTTPSocket
BUFSIZE = 4096
TIMEOUT = 3
FORMAT = "UTF-8"
MAXLINE = 4096
class ServerSocket(HTTPSocket):
def read_line(self):
try:
return super().read_line()
except UnicodeDecodeError:
raise BadRequest()

View File

@@ -4,7 +4,7 @@ import socket
import threading import threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from httplib.exceptions import HTTPServerException, InternalServerError from httplib.exceptions import HTTPServerException, InternalServerError, HTTPServerCloseException
from server.requesthandler import RequestHandler from server.requesthandler import RequestHandler
THREAD_LIMIT = 128 THREAD_LIMIT = 128
@@ -61,17 +61,25 @@ class Worker:
self.shutdown() self.shutdown()
def _handle_client(self, conn: socket.socket, addr): def _handle_client(self, conn: socket.socket, addr):
try:
handler = RequestHandler(conn, self.host) while True:
handler.listen() try:
except HTTPServerException as e: handler = RequestHandler(conn, self.host)
logging.debug("HTTP Exception:", exc_info=e) handler.listen()
RequestHandler.send_error(conn, e.status_code, e.message) except HTTPServerCloseException as e:
except socket.timeout: logging.debug("HTTP Exception:", exc_info=e)
logging.debug("Socket for client %s timed out", addr) RequestHandler.send_error(conn, e.status_code, e.message)
except Exception as e: break
logging.debug("Internal error", exc_info=e) except HTTPServerException as e:
RequestHandler.send_error(conn, InternalServerError.status_code, InternalServerError.message) logging.debug("HTTP Exception:", exc_info=e)
RequestHandler.send_error(conn, e.status_code, e.message)
except socket.timeout:
logging.debug("Socket for client %s timed out", addr)
break
except Exception as e:
logging.debug("Internal error", exc_info=e)
RequestHandler.send_error(conn, InternalServerError.status_code, InternalServerError.message)
break
conn.shutdown(socket.SHUT_RDWR) conn.shutdown(socket.SHUT_RDWR)
conn.close() conn.close()