import logging
import os
import sys
import time
from datetime import datetime
from socket import socket
from time import mktime
from typing import Union
from urllib.parse import ParseResultBytes, ParseResult
from wsgiref.handlers import format_date_time

from httplib import parser
from httplib.exceptions import (
    MethodNotAllowed,
    BadRequest,
    UnsupportedEncoding,
    NotImplemented,  # project exception class; shadows the builtin constant in this module
    NotFound,
    HTTPVersionNotSupported,
)
from httplib.httpsocket import HTTPSocket, FORMAT
from httplib.message import ServerMessage as Message
from httplib.retriever import Retriever, PreambleRetriever
from server import command
from server.serversocket import ServerSocket

# Request methods accepted by _check_request_line; any other raises MethodNotAllowed.
METHODS = ("GET", "HEAD", "PUT", "POST")


class RequestHandler:
    """Read, validate and answer HTTP requests arriving on one client connection."""

    # Wrapped client connection; __init__ assigns a ServerSocket here.
    conn: HTTPSocket
    # "public" directory next to the launched script. Not read inside this class;
    # presumably the document root used by command handlers — confirm.
    root = os.path.join(os.path.dirname(sys.argv[0]), "public")

    def __init__(self, conn: socket, host):
        """Wrap the accepted client socket *conn* together with the server *host*."""
        self.conn = ServerSocket(conn, host)

    def listen(self):
        """Handle requests on the connection one after another, without returning.

        Any exception raised while handling a message propagates to the caller.
        """
        retriever = PreambleRetriever(self.conn)
        while True:
            line = self.conn.read_line()
            # Tolerate stray empty lines between messages (RFC 9112 section 2.2).
            if line in ("\r\n", "\r", "\n"):
                continue
            retriever.reset_buffer(line)
            self._handle_message(retriever, line)

    def _handle_message(self, retriever, line):
        """Parse, validate and answer a single request whose request line is *line*.

        :param retriever: PreambleRetriever primed with *line*; its ``retrieve()``
            supplies the header lines handed to ``parser.parse_headers``.
        :param line: the request line as read from the connection.
        :raises NotImplemented: if the body's encoding is not supported.
        Validation failures raise the HTTP exceptions documented on the helpers.
        """
        lines = retriever.retrieve()
        method, target, version = parser.parse_request_line(line)
        headers = parser.parse_headers(lines)
        message = Message(version, method, target, headers, retriever.buffer)
        logging.debug("---request begin---\r\n%s---request end---", "".join(message.raw))

        self._validate_request(message)

        body = b""
        if self._has_body(headers):
            try:
                body_retriever = Retriever.create(self.conn, headers)
            except UnsupportedEncoding as e:
                logging.error("Encoding not supported: %s=%s", e.enc_type, e.encoding)
                raise NotImplemented(f"{e.enc_type}={e.encoding}") from e
            # Join once: repeated ``bytes +=`` copies the whole body on every chunk.
            body = b"".join(body_retriever.retrieve())
        message.body = body  # message is now complete

        cmd = command.create(message)
        msg = cmd.execute()
        logging.debug("---response begin---\r\n%s\r\n---response end---",
                      msg.split(b"\r\n\r\n", 1)[0].decode(FORMAT))
        self.conn.conn.sendall(msg)

    def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version):
        """Validate the method, HTTP version and request target.

        :raises MethodNotAllowed: if *method* is not in METHODS.
        :raises HTTPVersionNotSupported: if *version* is not "1.0" or "1.1".
        :raises BadRequest: if the target's scheme is neither empty nor "http".
        :raises NotFound: if the target names another host or its path does not start with "/".
        """
        if method not in METHODS:
            raise MethodNotAllowed(METHODS)
        if version not in ("1.0", "1.1"):
            raise HTTPVersionNotSupported(version)

        # Only origin-form and absolute-form are allowed.
        if target.scheme not in ("", "http"):  # only plain http is supported
            raise BadRequest(f"scheme={target.scheme}")
        host = self.conn.host
        # An absolute-form target may name this host with or without its port.
        if target.netloc != "" and target.netloc not in (host, host.split(":")[0]):
            raise NotFound(str(target))
        if target.path == "" or target.path[0] != "/":
            raise NotFound(str(target))

    def _validate_request(self, msg):
        """Reject HTTP/1.1 requests lacking a Host header, then check the request line.

        :raises BadRequest: if a 1.1 request has no "host" header.
        Also propagates the exceptions raised by _check_request_line.
        """
        if msg.version == "1.1" and "host" not in msg.headers:
            raise BadRequest("Missing host header")
        self._check_request_line(msg.method, msg.target, msg.version)

    def _has_body(self, headers):
        """Return True if the request announces a message body.

        A body is present when Transfer-Encoding is set or Content-Length is a
        positive integer.

        :raises BadRequest: if Content-Length is not a plain decimal number.
        """
        if "transfer-encoding" in headers:
            return True
        if "content-length" not in headers:
            return False
        value = headers["content-length"].strip()
        # RFC 9110 section 8.6: Content-Length = 1*DIGIT. Reject signs, "_" separators,
        # non-ASCII digits and garbage instead of letting int() raise ValueError.
        if not (value.isascii() and value.isdigit()):
            raise BadRequest(f"content-length={value}")
        return int(value) > 0

    @staticmethod
    def _get_date():
        """Return the current time formatted as an HTTP-date."""
        # time.time() is already an epoch stamp. Converting a naive local datetime
        # back through mktime() can be an hour off during the DST fall-back hour.
        return format_date_time(time.time())

    @staticmethod
    def send_error(client: socket, code, message):
        """Send a body-less error response with status *code* and reason phrase *message*."""
        response = (
            f"HTTP/1.1 {code} {message}\r\n"
            # A header line needs its field name; a bare date line is malformed.
            f"Date: {RequestHandler._get_date()}\r\n"
            "Content-Length: 0\r\n"
            "\r\n"
        )
        logging.debug("---response begin---\r\n%s---response end---", response)
        client.sendall(response.encode(FORMAT))