Update
This commit is contained in:
122
server/requesthandler.py
Normal file
122
server/requesthandler.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from socket import socket
|
||||
from time import mktime
|
||||
from typing import Union
|
||||
from urllib.parse import ParseResultBytes, ParseResult
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
from httplib import parser
|
||||
from httplib.exceptions import MethodNotAllowed, BadRequest, UnsupportedEncoding, NotImplemented, NotFound, \
|
||||
HTTPVersionNotSupported
|
||||
from httplib.httpsocket import HTTPSocket, FORMAT
|
||||
from httplib.message import ServerMessage as Message
|
||||
from httplib.retriever import Retriever, PreambleRetriever
|
||||
from server import command
|
||||
|
||||
METHODS = ("GET", "HEAD", "PUT", "POST")
|
||||
|
||||
|
||||
class RequestHandler:
|
||||
conn: HTTPSocket
|
||||
root = os.path.join(os.path.dirname(sys.argv[0]), "public")
|
||||
|
||||
def __init__(self, conn: socket, host):
|
||||
self.conn = HTTPSocket(conn, host)
|
||||
|
||||
def listen(self):
|
||||
|
||||
retriever = PreambleRetriever(self.conn)
|
||||
|
||||
while True:
|
||||
line = self.conn.read_line()
|
||||
|
||||
if line in ("\r\n", "\r", "\n"):
|
||||
continue
|
||||
|
||||
retriever.reset_buffer(line)
|
||||
self._handle_message(retriever, line)
|
||||
|
||||
def _handle_message(self, retriever, line):
|
||||
lines = retriever.retrieve()
|
||||
(method, target, version) = parser.parse_request_line(line)
|
||||
headers = parser.parse_headers(lines)
|
||||
|
||||
message = Message(version, method, target, headers, retriever.buffer)
|
||||
|
||||
logging.debug("---request begin---\r\n%s---request end---", "".join(message.raw))
|
||||
|
||||
self._validate_request(message)
|
||||
|
||||
body = b""
|
||||
if self._has_body(headers):
|
||||
try:
|
||||
retriever = Retriever.create(self.conn, headers)
|
||||
except UnsupportedEncoding as e:
|
||||
logging.error("Encoding not supported: %s=%s", e.enc_type, e.encoding)
|
||||
raise NotImplemented()
|
||||
|
||||
for buffer in retriever.retrieve():
|
||||
body += buffer
|
||||
|
||||
message.body = body
|
||||
|
||||
# completed message
|
||||
|
||||
cmd = command.create(message)
|
||||
msg = cmd.execute()
|
||||
self.conn.conn.sendall(msg)
|
||||
|
||||
def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version):
|
||||
|
||||
if method not in METHODS:
|
||||
raise MethodNotAllowed(METHODS)
|
||||
|
||||
if version not in ("1.0", "1.1"):
|
||||
raise HTTPVersionNotSupported()
|
||||
|
||||
# only origin-form and absolute-form are allowed
|
||||
if target.scheme not in ("", "http"):
|
||||
# Only http is supported...
|
||||
raise BadRequest()
|
||||
|
||||
if target.netloc != "" and target.netloc != self.conn.host and target.netloc != self.conn.host.split(":")[0]:
|
||||
raise NotFound()
|
||||
|
||||
if target.path == "" or target.path[0] != "/":
|
||||
raise NotFound()
|
||||
|
||||
def _validate_request(self, msg):
|
||||
if msg.version == "1.1" and "host" not in msg.headers:
|
||||
raise BadRequest()
|
||||
|
||||
self._check_request_line(msg.method, msg.target, msg.version)
|
||||
|
||||
def _has_body(self, headers):
|
||||
|
||||
if "transfer-encoding" in headers:
|
||||
return True
|
||||
|
||||
if "content-length" in headers and int(headers["content-length"]) > 0:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _get_date():
|
||||
now = datetime.now()
|
||||
stamp = mktime(now.timetuple())
|
||||
return format_date_time(stamp)
|
||||
|
||||
@staticmethod
|
||||
def send_error(client: socket, code, message):
|
||||
message = f"HTTP/1.1 {code} {message}\r\n"
|
||||
message += RequestHandler._get_date() + "\r\n"
|
||||
message += "Content-Length: 0\r\n"
|
||||
message += "\r\n"
|
||||
|
||||
logging.debug("Sending: %r", message)
|
||||
client.sendall(message.encode(FORMAT))
|
Reference in New Issue
Block a user