This commit is contained in:
2021-03-24 16:35:12 +01:00
parent 9ba7a030a7
commit d14252f707
10 changed files with 325 additions and 185 deletions

View File

@@ -1,6 +1,7 @@
import logging
import socket
from io import BufferedReader
from typing import Tuple
BUFSIZE = 4096
TIMEOUT = 3
@@ -11,7 +12,7 @@ MAXLINE = 4096
class HTTPSocket:
host: str
conn: socket.socket
file: BufferedReader
file: Tuple[BufferedReader, None]
def __init__(self, conn: socket.socket, host: str):
@@ -24,8 +25,12 @@ class HTTPSocket:
def close(self):
self.file.close()
# self.conn.shutdown(socket.SHUT_RDWR)
self.conn.close()
def is_closed(self):
return self.file is None
def reset_request(self):
self.file.close()
self.file = self.conn.makefile("rb")

16
httplib/message.py Normal file
View File

@@ -0,0 +1,16 @@
from typing import Dict
class Message:
version: str
status: int
msg: str
headers: Dict[str, str]
body: bytes
def __init__(self, version: str, status: int, msg: str, headers: Dict[str, str], body: bytes = None):
self.version = version
self.status = status
self.msg = msg
self.headers = headers
self.body = body

View File

@@ -1,6 +1,6 @@
import logging
import re
from urllib.parse import urlparse
from urllib.parse import urlparse, urlsplit
from httplib.exceptions import InvalidStatusLine, InvalidResponse, BadRequest
from httplib.httpsocket import HTTPSocket
@@ -42,6 +42,26 @@ def get_status_line(client: HTTPSocket):
return version, status, reason
def parse_status_line(line: str):
split = list(filter(None, line.strip().split(" ", 2)))
if len(split) < 3:
raise InvalidStatusLine(line) # TODO fix exception
(http_version, status, reason) = split
if not _is_valid_http_version(http_version):
raise InvalidStatusLine(line)
version = http_version[:4]
if not re.match(r"\d{3}", status):
raise InvalidStatusLine(line)
status = int(status)
if status < 100 or status > 999:
raise InvalidStatusLine(line)
return version, status, reason
def parse_request_line(client: HTTPSocket):
line, (method, target, version) = _get_start_line(client)
@@ -119,7 +139,7 @@ def parse_request_headers(client: HTTPSocket):
raise BadRequest()
headers[key] = value
return headers
@@ -157,6 +177,38 @@ def get_headers(client: HTTPSocket):
return result
def parse_headers(lines):
headers = []
# first header after the status-line may not contain a space
for line in lines:
line = next(lines)
if line[0].isspace():
continue
else:
break
for line in lines:
if line in ("\r\n", "\n", " "):
break
if line[0].isspace():
headers[-1] = headers[-1].rstrip("\r\n")
headers.append(line.lstrip())
result = {}
header_str = "".join(headers)
for line in header_str.splitlines():
pos = line.find(":")
if pos <= 0 or pos >= len(line) - 1:
continue
(header, value) = map(str.strip, line.split(":", 1))
check_next_header(result, header, value)
result[header.lower()] = value.lower()
return result
def check_next_header(headers, next_header: str, next_value: str):
if next_header == "content-length":
@@ -166,3 +218,25 @@ def check_next_header(headers, next_header: str, next_value: str):
if not next_value.isnumeric() or int(next_value) <= 0:
logging.error("Invalid content-length value: %r", next_value)
raise InvalidResponse()
def parse_uri(uri: str):
parsed = urlsplit(uri)
# If there is no netloc, the given string is not a valid URI, so split on /
if parsed.hostname:
host = parsed.hostname
path = parsed.path
if parsed.query != '':
path = f"{path}?{parsed.query}"
else:
(host, path) = uri.split("/", 1)
if ":" in host:
host, port = host.split(":", 1)
elif parsed.scheme == "https":
port = 443
else:
port = 80
return host, port, path

View File

@@ -42,6 +42,28 @@ class Retriever(ABC):
return ContentLengthRetriever(client, int(content_length))
class PreambleRetriever(Retriever):
client: HTTPSocket
buffer: []
def __init__(self, client: HTTPSocket):
super().__init__(client)
self.client = client
self.buffer = []
def retrieve(self):
line = self.client.read_line()
while True:
self.buffer.append(line)
if line in ("\r\n", "\n", " "):
break
yield line
line = self.client.read_line()
class ContentLengthRetriever(Retriever):
length: int
@@ -63,21 +85,16 @@ class ContentLengthRetriever(Retriever):
buffer = self.client.read(remaining)
except TimeoutError:
logging.error("Timed out before receiving complete payload")
self.client.close()
raise IncompleteResponse("Timed out before receiving complete payload")
except ConnectionError:
logging.error("Timed out before receiving complete payload")
self.client.close()
raise IncompleteResponse("Connection closed before receiving complete payload")
logging.debug("Received payload length: %s", len(buffer))
if len(buffer) == 0:
logging.warning("Received payload length %s less than expected %s", cur_payload_size, self.length)
break
cur_payload_size += len(buffer)
logging.debug("Processed payload: %r", cur_payload_size)
yield buffer
return b""
@@ -108,7 +125,6 @@ class ChunkedRetriever(Retriever):
yield buffer
self.client.read_line() # remove CRLF
return b""
def __get_chunk_size(self):
line = self.client.read_line()