Files
CN2021/server/requesthandler.py

145 lines
4.9 KiB
Python

import logging
from socket import socket
from typing import Union
from urllib.parse import ParseResultBytes, ParseResult
from httplib import parser
from httplib.exceptions import MethodNotAllowed, BadRequest, UnsupportedEncoding, NotImplemented, NotFound, \
HTTPVersionNotSupported
from httplib.httpsocket import HTTPSocket, FORMAT
from httplib.message import RequestMessage as Message
from httplib.retriever import Retriever, PreambleRetriever
from server import command
from server.serversocket import ServerSocket
# Request methods this server accepts. `_check_request_line` rejects any other
# method by raising MethodNotAllowed, passing this tuple along with it.
METHODS = (
    "GET",
    "HEAD",
    "PUT",
    "POST",
)
class RequestHandler:
    """
    Processes incoming HTTP request messages on one client connection.

    `listen` reads request after request from the connection, and for each
    one parses, validates and executes it and sends back the response.
    """

    # The wrapped client connection; __init__ stores a ServerSocket here.
    conn: HTTPSocket

    def __init__(self, conn: socket, host):
        """
        @param conn: the accepted client socket
        @param host: passed to ServerSocket; `_check_request_line` compares
            `self.conn.host` with the authority of absolute-form targets
        """
        self.conn = ServerSocket(conn, host)

    def listen(self):
        """
        Read and handle request messages from the connection, one after another.

        Empty lines in front of a request-line are skipped. The loop has no exit
        of its own: it only ends when one of the calls inside it raises.
        """
        retriever = PreambleRetriever(self.conn)
        while True:
            line = self.conn.read_line()
            # Tolerate stray empty lines between messages.
            # NOTE(review): if read_line() returns "" on EOF, that "" is handed
            # to the parser below - confirm how EOF is signalled.
            if line in ("\r\n", "\r", "\n"):
                continue
            retriever.reset_buffer(line)
            self._handle_message(retriever, line)

    def _handle_message(self, retriever, line):
        """
        Parse, validate and execute a single request, then send the response.

        Exceptions raised while parsing, validating, reading the body or
        executing the command propagate to the caller.
        @param retriever: the PreambleRetriever, already reset with `line`;
            its retrieve() supplies the header lines
        @param line: the request-line
        @raise NotImplemented: if the body uses an unsupported encoding
        @raise BadRequest: if the request or its Content-Length is invalid
        @see: _validate_request for the other validation errors
        """
        lines = retriever.retrieve()

        # Parse the request-line and headers
        (method, target, version) = parser.parse_request_line(line)
        headers = parser.parse_headers(lines)

        # Create the request message object; the body is attached below.
        message = Message(version, method, target, headers, retriever.buffer)
        logging.debug("---request begin---\r\n%s---request end---", "".join(message.raw))

        # Reject invalid requests before reading a (possibly large) body.
        self._validate_request(message)

        # The body (if available) hasn't been retrieved up till now.
        message.body = self._read_body(headers)

        # message completed
        cmd = command.create(message)
        msg = cmd.execute()
        response_head = msg.split(b"\r\n\r\n", 1)[0].decode(FORMAT)
        logging.debug("---response begin---\r\n%s\r\n---response end---", response_head)

        # Send the response message
        self.conn.conn.sendall(msg)

    def _read_body(self, headers):
        """
        Read the message body announced by the request headers.

        @param headers: the parsed request headers
        @return: the body as bytes, or b"" when the headers announce no body
        @raise NotImplemented: if Retriever.create reports an unsupported encoding
        @raise BadRequest: if the Content-Length value is invalid
        """
        if not self._has_body(headers):
            return b""
        try:
            retriever = Retriever.create(self.conn, headers)
        except UnsupportedEncoding as e:
            logging.error("Encoding not supported: %s=%s", e.enc_type, e.encoding)
            raise NotImplemented(f"{e.enc_type}={e.encoding}") from e
        # Join once: repeated `bytes +=` copies the whole body for every chunk.
        return b"".join(retriever.retrieve())

    def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version):
        """
        Checks if the request-line is valid. Throws an appropriate exception if not.
        @param method: HTTP request method
        @param target: The request target
        @param version: The HTTP version
        @raise MethodNotAllowed: if the method is not any of the allowed methods in `METHODS`
        @raise HTTPVersionNotSupported: If the HTTP version is not supported by this server
        @raise BadRequest: If the scheme of the target is not supported
        @raise NotFound: If the target is not found on this server
        """
        if method not in METHODS:
            raise MethodNotAllowed(METHODS)
        if version not in ("1.0", "1.1"):
            raise HTTPVersionNotSupported(version)

        # only origin-form and absolute-form are allowed
        if target.scheme not in ("", "http"):
            # Only http is supported...
            raise BadRequest(f"scheme={target.scheme}")

        # An absolute-form target must name this server. Host names are
        # case-insensitive; accept the configured host with or without ":port".
        netloc = target.netloc.lower()
        if netloc:
            host = self.conn.host.lower()
            if netloc not in (host, host.split(":")[0]):
                # geturl() gives the URL text; str() would give the namedtuple repr.
                raise NotFound(target.geturl())

        if not target.path.startswith("/"):
            raise NotFound(target.geturl())

    def _validate_request(self, msg):
        """
        Validates the message request-line and headers. Throws an error if the message is invalid.
        @see: _check_request_line for exceptions raised when validating the request-line.
        @param msg: the message to validate
        @raise BadRequest: if HTTP 1.1 and the Host header is missing
        """
        if msg.version == "1.1" and "host" not in msg.headers:
            raise BadRequest("Missing host header")

        self._check_request_line(msg.method, msg.target, msg.version)

    def _has_body(self, headers):
        """
        Check if the headers announce the existence of a message body.

        Transfer-Encoding is checked first. Otherwise there is a body only when
        Content-Length is present and greater than zero.
        @param headers: the headers to check
        @return: True if the message has a body. False otherwise.
        @raise BadRequest: if Content-Length is present but not a valid decimal length
        """
        if "transfer-encoding" in headers:
            return True
        if "content-length" not in headers:
            return False
        return self._parse_content_length(headers["content-length"]) > 0

    @staticmethod
    def _parse_content_length(value):
        """
        Parse a Content-Length field value.

        Only ASCII decimal digits are accepted. int() alone would also accept
        signs, "_" separators and non-ASCII digits, or raise a bare ValueError.
        An invalid value makes the message framing unrecoverable, so it is
        rejected with BadRequest.
        @param value: the raw header value
        @return: the length as an int (>= 0)
        @raise BadRequest: if the value is empty or contains a non-digit
        """
        value = value.strip()
        if not value or any(ch not in "0123456789" for ch in value):
            raise BadRequest(f"content-length={value}")
        return int(value)

    @staticmethod
    def send_error(client: socket, code, message):
        """
        Send a minimal error response directly on a raw client socket.

        The response has a status-line, the line returned by parser.get_date()
        (presumably the Date header), and "Content-Length: 0", with no body.
        @param client: the socket to write the response to
        @param code: the HTTP status code
        @param message: the reason phrase placed after the status code
        """
        response = "".join((
            f"HTTP/1.1 {code} {message}\r\n",
            parser.get_date() + "\r\n",
            "Content-Length: 0\r\n",
            "\r\n",
        ))
        logging.debug("---response begin---\r\n%s---response end---", response)
        client.sendall(response.encode(FORMAT))