import logging
import mimetypes
import os
import sys
from datetime import datetime
from socket import socket
from time import mktime
from typing import Union
from urllib.parse import ParseResultBytes, ParseResult
from wsgiref.handlers import format_date_time

from httplib import parser
from httplib.exceptions import MethodNotAllowed, BadRequest, UnsupportedEncoding, NotImplemented, NotFound
from httplib.httpsocket import HTTPSocket, FORMAT
from httplib.retriever import Retriever

METHODS = ("GET", "HEAD", "PUT", "POST")


class RequestHandler:
    """Handle a single HTTP request on an accepted client socket.

    Files are served from the ``public`` directory next to the running script.
    Errors are signalled by raising the ``httplib.exceptions`` types; the
    caller is expected to turn them into responses (see ``send_error``).
    """

    conn: HTTPSocket
    # Document root: "<directory of the entry script>/public".
    root = os.path.join(os.path.dirname(sys.argv[0]), "public")

    def __init__(self, conn: socket, host):
        """Wrap the accepted socket *conn*; *host* is this server's host[:port]."""
        self.conn = HTTPSocket(conn, host)

    def listen(self):
        """Read one request from the connection, validate it and respond.

        :raises BadRequest, MethodNotAllowed, NotFound: from validation.
        :raises NotImplemented: if the request body uses an encoding the
            ``Retriever`` does not support, or the method has no handler.
        """
        logging.debug("Parsing request line")
        (method, target, version) = parser.parse_request_line(self.conn)
        headers = parser.parse_request_headers(self.conn)
        self._validate_request(method, target, version, headers)
        logging.debug("Parsed request-line: method: %s, target: %r", method, target)

        body = b""
        if self._has_body(headers):
            try:
                retriever = Retriever.create(self.conn, headers)
            except UnsupportedEncoding as e:
                logging.error("Encoding not supported: %s=%s", e.enc_type, e.encoding)
                raise NotImplemented() from e
            # join once instead of repeated bytes concatenation (quadratic)
            body = b"".join(retriever.retrieve())

        # completed message
        self._handle_message(method, target.path, body)

    def _resolve_path(self, path: str) -> str:
        """Map a request path to the file path under ``self.root`` it denotes.

        ``/`` is served as ``/index.html``. ``os.path.normpath`` collapses
        ``.``/``..`` segments and, for an absolute path, drops any ``..`` that
        would climb above ``/``, so the result cannot escape ``self.root``.
        Validation and serving both use this so they agree on the file.
        """
        if path == "/":
            path = "/index.html"
        return self.root + os.path.normpath(path)

    def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version):
        """Validate the request-line components.

        :raises MethodNotAllowed: if *method* is not in ``METHODS``.
        :raises BadRequest: for an unsupported HTTP version or a non-http scheme.
        :raises NotFound: for a foreign host, a non-absolute path, or a path
            that does not name an existing file under ``self.root``.
        """
        if method not in METHODS:
            raise MethodNotAllowed(METHODS)

        if version not in ("1.0", "1.1"):
            raise BadRequest()

        # only origin-form and absolute-form are allowed
        if target.scheme not in ("", "http"):
            # Only http is supported...
            raise BadRequest()

        # absolute-form must address this server (with or without the port)
        if target.netloc != "" and target.netloc != self.conn.host and target.netloc != self.conn.host.split(":")[0]:
            raise NotFound()

        if not target.path.startswith("/"):
            raise NotFound()

        # isfile (not exists): a directory would otherwise pass and crash open()
        if not os.path.isfile(self._resolve_path(target.path)):
            raise NotFound()

    def _validate_request(self, method, target, version, headers):
        """Validate headers and request line; HTTP/1.1 requires a Host header."""
        if version == "1.1" and "host" not in headers:
            raise BadRequest()

        self._check_request_line(method, target, version)

    def _has_body(self, headers):
        """Return whether the request carries a body the Retriever should read."""
        # NOTE(review): RFC 7230 §3.3.3 signals a body with Transfer-Encoding or
        # Content-Length; "content-encoding" here looks like it may have been
        # meant as "content-length" — confirm against Retriever.create.
        return "transfer-encoding" in headers or "content-encoding" in headers

    @staticmethod
    def _get_date():
        """Return the current time formatted as an HTTP date (GMT)."""
        now = datetime.now()
        stamp = mktime(now.timetuple())
        return format_date_time(stamp)

    def _handle_message(self, method: str, target, body: bytes):
        """Send the response to a validated request.

        GET serves the file that *target* maps to under ``self.root``; HEAD
        sends the same status and headers without the body.

        :param method: request method, already checked against ``METHODS``.
        :param target: request path (``target.path`` from ``listen``).
        :param body: request body; not used by GET/HEAD.
        :raises NotImplemented: for PUT and POST, which pass validation but
            have no handler yet (previously the client got no response).
        """
        if method not in ("GET", "HEAD"):
            raise NotImplemented()

        path = self._resolve_path(target)
        # Always read raw bytes: the body is sent verbatim and its length is
        # what Content-Length must report.
        with open(path, "rb") as file:
            buffer = file.read()

        mime = mimetypes.guess_type(path)[0]
        lines = ["HTTP/1.1 200 OK", f"Date: {self._get_date()}"]
        if mime:
            content_type = f"Content-Type: {mime}"
            if mime.startswith("text"):
                content_type += "; charset=UTF-8"
            lines.append(content_type)
        lines.append(f"Content-Length: {len(buffer)}")

        message = ("\r\n".join(lines) + "\r\n\r\n").encode(FORMAT)
        if method == "GET":
            # exactly Content-Length bytes follow the header section
            message += buffer

        logging.debug("Sending: %r", message)
        self.conn.conn.sendall(message)

    @staticmethod
    def send_error(client: socket, code, message):
        """Send a body-less error response ``HTTP/1.1 <code> <message>`` on *client*."""
        lines = [
            f"HTTP/1.1 {code} {message}",
            f"Date: {RequestHandler._get_date()}",
            "Content-Length: 0",
        ]
        response = "\r\n".join(lines) + "\r\n\r\n"
        logging.debug("Sending: %r", response)
        client.sendall(response.encode(FORMAT))