client: cleanup
This commit is contained in:
@@ -3,7 +3,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from client.command import Command
|
from client.command import AbstractCommand
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -18,7 +18,7 @@ def main():
|
|||||||
logging.basicConfig(level=logging.ERROR - (10 * arguments.verbose))
|
logging.basicConfig(level=logging.ERROR - (10 * arguments.verbose))
|
||||||
logging.debug("Arguments: %s", arguments)
|
logging.debug("Arguments: %s", arguments)
|
||||||
|
|
||||||
command = Command.create(arguments.command, arguments.URI, arguments.port)
|
command = AbstractCommand.create(arguments.command, arguments.URI, arguments.port)
|
||||||
command.execute()
|
command.execute()
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@@ -10,21 +11,7 @@ from client.Retriever import Retriever
|
|||||||
from client.httpclient import HTTPClient, UnsupportedEncoding, FORMAT, InvalidResponse, InvalidStatusLine
|
from client.httpclient import HTTPClient, UnsupportedEncoding, FORMAT, InvalidResponse, InvalidStatusLine
|
||||||
|
|
||||||
|
|
||||||
def parse_uri(uri: str):
|
class ResponseHandler(ABC):
|
||||||
parsed = urlparse(uri)
|
|
||||||
|
|
||||||
# If there is no netloc, the url is invalid, so prepend `//` and try again
|
|
||||||
if parsed.netloc == "":
|
|
||||||
parsed = urlparse("//" + uri)
|
|
||||||
|
|
||||||
host = parsed.netloc
|
|
||||||
path = parsed.path
|
|
||||||
if len(path) == 0 or path[0] != '/':
|
|
||||||
path = "/" + path
|
|
||||||
return host, path
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseHandler:
|
|
||||||
client: HTTPClient
|
client: HTTPClient
|
||||||
headers: Dict[str, str]
|
headers: Dict[str, str]
|
||||||
status_code: int
|
status_code: int
|
||||||
@@ -38,6 +25,7 @@ class ResponseHandler:
|
|||||||
self.retriever = retriever
|
self.retriever = retriever
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def handle(self):
|
def handle(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -133,8 +121,22 @@ class ResponseHandler:
|
|||||||
logging.error("Invalid content-length value: %r", next_value)
|
logging.error("Invalid content-length value: %r", next_value)
|
||||||
raise InvalidResponse()
|
raise InvalidResponse()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_uri(uri: str):
|
||||||
|
parsed = urlparse(uri)
|
||||||
|
|
||||||
class DownloadHandler(ResponseHandler):
|
# If there is no netloc, the url is invalid, so prepend `//` and try again
|
||||||
|
if parsed.netloc == "":
|
||||||
|
parsed = urlparse("//" + uri)
|
||||||
|
|
||||||
|
host = parsed.netloc
|
||||||
|
path = parsed.path
|
||||||
|
if len(path) == 0 or path[0] != '/':
|
||||||
|
path = "/" + path
|
||||||
|
return host, path
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadHandler(ResponseHandler, ABC):
|
||||||
path: str
|
path: str
|
||||||
|
|
||||||
def __init__(self, retriever: Retriever, client: HTTPClient, headers: Dict[str, str], url: str, dir=None):
|
def __init__(self, retriever: Retriever, client: HTTPClient, headers: Dict[str, str], url: str, dir=None):
|
||||||
@@ -152,9 +154,6 @@ class DownloadHandler(ResponseHandler):
|
|||||||
return HTMLDownloadHandler(retriever, client, headers, url, dir)
|
return HTMLDownloadHandler(retriever, client, headers, url, dir)
|
||||||
return RawDownloadHandler(retriever, client, headers, url, dir)
|
return RawDownloadHandler(retriever, client, headers, url, dir)
|
||||||
|
|
||||||
def handle(self) -> str:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _create_directory(self):
|
def _create_directory(self):
|
||||||
path = self._get_duplicate_name(os.path.abspath(self.client.host))
|
path = self._get_duplicate_name(os.path.abspath(self.client.host))
|
||||||
os.mkdir(path)
|
os.mkdir(path)
|
||||||
@@ -248,7 +247,7 @@ class HTMLDownloadHandler(DownloadHandler):
|
|||||||
|
|
||||||
def __download_images(self, tmp_filename, target_filename):
|
def __download_images(self, tmp_filename, target_filename):
|
||||||
|
|
||||||
(host, path) = parse_uri(self.url)
|
(host, path) = ResponseHandler.parse_uri(self.url)
|
||||||
with open(tmp_filename, "rb") as fp:
|
with open(tmp_filename, "rb") as fp:
|
||||||
soup = BeautifulSoup(fp, 'html.parser')
|
soup = BeautifulSoup(fp, 'html.parser')
|
||||||
|
|
||||||
@@ -286,7 +285,7 @@ class HTMLDownloadHandler(DownloadHandler):
|
|||||||
img_path = parsed.path
|
img_path = parsed.path
|
||||||
else:
|
else:
|
||||||
same_host = False
|
same_host = False
|
||||||
(img_host, img_path) = parse_uri(img_src)
|
(img_host, img_path) = ResponseHandler.parse_uri(img_src)
|
||||||
|
|
||||||
message = "GET {path} HTTP/1.1\r\n".format(path=img_path)
|
message = "GET {path} HTTP/1.1\r\n".format(path=img_path)
|
||||||
message += "Accept: */*\r\nAccept-Encoding: identity\r\n"
|
message += "Accept: */*\r\nAccept-Encoding: identity\r\n"
|
||||||
|
@@ -1,16 +1,17 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from client.httpclient import HTTPClient, BUFSIZE, IncompleteResponse, InvalidResponse, UnsupportedEncoding
|
from client.httpclient import HTTPClient, BUFSIZE, IncompleteResponse, InvalidResponse, UnsupportedEncoding
|
||||||
|
|
||||||
|
|
||||||
class Retriever:
|
class Retriever(ABC):
|
||||||
client: HTTPClient
|
client: HTTPClient
|
||||||
headers: Dict[str, str]
|
|
||||||
|
|
||||||
def __init__(self, client: HTTPClient):
|
def __init__(self, client: HTTPClient):
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def retrieve(self):
|
def retrieve(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -95,7 +96,7 @@ class ChunkedRetriever(Retriever):
|
|||||||
|
|
||||||
def retrieve(self):
|
def retrieve(self):
|
||||||
while True:
|
while True:
|
||||||
chunk_size = self._get_chunk_size()
|
chunk_size = self.__get_chunk_size()
|
||||||
logging.debug("chunk-size: %s", chunk_size)
|
logging.debug("chunk-size: %s", chunk_size)
|
||||||
if chunk_size == 0:
|
if chunk_size == 0:
|
||||||
self.client.reset_request()
|
self.client.reset_request()
|
||||||
@@ -108,7 +109,7 @@ class ChunkedRetriever(Retriever):
|
|||||||
self.client.read_line() # remove CRLF
|
self.client.read_line() # remove CRLF
|
||||||
return b""
|
return b""
|
||||||
|
|
||||||
def _get_chunk_size(self):
|
def __get_chunk_size(self):
|
||||||
line = self.client.read_line()
|
line = self.client.read_line()
|
||||||
sep_pos = line.find(";")
|
sep_pos = line.find(";")
|
||||||
if sep_pos >= 0:
|
if sep_pos >= 0:
|
||||||
|
@@ -1,17 +1,22 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from client.ResponseHandler import ResponseHandler
|
from client.ResponseHandler import ResponseHandler
|
||||||
from client.httpclient import FORMAT, HTTPClient, InvalidResponse, InvalidStatusLine, UnsupportedEncoding
|
from client.httpclient import FORMAT, HTTPClient, InvalidResponse, InvalidStatusLine, UnsupportedEncoding
|
||||||
|
|
||||||
|
|
||||||
class Command:
|
class AbstractCommand(ABC):
|
||||||
command: str
|
|
||||||
|
|
||||||
def __init__(self, url: str, port: str):
|
def __init__(self, url: str, port: str):
|
||||||
self.url = url
|
self.url = url
|
||||||
self.port = port
|
self.port = port
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def command(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(command: str, url: str, port: str):
|
def create(command: str, url: str, port: str):
|
||||||
if command == "GET":
|
if command == "GET":
|
||||||
@@ -56,8 +61,12 @@ class Command:
|
|||||||
finally:
|
finally:
|
||||||
client.close()
|
client.close()
|
||||||
|
|
||||||
def _await_response(self, client: HTTPClient):
|
def _await_response(self, client):
|
||||||
pass
|
while True:
|
||||||
|
line = client.read_line()
|
||||||
|
print(line, end="")
|
||||||
|
if line in ("\r\n", "\n", ""):
|
||||||
|
break
|
||||||
|
|
||||||
def _build_message(self, message: str) -> bytes:
|
def _build_message(self, message: str) -> bytes:
|
||||||
return (message + "\r\n").encode(FORMAT)
|
return (message + "\r\n").encode(FORMAT)
|
||||||
@@ -81,35 +90,10 @@ class Command:
|
|||||||
return host, path
|
return host, path
|
||||||
|
|
||||||
|
|
||||||
class HeadCommand(Command):
|
class AbstractWithBodyCommand(AbstractCommand, ABC):
|
||||||
command = "HEAD"
|
|
||||||
|
|
||||||
def _await_response(self, client):
|
|
||||||
while True:
|
|
||||||
line = client.read_line()
|
|
||||||
print(line, end="")
|
|
||||||
if line in ("\r\n", "\n", ""):
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
class GetCommand(Command):
|
|
||||||
command = "GET"
|
|
||||||
|
|
||||||
def _await_response(self, client):
|
|
||||||
(version, status, msg) = ResponseHandler.get_status_line(client)
|
|
||||||
logging.debug("Parsed status-line: version: %s, status: %s", version, status)
|
|
||||||
headers = ResponseHandler.get_headers(client)
|
|
||||||
logging.debug("Parsed headers: %r", headers)
|
|
||||||
|
|
||||||
handler = ResponseHandler.create(client, headers, status, self.url)
|
|
||||||
handler.handle()
|
|
||||||
|
|
||||||
|
|
||||||
class PostCommand(HeadCommand):
|
|
||||||
command = "POST"
|
|
||||||
|
|
||||||
def _build_message(self, message: str) -> bytes:
|
def _build_message(self, message: str) -> bytes:
|
||||||
body = input("Enter POST data: ").encode(FORMAT)
|
body = input(f"Enter {self.command} data: ").encode(FORMAT)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
message += "Content-Type: text/plain\r\n"
|
message += "Content-Type: text/plain\r\n"
|
||||||
@@ -122,5 +106,34 @@ class PostCommand(HeadCommand):
|
|||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
class PutCommand(PostCommand):
|
class HeadCommand(AbstractCommand):
|
||||||
command = "PUT"
|
@property
|
||||||
|
def command(self):
|
||||||
|
return "HEAD"
|
||||||
|
|
||||||
|
|
||||||
|
class GetCommand(AbstractCommand):
|
||||||
|
@property
|
||||||
|
def command(self):
|
||||||
|
return "GET"
|
||||||
|
|
||||||
|
def _await_response(self, client):
|
||||||
|
(version, status, msg) = ResponseHandler.get_status_line(client)
|
||||||
|
logging.debug("Parsed status-line: version: %s, status: %s", version, status)
|
||||||
|
headers = ResponseHandler.get_headers(client)
|
||||||
|
logging.debug("Parsed headers: %r", headers)
|
||||||
|
|
||||||
|
handler = ResponseHandler.create(client, headers, status, self.url)
|
||||||
|
handler.handle()
|
||||||
|
|
||||||
|
|
||||||
|
class PostCommand(AbstractWithBodyCommand):
|
||||||
|
@property
|
||||||
|
def command(self):
|
||||||
|
return "POST"
|
||||||
|
|
||||||
|
|
||||||
|
class PutCommand(AbstractWithBodyCommand):
|
||||||
|
@property
|
||||||
|
def command(self):
|
||||||
|
return "PUT"
|
||||||
|
@@ -1,8 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
import socket
|
import socket
|
||||||
from io import BufferedReader
|
from io import BufferedReader
|
||||||
from typing import TextIO, IO
|
|
||||||
|
|
||||||
BUFSIZE = 4096
|
BUFSIZE = 4096
|
||||||
TIMEOUT = 3
|
TIMEOUT = 3
|
||||||
@@ -31,7 +29,7 @@ class HTTPClient(socket.socket):
|
|||||||
self.file.close()
|
self.file.close()
|
||||||
self.file = self.makefile("rb")
|
self.file = self.makefile("rb")
|
||||||
|
|
||||||
def _do_receive(self):
|
def __do_receive(self):
|
||||||
if self.fileno() == -1:
|
if self.fileno() == -1:
|
||||||
raise Exception("Connection closed")
|
raise Exception("Connection closed")
|
||||||
|
|
||||||
@@ -45,7 +43,7 @@ class HTTPClient(socket.socket):
|
|||||||
while True:
|
while True:
|
||||||
count += 1
|
count += 1
|
||||||
try:
|
try:
|
||||||
return self._do_receive()
|
return self.__do_receive()
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
logging.debug("Socket receive timed out after %s seconds", TIMEOUT)
|
logging.debug("Socket receive timed out after %s seconds", TIMEOUT)
|
||||||
if count == 3:
|
if count == 3:
|
||||||
@@ -75,69 +73,6 @@ class HTTPClient(socket.socket):
|
|||||||
|
|
||||||
return line
|
return line
|
||||||
|
|
||||||
def validate_status_line(self, status_line: str):
|
|
||||||
split = list(filter(None, status_line.split(" ")))
|
|
||||||
if len(split) < 3:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check HTTP version
|
|
||||||
http_version = split.pop(0)
|
|
||||||
if len(http_version) < 8 or http_version[4] != "/":
|
|
||||||
raise InvalidStatusLine(status_line)
|
|
||||||
(name, version) = http_version[:4], http_version[5:]
|
|
||||||
if name != "HTTP" or not re.match(r"1\.[0|1]", version):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not re.match(r"\d{3}", split[0]):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_crlf_chunk(self, buffer: bytes):
|
|
||||||
"""Finds the line break type (`CRLF` or `LF`) and splits the specified buffer
|
|
||||||
when encountering 2 consecutive linebreaks.
|
|
||||||
Returns a tuple with the first part and the remaining of the buffer.
|
|
||||||
|
|
||||||
:param buffer:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
lf_pos = buffer.find(b"\n\n")
|
|
||||||
crlf_pos = buffer.find(b"\r\n\r\n")
|
|
||||||
if lf_pos != -1 and lf_pos < crlf_pos:
|
|
||||||
split_start = lf_pos
|
|
||||||
split_end = lf_pos + 2
|
|
||||||
else:
|
|
||||||
split_start = crlf_pos
|
|
||||||
split_end = crlf_pos + 4
|
|
||||||
|
|
||||||
return buffer[:split_start], buffer[split_end:]
|
|
||||||
|
|
||||||
def parse_headers(self, data: bytes):
|
|
||||||
headers = {}
|
|
||||||
|
|
||||||
# decode bytes, split into lines and filter
|
|
||||||
header_split = list(
|
|
||||||
filter(lambda l: l is not "" and not l[0].isspace(), map(str.strip, data.decode("utf-8").split("\n"))))
|
|
||||||
|
|
||||||
if len(header_split) == 0:
|
|
||||||
raise InvalidResponse(data)
|
|
||||||
|
|
||||||
start_line = header_split.pop(0)
|
|
||||||
logging.debug("start-line: %r", start_line)
|
|
||||||
|
|
||||||
for line in header_split:
|
|
||||||
pos = line.find(":")
|
|
||||||
|
|
||||||
if pos <= 0 or pos >= len(line) - 1:
|
|
||||||
continue
|
|
||||||
|
|
||||||
(header, value) = map(str.strip, line.split(":", 1))
|
|
||||||
headers[header.lower()] = value.lower()
|
|
||||||
|
|
||||||
logging.debug("Parsed headers: %r", headers)
|
|
||||||
|
|
||||||
return start_line, headers
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPException(Exception):
|
class HTTPException(Exception):
|
||||||
""" Base class for HTTP exceptions """
|
""" Base class for HTTP exceptions """
|
||||||
@@ -164,6 +99,7 @@ class UnsupportedEncoding(HTTPException):
|
|||||||
self.enc_type = enc_type
|
self.enc_type = enc_type
|
||||||
self.encoding = encoding
|
self.encoding = encoding
|
||||||
|
|
||||||
|
|
||||||
class IncompleteResponse(HTTPException):
|
class IncompleteResponse(HTTPException):
|
||||||
def __init(self, cause):
|
def __init(self, cause):
|
||||||
self.cause = cause
|
self.cause = cause
|
||||||
|
Reference in New Issue
Block a user