Compare commits

...

12 Commits

17 changed files with 527 additions and 215 deletions

View File

@@ -4,6 +4,7 @@ import logging
import sys import sys
from client import command as cmd from client import command as cmd
from httplib.exceptions import UnhandledHTTPCode
def main(): def main():
@@ -15,7 +16,8 @@ def main():
arguments = parser.parse_args() arguments = parser.parse_args()
logging.basicConfig(level=logging.ERROR - (10 * arguments.verbose), format="[%(levelname)s] %(message)s") # Setup logging
logging.basicConfig(level=logging.INFO - (10 * arguments.verbose), format="[%(levelname)s] %(message)s")
logging.debug("Arguments: %s", arguments) logging.debug("Arguments: %s", arguments)
command = cmd.create(arguments.command, arguments.URI, arguments.port) command = cmd.create(arguments.command, arguments.URI, arguments.port)
@@ -24,7 +26,10 @@ def main():
try: try:
main() main()
except UnhandledHTTPCode as e:
logging.info(f"[{e.status_code}] {e.cause}:\r\n{e.headers}")
sys.exit(2)
except Exception as e: except Exception as e:
print("[ABRT] Internal error: " + str(e), file=sys.stderr) logging.info("[ABRT] Internal error: %s", e)
logging.debug("Internal error", exc_info=e) logging.debug("Internal error", exc_info=e)
sys.exit(70) sys.exit(1)

View File

@@ -1,11 +1,10 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Tuple from typing import Dict
from urllib.parse import urlparse
from client.httpclient import HTTPClient from client.httpclient import HTTPClient
from httplib import parser from httplib import parser
from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding, UnsupportedProtocol
from httplib.httpsocket import FORMAT from httplib.httpsocket import FORMAT
from httplib.message import ResponseMessage as Message from httplib.message import ResponseMessage as Message
from httplib.retriever import PreambleRetriever from httplib.retriever import PreambleRetriever
@@ -21,7 +20,7 @@ def create(method: str, url: str, port):
@param port: The port for the command @param port: The port for the command
""" """
uri = parser.get_uri(url) uri = parser.uri_from_url(url)
if method == "GET": if method == "GET":
return GetCommand(uri, port) return GetCommand(uri, port)
elif method == "HEAD": elif method == "HEAD":
@@ -42,12 +41,10 @@ class AbstractCommand(ABC):
_host: str _host: str
_path: str _path: str
_port: int _port: int
sub_request: bool
def __init__(self, uri: str, port): def __init__(self, uri: str, port):
self.uri = uri self.uri = uri
self._port = int(port) self._port = int(port)
self.sub_request = False
@property @property
def uri(self): def uri(self):
@@ -81,23 +78,24 @@ class AbstractCommand(ABC):
@param sub_request: If this execution is in function of a prior command. @param sub_request: If this execution is in function of a prior command.
""" """
self.uri = ""
self.sub_request = sub_request
(host, path) = self.parse_uri()
client = sockets.get(host) client = sockets.get(self.host)
if client and client.is_closed(): if client and client.is_closed():
sockets.pop(self.host) sockets.pop(self.host)
client = None client = None
if not client: if not client:
client = HTTPClient(host) logging.info("Connecting to %s", self.host)
client.conn.connect((host, self.port)) client = HTTPClient(self.host)
sockets[host] = client client.conn.connect((self.host, self.port))
logging.info("Connected.")
sockets[self.host] = client
else:
logging.info("Reusing socket for %s", self.host)
message = f"{self.method} {path} HTTP/1.1\r\n" message = f"{self.method} {self.path} HTTP/1.1\r\n"
message += f"Host: {host}:{self.port}\r\n" message += f"Host: {self.host}:{self.port}\r\n"
message += "Accept: */*\r\n" message += "Accept: */*\r\n"
message += "Accept-Encoding: identity\r\n" message += "Accept-Encoding: identity\r\n"
encoded_msg = self._build_message(message) encoded_msg = self._build_message(message)
@@ -111,73 +109,51 @@ class AbstractCommand(ABC):
try: try:
self._await_response(client) self._await_response(client)
except InvalidResponse as e: except InvalidResponse as e:
logging.debug("Internal error: Response could not be parsed", exc_info=e) logging.error("Response could not be parsed")
return logging.debug("", exc_info=e)
except InvalidStatusLine as e: except InvalidStatusLine as e:
logging.debug("Internal error: Invalid status-line in response", exc_info=e) logging.error("Invalid status-line in response")
return logging.debug("", exc_info=e)
except UnsupportedEncoding as e: except UnsupportedEncoding as e:
logging.debug("Internal error: Unsupported encoding in response", exc_info=e) logging.error("Unsupported encoding in response")
logging.debug("", exc_info=e)
except UnsupportedProtocol as e:
logging.error("Unsupported protocol: %s", e.protocol)
logging.debug("", exc_info=e)
finally: finally:
if not sub_request: if not sub_request:
client.close() client.close()
def _get_preamble(self, client):
"""
Returns the preamble (start-line and headers) of the response of this command.
@param client: the client object to retrieve from
@return: A Message object containing the HTTP-version, status code, status message, headers and buffer
"""
retriever = PreambleRetriever(client)
lines = retriever.retrieve()
(version, status, msg) = parser.parse_status_line(next(lines))
headers = parser.parse_headers(lines)
buffer = retriever.buffer
logging.debug("---response begin---\r\n%s---response end---", "".join(buffer))
return Message(version, status, msg, headers, buffer)
def _await_response(self, client): def _await_response(self, client):
""" """
Simple response method. Simple response method.
Receives the response and prints to stdout. Receives the response and prints to stdout.
""" """
while True:
line = client.read_line() msg = self._get_preamble(client)
print(line, end="")
if line in ("\r\n", "\n", ""): print("".join(msg.raw))
break
def _build_message(self, message: str) -> bytes: def _build_message(self, message: str) -> bytes:
return (message + "\r\n").encode(FORMAT) return (message + "\r\n").encode(FORMAT)
def parse_uri(self):
"""
Parses the URI and returns the hostname and path.
@return: A tuple of the hostname and path.
"""
parsed = urlparse(self.uri)
# If there is no netloc, the url is invalid, so prepend `//` and try again
if parsed.netloc == "":
parsed = urlparse("http://" + self.uri)
host = parsed.netloc
path = parsed.path
if len(path) == 0 or path[0] != '/':
path = "/" + path
port_pos = host.find(":")
if port_pos >= 0:
host = host[:port_pos]
return host, path
class AbstractWithBodyCommand(AbstractCommand, ABC):
"""
The building block for creating an HTTP message for an HTTP method with a body (POST and PUT).
"""
def _build_message(self, message: str) -> bytes:
body = input(f"Enter {self.method} data: ").encode(FORMAT)
print()
message += "Content-Type: text/plain\r\n"
message += f"Content-Length: {len(body)}\r\n"
message += "\r\n"
message = message.encode(FORMAT)
message += body
message += b"\r\n"
return message
class HeadCommand(AbstractCommand): class HeadCommand(AbstractCommand):
""" """
@@ -204,30 +180,35 @@ class GetCommand(AbstractCommand):
def method(self): def method(self):
return "GET" return "GET"
def _get_preamble(self, client):
"""
Returns the preamble (start-line and headers) of the response of this command.
@param client: the client object to retrieve from
@return: A Message object containing the HTTP-version, status code, status message, headers and buffer
"""
retriever = PreambleRetriever(client)
lines = retriever.retrieve()
(version, status, msg) = parser.parse_status_line(next(lines))
headers = parser.parse_headers(lines)
buffer = retriever.buffer
logging.debug("---response begin---\r\n%s---response end---", "".join(buffer))
return Message(version, status, msg, headers, buffer)
def _await_response(self, client): def _await_response(self, client):
""" """
Handles the response of this command. Handles the response of this command.
""" """
msg = self._get_preamble(client) msg = self._get_preamble(client)
from client import response_handler from client import responsehandler
self.filename = response_handler.handle(client, msg, self, self.dir) self.filename = responsehandler.handle(client, msg, self, self.dir)
class AbstractWithBodyCommand(AbstractCommand, ABC):
"""
The building block for creating an HTTP message for an HTTP method with a body (POST and PUT).
"""
def _build_message(self, message: str) -> bytes:
input_line = input(f"Enter {self.method} data: ")
input_line += "\r\n"
body = input_line.encode(FORMAT)
print()
message += "Content-Type: text/plain\r\n"
message += f"Content-Length: {len(body)}\r\n"
message += "\r\n"
message = message.encode(FORMAT)
message += body
message += b"\r\n"
return message
class PostCommand(AbstractWithBodyCommand): class PostCommand(AbstractWithBodyCommand):

View File

@@ -1,6 +0,0 @@
from bs4 import BeautifulSoup
class HTMLParser:
def __init__(self, soup: BeautifulSoup):
pass

View File

@@ -4,12 +4,23 @@ from httplib.httpsocket import HTTPSocket, InvalidResponse
class HTTPClient(HTTPSocket): class HTTPClient(HTTPSocket):
"""
Wrapper class for a socket. Represents a client which connects to a server.
"""
host: str host: str
def __init__(self, host: str): def __init__(self, host: str):
super().__init__(socket.socket(socket.AF_INET, socket.SOCK_STREAM), host) super().__init__(socket.socket(socket.AF_INET, socket.SOCK_STREAM))
self.host = host
def read_line(self): def read_line(self):
"""
Reads the next line decoded as `httpsocket.FORMAT`
@return: the decoded next line retrieved from the socket
@raise InvalidResponse: If the next line couldn't be decoded, but was expected to
"""
try: try:
return super().read_line() return super().read_line()
except UnicodeDecodeError: except UnicodeDecodeError:

View File

@@ -7,7 +7,7 @@ from urllib.parse import urlsplit, unquote
from client.command import AbstractCommand, GetCommand from client.command import AbstractCommand, GetCommand
from client.httpclient import HTTPClient from client.httpclient import HTTPClient
from httplib import parser from httplib import parser
from httplib.exceptions import InvalidResponse from httplib.exceptions import InvalidResponse, UnhandledHTTPCode, UnsupportedProtocol
from httplib.httpsocket import FORMAT from httplib.httpsocket import FORMAT
from httplib.message import ResponseMessage as Message from httplib.message import ResponseMessage as Message
from httplib.retriever import Retriever from httplib.retriever import Retriever
@@ -17,6 +17,14 @@ IMG_REGEX = re.compile(r"<\s*img[^>]*\ssrc\s*=\s*['\"]([^\"']+)['\"][^>]*>", re.
def handle(client: HTTPClient, msg: Message, command: AbstractCommand, directory=None): def handle(client: HTTPClient, msg: Message, command: AbstractCommand, directory=None):
"""
Handle the response of the request message
@param client: the client which sent the request.
@param msg: the response message
@param command: the command of the sent request-message
@param directory: the directory to download the response to (if available)
"""
handler = BasicResponseHandler(client, msg, command) handler = BasicResponseHandler(client, msg, command)
retriever = handler.handle() retriever = handler.handle()
@@ -33,6 +41,9 @@ def handle(client: HTTPClient, msg: Message, command: AbstractCommand, directory
class ResponseHandler(ABC): class ResponseHandler(ABC):
"""
Helper class for handling response messages.
"""
client: HTTPClient client: HTTPClient
retriever: Retriever retriever: Retriever
msg: Message msg: Message
@@ -46,12 +57,15 @@ class ResponseHandler(ABC):
@abstractmethod @abstractmethod
def handle(self): def handle(self):
"""
Handle the response.
"""
pass pass
class BasicResponseHandler(ResponseHandler): class BasicResponseHandler(ResponseHandler):
""" """
Response handler which throws away the body and only shows the headers. Response handler which will handle redirects and other HTTP status codes.
In case of a redirect, it will process it and pass it to the appropriate response handler. In case of a redirect, it will process it and pass it to the appropriate response handler.
""" """
@@ -67,7 +81,7 @@ class BasicResponseHandler(ResponseHandler):
for line in self.retriever.retrieve(): for line in self.retriever.retrieve():
try: try:
logging.debug("%s", line.decode(FORMAT)) logging.debug("%s", line.decode(FORMAT))
except Exception: except UnicodeDecodeError:
logging.debug("%r", line) logging.debug("%r", line)
logging.debug("] done.") logging.debug("] done.")
@@ -77,8 +91,7 @@ class BasicResponseHandler(ResponseHandler):
if self.msg.status == 101: if self.msg.status == 101:
# Switching protocols is not supported # Switching protocols is not supported
print("".join(self.msg.raw), end="") raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), "Switching protocols is not supported")
return None
if 200 <= self.msg.status < 300: if 200 <= self.msg.status < 300:
return self.retriever return self.retriever
@@ -91,16 +104,13 @@ class BasicResponseHandler(ResponseHandler):
if 400 <= self.msg.status < 600: if 400 <= self.msg.status < 600:
self._skip_body() self._skip_body()
# Dump headers and exit with error # Dump headers and exit with error
if not self.cmd.sub_request: raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
print("".join(self.msg.raw), end="")
return None
return None return None
def _handle_redirect(self): def _handle_redirect(self):
if self.msg.status == 304: if self.msg.status == 304:
print("".join(self.msg.raw), end="") raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
return None
location = self.msg.headers.get("location") location = self.msg.headers.get("location")
if not location or len(location.strip()) == 0: if not location or len(location.strip()) == 0:
@@ -112,7 +122,7 @@ class BasicResponseHandler(ResponseHandler):
raise InvalidResponse("Invalid location") raise InvalidResponse("Invalid location")
if not parsed_location.scheme == "http": if not parsed_location.scheme == "http":
raise InvalidResponse("Only http is supported") raise UnsupportedProtocol(parsed_location.scheme)
self.cmd.uri = location self.cmd.uri = location
@@ -181,7 +191,7 @@ class RawDownloadHandler(DownloadHandler):
super().__init__(retriever, client, msg, cmd, directory) super().__init__(retriever, client, msg, cmd, directory)
def handle(self) -> str: def handle(self) -> str:
logging.debug("Retrieving payload") logging.info("Saving to '%s'", parser.get_relative_save_path(self.path))
file = open(self.path, "wb") file = open(self.path, "wb")
for buffer in self.retriever.retrieve(): for buffer in self.retriever.retrieve():
@@ -211,20 +221,20 @@ class HTMLDownloadHandler(DownloadHandler):
os.remove(tmp_path) os.remove(tmp_path)
return self.path return self.path
def _download_images(self, tmp_filename, target_filename, charset=FORMAT): def _download_images(self, tmp_path, target_path, charset=FORMAT):
""" """
Downloads images referenced in the html of `tmp_filename` and replaces the references in the html Download images referenced in the html of `tmp_filename` and replaces the references in the html
and writes it to `target_filename`. and writes it to `target_filename`.
@param tmp_filename: the path to the temporary html file @param tmp_path: the path to the temporary html file
@param target_filename: the path for the final html fil @param target_path: the path for the final html file
@param charset: the charset to decode `tmp_filename` @param charset: the charset to decode `tmp_filename`
""" """
try: try:
fp = open(tmp_filename, "r", encoding=charset) fp = open(tmp_path, "r", encoding=charset)
html = fp.read() html = fp.read()
except UnicodeDecodeError: except UnicodeDecodeError or LookupError:
fp = open(tmp_filename, "r", encoding=FORMAT, errors="replace") fp = open(tmp_path, "r", encoding=FORMAT, errors="replace")
html = fp.read() html = fp.read()
fp.close() fp.close()
@@ -237,6 +247,7 @@ class HTMLDownloadHandler(DownloadHandler):
processed = {} processed = {}
to_replace = [] to_replace = []
# Find all <img> tags, and the urls from the corresponding `src` fields
for m in IMG_REGEX.finditer(html): for m in IMG_REGEX.finditer(html):
url_start = m.start(1) url_start = m.start(1)
url_end = m.end(1) url_end = m.end(1)
@@ -245,14 +256,12 @@ class HTMLDownloadHandler(DownloadHandler):
try: try:
if len(target) == 0: if len(target) == 0:
continue continue
if target in processed: if target in processed:
# url is already processed
new_url = processed.get(target) new_url = processed.get(target)
else: else:
new_url = self.__download_image(target, base_url) new_url = self.__download_image(target, base_url)
if not new_url:
# Image failed to download
continue
processed[target] = new_url processed[target] = new_url
if new_url: if new_url:
@@ -260,13 +269,18 @@ class HTMLDownloadHandler(DownloadHandler):
to_replace.append((url_start, url_end, local_path)) to_replace.append((url_start, url_end, local_path))
except Exception as e: except Exception as e:
logging.error("Failed to download image: %s, skipping...", target, exc_info=e) logging.error("Failed to download image: %s, skipping...", target)
logging.debug("", exc_info=e)
processed[target] = None
# reverse the list so urls at the bottom of the html file are processed first.
# Otherwise, our start and end positions won't be correct.
to_replace.reverse() to_replace.reverse()
for (start, end, path) in to_replace: for (start, end, path) in to_replace:
html = html[:start] + path + html[end:] html = html[:start] + path + html[end:]
with open(target_filename, 'w', encoding=FORMAT) as file: logging.info("Saving HTML to '%s'", parser.get_relative_save_path(target_path))
with open(target_path, 'w', encoding=FORMAT) as file:
file.write(html) file.write(html)
def __download_image(self, img_src, base_url): def __download_image(self, img_src, base_url):
@@ -280,6 +294,7 @@ class HTMLDownloadHandler(DownloadHandler):
parsed = urlsplit(img_src) parsed = urlsplit(img_src)
img_src = parser.urljoin(base_url, img_src) img_src = parser.urljoin(base_url, img_src)
# Check if the port of the image sh
if parsed.hostname is None or parsed.hostname == self.cmd.host: if parsed.hostname is None or parsed.hostname == self.cmd.host:
port = self.cmd.port port = self.cmd.port
elif ":" in parsed.netloc: elif ":" in parsed.netloc:

View File

@@ -1,109 +1,163 @@
class HTTPException(Exception): class HTTPException(Exception):
""" Base class for HTTP exceptions """ """
Base class for HTTP exceptions
"""
class UnhandledHTTPCode(Exception):
"""
Exception thrown if HTTP codes are not further processed.
"""
status_code: str
headers: str
cause: str
def __init__(self, status, headers, cause):
self.status_code = status
self.headers = headers
self.cause = cause
class InvalidResponse(HTTPException): class InvalidResponse(HTTPException):
""" Response message cannot be parsed """ """
Response message cannot be parsed
"""
def __init(self, message): def __init__(self, message):
self.message = message self.message = message
class InvalidStatusLine(HTTPException): class InvalidStatusLine(HTTPException):
""" Response status line is invalid """ """
Response status line is invalid
"""
def __init(self, line): def __init__(self, line):
self.line = line self.line = line
class UnsupportedEncoding(HTTPException): class UnsupportedEncoding(HTTPException):
""" Encoding not supported """ """
Encoding not supported
"""
def __init(self, enc_type, encoding): def __init__(self, enc_type, encoding):
self.enc_type = enc_type self.enc_type = enc_type
self.encoding = encoding self.encoding = encoding
class UnsupportedProtocol(HTTPException):
"""
Protocol is not supported
"""
def __init__(self, protocol):
self.protocol = protocol
class IncompleteResponse(HTTPException): class IncompleteResponse(HTTPException):
def __init(self, cause): def __init__(self, cause):
self.cause = cause self.cause = cause
class HTTPServerException(Exception): class HTTPServerException(HTTPException):
""" Base class for HTTP Server exceptions """ """
Base class for HTTP Server exceptions
"""
status_code: str status_code: str
message: str message: str
body: str
arg: str arg: str
def __init__(self, arg, body=""): def __init__(self, arg):
self.arg = arg self.arg = arg
self.body = body
class HTTPServerCloseException(HTTPServerException): class HTTPServerCloseException(HTTPServerException):
""" When thrown, the connection should be closed """ """
When raised, the connection should be closed
"""
class BadRequest(HTTPServerCloseException): class BadRequest(HTTPServerCloseException):
""" Malformed HTTP request""" """
Malformed HTTP request
"""
status_code = 400 status_code = 400
message = "Bad Request" message = "Bad Request"
class Forbidden(HTTPServerException): class Forbidden(HTTPServerException):
""" Request not allowed """ """
Request not allowed
"""
status_code = 403 status_code = 403
message = "Forbidden" message = "Forbidden"
class NotFound(HTTPServerException): class NotFound(HTTPServerException):
""" Resource not found """ """
Resource not found
"""
status_code = 404 status_code = 404
message = "Not Found" message = "Not Found"
class MethodNotAllowed(HTTPServerException): class MethodNotAllowed(HTTPServerException):
""" Method is not allowed """ """
Method is not allowed
"""
status_code = 405 status_code = 405
message = "Method Not Allowed" message = "Method Not Allowed"
def __init(self, allowed_methods): def __init__(self, allowed_methods):
self.allowed_methods = allowed_methods self.allowed_methods = allowed_methods
class InternalServerError(HTTPServerCloseException): class InternalServerError(HTTPServerCloseException):
""" Internal Server Error """ """
Internal Server Error
"""
status_code = 500 status_code = 500
message = "Internal Server Error" message = "Internal Server Error"
class NotImplemented(HTTPServerException): class NotImplemented(HTTPServerException):
""" Functionality not implemented """ """
Functionality not implemented
"""
status_code = 501 status_code = 501
message = "Not Implemented" message = "Not Implemented"
class HTTPVersionNotSupported(HTTPServerCloseException): class HTTPVersionNotSupported(HTTPServerCloseException):
""" The server does not support the major version HTTP used in the request message """ """
The server does not support the major version HTTP used in the request message
"""
status_code = 505 status_code = 505
message = "HTTP Version Not Supported" message = "HTTP Version Not Supported"
class Conflict(HTTPServerException): class Conflict(HTTPServerException):
""" Conflict in the current state of the target resource """ """
Conflict in the current state of the target resource
"""
status_code = 409 status_code = 409
message = "Conflict" message = "Conflict"
class NotModified(HTTPServerException): class NotModified(HTTPServerException):
""" Requested resource was not modified """ """
Requested resource was not modified
"""
status_code = 304 status_code = 304
message = "Not Modified" message = "Not Modified"
class InvalidRequestLine(BadRequest): class InvalidRequestLine(BadRequest):
""" Request start-line is invalid """ """
Request start-line is invalid
"""
def __init__(self, line): def __init__(self, line, arg):
super().__init__(arg)
self.request_line = line self.request_line = line

View File

@@ -1,10 +1,7 @@
import logging
import socket import socket
from io import BufferedReader from io import BufferedReader
from typing import Tuple from typing import Tuple
from httplib.exceptions import BadRequest
BUFSIZE = 4096 BUFSIZE = 4096
TIMEOUT = 3 TIMEOUT = 3
FORMAT = "UTF-8" FORMAT = "UTF-8"
@@ -12,13 +9,20 @@ MAXLINE = 4096
class HTTPSocket: class HTTPSocket:
host: str """
Wrapper class for a socket. Represents an HTTP connection.
This class adds helper methods to read the underlying socket as a file.
"""
conn: socket.socket conn: socket.socket
file: Tuple[BufferedReader, None] file: BufferedReader
def __init__(self, conn: socket.socket, host: str): def __init__(self, conn: socket.socket):
"""
Initialize an HTTPSocket with the given socket and host.
@param conn: the socket object
"""
self.host = host
self.conn = conn self.conn = conn
self.conn.settimeout(TIMEOUT) self.conn.settimeout(TIMEOUT)
self.conn.setblocking(True) self.conn.setblocking(True)
@@ -78,11 +82,15 @@ class HTTPSocket:
class HTTPException(Exception): class HTTPException(Exception):
""" Base class for HTTP exceptions """ """
Base class for HTTP exceptions
"""
class InvalidResponse(HTTPException): class InvalidResponse(HTTPException):
""" Response message cannot be parsed """ """
Response message cannot be parsed
"""
def __init(self, message): def __init(self, message):
self.message = message self.message = message

View File

@@ -6,7 +6,7 @@ from urllib.parse import SplitResult
class Message(ABC): class Message(ABC):
version: str version: str
headers: Dict[str, str] headers: Dict[str, str]
raw: str raw: [str]
body: bytes body: bytes
def __init__(self, version: str, headers: Dict[str, str], raw=None, body: bytes = None): def __init__(self, version: str, headers: Dict[str, str], raw=None, body: bytes = None):

View File

@@ -1,8 +1,13 @@
import logging import logging
import os
import pathlib
import re import re
import urllib import urllib
from datetime import datetime
from time import mktime
from typing import Dict from typing import Dict
from urllib.parse import urlparse, urlsplit from urllib.parse import urlparse, urlsplit
from wsgiref.handlers import format_date_time
from httplib.exceptions import InvalidStatusLine, InvalidResponse, BadRequest, InvalidRequestLine from httplib.exceptions import InvalidStatusLine, InvalidResponse, BadRequest, InvalidRequestLine
from httplib.httpsocket import FORMAT from httplib.httpsocket import FORMAT
@@ -56,7 +61,7 @@ def parse_status_line(line: str):
def parse_request_line(line: str): def parse_request_line(line: str):
""" """
Parses the specified line as and HTTP request-line. Parses the specified line as an HTTP request-line.
Returns the method, target as ParseResult and HTTP version from the request-line. Returns the method, target as ParseResult and HTTP version from the request-line.
@param line: the request-line to be parsed @param line: the request-line to be parsed
@@ -67,7 +72,7 @@ def parse_request_line(line: str):
split = list(filter(None, line.rstrip().split(" ", 2))) split = list(filter(None, line.rstrip().split(" ", 2)))
if len(split) < 3: if len(split) < 3:
raise InvalidRequestLine(line) raise InvalidRequestLine(line, "missing argument in request-line")
method, target, version = split method, target, version = split
if method not in ("CONNECT", "DELETE", "GET", "HEAD", "OPTIONS", "POST", "PUT", "TRACE"): if method not in ("CONNECT", "DELETE", "GET", "HEAD", "OPTIONS", "POST", "PUT", "TRACE"):
@@ -85,6 +90,12 @@ def parse_request_line(line: str):
def parse_headers(lines): def parse_headers(lines):
"""
Parses the lines from the `lines` iterator as headers.
@param lines: iterator to retrieve the lines from.
@return: A dictionary with header as key and value as value.
"""
headers = [] headers = []
try: try:
@@ -97,7 +108,7 @@ def parse_headers(lines):
break break
while True: while True:
if line in ("\r\n", "\n", ""): if line in ("\r\n", "\r", "\n", ""):
break break
if line[0].isspace(): if line[0].isspace():
@@ -127,17 +138,21 @@ def parse_headers(lines):
def check_next_header(headers, next_header: str, next_value: str): def check_next_header(headers, next_header: str, next_value: str):
if next_header == "content-length": if next_header == "content-length":
if "content-length" in headers: if "content-length" in headers:
logging.error("Multiple content-length headers specified") raise InvalidResponse("Multiple content-length headers specified")
raise InvalidResponse()
if not next_value.isnumeric() or int(next_value) <= 0: if not next_value.isnumeric() or int(next_value) <= 0:
logging.error("Invalid content-length value: %r", next_value) raise InvalidResponse(f"Invalid content-length value: {next_value}")
raise InvalidResponse()
def parse_uri(uri: str): def parse_uri(uri: str):
"""
Parse the specified URI into the host, port and path.
If the URI is invalid, this method will try to create one.
@param uri: the URI to be parsed
@return: A tuple with the host, port and path
"""
parsed = urlsplit(uri) parsed = urlsplit(uri)
# If there is no netloc, the given string is not a valid URI, so split on / # If there is no hostname, the given string is not a valid URI, so split on /
if parsed.hostname: if parsed.hostname:
host = parsed.hostname host = parsed.hostname
path = parsed.path path = parsed.path
@@ -159,13 +174,21 @@ def parse_uri(uri: str):
return host, port, path return host, port, path
def get_uri(url: str): def uri_from_url(url: str):
""" """
Returns a valid URI of the specified URL. Returns a valid URI of the specified URL.
""" """
parsed = urlsplit(url) parsed = urlsplit(url)
result = f"http://{parsed.netloc}{parsed.path}" if parsed.hostname is None:
url = f"http://{url}"
parsed = urlsplit(url)
path = parsed.path
if path == "":
path = "/"
result = f"http://{parsed.netloc}{path}"
if parsed.query != "": if parsed.query != "":
result = f"{result}?{parsed.query}" result = f"{result}?{parsed.query}"
@@ -174,12 +197,18 @@ def get_uri(url: str):
def urljoin(base, url): def urljoin(base, url):
""" """
Join a base url and a URL to form an absolute url. Join a base url, and a URL to form an absolute url.
""" """
return urllib.parse.urljoin(base, url) return urllib.parse.urljoin(base, url)
def get_charset(headers: Dict[str, str]): def get_charset(headers: Dict[str, str]):
"""
Returns the charset of the content from the headers if found. Otherwise, returns `FORMAT`
@param headers: the headers to retrieve the charset from
@return: A charset
"""
if "content-type" in headers: if "content-type" in headers:
content_type = headers["content-type"] content_type = headers["content-type"]
match = re.search(r"charset\s*=\s*([a-z\-0-9]*)", content_type, re.I) match = re.search(r"charset\s*=\s*([a-z\-0-9]*)", content_type, re.I)
@@ -187,3 +216,26 @@ def get_charset(headers: Dict[str, str]):
return match.group(1) return match.group(1)
return FORMAT return FORMAT
def get_relative_save_path(path: str):
"""
Returns the specified path relative to the working directory.
@param path: the path to compute
@return: the relative path
"""
path_obj = pathlib.PurePath(path)
root = pathlib.PurePath(os.getcwd())
rel = path_obj.relative_to(root)
return str(rel)
def get_date():
"""
Returns a string representation of the current date according to RFC 1123.
"""
now = datetime.now()
stamp = mktime(now.timetuple())
return format_date_time(stamp)

View File

@@ -62,11 +62,18 @@ class PreambleRetriever(Retriever):
""" """
Retriever instance for retrieving the start-line and headers of an HTTP message. Retriever instance for retrieving the start-line and headers of an HTTP message.
""" """
client: HTTPSocket client: HTTPSocket
_buffer: [] _buffer: []
@property @property
def buffer(self): def buffer(self):
"""
Returns a copy of the internal buffer.
Clears the internal buffer afterwards.
@return: A list of the buffered lines.
"""
tmp_buffer = self._buffer tmp_buffer = self._buffer
self._buffer = [] self._buffer = []
@@ -87,7 +94,7 @@ class PreambleRetriever(Retriever):
while True: while True:
self._buffer.append(line) self._buffer.append(line)
if line in ("\r\n", "\n", ""): if line in ("\r\n", "\r", "\n", ""):
return line return line
yield line yield line
@@ -140,8 +147,8 @@ class ContentLengthRetriever(Retriever):
class RawRetriever(Retriever): class RawRetriever(Retriever):
""" """
Retriever instance for retrieve a message body without any length specifier or encoding. Retriever instance for retrieving a message body without any length specifier or encoding.
This retriever will keep waiting until a timeout occurs or the connection is disconnected. This retriever will keep waiting until a timeout occurs, or the connection is disconnected.
""" """
def retrieve(self): def retrieve(self):
@@ -161,6 +168,7 @@ class ChunkedRetriever(Retriever):
""" """
Returns an iterator of the received message bytes. Returns an iterator of the received message bytes.
The size of each iteration is not necessarily constant. The size of each iteration is not necessarily constant.
@raise IncompleteResponse: if the connection is closed or timed out before receiving the complete payload. @raise IncompleteResponse: if the connection is closed or timed out before receiving the complete payload.
@raise InvalidResponse: if the length of a chunk could not be determined. @raise InvalidResponse: if the length of a chunk could not be determined.
""" """
@@ -184,6 +192,12 @@ class ChunkedRetriever(Retriever):
raise IncompleteResponse("Connection closed before receiving the complete payload!") raise IncompleteResponse("Connection closed before receiving the complete payload!")
def __get_chunk_size(self): def __get_chunk_size(self):
"""
Returns the next chunk size.
@return: The chunk size in bytes
@raise InvalidResponse: If an error occurred when parsing the chunk size.
"""
line = self.client.read_line() line = self.client.read_line()
sep_pos = line.find(";") sep_pos = line.find(";")
if sep_pos >= 0: if sep_pos >= 0:
@@ -192,4 +206,4 @@ class ChunkedRetriever(Retriever):
try: try:
return int(line, 16) return int(line, 16)
except ValueError: except ValueError:
raise InvalidResponse() raise InvalidResponse("Failed to parse chunk size")

View File

@@ -1 +0,0 @@
lxml~=4.6.2

View File

@@ -14,7 +14,7 @@ def main():
parser.add_argument("--workers", "-w", parser.add_argument("--workers", "-w",
help="The amount of worker processes. This is by default based on the number of cpu threads.", help="The amount of worker processes. This is by default based on the number of cpu threads.",
type=int) type=int)
parser.add_argument("--port", "-p", help="The port to listen on", default=8000) parser.add_argument("--port", "-p", help="The port to listen on", default=5055)
arguments = parser.parse_args() arguments = parser.parse_args()
logging_level = logging.ERROR - (10 * arguments.verbose) logging_level = logging.ERROR - (10 * arguments.verbose)

View File

@@ -3,11 +3,9 @@ import os
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from time import mktime
from wsgiref.handlers import format_date_time
from httplib import parser from httplib import parser
from httplib.exceptions import NotFound, Forbidden, NotModified from httplib.exceptions import NotFound, Forbidden, NotModified, BadRequest
from httplib.httpsocket import FORMAT from httplib.httpsocket import FORMAT
from httplib.message import RequestMessage as Message from httplib.message import RequestMessage as Message
@@ -60,28 +58,36 @@ class AbstractCommand(ABC):
@property @property
@abstractmethod @abstractmethod
def _conditional_headers(self): def _conditional_headers(self):
"""
The conditional headers specific to this command instance.
"""
pass pass
def _get_date(self):
"""
Returns a string representation of the current date according to RFC 1123.
"""
now = datetime.now()
stamp = mktime(now.timetuple())
return format_date_time(stamp)
@abstractmethod @abstractmethod
def execute(self): def execute(self):
"""
Execute the command
"""
pass pass
def _build_message(self, status: int, content_type: str, body: bytes, extra_headers=None): def _build_message(self, status: int, content_type: str, body: bytes, extra_headers=None):
"""
Build the response message.
@param status: The response status code
@param content_type: The response content-type header
@param body: The response body, may be empty.
@param extra_headers: Extra headers needed in the response message
@return: The encoded response message
"""
if extra_headers is None: if extra_headers is None:
extra_headers = {} extra_headers = {}
self._process_conditional_headers() self._process_conditional_headers()
message = f"HTTP/1.1 {status} {status_message[status]}\r\n" message = f"HTTP/1.1 {status} {status_message[status]}\r\n"
message += f"Date: {self._get_date()}\r\n" message += f"Date: {parser.get_date()}\r\n"
content_length = len(body) content_length = len(body)
message += f"Content-Length: {content_length}\r\n" message += f"Content-Length: {content_length}\r\n"
@@ -105,6 +111,13 @@ class AbstractCommand(ABC):
return message return message
def _get_path(self, check=True): def _get_path(self, check=True):
"""
Returns the absolute file system path of the resource in the request.
@param check: If True, throws an error if the file doesn't exist
@raise NotFound: if `check` is True and the path doesn't exist
"""
norm_path = os.path.normpath(self.msg.target.path) norm_path = os.path.normpath(self.msg.target.path)
if norm_path == "/": if norm_path == "/":
@@ -118,6 +131,9 @@ class AbstractCommand(ABC):
return path return path
def _process_conditional_headers(self): def _process_conditional_headers(self):
"""
Processes the conditional headers for this command instance.
"""
for header in self._conditional_headers: for header in self._conditional_headers:
tmp = self.msg.headers.get(header) tmp = self.msg.headers.get(header)
@@ -127,6 +143,13 @@ class AbstractCommand(ABC):
self._conditional_headers[header]() self._conditional_headers[header]()
def _if_modified_since(self): def _if_modified_since(self):
"""
Processes the if-modified-since header.
@return: True if the header is invalid, and thus shouldn't be taken into account, throws NotModified
if the content isn't modified since the given date.
@raise NotModified: If the date of if-modified-since greater than the modify-date of the resource.
"""
date_val = self.msg.headers.get("if-modified-since") date_val = self.msg.headers.get("if-modified-since")
if not date_val: if not date_val:
return True return True
@@ -141,7 +164,14 @@ class AbstractCommand(ABC):
return True return True
def get_mimetype(self, path): @staticmethod
def get_mimetype(path):
"""
Guess the type of file.
@param path: the path to the file to guess the type of
@return: The mimetype based on the extension, or if that fails, returns "text/plain" if the file is text,
otherwise returns "application/octet-stream"
"""
mime = mimetypes.guess_type(path)[0] mime = mimetypes.guess_type(path)[0]
if mime: if mime:
@@ -157,10 +187,16 @@ class AbstractCommand(ABC):
class AbstractModifyCommand(AbstractCommand, ABC): class AbstractModifyCommand(AbstractCommand, ABC):
"""
Base class for commands which modify a resource based on the request.
"""
@property @property
@abstractmethod @abstractmethod
def _file_mode(self): def _file_mode(self):
"""
The mode to open the target resource with. (e.g. 'a' or 'w')
"""
pass pass
@property @property
@@ -194,6 +230,10 @@ class AbstractModifyCommand(AbstractCommand, ABC):
class HeadCommand(AbstractCommand): class HeadCommand(AbstractCommand):
"""
A Command instance which represents a HEAD request
"""
@property @property
def command(self): def command(self):
return "HEAD" return "HEAD"
@@ -204,12 +244,16 @@ class HeadCommand(AbstractCommand):
def execute(self): def execute(self):
path = self._get_path() path = self._get_path()
mime = self.get_mimetype(path) mime = self.get_mimetype(path)
return self._build_message(200, mime, b"") return self._build_message(200, mime, b"")
class GetCommand(AbstractCommand): class GetCommand(AbstractCommand):
"""
A Command instance which represents a GET request
"""
@property @property
def command(self): def command(self):
return "GET" return "GET"
@@ -230,6 +274,10 @@ class GetCommand(AbstractCommand):
class PostCommand(AbstractModifyCommand): class PostCommand(AbstractModifyCommand):
"""
A Command instance which represents a POST request
"""
@property @property
def command(self): def command(self):
return "POST" return "POST"
@@ -240,6 +288,10 @@ class PostCommand(AbstractModifyCommand):
class PutCommand(AbstractModifyCommand): class PutCommand(AbstractModifyCommand):
"""
A Command instance which represents a PUT request
"""
@property @property
def command(self): def command(self):
return "PUT" return "PUT"
@@ -247,3 +299,9 @@ class PutCommand(AbstractModifyCommand):
@property @property
def _file_mode(self): def _file_mode(self):
return "w" return "w"
def execute(self):
    """
    Execute the PUT request.

    A PUT request carrying a Content-Range header is rejected; any other request is
    delegated to the parent implementation.

    @return: the value returned by the parent implementation (presumably the encoded
             response message; verify against AbstractModifyCommand.execute)
    @raise BadRequest: if the request contains a Content-Range header
    """
    if "content-range" in self.msg.headers:
        raise BadRequest("PUT request contains a Content-Range header")

    # Hand the parent's result back to the caller instead of silently dropping it.
    return super().execute()

View File

@@ -67,7 +67,7 @@ class HTTPServer:
""" """
self.server.listen() self.server.listen()
logging.debug("Listening on %s:%d", self.address, self.port) logging.info("Listening on %s:%d", self.address, self.port)
while True: while True:
conn, addr = self.server.accept() conn, addr = self.server.accept()
@@ -86,6 +86,8 @@ class HTTPServer:
Notifies the worker processes to shut down and eventually closes the server socket Notifies the worker processes to shut down and eventually closes the server socket
""" """
logging.info("Shutting down server...")
# Set stop event # Set stop event
self._stop_event.set() self._stop_event.set()
@@ -111,7 +113,7 @@ class HTTPServer:
""" """
Create worker processes up to `self.worker_count`. Create worker processes up to `self.worker_count`.
A worker process is created with start method "spawn", target `worker.worker` and the `self.logging_level` A worker process is created with start method "spawn", target `worker.worker`, and the `self.logging_level`
is passed along with the `self.dispatch_queue` and `self._stop_event` is passed along with the `self.dispatch_queue` and `self._stop_event`
""" """
for i in range(self.worker_count): for i in range(self.worker_count):

View File

@@ -1,17 +1,12 @@
import logging import logging
import os
import sys
from datetime import datetime
from socket import socket from socket import socket
from time import mktime
from typing import Union from typing import Union
from urllib.parse import ParseResultBytes, ParseResult from urllib.parse import ParseResultBytes, ParseResult
from wsgiref.handlers import format_date_time
from httplib import parser from httplib import parser
from httplib.exceptions import MethodNotAllowed, BadRequest, UnsupportedEncoding, NotImplemented, NotFound, \ from httplib.exceptions import MethodNotAllowed, BadRequest, UnsupportedEncoding, NotImplemented, NotFound, \
HTTPVersionNotSupported HTTPVersionNotSupported
from httplib.httpsocket import HTTPSocket, FORMAT from httplib.httpsocket import FORMAT
from httplib.message import RequestMessage as Message from httplib.message import RequestMessage as Message
from httplib.retriever import Retriever, PreambleRetriever from httplib.retriever import Retriever, PreambleRetriever
from server import command from server import command
@@ -21,13 +16,24 @@ METHODS = ("GET", "HEAD", "PUT", "POST")
class RequestHandler: class RequestHandler:
conn: HTTPSocket """
root = os.path.join(os.path.dirname(sys.argv[0]), "public") A RequestHandler instance processes incoming HTTP requests messages from a single client.
RequestHandler instances are created every time a client connects. They will read the incoming
messages, parse, verify them and send a response.
"""
conn: ServerSocket
host: str
def __init__(self, conn: socket, host): def __init__(self, conn: socket, host):
self.conn = ServerSocket(conn, host) self.conn = ServerSocket(conn)
self.host = host
def listen(self): def listen(self):
"""
Listen to incoming messages and process them.
"""
retriever = PreambleRetriever(self.conn) retriever = PreambleRetriever(self.conn)
@@ -41,16 +47,27 @@ class RequestHandler:
self._handle_message(retriever, line) self._handle_message(retriever, line)
def _handle_message(self, retriever, line): def _handle_message(self, retriever, line):
"""
Retrieves and processes the request message.
@param retriever: the retriever instance to retrieve the lines.
@param line: the first received line.
"""
lines = retriever.retrieve() lines = retriever.retrieve()
# Parse the request-line and headers
(method, target, version) = parser.parse_request_line(line) (method, target, version) = parser.parse_request_line(line)
headers = parser.parse_headers(lines) headers = parser.parse_headers(lines)
# Create the response message object
message = Message(version, method, target, headers, retriever.buffer) message = Message(version, method, target, headers, retriever.buffer)
logging.debug("---request begin---\r\n%s---request end---", "".join(message.raw)) logging.debug("---request begin---\r\n%s---request end---", "".join(message.raw))
# validate if the request is valid
self._validate_request(message) self._validate_request(message)
# The body (if available) hasn't been retrieved up till now.
body = b"" body = b""
if self._has_body(headers): if self._has_body(headers):
try: try:
@@ -64,14 +81,26 @@ class RequestHandler:
message.body = body message.body = body
# completed message # message completed
cmd = command.create(message) cmd = command.create(message)
msg = cmd.execute() msg = cmd.execute()
logging.debug("---response begin---\r\n%s\r\n---response end---", msg.split(b"\r\n\r\n", 1)[0].decode(FORMAT)) logging.debug("---response begin---\r\n%s\r\n---response end---", msg.split(b"\r\n\r\n", 1)[0].decode(FORMAT))
# Send the response message
self.conn.conn.sendall(msg) self.conn.conn.sendall(msg)
def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version): def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version):
"""
Checks if the request-line is valid. Throws an appropriate exception if not.
@param method: HTTP request method
@param target: The request target
@param version: The HTTP version
@raise MethodNotAllowed: if the method is not any of the allowed methods in `METHODS`
@raise HTTPVersionNotSupported: If the HTTP version is not supported by this server
@raise BadRequest: If the scheme of the target is not supported
@raise NotFound: If the target is not found on this server
"""
if method not in METHODS: if method not in METHODS:
raise MethodNotAllowed(METHODS) raise MethodNotAllowed(METHODS)
@@ -84,19 +113,33 @@ class RequestHandler:
# Only http is supported... # Only http is supported...
raise BadRequest(f"scheme={target.scheme}") raise BadRequest(f"scheme={target.scheme}")
if target.netloc != "" and target.netloc != self.conn.host and target.netloc != self.conn.host.split(":")[0]: if target.netloc != "" and target.netloc != self.host and target.netloc != self.host.split(":")[0]:
raise NotFound(str(target)) raise NotFound(str(target))
if target.path == "" or target.path[0] != "/": if target.path == "" or target.path[0] != "/":
raise NotFound(str(target)) raise NotFound(str(target))
def _validate_request(self, msg): def _validate_request(self, msg):
"""
Validates the message request-line and headers. Throws an error if the message is invalid.
@see: _check_request_line for exceptions raised when validating the request-line.
@param msg: the message to validate
@raise BadRequest: if HTTP 1.1, and the Host header is missing
"""
if msg.version == "1.1" and "host" not in msg.headers: if msg.version == "1.1" and "host" not in msg.headers:
raise BadRequest("Missing host header") raise BadRequest("Missing host header")
self._check_request_line(msg.method, msg.target, msg.version) self._check_request_line(msg.method, msg.target, msg.version)
def _has_body(self, headers): def _has_body(self, headers):
"""
Check if the headers indicate the existence of a message body.
@param headers: the headers to check
@return: True if the message has a body. False otherwise.
"""
if "transfer-encoding" in headers: if "transfer-encoding" in headers:
return True return True
@@ -106,16 +149,18 @@ class RequestHandler:
return False return False
@staticmethod
def _get_date():
now = datetime.now()
stamp = mktime(now.timetuple())
return format_date_time(stamp)
@staticmethod @staticmethod
def send_error(client: socket, code, message): def send_error(client: socket, code, message):
"""
Send an HTTP error response to the client.
@param client: the client to send the response to
@param code: the HTTP status code
@param message: the status code message
"""
message = f"HTTP/1.1 {code} {message}\r\n" message = f"HTTP/1.1 {code} {message}\r\n"
message += RequestHandler._get_date() + "\r\n" message += parser.get_date() + "\r\n"
message += "Content-Length: 0\r\n" message += "Content-Length: 0\r\n"
message += "\r\n" message += "\r\n"

View File

@@ -1,11 +1,18 @@
import socket
from httplib.exceptions import BadRequest from httplib.exceptions import BadRequest
from httplib.httpsocket import HTTPSocket from httplib.httpsocket import HTTPSocket
class ServerSocket(HTTPSocket): class ServerSocket(HTTPSocket):
"""
Wrapper class for a socket. Represents a client connected to this server.
"""
"""
Reads the next line decoded as `httpsocket.FORMAT`
@return: the decoded next line retrieved from the socket
@raise InvalidResponse: If the next line couldn't be decoded, but was expected to
"""
def read_line(self): def read_line(self):
try: try:
return super().read_line() return super().read_line()

View File

@@ -3,6 +3,7 @@ import multiprocessing as mp
import socket import socket
import threading import threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Dict
from httplib.exceptions import HTTPServerException, InternalServerError, HTTPServerCloseException from httplib.exceptions import HTTPServerException, InternalServerError, HTTPServerCloseException
from server.requesthandler import RequestHandler from server.requesthandler import RequestHandler
@@ -18,11 +19,19 @@ def worker(address, name, logging_level, queue: mp.Queue, stop_event: mp.Event):
try: try:
runner.run() runner.run()
except KeyboardInterrupt: except KeyboardInterrupt:
# Catch exit signals and close the threads appropriately.
logging.debug("Ctrl+C pressed, terminating") logging.debug("Ctrl+C pressed, terminating")
runner.shutdown() runner.shutdown()
class Worker: class Worker:
"""
A Worker instance represents a parallel execution process to handle incoming connections.
Worker instances are created when the HTTP server starts. They are used to handle many incoming connections
asynchronously.
"""
host: str host: str
name: str name: str
queue: mp.Queue queue: mp.Queue
@@ -30,24 +39,40 @@ class Worker:
stop_event: mp.Event stop_event: mp.Event
finished_queue: mp.Queue finished_queue: mp.Queue
dispatched_sockets: Dict[int, socket.socket]
def __init__(self, host, name, queue: mp.Queue, stop_event: mp.Event):
    """
    Initialise a Worker.

    @param host: The hostname of the HTTP server
    @param name: The name of this Worker instance
    @param queue: The dispatch queue for incoming socket connections
    @param stop_event: The Event that signals when to shut down this worker.
    """
    self.host, self.name = host, name
    self.queue = queue
    self.stop_event = stop_event

    self.executor = ThreadPoolExecutor(THREAD_LIMIT)
    self.dispatched_sockets = {}

    # Seed the queue with one entry per executor thread.
    self.finished_queue = mp.Queue()
    for slot in range(THREAD_LIMIT):
        self.finished_queue.put(slot)
def run(self): def run(self):
"""
Run this worker.
The worker will start waiting for incoming clients being added to the queue and submit them to
the executor.
"""
while not self.stop_event.is_set(): while not self.stop_event.is_set():
# Blocks until thread is free # Blocks until the thread is free
self.finished_queue.get() self.finished_queue.get()
# Blocks until new client connects # Blocks until a new client connects
conn, addr = self.queue.get() conn, addr = self.queue.get()
if conn is None or addr is None: if conn is None or addr is None:
@@ -55,12 +80,32 @@ class Worker:
logging.debug("Processing new client: %s", addr) logging.debug("Processing new client: %s", addr)
# submit client to thread # submit the client to the executor
self.executor.submit(self._handle_client, conn, addr) self.executor.submit(self._handle_client, conn, addr)
self.shutdown() self.shutdown()
def _handle_client(self, conn: socket.socket, addr): def _handle_client(self, conn: socket.socket, addr):
"""
Target method for the worker threads.
Creates a RequestHandler and handles any exceptions which may occur.
@param conn: The client socket
@param addr: The address of the client.
"""
self.dispatched_sockets[threading.get_ident()] = conn
try:
self.__do_handle_client(conn, addr)
except Exception:
if not self.stop_event:
logging.debug("Internal error in thread:", exc_info=True)
self.dispatched_sockets.pop(threading.get_ident())
# Finished, put back into queue
self.finished_queue.put(threading.get_ident())
def __do_handle_client(self, conn: socket.socket, addr):
handler = RequestHandler(conn, self.host) handler = RequestHandler(conn, self.host)
@@ -68,28 +113,50 @@ class Worker:
try: try:
handler.listen() handler.listen()
except HTTPServerCloseException as e: except HTTPServerCloseException as e:
# Exception raised after which the client should be disconnected.
logging.warning("[HTTP: %s] %s. Reason: %s", e.status_code, e.message, e.arg) logging.warning("[HTTP: %s] %s. Reason: %s", e.status_code, e.message, e.arg)
RequestHandler.send_error(conn, e.status_code, e.message) RequestHandler.send_error(conn, e.status_code, e.message)
break break
except HTTPServerException as e: except HTTPServerException as e:
# Normal HTTP exception raised (e.g. 404); continue listening.
logging.debug("[HTTP: %s] %s. Reason: %s", e.status_code, e.message, e.arg) logging.debug("[HTTP: %s] %s. Reason: %s", e.status_code, e.message, e.arg)
RequestHandler.send_error(conn, e.status_code, e.message) RequestHandler.send_error(conn, e.status_code, e.message)
except socket.timeout: except socket.timeout:
# socket timed out, disconnect.
logging.info("Socket for client %s timed out.", addr) logging.info("Socket for client %s timed out.", addr)
break break
except ConnectionAbortedError: except ConnectionAbortedError:
# Client aborted connection
logging.info("Socket for client %s disconnected.", addr) logging.info("Socket for client %s disconnected.", addr)
break break
except Exception as e: except Exception as e:
# Unexpected exception raised. Send 500 and disconnect.
logging.error("Internal error", exc_info=e) logging.error("Internal error", exc_info=e)
RequestHandler.send_error(conn, InternalServerError.status_code, InternalServerError.message) RequestHandler.send_error(conn, InternalServerError.status_code, InternalServerError.message)
break break
conn.shutdown(socket.SHUT_RDWR) conn.shutdown(socket.SHUT_RDWR)
conn.close() conn.close()
# Finished, put back into queue
self.finished_queue.put(threading.get_ident())
def shutdown(self):
    """
    Shut down this worker.

    Stops the executor from accepting new work, closes every socket that is currently
    dispatched to a thread, and then waits for the executor to finish.
    """
    logging.info("shutting down")

    # shutdown executor, but do not wait
    self.executor.shutdown(wait=False)

    logging.info("Closing sockets")

    # Snapshot the values to prevent issues with concurrent modification
    for client in list(self.dispatched_sockets.values()):
        try:
            client.shutdown(socket.SHUT_RDWR)
        except OSError:
            # Ignore exception due to already closed or disconnected sockets
            pass
        finally:
            # Always release the descriptor, even when shutdown() failed.
            try:
                client.close()
            except OSError:
                pass

    # Call shutdown again and wait this time
    self.executor.shutdown()