This commit is contained in:
2021-03-21 23:01:09 +01:00
parent 638576f471
commit d25d2ef993
14 changed files with 681 additions and 226 deletions

57
server/RequestHandler.py Normal file
View File

@@ -0,0 +1,57 @@
import logging
from logging import Logger
from socket import socket
from typing import Union
from urllib.parse import ParseResultBytes, ParseResult
from httplib import parser
from httplib.exceptions import MethodNotAllowed, BadRequest
from httplib.httpsocket import HTTPSocket
from httplib.retriever import Retriever
METHODS = ("GET", "HEAD", "PUT", "POST")
class RequestHandler:
conn: HTTPSocket
logger: Logger
def __init__(self, conn: socket, logger, host):
self.conn = HTTPSocket(conn, host)
self.logger = logger
def listen(self):
self.logger.debug("Parsing request line")
logging.debug("test logger")
(method, target, version) = parser.parse_request_line(self.conn)
headers = parser.parse_request_headers(self.conn)
self._validate_request(method, target, version, headers)
self.logger.debug("Parsed request-line: version: %s, target: %r", method, target)
headers = parser.get_headers(self.conn)
self.logger.debug("Parsed headers: %r", headers)
retriever = Retriever.create(self.conn, headers)
body = retriever.retrieve()
self.logger.debug("body: %r", body)
def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version):
if method not in METHODS:
raise MethodNotAllowed(METHODS)
# only origin-form and absolute-form are allowed
if len(target.path) < 1 or target.path[0] != "/" or \
target.netloc not in ("http", "https") and target.hostname == "":
raise BadRequest()
if version not in ("1.0", "1.1"):
raise BadRequest()
def _validate_request(self, method, target, version, headers):
self._check_request_line(method, target, version)
if version == "1.1" and "host" not in headers:
raise BadRequest()

94
server/httpserver.py Normal file
View File

@@ -0,0 +1,94 @@
import logging
import multiprocessing as mp
import socket
import time
from multiprocessing.context import Process
from multiprocessing.queues import Queue
from multiprocessing.synchronize import Event
from server import worker
class HTTPServer:
address: str
port: int
workers = []
worker_count: int
server: socket
_dispatch_queue: Queue
_stop_event: Event
def __init__(self, address: str, port: int, worker_count, logging_level):
self.address = address
self.port = port
self.worker_count = worker_count
self.logging_level = logging_level
mp.set_start_method("spawn")
self._dispatch_queue = mp.Queue()
self._stop_event = mp.Event()
def start(self):
try:
self.__do_start()
except KeyboardInterrupt:
self.__shutdown()
def __do_start(self):
# Create socket
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server.bind((self.address, self.port))
self.__create_workers()
self.__listen()
def __listen(self):
self.server.listen()
logging.debug("Listening for connections")
while True:
if self._dispatch_queue.qsize() > self.worker_count:
time.sleep(0.01)
continue
conn, addr = self.server.accept()
logging.info("New connection: %s", addr[0])
self._dispatch_queue.put((conn, addr))
logging.debug("Dispatched connection %s", addr)
def __shutdown(self):
# Set stop event
self._stop_event.set()
# Wake up workers
logging.debug("Waking up workers")
for p in self.workers:
self._dispatch_queue.put((None, None))
logging.debug("Closing dispatch queue")
self._dispatch_queue.close()
logging.debug("Waiting for workers to shutdown")
p: Process
for p in self.workers:
p.join()
p.terminate()
logging.debug("Shutting down socket")
self.server.shutdown(socket.SHUT_RDWR)
self.server.close()
def __create_workers(self):
for i in range(self.worker_count):
logging.debug("Creating worker: %d", i + 1)
p = mp.Process(target=worker.worker,
args=(self.address, i + 1, self.logging_level, self._dispatch_queue, self._stop_event))
p.start()
self.workers.append(p)
time.sleep(0.1)

83
server/worker.py Normal file
View File

@@ -0,0 +1,83 @@
import logging
import multiprocessing
import multiprocessing as mp
import threading
from concurrent.futures import ThreadPoolExecutor
from logging import Logger
from socket import socket
from server.RequestHandler import RequestHandler
THREAD_LIMIT = 20
def worker(address, name, log_level, queue: mp.Queue, stop_event: mp.Event):
logging.basicConfig(level=log_level)
logger = multiprocessing.log_to_stderr(level=log_level)
runner = Worker(address, name, logger, queue, stop_event)
runner.logger.debug("Worker %s started", name)
try:
runner.run()
except KeyboardInterrupt:
logger.debug("Ctrl+C pressed, terminating")
runner.shutdown()
class Worker:
host: str
name: str
logger: Logger
queue: mp.Queue
executor: ThreadPoolExecutor
stop_event: mp.Event
finished_queue: mp.Queue
def __init__(self, host, name, logger, queue: mp.Queue, stop_event: mp.Event):
self.host = host
self.name = name
self.logger = logger
self.queue = queue
self.executor = ThreadPoolExecutor(THREAD_LIMIT)
self.stop_event = stop_event
self.finished_queue = mp.Queue()
for i in range(THREAD_LIMIT):
self.finished_queue.put(i)
def run(self):
while not self.stop_event.is_set():
# Blocks until thread is free
self.finished_queue.get()
# Blocks until new client connects
conn, addr = self.queue.get()
if conn is None or addr is None:
break
self.logger.debug("Received new client: %s", addr)
# submit client to thread
print(threading.get_ident())
self.executor.submit(self._handle_client, conn, addr)
self.shutdown()
def _handle_client(self, conn: socket, addr):
try:
self.logger.debug("Handling client: %s", addr)
handler = RequestHandler(conn, self.logger, self.host)
handler.listen()
except Exception as e:
self.logger.debug("Internal error", exc_info=e)
# Finished, put back into queue
self.finished_queue.put(threading.get_ident())
def shutdown(self):
self.logger.info("shutting down")
self.executor.shutdown()