This commit is contained in:
2021-03-26 18:25:03 +01:00
parent 7476870acc
commit fdbd865889
11 changed files with 297 additions and 136 deletions

View File

@@ -1,25 +1,39 @@
import logging
import mimetypes
import os
import sys
from abc import ABC, abstractmethod
from typing import Dict, Tuple
from urllib.parse import urlparse
from datetime import datetime
from time import mktime
from typing import Dict
from wsgiref.handlers import format_date_time
from client.httpclient import FORMAT, HTTPClient
from httplib import parser
from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding
from httplib.message import Message
from httplib.retriever import PreambleRetriever
from client.httpclient import FORMAT
from httplib.exceptions import NotFound, Conflict, Forbidden
from httplib.message import ServerMessage as Message
root = os.path.join(os.path.dirname(sys.argv[0]), "public")
status_message = {
200: "OK",
201: "Created",
202: "Accepted",
304: "Not Modified",
400: "Bad Request",
404: "Not Found",
500: "Internal Server Error",
}
def create(method: str, message: Message):
if method == "GET":
return GetCommand(url, port)
elif method == "HEAD":
return HeadCommand(url, port)
elif method == "POST":
return PostCommand(url, port)
elif method == "PUT":
return PutCommand(url, port)
def create(message: Message):
if message.method == "GET":
return GetCommand(message)
elif message.method == "HEAD":
return HeadCommand(message)
elif message.method == "POST":
return PostCommand(message)
elif message.method == "PUT":
return PutCommand(message)
else:
raise ValueError()
@@ -27,8 +41,10 @@ def create(method: str, message: Message):
class AbstractCommand(ABC):
path: str
headers: Dict[str, str]
msg: Message
def __init(self):
def __init__(self, message: Message):
self.msg = message
pass
@property
@@ -36,63 +52,133 @@ class AbstractCommand(ABC):
def command(self):
pass
def _get_date(self):
now = datetime.now()
stamp = mktime(now.timetuple())
return format_date_time(stamp)
class AbstractWithBodyCommand(AbstractCommand, ABC):
@abstractmethod
def execute(self):
pass
def _build_message(self, message: str) -> bytes:
body = input(f"Enter {self.command} data: ").encode(FORMAT)
print()
def _build_message(self, status: int, content_type: str, body: bytes):
message = f"HTTP/1.1 {status} {status_message[status]}\r\n"
message += self._get_date() + "\r\n"
content_length = len(body)
message += f"Content-Length: {content_length}\r\n"
if content_type:
message += f"Content-Type: {content_type}"
if content_type.startswith("text"):
message += "; charset=UTF-8"
message += "\r\n"
elif content_length > 0:
message += f"Content-Type: application/octet-stream"
message += "Content-Type: text/plain\r\n"
message += f"Content-Length: {len(body)}\r\n"
message += "\r\n"
message = message.encode(FORMAT)
message += body
message += b"\r\n"
if content_length > 0:
message += body
message += b"\r\n"
return message
def _get_path(self, check=True):
norm_path = os.path.normpath(self.msg.target.path)
if norm_path == "/":
path = root + "/index.html"
else:
path = root + norm_path
if check and not os.path.exists(path):
raise NotFound()
return path
class AbstractModifyCommand(AbstractCommand, ABC):
@property
@abstractmethod
def _file_mode(self):
pass
def execute(self):
path = self._get_path(False)
dir = os.path.dirname(path)
if not os.path.exists(dir):
raise Forbidden("Target directory does not exists!")
if os.path.exists(dir) and not os.path.isdir(dir):
raise Forbidden("Target directory is an existing file!")
try:
with open(path, mode=f"{self._file_mode}b") as file:
file.write(self.msg.body)
except IsADirectoryError:
raise Forbidden("The target resource is a directory!")
class HeadCommand(AbstractCommand):
def execute(self):
path = self._get_path()
mime = mimetypes.guess_type(path)[0]
return self._build_message(200, mime, b"")
@property
def command(self):
return "HEAD"
class GetCommand(AbstractCommand):
def __init__(self, uri: str, port, dir=None):
super().__init__(uri, port)
self.dir = dir
self.filename = None
@property
def command(self):
return "GET"
def _get_preamble(self, retriever):
lines = retriever.retrieve()
(version, status, msg) = parser.parse_status_line(next(lines))
headers = parser.parse_headers(lines)
def get_mimetype(self, path):
mime = mimetypes.guess_type(path)[0]
logging.debug("---response begin---\r\n%s--- response end---", "".join(retriever.buffer))
if mime:
return mime
return Message(version, status, msg, headers, retriever.buffer)
try:
file = open(path, "r", encoding="utf-8")
file.readline()
file.close()
return "text/plain"
except UnicodeDecodeError:
return "application/octet-stream"
def _await_response(self, client, retriever):
msg = self._get_preamble(retriever)
def execute(self):
path = self._get_path()
mime = self.get_mimetype(path)
from client import response_handler
self.filename = response_handler.handle(client, msg, self, self.dir)
file = open(path, "rb")
buffer = file.read()
file.close()
return self._build_message(200, mime, buffer)
class PostCommand(AbstractWithBodyCommand):
class PostCommand(AbstractModifyCommand):
@property
def command(self):
return "POST"
@property
def _file_mode(self):
return "a"
class PutCommand(AbstractWithBodyCommand):
class PutCommand(AbstractModifyCommand):
@property
def command(self):
return "PUT"
@property
def _file_mode(self):
return "w"

View File

@@ -37,7 +37,6 @@ class HTTPServer:
def __do_start(self):
# Create socket
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server.bind((self.address, self.port))
@@ -56,6 +55,8 @@ class HTTPServer:
continue
conn, addr = self.server.accept()
conn.settimeout(5)
logging.info("New connection: %s", addr[0])
self._dispatch_queue.put((conn, addr))
logging.debug("Dispatched connection %s", addr)

View File

@@ -1,7 +1,7 @@
import logging
import mimetypes
import os
import sys
import time
from datetime import datetime
from socket import socket
from time import mktime
@@ -10,9 +10,12 @@ from urllib.parse import ParseResultBytes, ParseResult
from wsgiref.handlers import format_date_time
from httplib import parser
from httplib.exceptions import MethodNotAllowed, BadRequest, UnsupportedEncoding, NotImplemented, NotFound
from httplib.exceptions import MethodNotAllowed, BadRequest, UnsupportedEncoding, NotImplemented, NotFound, \
HTTPVersionNotSupported
from httplib.httpsocket import HTTPSocket, FORMAT
from httplib.retriever import Retriever
from httplib.message import ServerMessage as Message
from httplib.retriever import Retriever, PreambleRetriever
from server import command
METHODS = ("GET", "HEAD", "PUT", "POST")
@@ -25,13 +28,28 @@ class RequestHandler:
self.conn = HTTPSocket(conn, host)
def listen(self):
logging.debug("Parsing request line")
(method, target, version) = parser.parse_request_line(self.conn)
headers = parser.parse_request_headers(self.conn)
self._validate_request(method, target, version, headers)
retriever = PreambleRetriever(self.conn)
logging.debug("Parsed request-line: method: %s, target: %r", method, target)
while True:
line = self.conn.read_line()
if line in ("\r\n", "\r", "\n"):
continue
retriever.reset_buffer(line)
self._handle_message(retriever, line)
def _handle_message(self, retriever, line):
lines = retriever.retrieve()
(method, target, version) = parser.parse_request_line(line)
headers = parser.parse_headers(lines)
message = Message(version, method, target, headers, retriever.buffer)
logging.debug("---request begin---\r\n%s---request end---", "".join(message.raw))
self._validate_request(message)
body = b""
if self._has_body(headers):
@@ -44,8 +62,13 @@ class RequestHandler:
for buffer in retriever.retrieve():
body += buffer
message.body = body
# completed message
self._handle_message(method, target.path, body)
cmd = command.create(message)
msg = cmd.execute()
self.conn.conn.sendall(msg)
def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version):
@@ -53,30 +76,34 @@ class RequestHandler:
raise MethodNotAllowed(METHODS)
if version not in ("1.0", "1.1"):
raise BadRequest()
raise HTTPVersionNotSupported()
# only origin-form and absolute-form are allowed
if target.scheme not in ("", "http"):
# Only http is supported...
raise BadRequest()
if target.netloc != "" and target.netloc != self.conn.host and target.netloc != self.conn.host.split(":")[0]:
raise NotFound()
if target.path == "" or target.path[0] != "/":
raise NotFound()
norm_path = os.path.normpath(target.path)
if not os.path.exists(self.root + norm_path):
raise NotFound()
def _validate_request(self, method, target, version, headers):
if version == "1.1" and "host" not in headers:
def _validate_request(self, msg):
if msg.version == "1.1" and "host" not in msg.headers:
raise BadRequest()
self._check_request_line(method, target, version)
self._check_request_line(msg.method, msg.target, msg.version)
def _has_body(self, headers):
return "transfer-encoding" in headers or "content-encoding" in headers
if "transfer-encoding" in headers:
return True
if "content-length" in headers and int(headers["content-length"]) > 0:
return True
return False
@staticmethod
def _get_date():
@@ -84,38 +111,6 @@ class RequestHandler:
stamp = mktime(now.timetuple())
return format_date_time(stamp)
def _handle_message(self, method: str, target, body: bytes):
date = self._get_date()
if method == "GET":
if target == "/":
path = self.root + "/index.html"
else:
path = self.root + target
mime = mimetypes.guess_type(path)[0]
if mime.startswith("text"):
file = open(path, "rb", FORMAT)
else:
file = open(path, "rb")
buffer = file.read()
file.close()
message = "HTTP/1.1 200 OK\r\n"
message += date + "\r\n"
if mime:
message += f"Content-Type: {mime}"
if mime.startswith("text"):
message += "; charset=UTF-8"
message += "\r\n"
message += f"Content-Length: {len(buffer)}\r\n"
message += "\r\n"
message = message.encode(FORMAT)
message += buffer
message += b"\r\n"
logging.debug("Sending: %r", message)
self.conn.conn.sendall(message)
@staticmethod
def send_error(client: socket, code, message):
message = f"HTTP/1.1 {code} {message}\r\n"

View File

@@ -5,7 +5,7 @@ import threading
from concurrent.futures import ThreadPoolExecutor
from httplib.exceptions import HTTPServerException, InternalServerError
from server.RequestHandler import RequestHandler
from server.requesthandler import RequestHandler
THREAD_LIMIT = 128
@@ -62,15 +62,16 @@ class Worker:
def _handle_client(self, conn: socket.socket, addr):
try:
logging.debug("Handling client: %s", addr)
handler = RequestHandler(conn, self.host)
handler.listen()
except HTTPServerException as e:
logging.debug("HTTP Exception:", exc_info=e)
RequestHandler.send_error(conn, e.status_code, e.message)
except socket.timeout:
logging.debug("Socket for client %s timed out", addr)
except Exception as e:
RequestHandler.send_error(conn, InternalServerError.status_code, InternalServerError.message)
logging.debug("Internal error", exc_info=e)
RequestHandler.send_error(conn, InternalServerError.status_code, InternalServerError.message)
conn.shutdown(socket.SHUT_RDWR)
conn.close()