client: cleanup

This commit is contained in:
2021-03-21 13:10:57 +01:00
parent d8a5765fd8
commit 638576f471
5 changed files with 77 additions and 128 deletions

View File

@@ -3,7 +3,7 @@ import argparse
import logging import logging
import sys import sys
from client.command import Command from client.command import AbstractCommand
def main(): def main():
@@ -18,7 +18,7 @@ def main():
logging.basicConfig(level=logging.ERROR - (10 * arguments.verbose)) logging.basicConfig(level=logging.ERROR - (10 * arguments.verbose))
logging.debug("Arguments: %s", arguments) logging.debug("Arguments: %s", arguments)
command = Command.create(arguments.command, arguments.URI, arguments.port) command = AbstractCommand.create(arguments.command, arguments.URI, arguments.port)
command.execute() command.execute()

View File

@@ -1,6 +1,7 @@
import logging import logging
import os import os
import re import re
from abc import ABC, abstractmethod
from typing import Dict from typing import Dict
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -10,21 +11,7 @@ from client.Retriever import Retriever
from client.httpclient import HTTPClient, UnsupportedEncoding, FORMAT, InvalidResponse, InvalidStatusLine from client.httpclient import HTTPClient, UnsupportedEncoding, FORMAT, InvalidResponse, InvalidStatusLine
def parse_uri(uri: str): class ResponseHandler(ABC):
parsed = urlparse(uri)
# If there is no netloc, the url is invalid, so prepend `//` and try again
if parsed.netloc == "":
parsed = urlparse("//" + uri)
host = parsed.netloc
path = parsed.path
if len(path) == 0 or path[0] != '/':
path = "/" + path
return host, path
class ResponseHandler:
client: HTTPClient client: HTTPClient
headers: Dict[str, str] headers: Dict[str, str]
status_code: int status_code: int
@@ -38,6 +25,7 @@ class ResponseHandler:
self.retriever = retriever self.retriever = retriever
pass pass
@abstractmethod
def handle(self): def handle(self):
pass pass
@@ -133,8 +121,22 @@ class ResponseHandler:
logging.error("Invalid content-length value: %r", next_value) logging.error("Invalid content-length value: %r", next_value)
raise InvalidResponse() raise InvalidResponse()
@staticmethod
def parse_uri(uri: str):
parsed = urlparse(uri)
class DownloadHandler(ResponseHandler): # If there is no netloc, the url is invalid, so prepend `//` and try again
if parsed.netloc == "":
parsed = urlparse("//" + uri)
host = parsed.netloc
path = parsed.path
if len(path) == 0 or path[0] != '/':
path = "/" + path
return host, path
class DownloadHandler(ResponseHandler, ABC):
path: str path: str
def __init__(self, retriever: Retriever, client: HTTPClient, headers: Dict[str, str], url: str, dir=None): def __init__(self, retriever: Retriever, client: HTTPClient, headers: Dict[str, str], url: str, dir=None):
@@ -152,9 +154,6 @@ class DownloadHandler(ResponseHandler):
return HTMLDownloadHandler(retriever, client, headers, url, dir) return HTMLDownloadHandler(retriever, client, headers, url, dir)
return RawDownloadHandler(retriever, client, headers, url, dir) return RawDownloadHandler(retriever, client, headers, url, dir)
def handle(self) -> str:
pass
def _create_directory(self): def _create_directory(self):
path = self._get_duplicate_name(os.path.abspath(self.client.host)) path = self._get_duplicate_name(os.path.abspath(self.client.host))
os.mkdir(path) os.mkdir(path)
@@ -248,7 +247,7 @@ class HTMLDownloadHandler(DownloadHandler):
def __download_images(self, tmp_filename, target_filename): def __download_images(self, tmp_filename, target_filename):
(host, path) = parse_uri(self.url) (host, path) = ResponseHandler.parse_uri(self.url)
with open(tmp_filename, "rb") as fp: with open(tmp_filename, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser') soup = BeautifulSoup(fp, 'html.parser')
@@ -286,7 +285,7 @@ class HTMLDownloadHandler(DownloadHandler):
img_path = parsed.path img_path = parsed.path
else: else:
same_host = False same_host = False
(img_host, img_path) = parse_uri(img_src) (img_host, img_path) = ResponseHandler.parse_uri(img_src)
message = "GET {path} HTTP/1.1\r\n".format(path=img_path) message = "GET {path} HTTP/1.1\r\n".format(path=img_path)
message += "Accept: */*\r\nAccept-Encoding: identity\r\n" message += "Accept: */*\r\nAccept-Encoding: identity\r\n"

View File

@@ -1,16 +1,17 @@
import logging import logging
from abc import ABC, abstractmethod
from typing import Dict from typing import Dict
from client.httpclient import HTTPClient, BUFSIZE, IncompleteResponse, InvalidResponse, UnsupportedEncoding from client.httpclient import HTTPClient, BUFSIZE, IncompleteResponse, InvalidResponse, UnsupportedEncoding
class Retriever: class Retriever(ABC):
client: HTTPClient client: HTTPClient
headers: Dict[str, str]
def __init__(self, client: HTTPClient): def __init__(self, client: HTTPClient):
self.client = client self.client = client
@abstractmethod
def retrieve(self): def retrieve(self):
pass pass
@@ -95,7 +96,7 @@ class ChunkedRetriever(Retriever):
def retrieve(self): def retrieve(self):
while True: while True:
chunk_size = self._get_chunk_size() chunk_size = self.__get_chunk_size()
logging.debug("chunk-size: %s", chunk_size) logging.debug("chunk-size: %s", chunk_size)
if chunk_size == 0: if chunk_size == 0:
self.client.reset_request() self.client.reset_request()
@@ -108,7 +109,7 @@ class ChunkedRetriever(Retriever):
self.client.read_line() # remove CRLF self.client.read_line() # remove CRLF
return b"" return b""
def _get_chunk_size(self): def __get_chunk_size(self):
line = self.client.read_line() line = self.client.read_line()
sep_pos = line.find(";") sep_pos = line.find(";")
if sep_pos >= 0: if sep_pos >= 0:

View File

@@ -1,17 +1,22 @@
import logging import logging
from abc import ABC, abstractmethod
from urllib.parse import urlparse from urllib.parse import urlparse
from client.ResponseHandler import ResponseHandler from client.ResponseHandler import ResponseHandler
from client.httpclient import FORMAT, HTTPClient, InvalidResponse, InvalidStatusLine, UnsupportedEncoding from client.httpclient import FORMAT, HTTPClient, InvalidResponse, InvalidStatusLine, UnsupportedEncoding
class Command: class AbstractCommand(ABC):
command: str
def __init__(self, url: str, port: str): def __init__(self, url: str, port: str):
self.url = url self.url = url
self.port = port self.port = port
@property
@abstractmethod
def command(self):
pass
@staticmethod @staticmethod
def create(command: str, url: str, port: str): def create(command: str, url: str, port: str):
if command == "GET": if command == "GET":
@@ -56,8 +61,12 @@ class Command:
finally: finally:
client.close() client.close()
def _await_response(self, client: HTTPClient): def _await_response(self, client):
pass while True:
line = client.read_line()
print(line, end="")
if line in ("\r\n", "\n", ""):
break
def _build_message(self, message: str) -> bytes: def _build_message(self, message: str) -> bytes:
return (message + "\r\n").encode(FORMAT) return (message + "\r\n").encode(FORMAT)
@@ -81,35 +90,10 @@ class Command:
return host, path return host, path
class HeadCommand(Command): class AbstractWithBodyCommand(AbstractCommand, ABC):
command = "HEAD"
def _await_response(self, client):
while True:
line = client.read_line()
print(line, end="")
if line in ("\r\n", "\n", ""):
break
class GetCommand(Command):
command = "GET"
def _await_response(self, client):
(version, status, msg) = ResponseHandler.get_status_line(client)
logging.debug("Parsed status-line: version: %s, status: %s", version, status)
headers = ResponseHandler.get_headers(client)
logging.debug("Parsed headers: %r", headers)
handler = ResponseHandler.create(client, headers, status, self.url)
handler.handle()
class PostCommand(HeadCommand):
command = "POST"
def _build_message(self, message: str) -> bytes: def _build_message(self, message: str) -> bytes:
body = input("Enter POST data: ").encode(FORMAT) body = input(f"Enter {self.command} data: ").encode(FORMAT)
print() print()
message += "Content-Type: text/plain\r\n" message += "Content-Type: text/plain\r\n"
@@ -122,5 +106,34 @@ class PostCommand(HeadCommand):
return message return message
class PutCommand(PostCommand): class HeadCommand(AbstractCommand):
command = "PUT" @property
def command(self):
return "HEAD"
class GetCommand(AbstractCommand):
@property
def command(self):
return "GET"
def _await_response(self, client):
(version, status, msg) = ResponseHandler.get_status_line(client)
logging.debug("Parsed status-line: version: %s, status: %s", version, status)
headers = ResponseHandler.get_headers(client)
logging.debug("Parsed headers: %r", headers)
handler = ResponseHandler.create(client, headers, status, self.url)
handler.handle()
class PostCommand(AbstractWithBodyCommand):
@property
def command(self):
return "POST"
class PutCommand(AbstractWithBodyCommand):
@property
def command(self):
return "PUT"

View File

@@ -1,8 +1,6 @@
import logging import logging
import re
import socket import socket
from io import BufferedReader from io import BufferedReader
from typing import TextIO, IO
BUFSIZE = 4096 BUFSIZE = 4096
TIMEOUT = 3 TIMEOUT = 3
@@ -31,7 +29,7 @@ class HTTPClient(socket.socket):
self.file.close() self.file.close()
self.file = self.makefile("rb") self.file = self.makefile("rb")
def _do_receive(self): def __do_receive(self):
if self.fileno() == -1: if self.fileno() == -1:
raise Exception("Connection closed") raise Exception("Connection closed")
@@ -45,7 +43,7 @@ class HTTPClient(socket.socket):
while True: while True:
count += 1 count += 1
try: try:
return self._do_receive() return self.__do_receive()
except socket.timeout: except socket.timeout:
logging.debug("Socket receive timed out after %s seconds", TIMEOUT) logging.debug("Socket receive timed out after %s seconds", TIMEOUT)
if count == 3: if count == 3:
@@ -75,69 +73,6 @@ class HTTPClient(socket.socket):
return line return line
def validate_status_line(self, status_line: str):
split = list(filter(None, status_line.split(" ")))
if len(split) < 3:
return False
# Check HTTP version
http_version = split.pop(0)
if len(http_version) < 8 or http_version[4] != "/":
raise InvalidStatusLine(status_line)
(name, version) = http_version[:4], http_version[5:]
if name != "HTTP" or not re.match(r"1\.[0|1]", version):
return False
if not re.match(r"\d{3}", split[0]):
return False
return True
def get_crlf_chunk(self, buffer: bytes):
"""Finds the line break type (`CRLF` or `LF`) and splits the specified buffer
when encountering 2 consecutive linebreaks.
Returns a tuple with the first part and the remaining of the buffer.
:param buffer:
:return:
"""
lf_pos = buffer.find(b"\n\n")
crlf_pos = buffer.find(b"\r\n\r\n")
if lf_pos != -1 and lf_pos < crlf_pos:
split_start = lf_pos
split_end = lf_pos + 2
else:
split_start = crlf_pos
split_end = crlf_pos + 4
return buffer[:split_start], buffer[split_end:]
def parse_headers(self, data: bytes):
headers = {}
# decode bytes, split into lines and filter
header_split = list(
filter(lambda l: l is not "" and not l[0].isspace(), map(str.strip, data.decode("utf-8").split("\n"))))
if len(header_split) == 0:
raise InvalidResponse(data)
start_line = header_split.pop(0)
logging.debug("start-line: %r", start_line)
for line in header_split:
pos = line.find(":")
if pos <= 0 or pos >= len(line) - 1:
continue
(header, value) = map(str.strip, line.split(":", 1))
headers[header.lower()] = value.lower()
logging.debug("Parsed headers: %r", headers)
return start_line, headers
class HTTPException(Exception): class HTTPException(Exception):
""" Base class for HTTP exceptions """ """ Base class for HTTP exceptions """
@@ -164,6 +99,7 @@ class UnsupportedEncoding(HTTPException):
self.enc_type = enc_type self.enc_type = enc_type
self.encoding = encoding self.encoding = encoding
class IncompleteResponse(HTTPException): class IncompleteResponse(HTTPException):
def __init(self, cause): def __init(self, cause):
self.cause = cause self.cause = cause