client: cleanup

This commit is contained in:
2021-03-21 13:10:57 +01:00
parent d8a5765fd8
commit 638576f471
5 changed files with 77 additions and 128 deletions

View File

@@ -1,6 +1,7 @@
import logging
import os
import re
from abc import ABC, abstractmethod
from typing import Dict
from urllib.parse import urlparse
@@ -10,21 +11,7 @@ from client.Retriever import Retriever
from client.httpclient import HTTPClient, UnsupportedEncoding, FORMAT, InvalidResponse, InvalidStatusLine
def parse_uri(uri: str):
    """Split ``uri`` into a ``(host, path)`` tuple.

    The host is the network location reported by ``urlparse``. The returned
    path always starts with ``/``; an empty path becomes ``"/"``.

    :param uri: the URI to split, with or without a scheme
    :return: tuple ``(host, path)``
    """
    components = urlparse(uri)
    if components.netloc == "":
        # urlparse only recognises a host when it is introduced by `//`,
        # so add that prefix and parse a second time.
        components = urlparse("//" + uri)

    host, path = components.netloc, components.path
    if not path.startswith("/"):
        path = "/" + path
    return host, path
class ResponseHandler:
class ResponseHandler(ABC):
client: HTTPClient
headers: Dict[str, str]
status_code: int
@@ -38,6 +25,7 @@ class ResponseHandler:
self.retriever = retriever
pass
@abstractmethod
def handle(self):
pass
@@ -133,8 +121,22 @@ class ResponseHandler:
logging.error("Invalid content-length value: %r", next_value)
raise InvalidResponse()
# Split a URI into a (host, path) tuple using urlparse.
# The returned path is given a leading "/" when it is empty or lacks one.
# NOTE(review): a `class DownloadHandler(ResponseHandler):` line appears in the
# middle of this function in this view; it looks like a leftover of the
# pre-change file rather than part of the function -- confirm.
@staticmethod
def parse_uri(uri: str):
parsed = urlparse(uri)
class DownloadHandler(ResponseHandler):
# If there is no netloc, the url is invalid, so prepend `//` and try again
if parsed.netloc == "":
parsed = urlparse("//" + uri)
host = parsed.netloc
path = parsed.path
# Make sure the path always starts with "/" (an empty path becomes "/")
if len(path) == 0 or path[0] != '/':
path = "/" + path
return host, path
class DownloadHandler(ResponseHandler, ABC):
path: str
def __init__(self, retriever: Retriever, client: HTTPClient, headers: Dict[str, str], url: str, dir=None):
@@ -152,9 +154,6 @@ class DownloadHandler(ResponseHandler):
return HTMLDownloadHandler(retriever, client, headers, url, dir)
return RawDownloadHandler(retriever, client, headers, url, dir)
def handle(self) -> str:
pass
def _create_directory(self):
path = self._get_duplicate_name(os.path.abspath(self.client.host))
os.mkdir(path)
@@ -248,7 +247,7 @@ class HTMLDownloadHandler(DownloadHandler):
def __download_images(self, tmp_filename, target_filename):
(host, path) = parse_uri(self.url)
(host, path) = ResponseHandler.parse_uri(self.url)
with open(tmp_filename, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
@@ -286,7 +285,7 @@ class HTMLDownloadHandler(DownloadHandler):
img_path = parsed.path
else:
same_host = False
(img_host, img_path) = parse_uri(img_src)
(img_host, img_path) = ResponseHandler.parse_uri(img_src)
message = "GET {path} HTTP/1.1\r\n".format(path=img_path)
message += "Accept: */*\r\nAccept-Encoding: identity\r\n"

View File

@@ -1,16 +1,17 @@
import logging
from abc import ABC, abstractmethod
from typing import Dict
from client.httpclient import HTTPClient, BUFSIZE, IncompleteResponse, InvalidResponse, UnsupportedEncoding
class Retriever:
class Retriever(ABC):
client: HTTPClient
headers: Dict[str, str]
def __init__(self, client: HTTPClient):
self.client = client
@abstractmethod
def retrieve(self):
pass
@@ -95,7 +96,7 @@ class ChunkedRetriever(Retriever):
def retrieve(self):
while True:
chunk_size = self._get_chunk_size()
chunk_size = self.__get_chunk_size()
logging.debug("chunk-size: %s", chunk_size)
if chunk_size == 0:
self.client.reset_request()
@@ -108,7 +109,7 @@ class ChunkedRetriever(Retriever):
self.client.read_line() # remove CRLF
return b""
def _get_chunk_size(self):
def __get_chunk_size(self):
line = self.client.read_line()
sep_pos = line.find(";")
if sep_pos >= 0:

View File

@@ -1,17 +1,22 @@
import logging
from abc import ABC, abstractmethod
from urllib.parse import urlparse
from client.ResponseHandler import ResponseHandler
from client.httpclient import FORMAT, HTTPClient, InvalidResponse, InvalidStatusLine, UnsupportedEncoding
class Command:
command: str
class AbstractCommand(ABC):
def __init__(self, url: str, port: str):
self.url = url
self.port = port
@property
@abstractmethod
def command(self):
pass
@staticmethod
def create(command: str, url: str, port: str):
if command == "GET":
@@ -56,8 +61,12 @@ class Command:
finally:
client.close()
def _await_response(self, client: HTTPClient):
pass
def _await_response(self, client):
    """Print lines read from ``client`` up to and including the first blank line.

    Each line returned by ``client.read_line()`` is echoed to stdout as-is.
    Reading stops once a line equal to ``"\\r\\n"``, ``"\\n"`` or ``""`` has
    been printed.

    :param client: object providing a ``read_line()`` method
    """
    terminators = ("\r\n", "\n", "")
    line = client.read_line()
    print(line, end="")
    while line not in terminators:
        line = client.read_line()
        print(line, end="")
def _build_message(self, message: str) -> bytes:
    """Terminate ``message`` with a CRLF and encode it using ``FORMAT``.

    :param message: the message text built so far
    :return: the encoded message bytes
    """
    terminated = message + "\r\n"
    return terminated.encode(FORMAT)
@@ -81,35 +90,10 @@ class Command:
return host, path
class HeadCommand(Command):
    """Command named ``HEAD`` that echoes the response lines it receives."""

    command = "HEAD"

    def _await_response(self, client):
        """Print lines read from ``client`` up to and including the first blank line.

        :param client: object providing a ``read_line()`` method
        """
        terminators = ("\r\n", "\n", "")
        line = client.read_line()
        print(line, end="")
        while line not in terminators:
            line = client.read_line()
            print(line, end="")
class GetCommand(Command):
    """Command named ``GET`` that hands the response to a ``ResponseHandler``."""

    command = "GET"

    def _await_response(self, client):
        """Read the status line and headers, then run the matching handler.

        The status line and headers are obtained through ``ResponseHandler``;
        the handler returned by ``ResponseHandler.create`` is then invoked.

        :param client: the client the response is read from
        """
        version, status, _reason = ResponseHandler.get_status_line(client)
        logging.debug("Parsed status-line: version: %s, status: %s", version, status)

        headers = ResponseHandler.get_headers(client)
        logging.debug("Parsed headers: %r", headers)

        ResponseHandler.create(client, headers, status, self.url).handle()
class PostCommand(HeadCommand):
command = "POST"
class AbstractWithBodyCommand(AbstractCommand, ABC):
def _build_message(self, message: str) -> bytes:
body = input("Enter POST data: ").encode(FORMAT)
body = input(f"Enter {self.command} data: ").encode(FORMAT)
print()
message += "Content-Type: text/plain\r\n"
@@ -122,5 +106,34 @@ class PostCommand(HeadCommand):
return message
class PutCommand(PostCommand):
    """Variant of ``PostCommand`` whose command name is ``PUT``."""

    command = "PUT"
class HeadCommand(AbstractCommand):
    """Concrete command whose ``command`` property is ``"HEAD"``."""

    @property
    def command(self):
        """Return the name of this command."""
        return "HEAD"
class GetCommand(AbstractCommand):
    """Concrete command named ``GET`` that delegates to a ``ResponseHandler``."""

    @property
    def command(self):
        """Return the name of this command."""
        return "GET"

    def _await_response(self, client):
        """Read the status line and headers, then run the matching handler.

        The status line and headers are obtained through ``ResponseHandler``;
        the handler returned by ``ResponseHandler.create`` is then invoked.

        :param client: the client the response is read from
        """
        version, status, _reason = ResponseHandler.get_status_line(client)
        logging.debug("Parsed status-line: version: %s, status: %s", version, status)

        headers = ResponseHandler.get_headers(client)
        logging.debug("Parsed headers: %r", headers)

        ResponseHandler.create(client, headers, status, self.url).handle()
class PostCommand(AbstractWithBodyCommand):
    """Concrete body-carrying command whose ``command`` property is ``"POST"``."""

    @property
    def command(self):
        """Return the name of this command."""
        return "POST"
class PutCommand(AbstractWithBodyCommand):
    """Concrete body-carrying command whose ``command`` property is ``"PUT"``."""

    @property
    def command(self):
        """Return the name of this command."""
        return "PUT"

View File

@@ -1,8 +1,6 @@
import logging
import re
import socket
from io import BufferedReader
from typing import TextIO, IO
BUFSIZE = 4096
TIMEOUT = 3
@@ -31,7 +29,7 @@ class HTTPClient(socket.socket):
self.file.close()
self.file = self.makefile("rb")
def _do_receive(self):
def __do_receive(self):
if self.fileno() == -1:
raise Exception("Connection closed")
@@ -45,7 +43,7 @@ class HTTPClient(socket.socket):
while True:
count += 1
try:
return self._do_receive()
return self.__do_receive()
except socket.timeout:
logging.debug("Socket receive timed out after %s seconds", TIMEOUT)
if count == 3:
@@ -75,69 +73,6 @@ class HTTPClient(socket.socket):
return line
def validate_status_line(self, status_line: str):
    """Check that ``status_line`` is an HTTP/1.0 or HTTP/1.1 status line.

    The line is split on single spaces (empty pieces are dropped) and must
    contain at least three parts: version, status code and reason phrase.

    :param status_line: the raw status line
    :return: ``True`` when the version is exactly ``HTTP/1.0`` or
        ``HTTP/1.1`` and the status code is exactly three ASCII digits,
        ``False`` otherwise
    :raises InvalidStatusLine: when the version token is shorter than eight
        characters or has no ``/`` at index 4
    """
    parts = [token for token in status_line.split(" ") if token]
    if len(parts) < 3:
        return False

    # Check HTTP version
    http_version, status_code = parts[0], parts[1]
    if len(http_version) < 8 or http_version[4] != "/":
        raise InvalidStatusLine(status_line)

    name, version = http_version[:4], http_version[5:]
    # fullmatch + a proper character class: the previous pattern `[0|1]`
    # accepted a literal "|" and, being unanchored, also "1.15" and the like.
    if name != "HTTP" or not re.fullmatch(r"1\.[01]", version):
        return False

    # Exactly three ASCII digits; an unanchored `\d{3}` also accepted "2000".
    return re.fullmatch(r"[0-9]{3}", status_code) is not None
def get_crlf_chunk(self, buffer: bytes):
    """Split ``buffer`` at the first pair of consecutive line breaks.

    Both ``CRLF CRLF`` and ``LF LF`` are recognised; whichever occurs first
    in the buffer is used as the split point and is removed from the result.

    :param buffer: the bytes to split
    :return: tuple ``(head, rest)``; when the buffer contains no double line
        break, ``head`` is the whole buffer and ``rest`` is empty
    """
    matches = []
    for separator in (b"\r\n\r\n", b"\n\n"):
        position = buffer.find(separator)
        if position != -1:
            matches.append((position, len(separator)))

    if not matches:
        # Previously a "not found" index of -1 was used for slicing, which
        # silently returned a corrupted split.
        return buffer, b""

    # The two separators start with different bytes, so positions never tie.
    split_start, width = min(matches)
    return buffer[:split_start], buffer[split_start + width:]
def parse_headers(self, data: bytes):
    """Decode a header section and split it into start line and headers.

    ``data`` is decoded as UTF-8 and split on ``\\n``; every line is stripped
    and empty lines are dropped. The first remaining line is the start line.
    Each following line of the form ``name: value`` is stored with both name
    and value lower-cased; lines without a colon, with a leading colon, or
    with nothing after the colon are ignored.

    :param data: raw bytes of the header section
    :return: tuple ``(start_line, headers)`` where ``headers`` is a dict
    :raises InvalidResponse: when ``data`` contains no non-empty line
    """
    # Truthiness replaces the former `l is not ""` identity comparison.
    # NOTE(review): the old filter also tested `l[0].isspace()`, but lines
    # were already stripped at that point, so that test never rejected
    # anything; folded continuation lines are therefore still parsed like
    # ordinary lines -- confirm whether they should be skipped.
    lines = [line for line in map(str.strip, data.decode("utf-8").split("\n")) if line]
    if not lines:
        raise InvalidResponse(data)

    start_line, *header_lines = lines
    logging.debug("start-line: %r", start_line)

    headers = {}
    for line in header_lines:
        pos = line.find(":")
        if pos <= 0 or pos >= len(line) - 1:
            continue

        name, _, value = line.partition(":")
        headers[name.strip().lower()] = value.strip().lower()

    logging.debug("Parsed headers: %r", headers)
    return start_line, headers
class HTTPException(Exception):
""" Base class for HTTP exceptions """
@@ -164,6 +99,7 @@ class UnsupportedEncoding(HTTPException):
self.enc_type = enc_type
self.encoding = encoding
class IncompleteResponse(HTTPException):
    """HTTP exception that records the underlying ``cause``."""

    # The constructor was previously spelled `__init`, so it was never run
    # and instances had no `cause` attribute.
    def __init__(self, cause):
        """Store ``cause`` on the exception instance.

        :param cause: the value describing why the response is incomplete
        """
        self.cause = cause