diff --git a/client/command.py b/client/command.py index e3d1c9c..19b8ae6 100644 --- a/client/command.py +++ b/client/command.py @@ -2,9 +2,10 @@ import logging from abc import ABC, abstractmethod from urllib.parse import urlparse -from client.ResponseHandler import ResponseHandler -from client.httpclient import FORMAT, HTTPClient, InvalidResponse, InvalidStatusLine, UnsupportedEncoding - +from client.response_handler import ResponseHandler +from client.httpclient import FORMAT, HTTPClient +from httplib import parser +from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding class AbstractCommand(ABC): @@ -34,7 +35,7 @@ class AbstractCommand(ABC): (host, path) = self.parse_uri() client = HTTPClient(host) - client.connect((host, int(self.port))) + client.conn.connect((host, int(self.port))) message = f"{self.command} {path} HTTP/1.1\r\n" message += f"Host: {host}\r\n" @@ -44,7 +45,7 @@ class AbstractCommand(ABC): logging.info("---request begin---\r\n%s---request end---", encoded_msg.decode(FORMAT)) logging.debug("Sending HTTP message: %r", encoded_msg) - client.sendall(encoded_msg) + client.conn.sendall(encoded_msg) logging.info("HTTP request sent, awaiting response...") @@ -118,9 +119,9 @@ class GetCommand(AbstractCommand): return "GET" def _await_response(self, client): - (version, status, msg) = ResponseHandler.get_status_line(client) + (version, status, msg) = parser.get_status_line(client) logging.debug("Parsed status-line: version: %s, status: %s", version, status) - headers = ResponseHandler.get_headers(client) + headers = parser.get_headers(client) logging.debug("Parsed headers: %r", headers) handler = ResponseHandler.create(client, headers, status, self.url) diff --git a/client/httpclient.py b/client/httpclient.py index c3e4376..e0f23bc 100644 --- a/client/httpclient.py +++ b/client/httpclient.py @@ -1,6 +1,6 @@ -import logging import socket -from io import BufferedReader + +from httplib.httpsocket import HTTPSocket BUFSIZE = 4096 TIMEOUT = 3 @@ -8,98 +8,8 @@ FORMAT = "UTF-8" MAXLINE = 4096 -class HTTPClient(socket.socket): +class HTTPClient(HTTPSocket): host: str - file: BufferedReader def __init__(self, host: str): - - super().__init__(socket.AF_INET, socket.SOCK_STREAM) - self.settimeout(TIMEOUT) - self.host = host - self.setblocking(True) - self.settimeout(3.0) - self.file = self.makefile("rb") - - def close(self): - self.file.close() - super().close() - - def reset_request(self): - self.file.close() - self.file = self.makefile("rb") - - def __do_receive(self): - if self.fileno() == -1: - raise Exception("Connection closed") - - result = self.recv(BUFSIZE) - return result - - def receive(self): - """Receive data from the client up to BUFSIZE - """ - count = 0 - while True: - count += 1 - try: - return self.__do_receive() - except socket.timeout: - logging.debug("Socket receive timed out after %s seconds", TIMEOUT) - if count == 3: - break - logging.debug("Retrying %s", count) - - logging.debug("Timed out after waiting %s seconds for response", TIMEOUT * count) - raise TimeoutError("Request timed out") - - def read(self, size=BUFSIZE, blocking=True) -> bytes: - if blocking: - return self.file.read(size) - - return self.file.read1(size) - - def read_line(self): - return str(self.read_bytes_line(), FORMAT) - - def read_bytes_line(self): - """ - - :rtype: bytes - """ - line = self.file.readline(MAXLINE + 1) - if len(line) > MAXLINE: - raise InvalidResponse("Line too long") - - return line - - -class HTTPException(Exception): - """ Base class for HTTP exceptions """ - - -class InvalidResponse(HTTPException): - """ Response message cannot be parsed """ - - def __init(self, message): - self.message = message - - -class InvalidStatusLine(HTTPException): - """ Response status line is invalid """ - - def __init(self, line): - self.line = line - - -class UnsupportedEncoding(HTTPException): - """ Reponse Encoding not support """ - - def __init(self, enc_type, encoding): - self.enc_type = enc_type - self.encoding = encoding - - -class IncompleteResponse(HTTPException): - def __init(self, cause): - self.cause = cause + super().__init__(socket.socket(socket.AF_INET, socket.SOCK_STREAM), host) diff --git a/client/ResponseHandler.py b/client/response_handler.py similarity index 68% rename from client/ResponseHandler.py rename to client/response_handler.py index 6b71282..824f383 100644 --- a/client/ResponseHandler.py +++ b/client/response_handler.py @@ -1,14 +1,15 @@ import logging import os -import re from abc import ABC, abstractmethod from typing import Dict from urllib.parse import urlparse from bs4 import BeautifulSoup -from client.Retriever import Retriever -from client.httpclient import HTTPClient, UnsupportedEncoding, FORMAT, InvalidResponse, InvalidStatusLine +from client.httpclient import HTTPClient, FORMAT +from httplib.retriever import Retriever +from httplib import parser +from httplib.exceptions import InvalidResponse class ResponseHandler(ABC): @@ -31,17 +32,6 @@ class ResponseHandler(ABC): @staticmethod def create(client: HTTPClient, headers, status_code, url): - # only chunked transfer-encoding is supported - transfer_encoding = headers.get("transfer-encoding") - if transfer_encoding and transfer_encoding != "chunked": - raise UnsupportedEncoding("transfer-encoding", transfer_encoding) - chunked = transfer_encoding - - # content-encoding is not supported - content_encoding = headers.get("content-encoding") - if content_encoding: - raise UnsupportedEncoding("content-encoding", content_encoding) - retriever = Retriever.create(client, headers) content_type = headers.get("content-type") @@ -49,78 +39,6 @@ class ResponseHandler(ABC): return HTMLDownloadHandler(retriever, client, headers, url) return RawDownloadHandler(retriever, client, headers, url) - @staticmethod - def get_status_line(client: HTTPClient): - line = client.read_line() - - split = list(filter(None, line.split(" "))) - if len(split) < 3: - raise InvalidStatusLine(line) - - # Check HTTP version - http_version = split.pop(0) - if len(http_version) < 8 or http_version[4] != "/": - raise InvalidStatusLine(line) - - (name, version) = http_version[:4], http_version[5:] - if name != "HTTP" or not re.match(r"1\.[0|1]", version): - raise InvalidStatusLine(line) - - status = split.pop(0) - if not re.match(r"\d{3}", status): - raise InvalidStatusLine(line) - status = int(status) - if status < 100 or status > 999: - raise InvalidStatusLine(line) - - reason = split.pop(0) - return version, status, reason - - @staticmethod - def get_headers(client: HTTPClient): - headers = [] - # first header after the status-line may not contain a space - while True: - line = client.read_line() - if line[0].isspace(): - continue - else: - break - - while True: - if line in ("\r\n", "\n", " "): - break - - if line[0].isspace(): - headers[-1] = headers[-1].rstrip("\r\n") - - headers.append(line.lstrip()) - line = client.read_line() - - result = {} - header_str = "".join(headers) - for line in header_str.splitlines(): - pos = line.find(":") - - if pos <= 0 or pos >= len(line) - 1: - continue - - (header, value) = map(str.strip, line.split(":", 1)) - ResponseHandler.check_next_header(result, header, value) - result[header.lower()] = value.lower() - - return result - - @staticmethod - def check_next_header(headers, next_header: str, next_value: str): - if next_header == "content-length": - if "content-length" in headers: - logging.error("Multiple content-length headers specified") - raise InvalidResponse() - if not next_value.isnumeric() or int(next_value) <= 0: - logging.error("Invalid content-length value: %r", next_value) - raise InvalidResponse() - @staticmethod def parse_uri(uri: str): parsed = urlparse(uri) @@ -196,9 +114,9 @@ class DownloadHandler(ResponseHandler, ABC): def _handle_sub_request(self, client, url): - (version, status, _) = self.get_status_line(client) + (version, status, _) = parser.get_status_line(client) logging.debug("Parsed status-line: version: %s, status: %s", version, status) - headers = self.get_headers(client) + headers = parser.get_headers(client) logging.debug("Parsed headers: %r", headers) if status != 200: @@ -297,8 +215,8 @@ class HTMLDownloadHandler(DownloadHandler): client.reset_request() else: client = HTTPClient(img_src) - client.connect((img_host, 80)) - client.sendall(message) + client.conn.connect((img_host, 80)) + client.conn.sendall(message) filename = self._handle_sub_request(client, img_host + img_path) if not same_host: diff --git a/httplib/exceptions.py b/httplib/exceptions.py new file mode 100644 index 0000000..930ae7a --- /dev/null +++ b/httplib/exceptions.py @@ -0,0 +1,41 @@ +class HTTPException(Exception): + """ Base class for HTTP exceptions """ + + +class InvalidResponse(HTTPException): + """ Response message cannot be parsed """ + + def __init(self, message): + self.message = message + + +class InvalidStatusLine(HTTPException): + """ Response status line is invalid """ + + def __init(self, line): + self.line = line + + +class UnsupportedEncoding(HTTPException): + """ Reponse Encoding not support """ + + def __init(self, enc_type, encoding): + self.enc_type = enc_type + self.encoding = encoding + + +class IncompleteResponse(HTTPException): + def __init(self, cause): + self.cause = cause + +class HTTPServerException(Exception): + """ Base class for HTTP Server exceptions """ + + +class BadRequest(HTTPServerException): + """ Malformed HTTP request""" + +class MethodNotAllowed(HTTPServerException): + """ Method is not allowed """ + def __init(self, allowed_methods): + self.allowed_methods = allowed_methods \ No newline at end of file diff --git a/httplib/httpsocket.py b/httplib/httpsocket.py new file mode 100644 index 0000000..a894d09 --- /dev/null +++ b/httplib/httpsocket.py @@ -0,0 +1,82 @@ +import logging +import socket +from io import BufferedReader + +BUFSIZE = 4096 +TIMEOUT = 3 +FORMAT = "UTF-8" +MAXLINE = 4096 + + +class HTTPSocket: + host: str + conn: socket.socket + file: BufferedReader + + def __init__(self, conn: socket.socket, host: str): + + self.host = host + self.conn = conn + self.conn.settimeout(TIMEOUT) + self.conn.setblocking(True) + self.conn.settimeout(3.0) + self.file = self.conn.makefile("rb") + + def close(self): + self.file.close() + self.conn.close() + + def reset_request(self): + self.file.close() + self.file = self.conn.makefile("rb") + + def __do_receive(self): + if self.conn.fileno() == -1: + raise Exception("Connection closed") + + result = self.conn.recv(BUFSIZE) + return result + + def receive(self): + """Receive data from the client up to BUFSIZE + """ + count = 0 + while True: + count += 1 + try: + return self.__do_receive() + except socket.timeout: + logging.debug("Socket receive timed out after %s seconds", TIMEOUT) + if count == 3: + break + logging.debug("Retrying %s", count) + + logging.debug("Timed out after waiting %s seconds for response", TIMEOUT * count) + raise TimeoutError("Request timed out") + + def read(self, size=BUFSIZE, blocking=True) -> bytes: + if blocking: + return self.file.read(size) + + return self.file.read1(size) + + def read_line(self): + return str(self.read_bytes_line(), FORMAT) + + def read_bytes_line(self) -> bytes: + line = self.file.readline(MAXLINE + 1) + if len(line) > MAXLINE: + raise InvalidResponse("Line too long") + + return line + + +class HTTPException(Exception): + """ Base class for HTTP exceptions """ + + +class InvalidResponse(HTTPException): + """ Response message cannot be parsed """ + + def __init(self, message): + self.message = message diff --git a/httplib/parser.py b/httplib/parser.py new file mode 100644 index 0000000..771b7e1 --- /dev/null +++ b/httplib/parser.py @@ -0,0 +1,160 @@ +import logging +import re +from urllib.parse import urlparse + +from httplib.exceptions import InvalidStatusLine, InvalidResponse, BadRequest +from httplib.httpsocket import HTTPSocket + + +def _get_start_line(client: HTTPSocket): + line = client.read_line() + split = list(filter(None, line.split(" "))) + if len(split) < 3: + raise InvalidStatusLine(line) # TODO fix exception + + return line, split + + +def _is_valid_http_version(http_version: str): + if len(http_version) < 8 or http_version[4] != "/": + return False + + (name, version) = http_version[:4], http_version[5:] + if name != "HTTP" or not re.match(r"1\.[0|1]", version): + return False + + +def get_status_line(client: HTTPSocket): + line, (http_version, status, reason) = _get_start_line(client) + + if not _is_valid_http_version(http_version): + raise InvalidStatusLine(line) + version = http_version[:4] + + if not re.match(r"\d{3}", status): + raise InvalidStatusLine(line) + status = int(status) + if status < 100 or status > 999: + raise InvalidStatusLine(line) + + return version, status, reason + + +def parse_request_line(client: HTTPSocket): + line, (method, target, version) = _get_start_line(client) + + if method not in ("CONNECT", "DELETE", "GET", "HEAD", "OPTIONS", "POST", "PUT", "TRACE"): + raise BadRequest() + + if not _is_valid_http_version(version): + raise BadRequest() + + if len(target) == "": + raise BadRequest() + parsed_target = urlparse(target) + + return method, parsed_target, version + + +def retrieve_headers(client: HTTPSocket): + raw_headers = [] + # first header after the status-line may not contain a space + while True: + line = client.read_line() + if line[0].isspace(): + continue + else: + break + + while True: + if line in ("\r\n", "\n", " "): + break + + if line[0].isspace(): + raw_headers[-1] = raw_headers[-1].rstrip("\r\n") + + raw_headers.append(line.lstrip()) + line = client.read_line() + + result = [] + header_str = "".join(raw_headers) + for line in header_str.splitlines(): + pos = line.find(":") + + if pos <= 0 or pos >= len(line) - 1: + continue + + (header, value) = line.split(":", 1) + result.append((header.lower(), value.lower())) + + return result + + +def parse_request_headers(client: HTTPSocket): + raw_headers = retrieve_headers(client) + headers = {} + + key: str + for (key, value) in raw_headers: + if any((c.isspace()) for c in key): + raise BadRequest() + + if key == "content-length": + if key in headers: + logging.error("Multiple content-length headers specified") + raise BadRequest() + if not value.isnumeric() or int(value) <= 0: + logging.error("Invalid content-length value: %r", value) + raise BadRequest() + elif key == "host": + if value != client.host or key in headers: + raise BadRequest() + + headers[key] = value + + return headers + + +def get_headers(client: HTTPSocket): + headers = [] + # first header after the status-line may not contain a space + while True: + line = client.read_line() + if line[0].isspace(): + continue + else: + break + + while True: + if line in ("\r\n", "\n", " "): + break + + if line[0].isspace(): + headers[-1] = headers[-1].rstrip("\r\n") + + headers.append(line.lstrip()) + line = client.read_line() + + result = {} + header_str = "".join(headers) + for line in header_str.splitlines(): + pos = line.find(":") + + if pos <= 0 or pos >= len(line) - 1: + continue + + (header, value) = map(str.strip, line.split(":", 1)) + check_next_header(result, header, value) + result[header.lower()] = value.lower() + + return result + + +def check_next_header(headers, next_header: str, next_value: str): + if next_header == "content-length": + if "content-length" in headers: + logging.error("Multiple content-length headers specified") + raise InvalidResponse() + if not next_value.isnumeric() or int(next_value) <= 0: + logging.error("Invalid content-length value: %r", next_value) + raise InvalidResponse() diff --git a/client/Retriever.py b/httplib/retriever.py similarity index 91% rename from client/Retriever.py rename to httplib/retriever.py index 5a66cf9..280a3d6 100644 --- a/client/Retriever.py +++ b/httplib/retriever.py @@ -2,13 +2,14 @@ import logging from abc import ABC, abstractmethod from typing import Dict -from client.httpclient import HTTPClient, BUFSIZE, IncompleteResponse, InvalidResponse, UnsupportedEncoding +from httplib.exceptions import IncompleteResponse, InvalidResponse, UnsupportedEncoding +from httplib.httpsocket import HTTPSocket, BUFSIZE class Retriever(ABC): - client: HTTPClient + client: HTTPSocket - def __init__(self, client: HTTPClient): + def __init__(self, client: HTTPSocket): self.client = client @abstractmethod @@ -16,7 +17,7 @@ class Retriever(ABC): pass @staticmethod - def create(client: HTTPClient, headers: Dict[str, str]): + def create(client: HTTPSocket, headers: Dict[str, str]): # only chunked transfer-encoding is supported transfer_encoding = headers.get("transfer-encoding") @@ -44,7 +45,7 @@ class Retriever(ABC): class ContentLengthRetriever(Retriever): length: int - def __init__(self, client: HTTPClient, length: int): + def __init__(self, client: HTTPSocket, length: int): super().__init__(client) self.length = length diff --git a/public/index.html b/public/index.html new file mode 100644 index 0000000..225dbac --- /dev/null +++ b/public/index.html @@ -0,0 +1,53 @@ + + + + Computer Networks example + + + + + + + + +
+

Example Domain

+

This domain is for use in illustrative examples in documents. You may use this + domain in literature without prior coordination or asking for permission.

+
+
+

Remote image

+ +
+
+

Local image

+ + + + diff --git a/public/ulyssis.png b/public/ulyssis.png new file mode 100644 index 0000000..14819f5 Binary files /dev/null and b/public/ulyssis.png differ diff --git a/server.py b/server.py index 0d5c7ce..fea876d 100644 --- a/server.py +++ b/server.py @@ -1,39 +1,90 @@ #!/usr/bin/env python3 - +import argparse +import logging +import multiprocessing import socket +import sys -# socket heeft een listening and accept method -import time +from server.httpserver import HTTPServer -SERVER = "127.0.0.1" #dynamisch fixen in project -PORT = 5055 -server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) -ADDR = (SERVER, PORT) # hier wordt de socket gebonden aan mijn IP adres, dit moet wel anders -server.bind(ADDR) # in het project gebeuren +def main(): + parser = argparse.ArgumentParser(description='HTTP Server') + parser.add_argument("--verbose", "-v", action='count', default=0, help="Increase verbosity level of logging") + parser.add_argument("--workers", "-w", + help="The amount of worker processes. This is by default based on the number of cpu threads.", + type=int) + parser.add_argument("--port", "-p", help="The port to listen on", default=8000) + arguments = parser.parse_args() -HEADER = 64 # maximum size messages -FORMAT = 'utf-8' # recieving images through this format does not work -DISCONNECT_MESSAGE = "DISCONNECT!" # special message for disconnecting client and server + logging_level = logging.ERROR - (10 * arguments.verbose) + logging.basicConfig(level=logging_level) + logging.debug("Arguments: %s", arguments) -# function for starting server -def start(): - pass - server.listen() - while True: # infinite loop in which server accept incoming connections, we want to run it forever - conn, addr = server.accept() # Server blocks untill a client connects - print("new connection: ", addr[0], " connected.") - connected = True - while connected: # while client is connected, we want to recieve messages - msg = conn.recv(HEADER).decode(FORMAT).rstrip() # Argument is maximum size of msg (in project look into details of accp), decode is for converting bytes to strings, rstrip is for stripping messages for special hidden characters - print("message: ", msg) - for i in range(0,10): - conn.send(b"test") - time.sleep(1) + # Set workers + if arguments.workers: + workers = int(arguments.workers) + else: + workers = multiprocessing.cpu_count() - break - print("close connection ", addr[0], " disconnected.") - conn.close() + # Set port + if arguments.port: + port = int(arguments.port) + else: + port = 8000 -print("server is starting ... ") -start() \ No newline at end of file + # Get hostname and address + hostname = socket.gethostname() + address = socket.gethostbyname(hostname) + server = HTTPServer(address, port, workers, logging_level) + server.start() + + +try: + if __name__ == '__main__': + main() +except Exception as e: + print("[ABRT] Internal error: " + str(e), file=sys.stderr) + logging.debug("Internal error", exc_info=e) + sys.exit(70) + +# import socket +# +# # Get hostname and address +# hostname = socket.gethostname() +# address = socket.gethostbyname(hostname) +# +# # socket heeft een listening and accept method +# +# SERVER = "127.0.0.1" # dynamisch fixen in project +# PORT = 5055 +# server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +# +# ADDR = (SERVER, PORT) # hier wordt de socket gebonden aan mijn IP adres, dit moet wel anders +# server.bind(ADDR) # in het project gebeuren +# +# HEADER = 64 # maximum size messages +# FORMAT = 'utf-8' +# DISCONNECT_MESSAGE = "DISCONNECT!" # special message for disconnecting client and server +# +# +# # function for starting server +# def start(): +# pass +# server.listen() +# while True: # infinite loop in which server accept incoming connections, we want to run it forever +# conn, addr = server.accept() # Server blocks untill a client connects +# print("new connection: ", addr[0], " connected.") +# connected = True +# while connected: # while client is connected, we want to recieve messages +# msg = conn.recv(HEADER).decode( +# FORMAT).rstrip() # Argument is maximum size of msg (in project look into details of accp), decode is for converting bytes to strings, rstrip is for stripping messages for special hidden characters +# print("message: ", msg) +# if msg == DISCONNECT_MESSAGE: +# connected = False +# print("close connection ", addr[0], " disconnected.") +# conn.close() +# +# +# print("server is starting ... ") +# start() diff --git a/server/RequestHandler.py b/server/RequestHandler.py new file mode 100644 index 0000000..d3ce433 --- /dev/null +++ b/server/RequestHandler.py @@ -0,0 +1,57 @@ +import logging +from logging import Logger +from socket import socket +from typing import Union +from urllib.parse import ParseResultBytes, ParseResult + +from httplib import parser +from httplib.exceptions import MethodNotAllowed, BadRequest +from httplib.httpsocket import HTTPSocket +from httplib.retriever import Retriever + +METHODS = ("GET", "HEAD", "PUT", "POST") + + +class RequestHandler: + conn: HTTPSocket + logger: Logger + + def __init__(self, conn: socket, logger, host): + self.conn = HTTPSocket(conn, host) + self.logger = logger + + def listen(self): + self.logger.debug("Parsing request line") + logging.debug("test logger") + (method, target, version) = parser.parse_request_line(self.conn) + headers = parser.parse_request_headers(self.conn) + + self._validate_request(method, target, version, headers) + + self.logger.debug("Parsed request-line: version: %s, target: %r", method, target) + headers = parser.get_headers(self.conn) + self.logger.debug("Parsed headers: %r", headers) + retriever = Retriever.create(self.conn, headers) + body = retriever.retrieve() + + self.logger.debug("body: %r", body) + + def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version): + + if method not in METHODS: + raise MethodNotAllowed(METHODS) + + # only origin-form and absolute-form are allowed + if len(target.path) < 1 or target.path[0] != "/" or \ + target.netloc not in ("http", "https") and target.hostname == "": + raise BadRequest() + + if version not in ("1.0", "1.1"): + raise BadRequest() + + def _validate_request(self, method, target, version, headers): + + self._check_request_line(method, target, version) + + if version == "1.1" and "host" not in headers: + raise BadRequest() diff --git a/server/httpserver.py b/server/httpserver.py new file mode 100644 index 0000000..f4b3ebd --- /dev/null +++ b/server/httpserver.py @@ -0,0 +1,94 @@ +import logging +import multiprocessing as mp +import socket +import time +from multiprocessing.context import Process +from multiprocessing.queues import Queue +from multiprocessing.synchronize import Event + +from server import worker + + +class HTTPServer: + address: str + port: int + workers = [] + worker_count: int + server: socket + + _dispatch_queue: Queue + _stop_event: Event + + def __init__(self, address: str, port: int, worker_count, logging_level): + self.address = address + self.port = port + self.worker_count = worker_count + self.logging_level = logging_level + + mp.set_start_method("spawn") + self._dispatch_queue = mp.Queue() + self._stop_event = mp.Event() + + def start(self): + try: + self.__do_start() + except KeyboardInterrupt: + self.__shutdown() + + def __do_start(self): + # Create socket + + self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server.bind((self.address, self.port)) + + self.__create_workers() + + self.__listen() + + def __listen(self): + + self.server.listen() + logging.debug("Listening for connections") + + while True: + if self._dispatch_queue.qsize() > self.worker_count: + time.sleep(0.01) + continue + + conn, addr = self.server.accept() + logging.info("New connection: %s", addr[0]) + self._dispatch_queue.put((conn, addr)) + logging.debug("Dispatched connection %s", addr) + + def __shutdown(self): + + # Set stop event + self._stop_event.set() + + # Wake up workers + logging.debug("Waking up workers") + for p in self.workers: + self._dispatch_queue.put((None, None)) + + logging.debug("Closing dispatch queue") + self._dispatch_queue.close() + + logging.debug("Waiting for workers to shutdown") + p: Process + for p in self.workers: + p.join() + p.terminate() + + logging.debug("Shutting down socket") + self.server.shutdown(socket.SHUT_RDWR) + self.server.close() + + def __create_workers(self): + for i in range(self.worker_count): + logging.debug("Creating worker: %d", i + 1) + p = mp.Process(target=worker.worker, + args=(self.address, i + 1, self.logging_level, self._dispatch_queue, self._stop_event)) + p.start() + self.workers.append(p) + + time.sleep(0.1) diff --git a/server/worker.py b/server/worker.py new file mode 100644 index 0000000..c63b1d1 --- /dev/null +++ b/server/worker.py @@ -0,0 +1,83 @@ +import logging +import multiprocessing +import multiprocessing as mp +import threading +from concurrent.futures import ThreadPoolExecutor +from logging import Logger +from socket import socket + +from server.RequestHandler import RequestHandler + +THREAD_LIMIT = 20 + + +def worker(address, name, log_level, queue: mp.Queue, stop_event: mp.Event): + logging.basicConfig(level=log_level) + logger = multiprocessing.log_to_stderr(level=log_level) + runner = Worker(address, name, logger, queue, stop_event) + runner.logger.debug("Worker %s started", name) + + try: + runner.run() + except KeyboardInterrupt: + logger.debug("Ctrl+C pressed, terminating") + runner.shutdown() + + +class Worker: + + host: str + name: str + logger: Logger + queue: mp.Queue + executor: ThreadPoolExecutor + stop_event: mp.Event + + finished_queue: mp.Queue + + def __init__(self, host, name, logger, queue: mp.Queue, stop_event: mp.Event): + self.host = host + self.name = name + self.logger = logger + self.queue = queue + self.executor = ThreadPoolExecutor(THREAD_LIMIT) + self.stop_event = stop_event + self.finished_queue = mp.Queue() + + for i in range(THREAD_LIMIT): + self.finished_queue.put(i) + + def run(self): + while not self.stop_event.is_set(): + + # Blocks until thread is free + self.finished_queue.get() + # Blocks until new client connects + conn, addr = self.queue.get() + + if conn is None or addr is None: + break + + self.logger.debug("Received new client: %s", addr) + + # submit client to thread + print(threading.get_ident()) + self.executor.submit(self._handle_client, conn, addr) + + self.shutdown() + + def _handle_client(self, conn: socket, addr): + try: + self.logger.debug("Handling client: %s", addr) + + handler = RequestHandler(conn, self.logger, self.host) + handler.listen() + except Exception as e: + self.logger.debug("Internal error", exc_info=e) + + # Finished, put back into queue + self.finished_queue.put(threading.get_ident()) + + def shutdown(self): + self.logger.info("shutting down") + self.executor.shutdown() diff --git a/server_flow.md b/server_flow.md new file mode 100644 index 0000000..f486721 --- /dev/null +++ b/server_flow.md @@ -0,0 +1,4 @@ +# Flow +- listen +- dispatch asap +- throw error if too full \ No newline at end of file