import logging
import os
import sys
from socket import socket
from typing import Union
from urllib.parse import ParseResultBytes, ParseResult

from httplib import parser
from httplib.exceptions import (MethodNotAllowed, BadRequest, UnsupportedEncoding, NotImplemented, NotFound,
                                HTTPVersionNotSupported)
from httplib.httpsocket import HTTPSocket, FORMAT
from httplib.message import RequestMessage as Message
from httplib.retriever import Retriever, PreambleRetriever
from server import command
from server.serversocket import ServerSocket

METHODS = ("GET", "HEAD", "PUT", "POST")


class RequestHandler:
    """
    Processes incoming HTTP request messages.

    One handler wraps one client connection: it reads requests from it, validates them, dispatches
    them to a `command` and writes the command's response back on the same connection.
    """
    conn: HTTPSocket

    def __init__(self, conn: socket, host):
        """
        @param conn: the accepted client socket
        @param host: the host this server answers for; request targets in absolute-form must name it
            (presumably exposed by ServerSocket as `self.conn.host` — confirm)
        """
        self.conn = ServerSocket(conn, host)

    def listen(self):
        """
        Serve requests on this connection in a loop.

        Bare line terminators preceding a request-line are skipped (RFC 7230 §3.5 allows a server to
        ignore empty lines received before the request-line); any other line is treated as the
        request-line of a new message.

        NOTE(review): there is no break out of the loop; it only ends when read_line() or request
        handling raises (e.g. on connection close) — confirm that is the intended termination.
        """
        retriever = PreambleRetriever(self.conn)

        while True:
            line = self.conn.read_line()

            if line in ("\r\n", "\r", "\n"):
                continue

            retriever.reset_buffer(line)
            self._handle_message(retriever, line)

    def _handle_message(self, retriever, line):
        """
        Process a single request whose request-line has already been read.

        Reads the rest of the preamble, parses and validates it, reads the body (if any), executes
        the matching command and sends the command's response on the connection.

        @param retriever: the preamble retriever, already reset with `line`
        @param line: the request-line of the message
        @raise NotImplemented: if the body uses an encoding the retriever does not support
        @see: _validate_request and _has_body for the exceptions raised while validating
        """
        lines = retriever.retrieve()

        # Parse the request-line and headers
        (method, target, version) = parser.parse_request_line(line)
        headers = parser.parse_headers(lines)

        # Create the request message object
        message = Message(version, method, target, headers, retriever.buffer)
        logging.debug("---request begin---\r\n%s---request end---", "".join(message.raw))

        # Reject invalid requests before reading a (possibly large) body.
        self._validate_request(message)

        # The body (if available) hasn't been retrieved up till now.
        chunks = []
        if self._has_body(headers):
            try:
                body_retriever = Retriever.create(self.conn, headers)
            except UnsupportedEncoding as e:
                logging.error("Encoding not supported: %s=%s", e.enc_type, e.encoding)
                raise NotImplemented(f"{e.enc_type}={e.encoding}") from e

            # Join once at the end: repeated `bytes +=` is quadratic in the body size.
            chunks.extend(body_retriever.retrieve())
        message.body = b"".join(chunks)

        # message completed
        cmd = command.create(message)
        msg = cmd.execute()

        if logging.getLogger().isEnabledFor(logging.DEBUG):
            # Log only the status-line and headers. errors="replace" because this is diagnostic
            # output only: a decoding problem here must never prevent the response from being sent.
            head = msg.split(b"\r\n\r\n", 1)[0].decode(FORMAT, errors="replace")
            logging.debug("---response begin---\r\n%s\r\n---response end---", head)

        # Send the response message
        self.conn.conn.sendall(msg)

    def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version):
        """
        Checks if the request-line is valid. Throws an appropriate exception if not.

        @param method: HTTP request method
        @param target: The request target
        @param version: The HTTP version
        @raise MethodNotAllowed: if the method is not any of the allowed methods in `METHODS`
        @raise HTTPVersionNotSupported: If the HTTP version is not supported by this server
        @raise BadRequest: If the scheme of the target is not supported
        @raise NotFound: If the target is not found on this server
        """
        if method not in METHODS:
            raise MethodNotAllowed(METHODS)

        if version not in ("1.0", "1.1"):
            raise HTTPVersionNotSupported(version)

        # only origin-form and absolute-form are allowed
        if target.scheme not in ("", "http"):
            # Only http is supported...
            raise BadRequest(f"scheme={target.scheme}")

        # In absolute-form the authority must name this server, with or without the port.
        host = self.conn.host
        if target.netloc not in ("", host, host.split(":")[0]):
            raise NotFound(str(target))

        # The path must be absolute (this also rejects an empty path).
        if not target.path.startswith("/"):
            raise NotFound(str(target))

    def _validate_request(self, msg):
        """
        Validates the message request-line and headers. Throws an error if the message is invalid.

        @see: _check_request_line for exceptions raised when validating the request-line.
        @param msg: the message to validate
        @raise BadRequest: if HTTP 1.1 and the Host header is missing
        """
        if msg.version == "1.1" and "host" not in msg.headers:
            raise BadRequest("Missing host header")

        self._check_request_line(msg.method, msg.target, msg.version)

    def _has_body(self, headers):
        """
        Check whether the headers announce a message body.

        A Transfer-Encoding header always implies a body; otherwise a body is present only when
        Content-Length is a positive decimal integer. Header names are looked up in lower case
        (presumably normalized by parser.parse_headers — confirm).

        @param headers: the parsed request headers
        @return: True if the message has a body. False otherwise.
        @raise BadRequest: if the Content-Length value is not a decimal integer
        """
        if "transfer-encoding" in headers:
            return True

        if "content-length" not in headers:
            return False

        value = str(headers["content-length"]).strip()
        # RFC 7230 §3.3.2: Content-Length = 1*DIGIT. Anything else (empty, signed, non-ASCII digits,
        # comma lists) is an invalid request rather than a ValueError escaping from int().
        if not value or any(ch not in "0123456789" for ch in value):
            raise BadRequest(f"content-length={value}")
        return int(value) > 0

    @staticmethod
    def send_error(client: socket, code, message):
        """
        Send a body-less error response directly on a raw client socket.

        @param client: the socket to write the response to
        @param code: the HTTP status code
        @param message: the reason phrase placed after the status code
        """
        response = "".join((
            f"HTTP/1.1 {code} {message}\r\n",
            # presumably the complete "Date: ..." header line, without line terminator
            parser.get_date() + "\r\n",
            "Content-Length: 0\r\n",
            "\r\n",
        ))

        logging.debug("---response begin---\r\n%s---response end---", response)
        client.sendall(response.encode(FORMAT))