This commit is contained in:
2021-03-21 23:01:09 +01:00
parent 638576f471
commit d25d2ef993
14 changed files with 681 additions and 226 deletions

View File

@@ -2,9 +2,10 @@ import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from urllib.parse import urlparse from urllib.parse import urlparse
from client.ResponseHandler import ResponseHandler from client.response_handler import ResponseHandler
from client.httpclient import FORMAT, HTTPClient, InvalidResponse, InvalidStatusLine, UnsupportedEncoding from client.httpclient import FORMAT, HTTPClient
from httplib import parser
from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding
class AbstractCommand(ABC): class AbstractCommand(ABC):
@@ -34,7 +35,7 @@ class AbstractCommand(ABC):
(host, path) = self.parse_uri() (host, path) = self.parse_uri()
client = HTTPClient(host) client = HTTPClient(host)
client.connect((host, int(self.port))) client.conn.connect((host, int(self.port)))
message = f"{self.command} {path} HTTP/1.1\r\n" message = f"{self.command} {path} HTTP/1.1\r\n"
message += f"Host: {host}\r\n" message += f"Host: {host}\r\n"
@@ -44,7 +45,7 @@ class AbstractCommand(ABC):
logging.info("---request begin---\r\n%s---request end---", encoded_msg.decode(FORMAT)) logging.info("---request begin---\r\n%s---request end---", encoded_msg.decode(FORMAT))
logging.debug("Sending HTTP message: %r", encoded_msg) logging.debug("Sending HTTP message: %r", encoded_msg)
client.sendall(encoded_msg) client.conn.sendall(encoded_msg)
logging.info("HTTP request sent, awaiting response...") logging.info("HTTP request sent, awaiting response...")
@@ -118,9 +119,9 @@ class GetCommand(AbstractCommand):
return "GET" return "GET"
def _await_response(self, client): def _await_response(self, client):
(version, status, msg) = ResponseHandler.get_status_line(client) (version, status, msg) = parser.get_status_line(client)
logging.debug("Parsed status-line: version: %s, status: %s", version, status) logging.debug("Parsed status-line: version: %s, status: %s", version, status)
headers = ResponseHandler.get_headers(client) headers = parser.get_headers(client)
logging.debug("Parsed headers: %r", headers) logging.debug("Parsed headers: %r", headers)
handler = ResponseHandler.create(client, headers, status, self.url) handler = ResponseHandler.create(client, headers, status, self.url)

View File

@@ -1,6 +1,6 @@
import logging
import socket import socket
from io import BufferedReader
from httplib.httpsocket import HTTPSocket
BUFSIZE = 4096 BUFSIZE = 4096
TIMEOUT = 3 TIMEOUT = 3
@@ -8,98 +8,8 @@ FORMAT = "UTF-8"
MAXLINE = 4096 MAXLINE = 4096
class HTTPClient(socket.socket): class HTTPClient(HTTPSocket):
host: str host: str
file: BufferedReader
def __init__(self, host: str): def __init__(self, host: str):
super().__init__(socket.socket(socket.AF_INET, socket.SOCK_STREAM), host)
super().__init__(socket.AF_INET, socket.SOCK_STREAM)
self.settimeout(TIMEOUT)
self.host = host
self.setblocking(True)
self.settimeout(3.0)
self.file = self.makefile("rb")
def close(self):
self.file.close()
super().close()
def reset_request(self):
self.file.close()
self.file = self.makefile("rb")
def __do_receive(self):
if self.fileno() == -1:
raise Exception("Connection closed")
result = self.recv(BUFSIZE)
return result
def receive(self):
"""Receive data from the client up to BUFSIZE
"""
count = 0
while True:
count += 1
try:
return self.__do_receive()
except socket.timeout:
logging.debug("Socket receive timed out after %s seconds", TIMEOUT)
if count == 3:
break
logging.debug("Retrying %s", count)
logging.debug("Timed out after waiting %s seconds for response", TIMEOUT * count)
raise TimeoutError("Request timed out")
def read(self, size=BUFSIZE, blocking=True) -> bytes:
if blocking:
return self.file.read(size)
return self.file.read1(size)
def read_line(self):
return str(self.read_bytes_line(), FORMAT)
def read_bytes_line(self):
"""
:rtype: bytes
"""
line = self.file.readline(MAXLINE + 1)
if len(line) > MAXLINE:
raise InvalidResponse("Line too long")
return line
class HTTPException(Exception):
""" Base class for HTTP exceptions """
class InvalidResponse(HTTPException):
""" Response message cannot be parsed """
def __init(self, message):
self.message = message
class InvalidStatusLine(HTTPException):
""" Response status line is invalid """
def __init(self, line):
self.line = line
class UnsupportedEncoding(HTTPException):
""" Reponse Encoding not support """
def __init(self, enc_type, encoding):
self.enc_type = enc_type
self.encoding = encoding
class IncompleteResponse(HTTPException):
def __init(self, cause):
self.cause = cause

View File

@@ -1,14 +1,15 @@
import logging import logging
import os import os
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict from typing import Dict
from urllib.parse import urlparse from urllib.parse import urlparse
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from client.Retriever import Retriever from client.httpclient import HTTPClient, FORMAT
from client.httpclient import HTTPClient, UnsupportedEncoding, FORMAT, InvalidResponse, InvalidStatusLine from httplib.retriever import Retriever
from httplib import parser
from httplib.exceptions import InvalidResponse
class ResponseHandler(ABC): class ResponseHandler(ABC):
@@ -31,17 +32,6 @@ class ResponseHandler(ABC):
@staticmethod @staticmethod
def create(client: HTTPClient, headers, status_code, url): def create(client: HTTPClient, headers, status_code, url):
# only chunked transfer-encoding is supported
transfer_encoding = headers.get("transfer-encoding")
if transfer_encoding and transfer_encoding != "chunked":
raise UnsupportedEncoding("transfer-encoding", transfer_encoding)
chunked = transfer_encoding
# content-encoding is not supported
content_encoding = headers.get("content-encoding")
if content_encoding:
raise UnsupportedEncoding("content-encoding", content_encoding)
retriever = Retriever.create(client, headers) retriever = Retriever.create(client, headers)
content_type = headers.get("content-type") content_type = headers.get("content-type")
@@ -49,78 +39,6 @@ class ResponseHandler(ABC):
return HTMLDownloadHandler(retriever, client, headers, url) return HTMLDownloadHandler(retriever, client, headers, url)
return RawDownloadHandler(retriever, client, headers, url) return RawDownloadHandler(retriever, client, headers, url)
@staticmethod
def get_status_line(client: HTTPClient):
line = client.read_line()
split = list(filter(None, line.split(" ")))
if len(split) < 3:
raise InvalidStatusLine(line)
# Check HTTP version
http_version = split.pop(0)
if len(http_version) < 8 or http_version[4] != "/":
raise InvalidStatusLine(line)
(name, version) = http_version[:4], http_version[5:]
if name != "HTTP" or not re.match(r"1\.[0|1]", version):
raise InvalidStatusLine(line)
status = split.pop(0)
if not re.match(r"\d{3}", status):
raise InvalidStatusLine(line)
status = int(status)
if status < 100 or status > 999:
raise InvalidStatusLine(line)
reason = split.pop(0)
return version, status, reason
@staticmethod
def get_headers(client: HTTPClient):
headers = []
# first header after the status-line may not contain a space
while True:
line = client.read_line()
if line[0].isspace():
continue
else:
break
while True:
if line in ("\r\n", "\n", " "):
break
if line[0].isspace():
headers[-1] = headers[-1].rstrip("\r\n")
headers.append(line.lstrip())
line = client.read_line()
result = {}
header_str = "".join(headers)
for line in header_str.splitlines():
pos = line.find(":")
if pos <= 0 or pos >= len(line) - 1:
continue
(header, value) = map(str.strip, line.split(":", 1))
ResponseHandler.check_next_header(result, header, value)
result[header.lower()] = value.lower()
return result
@staticmethod
def check_next_header(headers, next_header: str, next_value: str):
if next_header == "content-length":
if "content-length" in headers:
logging.error("Multiple content-length headers specified")
raise InvalidResponse()
if not next_value.isnumeric() or int(next_value) <= 0:
logging.error("Invalid content-length value: %r", next_value)
raise InvalidResponse()
@staticmethod @staticmethod
def parse_uri(uri: str): def parse_uri(uri: str):
parsed = urlparse(uri) parsed = urlparse(uri)
@@ -196,9 +114,9 @@ class DownloadHandler(ResponseHandler, ABC):
def _handle_sub_request(self, client, url): def _handle_sub_request(self, client, url):
(version, status, _) = self.get_status_line(client) (version, status, _) = parser.get_status_line(client)
logging.debug("Parsed status-line: version: %s, status: %s", version, status) logging.debug("Parsed status-line: version: %s, status: %s", version, status)
headers = self.get_headers(client) headers = parser.get_headers(client)
logging.debug("Parsed headers: %r", headers) logging.debug("Parsed headers: %r", headers)
if status != 200: if status != 200:
@@ -297,8 +215,8 @@ class HTMLDownloadHandler(DownloadHandler):
client.reset_request() client.reset_request()
else: else:
client = HTTPClient(img_src) client = HTTPClient(img_src)
client.connect((img_host, 80)) client.conn.connect((img_host, 80))
client.sendall(message) client.conn.sendall(message)
filename = self._handle_sub_request(client, img_host + img_path) filename = self._handle_sub_request(client, img_host + img_path)
if not same_host: if not same_host:

41
httplib/exceptions.py Normal file
View File

@@ -0,0 +1,41 @@
class HTTPException(Exception):
""" Base class for HTTP exceptions """
class InvalidResponse(HTTPException):
""" Response message cannot be parsed """
def __init(self, message):
self.message = message
class InvalidStatusLine(HTTPException):
""" Response status line is invalid """
def __init(self, line):
self.line = line
class UnsupportedEncoding(HTTPException):
""" Reponse Encoding not support """
def __init(self, enc_type, encoding):
self.enc_type = enc_type
self.encoding = encoding
class IncompleteResponse(HTTPException):
def __init(self, cause):
self.cause = cause
class HTTPServerException(Exception):
""" Base class for HTTP Server exceptions """
class BadRequest(HTTPServerException):
""" Malformed HTTP request"""
class MethodNotAllowed(HTTPServerException):
""" Method is not allowed """
def __init(self, allowed_methods):
self.allowed_methods = allowed_methods

82
httplib/httpsocket.py Normal file
View File

@@ -0,0 +1,82 @@
import logging
import socket
from io import BufferedReader
BUFSIZE = 4096
TIMEOUT = 3
FORMAT = "UTF-8"
MAXLINE = 4096
class HTTPSocket:
host: str
conn: socket.socket
file: BufferedReader
def __init__(self, conn: socket.socket, host: str):
self.host = host
self.conn = conn
self.conn.settimeout(TIMEOUT)
self.conn.setblocking(True)
self.conn.settimeout(3.0)
self.file = self.conn.makefile("rb")
def close(self):
self.file.close()
self.conn.close()
def reset_request(self):
self.file.close()
self.file = self.conn.makefile("rb")
def __do_receive(self):
if self.conn.fileno() == -1:
raise Exception("Connection closed")
result = self.conn.recv(BUFSIZE)
return result
def receive(self):
"""Receive data from the client up to BUFSIZE
"""
count = 0
while True:
count += 1
try:
return self.__do_receive()
except socket.timeout:
logging.debug("Socket receive timed out after %s seconds", TIMEOUT)
if count == 3:
break
logging.debug("Retrying %s", count)
logging.debug("Timed out after waiting %s seconds for response", TIMEOUT * count)
raise TimeoutError("Request timed out")
def read(self, size=BUFSIZE, blocking=True) -> bytes:
if blocking:
return self.file.read(size)
return self.file.read1(size)
def read_line(self):
return str(self.read_bytes_line(), FORMAT)
def read_bytes_line(self) -> bytes:
line = self.file.readline(MAXLINE + 1)
if len(line) > MAXLINE:
raise InvalidResponse("Line too long")
return line
class HTTPException(Exception):
""" Base class for HTTP exceptions """
class InvalidResponse(HTTPException):
""" Response message cannot be parsed """
def __init(self, message):
self.message = message

160
httplib/parser.py Normal file
View File

@@ -0,0 +1,160 @@
import logging
import re
from urllib.parse import urlparse
from httplib.exceptions import InvalidStatusLine, InvalidResponse, BadRequest
from httplib.httpsocket import HTTPSocket
def _get_start_line(client: HTTPSocket):
line = client.read_line()
split = list(filter(None, line.split(" ")))
if len(split) < 3:
raise InvalidStatusLine(line) # TODO fix exception
return line, split
def _is_valid_http_version(http_version: str):
if len(http_version) < 8 or http_version[4] != "/":
return False
(name, version) = http_version[:4], http_version[5:]
if name != "HTTP" or not re.match(r"1\.[0|1]", version):
return False
def get_status_line(client: HTTPSocket):
line, (http_version, status, reason) = _get_start_line(client)
if not _is_valid_http_version(http_version):
raise InvalidStatusLine(line)
version = http_version[:4]
if not re.match(r"\d{3}", status):
raise InvalidStatusLine(line)
status = int(status)
if status < 100 or status > 999:
raise InvalidStatusLine(line)
return version, status, reason
def parse_request_line(client: HTTPSocket):
line, (method, target, version) = _get_start_line(client)
if method not in ("CONNECT", "DELETE", "GET", "HEAD", "OPTIONS", "POST", "PUT", "TRACE"):
raise BadRequest()
if not _is_valid_http_version(version):
raise BadRequest()
if len(target) == "":
raise BadRequest()
parsed_target = urlparse(target)
return method, parsed_target, version
def retrieve_headers(client: HTTPSocket):
raw_headers = []
# first header after the status-line may not contain a space
while True:
line = client.read_line()
if line[0].isspace():
continue
else:
break
while True:
if line in ("\r\n", "\n", " "):
break
if line[0].isspace():
raw_headers[-1] = raw_headers[-1].rstrip("\r\n")
raw_headers.append(line.lstrip())
line = client.read_line()
result = []
header_str = "".join(raw_headers)
for line in header_str.splitlines():
pos = line.find(":")
if pos <= 0 or pos >= len(line) - 1:
continue
(header, value) = line.split(":", 1)
result.append((header.lower(), value.lower()))
return result
def parse_request_headers(client: HTTPSocket):
raw_headers = retrieve_headers(client)
headers = {}
key: str
for (key, value) in raw_headers:
if any((c.isspace()) for c in key):
raise BadRequest()
if key == "content-length":
if key in headers:
logging.error("Multiple content-length headers specified")
raise BadRequest()
if not value.isnumeric() or int(value) <= 0:
logging.error("Invalid content-length value: %r", value)
raise BadRequest()
elif key == "host":
if value != client.host or key in headers:
raise BadRequest()
headers[key] = value
return headers
def get_headers(client: HTTPSocket):
headers = []
# first header after the status-line may not contain a space
while True:
line = client.read_line()
if line[0].isspace():
continue
else:
break
while True:
if line in ("\r\n", "\n", " "):
break
if line[0].isspace():
headers[-1] = headers[-1].rstrip("\r\n")
headers.append(line.lstrip())
line = client.read_line()
result = {}
header_str = "".join(headers)
for line in header_str.splitlines():
pos = line.find(":")
if pos <= 0 or pos >= len(line) - 1:
continue
(header, value) = map(str.strip, line.split(":", 1))
check_next_header(result, header, value)
result[header.lower()] = value.lower()
return result
def check_next_header(headers, next_header: str, next_value: str):
if next_header == "content-length":
if "content-length" in headers:
logging.error("Multiple content-length headers specified")
raise InvalidResponse()
if not next_value.isnumeric() or int(next_value) <= 0:
logging.error("Invalid content-length value: %r", next_value)
raise InvalidResponse()

View File

@@ -2,13 +2,14 @@ import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict from typing import Dict
from client.httpclient import HTTPClient, BUFSIZE, IncompleteResponse, InvalidResponse, UnsupportedEncoding from httplib.exceptions import IncompleteResponse, InvalidResponse, UnsupportedEncoding
from httplib.httpsocket import HTTPSocket, BUFSIZE
class Retriever(ABC): class Retriever(ABC):
client: HTTPClient client: HTTPSocket
def __init__(self, client: HTTPClient): def __init__(self, client: HTTPSocket):
self.client = client self.client = client
@abstractmethod @abstractmethod
@@ -16,7 +17,7 @@ class Retriever(ABC):
pass pass
@staticmethod @staticmethod
def create(client: HTTPClient, headers: Dict[str, str]): def create(client: HTTPSocket, headers: Dict[str, str]):
# only chunked transfer-encoding is supported # only chunked transfer-encoding is supported
transfer_encoding = headers.get("transfer-encoding") transfer_encoding = headers.get("transfer-encoding")
@@ -44,7 +45,7 @@ class Retriever(ABC):
class ContentLengthRetriever(Retriever): class ContentLengthRetriever(Retriever):
length: int length: int
def __init__(self, client: HTTPClient, length: int): def __init__(self, client: HTTPSocket, length: int):
super().__init__(client) super().__init__(client)
self.length = length self.length = length

53
public/index.html Normal file
View File

@@ -0,0 +1,53 @@
<!doctype html>
<html>
<head>
<title>Computer Networks example</title>
<meta charset="utf-8" />
<meta http-equiv="Content-type" content="text/html; charset=utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<style type="text/css">
body {
background-color: #f0f0f2;
margin: 0;
padding: 0;
font-family: -apple-system, system-ui, BlinkMacSystemFont, "Segoe UI", "Open Sans", "Helvetica Neue", Helvetica, Arial, sans-serif;
}
div {
width: 600px;
margin: 5em auto;
padding: 2em;
background-color: #fdfdff;
border-radius: 0.5em;
box-shadow: 2px 3px 7px 2px rgba(0,0,0,0.02);
}
a:link, a:visited {
color: #38488f;
text-decoration: none;
}
@media (max-width: 700px) {
div {
margin: 0 auto;
width: auto;
}
}
</style>
</head>
<body>
<div>
<h1>Example Domain</h1>
<p>This domain is for use in illustrative examples in documents. You may use this
domain in literature without prior coordination or asking for permission.</p>
</div>
<div>
<h2>Remote image</h2>
<img width="200px" src="http://archive.fabacademy.org/2018/labs/fablabrwanda/students/fred-rwema/media/week_4/fsf_logo.png">
</div>
<div>
<h2>Local image</h2>
<img width="200px" src="ulyssis.png">
</body>
</html>

BIN
public/ulyssis.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

111
server.py
View File

@@ -1,39 +1,90 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse
import logging
import multiprocessing
import socket import socket
import sys
# socket heeft een listening and accept method from server.httpserver import HTTPServer
import time
SERVER = "127.0.0.1" #dynamisch fixen in project
PORT = 5055
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
ADDR = (SERVER, PORT) # hier wordt de socket gebonden aan mijn IP adres, dit moet wel anders def main():
server.bind(ADDR) # in het project gebeuren parser = argparse.ArgumentParser(description='HTTP Server')
parser.add_argument("--verbose", "-v", action='count', default=0, help="Increase verbosity level of logging")
parser.add_argument("--workers", "-w",
help="The amount of worker processes. This is by default based on the number of cpu threads.",
type=int)
parser.add_argument("--port", "-p", help="The port to listen on", default=8000)
arguments = parser.parse_args()
HEADER = 64 # maximum size messages logging_level = logging.ERROR - (10 * arguments.verbose)
FORMAT = 'utf-8' # recieving images through this format does not work logging.basicConfig(level=logging_level)
DISCONNECT_MESSAGE = "DISCONNECT!" # special message for disconnecting client and server logging.debug("Arguments: %s", arguments)
# function for starting server # Set workers
def start(): if arguments.workers:
pass workers = int(arguments.workers)
server.listen() else:
while True: # infinite loop in which server accept incoming connections, we want to run it forever workers = multiprocessing.cpu_count()
conn, addr = server.accept() # Server blocks untill a client connects
print("new connection: ", addr[0], " connected.")
connected = True
while connected: # while client is connected, we want to recieve messages
msg = conn.recv(HEADER).decode(FORMAT).rstrip() # Argument is maximum size of msg (in project look into details of accp), decode is for converting bytes to strings, rstrip is for stripping messages for special hidden characters
print("message: ", msg)
for i in range(0,10):
conn.send(b"test")
time.sleep(1)
break # Set port
print("close connection ", addr[0], " disconnected.") if arguments.port:
conn.close() port = int(arguments.port)
else:
port = 8000
print("server is starting ... ") # Get hostname and address
start() hostname = socket.gethostname()
address = socket.gethostbyname(hostname)
server = HTTPServer(address, port, workers, logging_level)
server.start()
try:
if __name__ == '__main__':
main()
except Exception as e:
print("[ABRT] Internal error: " + str(e), file=sys.stderr)
logging.debug("Internal error", exc_info=e)
sys.exit(70)
# import socket
#
# # Get hostname and address
# hostname = socket.gethostname()
# address = socket.gethostbyname(hostname)
#
# # socket heeft een listening and accept method
#
# SERVER = "127.0.0.1" # dynamisch fixen in project
# PORT = 5055
# server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
#
# ADDR = (SERVER, PORT) # hier wordt de socket gebonden aan mijn IP adres, dit moet wel anders
# server.bind(ADDR) # in het project gebeuren
#
# HEADER = 64 # maximum size messages
# FORMAT = 'utf-8'
# DISCONNECT_MESSAGE = "DISCONNECT!" # special message for disconnecting client and server
#
#
# # function for starting server
# def start():
# pass
# server.listen()
# while True: # infinite loop in which server accept incoming connections, we want to run it forever
# conn, addr = server.accept() # Server blocks untill a client connects
# print("new connection: ", addr[0], " connected.")
# connected = True
# while connected: # while client is connected, we want to recieve messages
# msg = conn.recv(HEADER).decode(
# FORMAT).rstrip() # Argument is maximum size of msg (in project look into details of accp), decode is for converting bytes to strings, rstrip is for stripping messages for special hidden characters
# print("message: ", msg)
# if msg == DISCONNECT_MESSAGE:
# connected = False
# print("close connection ", addr[0], " disconnected.")
# conn.close()
#
#
# print("server is starting ... ")
# start()

57
server/RequestHandler.py Normal file
View File

@@ -0,0 +1,57 @@
import logging
from logging import Logger
from socket import socket
from typing import Union
from urllib.parse import ParseResultBytes, ParseResult
from httplib import parser
from httplib.exceptions import MethodNotAllowed, BadRequest
from httplib.httpsocket import HTTPSocket
from httplib.retriever import Retriever
METHODS = ("GET", "HEAD", "PUT", "POST")
class RequestHandler:
conn: HTTPSocket
logger: Logger
def __init__(self, conn: socket, logger, host):
self.conn = HTTPSocket(conn, host)
self.logger = logger
def listen(self):
self.logger.debug("Parsing request line")
logging.debug("test logger")
(method, target, version) = parser.parse_request_line(self.conn)
headers = parser.parse_request_headers(self.conn)
self._validate_request(method, target, version, headers)
self.logger.debug("Parsed request-line: version: %s, target: %r", method, target)
headers = parser.get_headers(self.conn)
self.logger.debug("Parsed headers: %r", headers)
retriever = Retriever.create(self.conn, headers)
body = retriever.retrieve()
self.logger.debug("body: %r", body)
def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version):
if method not in METHODS:
raise MethodNotAllowed(METHODS)
# only origin-form and absolute-form are allowed
if len(target.path) < 1 or target.path[0] != "/" or \
target.netloc not in ("http", "https") and target.hostname == "":
raise BadRequest()
if version not in ("1.0", "1.1"):
raise BadRequest()
def _validate_request(self, method, target, version, headers):
self._check_request_line(method, target, version)
if version == "1.1" and "host" not in headers:
raise BadRequest()

94
server/httpserver.py Normal file
View File

@@ -0,0 +1,94 @@
import logging
import multiprocessing as mp
import socket
import time
from multiprocessing.context import Process
from multiprocessing.queues import Queue
from multiprocessing.synchronize import Event
from server import worker
class HTTPServer:
address: str
port: int
workers = []
worker_count: int
server: socket
_dispatch_queue: Queue
_stop_event: Event
def __init__(self, address: str, port: int, worker_count, logging_level):
self.address = address
self.port = port
self.worker_count = worker_count
self.logging_level = logging_level
mp.set_start_method("spawn")
self._dispatch_queue = mp.Queue()
self._stop_event = mp.Event()
def start(self):
try:
self.__do_start()
except KeyboardInterrupt:
self.__shutdown()
def __do_start(self):
# Create socket
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server.bind((self.address, self.port))
self.__create_workers()
self.__listen()
def __listen(self):
self.server.listen()
logging.debug("Listening for connections")
while True:
if self._dispatch_queue.qsize() > self.worker_count:
time.sleep(0.01)
continue
conn, addr = self.server.accept()
logging.info("New connection: %s", addr[0])
self._dispatch_queue.put((conn, addr))
logging.debug("Dispatched connection %s", addr)
def __shutdown(self):
# Set stop event
self._stop_event.set()
# Wake up workers
logging.debug("Waking up workers")
for p in self.workers:
self._dispatch_queue.put((None, None))
logging.debug("Closing dispatch queue")
self._dispatch_queue.close()
logging.debug("Waiting for workers to shutdown")
p: Process
for p in self.workers:
p.join()
p.terminate()
logging.debug("Shutting down socket")
self.server.shutdown(socket.SHUT_RDWR)
self.server.close()
def __create_workers(self):
for i in range(self.worker_count):
logging.debug("Creating worker: %d", i + 1)
p = mp.Process(target=worker.worker,
args=(self.address, i + 1, self.logging_level, self._dispatch_queue, self._stop_event))
p.start()
self.workers.append(p)
time.sleep(0.1)

83
server/worker.py Normal file
View File

@@ -0,0 +1,83 @@
import logging
import multiprocessing
import multiprocessing as mp
import threading
from concurrent.futures import ThreadPoolExecutor
from logging import Logger
from socket import socket
from server.RequestHandler import RequestHandler
THREAD_LIMIT = 20
def worker(address, name, log_level, queue: mp.Queue, stop_event: mp.Event):
logging.basicConfig(level=log_level)
logger = multiprocessing.log_to_stderr(level=log_level)
runner = Worker(address, name, logger, queue, stop_event)
runner.logger.debug("Worker %s started", name)
try:
runner.run()
except KeyboardInterrupt:
logger.debug("Ctrl+C pressed, terminating")
runner.shutdown()
class Worker:
host: str
name: str
logger: Logger
queue: mp.Queue
executor: ThreadPoolExecutor
stop_event: mp.Event
finished_queue: mp.Queue
def __init__(self, host, name, logger, queue: mp.Queue, stop_event: mp.Event):
self.host = host
self.name = name
self.logger = logger
self.queue = queue
self.executor = ThreadPoolExecutor(THREAD_LIMIT)
self.stop_event = stop_event
self.finished_queue = mp.Queue()
for i in range(THREAD_LIMIT):
self.finished_queue.put(i)
def run(self):
while not self.stop_event.is_set():
# Blocks until thread is free
self.finished_queue.get()
# Blocks until new client connects
conn, addr = self.queue.get()
if conn is None or addr is None:
break
self.logger.debug("Received new client: %s", addr)
# submit client to thread
print(threading.get_ident())
self.executor.submit(self._handle_client, conn, addr)
self.shutdown()
def _handle_client(self, conn: socket, addr):
try:
self.logger.debug("Handling client: %s", addr)
handler = RequestHandler(conn, self.logger, self.host)
handler.listen()
except Exception as e:
self.logger.debug("Internal error", exc_info=e)
# Finished, put back into queue
self.finished_queue.put(threading.get_ident())
def shutdown(self):
self.logger.info("shutting down")
self.executor.shutdown()

4
server_flow.md Normal file
View File

@@ -0,0 +1,4 @@
# Flow
- listen
- dispatch asap
- throw error if too full