import logging
from socket import socket
from typing import Union
from urllib.parse import ParseResultBytes, ParseResult

from httplib import parser
from httplib.exceptions import (
    MethodNotAllowed,
    BadRequest,
    UnsupportedEncoding,
    NotImplemented,
    NotFound,
    HTTPVersionNotSupported,
)
from httplib.httpsocket import HTTPSocket, FORMAT
from httplib.message import RequestMessage as Message
from httplib.retriever import Retriever, PreambleRetriever
from server import command
from server.serversocket import ServerSocket

# Request methods accepted by `RequestHandler._check_request_line`.
METHODS = ("GET", "HEAD", "PUT", "POST")


class RequestHandler:
    """
    Processes incoming HTTP request messages.
    """

    # Wrapped client connection; assigned a `ServerSocket` in `__init__`.
    conn: HTTPSocket

    def __init__(self, conn: socket, host):
        """
        Wraps the accepted client socket in a `ServerSocket`.

        @param conn: the client socket
        @param host: passed on to `ServerSocket`; later read back as `self.conn.host`
            when validating the request target
        """
        self.conn = ServerSocket(conn, host)

    def listen(self) -> None:
        """
        Reads request messages from the connection and handles them one by one.

        The loop has no exit condition of its own: it only ends when reading a line
        or handling a message raises.
        """
        retriever = PreambleRetriever(self.conn)

        while True:
            line = self.conn.read_line()

            # Skip bare line terminators in front of a request-line.
            if line in ("\r\n", "\r", "\n"):
                continue

            retriever.reset_buffer(line)
            self._handle_message(retriever, line)

    def _handle_message(self, retriever, line) -> None:
        """
        Parses, validates and executes a single request and sends back the response.

        @param retriever: the preamble retriever, already reset with `line`
        @param line: the request-line of the message
        @raise NotImplemented: if the body uses an encoding for which no retriever
            can be created
        @see: _validate_request and _has_body for the validation errors raised.
        """
        lines = retriever.retrieve()

        # Parse the request-line and headers
        (method, target, version) = parser.parse_request_line(line)
        headers = parser.parse_headers(lines)

        # Create the request message object
        message = Message(version, method, target, headers, retriever.buffer)

        logging.debug("---request begin---\r\n%s---request end---", "".join(message.raw))

        # Reject invalid requests before any body is read.
        self._validate_request(message)

        # The body (if available) hasn't been retrieved up till now.
        body = b""
        if self._has_body(headers):
            try:
                body_retriever = Retriever.create(self.conn, headers)
            except UnsupportedEncoding as e:
                logging.error("Encoding not supported: %s=%s", e.enc_type, e.encoding)
                # Chain the cause so the original error is kept in the traceback.
                raise NotImplemented(f"{e.enc_type}={e.encoding}") from e

            # Join once instead of repeated `+=`, which copies the body on every chunk.
            body = b"".join(body_retriever.retrieve())

        message.body = body

        # message completed
        cmd = command.create(message)
        msg = cmd.execute()

        # Only the part of the response before the first blank line is logged.
        logging.debug("---response begin---\r\n%s\r\n---response end---",
                      msg.split(b"\r\n\r\n", 1)[0].decode(FORMAT))

        # Send the response message
        self.conn.conn.sendall(msg)

    def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version) -> None:
        """
        Checks if the request-line is valid. Throws an appropriate exception if not.

        @param method: HTTP request method
        @param target: The request target
        @param version: The HTTP version
        @raise MethodNotAllowed: if the method is not any of the allowed methods in `METHODS`
        @raise HTTPVersionNotSupported: If the HTTP version is not supported by this server
        @raise BadRequest: If the scheme of the target is not supported
        @raise NotFound: If the target is not found on this server
        """
        if method not in METHODS:
            raise MethodNotAllowed(METHODS)

        if version not in ("1.0", "1.1"):
            raise HTTPVersionNotSupported(version)

        # only origin-form and absolute-form are allowed
        if target.scheme not in ("", "http"):
            # Only http is supported...
            raise BadRequest(f"scheme={target.scheme}")

        # An absolute-form target must name this server, with or without the part
        # of the host after the first ":".
        if target.netloc != "":
            host = self.conn.host
            if target.netloc not in (host, host.split(":")[0]):
                raise NotFound(str(target))

        # The path must be non-empty and start with a slash.
        if not target.path.startswith("/"):
            raise NotFound(str(target))

    def _validate_request(self, msg) -> None:
        """
        Validates the message request-line and headers. Throws an error if the message is invalid.

        @see: _check_request_line for exceptions raised when validating the request-line.
        @param msg: the message to validate
        @raise BadRequest: if HTTP 1.1 and the Host header is missing
        """
        if msg.version == "1.1" and "host" not in msg.headers:
            raise BadRequest("Missing host header")

        self._check_request_line(msg.method, msg.target, msg.version)

    def _has_body(self, headers) -> bool:
        """
        Check if the headers notify the existence of a message body.

        @param headers: the headers to check
        @return: True if the message has a body. False otherwise.
        @raise BadRequest: if the Content-Length header is not a valid integer
        """
        if "transfer-encoding" in headers:
            return True

        if "content-length" not in headers:
            return False

        raw_length = headers["content-length"]
        try:
            length = int(raw_length)
        except ValueError as e:
            # The value comes straight from the client: report it as a malformed
            # request instead of letting a bare ValueError escape.
            raise BadRequest(f"content-length={raw_length}") from e

        return length > 0

    @staticmethod
    def send_error(client: socket, code, message) -> None:
        """
        Sends a body-less HTTP/1.1 response with the given status to the client.

        @param client: the socket to send the response on
        @param code: the status code placed in the status line
        @param message: the reason phrase placed in the status line
        """
        response = "".join((
            f"HTTP/1.1 {code} {message}\r\n",
            # presumably a formatted Date header line; verify against parser.get_date
            parser.get_date() + "\r\n",
            "Content-Length: 0\r\n",
            "\r\n",
        ))

        logging.debug("---response begin---\r\n%s---response end---", response)
        client.sendall(response.encode(FORMAT))