Fix small issues, improve error handling and documentation

This commit is contained in:
2021-03-28 14:04:39 +02:00
parent 850535a060
commit 07b018d2ab
6 changed files with 112 additions and 52 deletions

View File

@@ -4,6 +4,7 @@ import logging
import sys import sys
from client import command as cmd from client import command as cmd
from httplib.exceptions import UnhandledHTTPCode
def main(): def main():
@@ -15,7 +16,7 @@ def main():
arguments = parser.parse_args() arguments = parser.parse_args()
logging.basicConfig(level=logging.ERROR - (10 * arguments.verbose), format="[%(levelname)s] %(message)s") logging.basicConfig(level=logging.INFO - (10 * arguments.verbose), format="[%(levelname)s] %(message)s")
logging.debug("Arguments: %s", arguments) logging.debug("Arguments: %s", arguments)
command = cmd.create(arguments.command, arguments.URI, arguments.port) command = cmd.create(arguments.command, arguments.URI, arguments.port)
@@ -24,7 +25,10 @@ def main():
try: try:
main() main()
except UnhandledHTTPCode as e:
print(f"[{e.status_code}] {e.cause}:\r\n{e.headers}")
sys.exit(2)
except Exception as e: except Exception as e:
print("[ABRT] Internal error: " + str(e), file=sys.stderr) print("[ABRT] Internal error: " + str(e), file=sys.stderr)
logging.debug("Internal error", exc_info=e) logging.debug("Internal error", exc_info=e)
sys.exit(70) sys.exit(1)

View File

@@ -5,7 +5,7 @@ from urllib.parse import urlparse
from client.httpclient import HTTPClient from client.httpclient import HTTPClient
from httplib import parser from httplib import parser
from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding, UnsupportedProtocol
from httplib.httpsocket import FORMAT from httplib.httpsocket import FORMAT
from httplib.message import ResponseMessage as Message from httplib.message import ResponseMessage as Message
from httplib.retriever import PreambleRetriever from httplib.retriever import PreambleRetriever
@@ -42,12 +42,10 @@ class AbstractCommand(ABC):
_host: str _host: str
_path: str _path: str
_port: int _port: int
sub_request: bool
def __init__(self, uri: str, port): def __init__(self, uri: str, port):
self.uri = uri self.uri = uri
self._port = int(port) self._port = int(port)
self.sub_request = False
@property @property
def uri(self): def uri(self):
@@ -81,23 +79,24 @@ class AbstractCommand(ABC):
@param sub_request: If this execution is in function of a prior command. @param sub_request: If this execution is in function of a prior command.
""" """
self.uri = ""
self.sub_request = sub_request
(host, path) = self.parse_uri()
client = sockets.get(host) client = sockets.get(self.host)
if client and client.is_closed(): if client and client.is_closed():
sockets.pop(self.host) sockets.pop(self.host)
client = None client = None
if not client: if not client:
client = HTTPClient(host) logging.info("Connecting to %s", self.host)
client.conn.connect((host, self.port)) client = HTTPClient(self.host)
sockets[host] = client client.conn.connect((self.host, self.port))
logging.info("Connected.")
sockets[self.host] = client
else:
logging.info("Reusing socket for %s", self.host)
message = f"{self.method} {path} HTTP/1.1\r\n" message = f"{self.method} {self.path} HTTP/1.1\r\n"
message += f"Host: {host}:{self.port}\r\n" message += f"Host: {self.host}:{self.port}\r\n"
message += "Accept: */*\r\n" message += "Accept: */*\r\n"
message += "Accept-Encoding: identity\r\n" message += "Accept-Encoding: identity\r\n"
encoded_msg = self._build_message(message) encoded_msg = self._build_message(message)
@@ -111,13 +110,17 @@ class AbstractCommand(ABC):
try: try:
self._await_response(client) self._await_response(client)
except InvalidResponse as e: except InvalidResponse as e:
logging.debug("Internal error: Response could not be parsed", exc_info=e) logging.error("Response could not be parsed")
return logging.debug("", exc_info=e)
except InvalidStatusLine as e: except InvalidStatusLine as e:
logging.debug("Internal error: Invalid status-line in response", exc_info=e) logging.error("Invalid status-line in response")
return logging.debug("", exc_info=e)
except UnsupportedEncoding as e: except UnsupportedEncoding as e:
logging.debug("Internal error: Unsupported encoding in response", exc_info=e) logging.error("Unsupported encoding in response")
logging.debug("", exc_info=e)
except UnsupportedProtocol as e:
logging.error("Unsupported protocol: %s", e.protocol)
logging.debug("", exc_info=e)
finally: finally:
if not sub_request: if not sub_request:
client.close() client.close()

View File

@@ -7,7 +7,7 @@ from urllib.parse import urlsplit, unquote
from client.command import AbstractCommand, GetCommand from client.command import AbstractCommand, GetCommand
from client.httpclient import HTTPClient from client.httpclient import HTTPClient
from httplib import parser from httplib import parser
from httplib.exceptions import InvalidResponse from httplib.exceptions import InvalidResponse, UnhandledHTTPCode, UnsupportedProtocol
from httplib.httpsocket import FORMAT from httplib.httpsocket import FORMAT
from httplib.message import ResponseMessage as Message from httplib.message import ResponseMessage as Message
from httplib.retriever import Retriever from httplib.retriever import Retriever
@@ -91,8 +91,7 @@ class BasicResponseHandler(ResponseHandler):
if self.msg.status == 101: if self.msg.status == 101:
# Switching protocols is not supported # Switching protocols is not supported
print("".join(self.msg.raw), end="") raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), "Switching protocols is not supported")
return None
if 200 <= self.msg.status < 300: if 200 <= self.msg.status < 300:
return self.retriever return self.retriever
@@ -105,16 +104,13 @@ class BasicResponseHandler(ResponseHandler):
if 400 <= self.msg.status < 600: if 400 <= self.msg.status < 600:
self._skip_body() self._skip_body()
# Dump headers and exit with error # Dump headers and exit with error
if not self.cmd.sub_request: raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
print("".join(self.msg.raw), end="")
return None
return None return None
def _handle_redirect(self): def _handle_redirect(self):
if self.msg.status == 304: if self.msg.status == 304:
print("".join(self.msg.raw), end="") raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
return None
location = self.msg.headers.get("location") location = self.msg.headers.get("location")
if not location or len(location.strip()) == 0: if not location or len(location.strip()) == 0:
@@ -126,7 +122,7 @@ class BasicResponseHandler(ResponseHandler):
raise InvalidResponse("Invalid location") raise InvalidResponse("Invalid location")
if not parsed_location.scheme == "http": if not parsed_location.scheme == "http":
raise InvalidResponse("Only http is supported") raise UnsupportedProtocol(parsed_location.scheme)
self.cmd.uri = location self.cmd.uri = location
@@ -195,7 +191,7 @@ class RawDownloadHandler(DownloadHandler):
super().__init__(retriever, client, msg, cmd, directory) super().__init__(retriever, client, msg, cmd, directory)
def handle(self) -> str: def handle(self) -> str:
logging.debug("Retrieving payload") logging.info("Saving to '%s'", parser.get_relative_save_path(self.path))
file = open(self.path, "wb") file = open(self.path, "wb")
for buffer in self.retriever.retrieve(): for buffer in self.retriever.retrieve():
@@ -225,20 +221,20 @@ class HTMLDownloadHandler(DownloadHandler):
os.remove(tmp_path) os.remove(tmp_path)
return self.path return self.path
def _download_images(self, tmp_filename, target_filename, charset=FORMAT): def _download_images(self, tmp_path, target_path, charset=FORMAT):
""" """
Downloads images referenced in the html of `tmp_filename` and replaces the references in the html Downloads images referenced in the html of `tmp_filename` and replaces the references in the html
and writes it to `target_filename`. and writes it to `target_filename`.
@param tmp_filename: the path to the temporary html file @param tmp_path: the path to the temporary html file
@param target_filename: the path for the final html file @param target_path: the path for the final html file
@param charset: the charset to decode `tmp_filename` @param charset: the charset to decode `tmp_filename`
""" """
try: try:
fp = open(tmp_filename, "r", encoding=charset) fp = open(tmp_path, "r", encoding="yeetus")
html = fp.read() html = fp.read()
except UnicodeDecodeError: except UnicodeDecodeError or LookupError:
fp = open(tmp_filename, "r", encoding=FORMAT, errors="replace") fp = open(tmp_path, "r", encoding=FORMAT, errors="replace")
html = fp.read() html = fp.read()
fp.close() fp.close()
@@ -281,7 +277,8 @@ class HTMLDownloadHandler(DownloadHandler):
for (start, end, path) in to_replace: for (start, end, path) in to_replace:
html = html[:start] + path + html[end:] html = html[:start] + path + html[end:]
with open(target_filename, 'w', encoding=FORMAT) as file: logging.info("Saving to HTML '%s'", parser.get_relative_save_path(target_path))
with open(target_path, 'w', encoding=FORMAT) as file:
file.write(html) file.write(html)
def __download_image(self, img_src, base_url): def __download_image(self, img_src, base_url):
@@ -295,6 +292,7 @@ class HTMLDownloadHandler(DownloadHandler):
parsed = urlsplit(img_src) parsed = urlsplit(img_src)
img_src = parser.urljoin(base_url, img_src) img_src = parser.urljoin(base_url, img_src)
# Check if the port of the image should be the same as the port of the current command
if parsed.hostname is None or parsed.hostname == self.cmd.host: if parsed.hostname is None or parsed.hostname == self.cmd.host:
port = self.cmd.port port = self.cmd.port
elif ":" in parsed.netloc: elif ":" in parsed.netloc:

View File

@@ -2,34 +2,58 @@ class HTTPException(Exception):
""" Base class for HTTP exceptions """ """ Base class for HTTP exceptions """
class InvalidResponse(HTTPException): class UnhandledHTTPCode(Exception):
""" Response message cannot be parsed """ status_code: str
headers: str
cause: str
def __init(self, message): def __init__(self, status, headers, cause):
self.status_code = status
self.headers = headers
self.cause = cause
class InvalidResponse(HTTPException):
"""
Response message cannot be parsed
"""
def __init__(self, message):
self.message = message self.message = message
class InvalidStatusLine(HTTPException): class InvalidStatusLine(HTTPException):
""" Response status line is invalid """ """
Response status line is invalid
"""
def __init(self, line): def __init__(self, line):
self.line = line self.line = line
class UnsupportedEncoding(HTTPException): class UnsupportedEncoding(HTTPException):
""" Encoding not supported """ """
Encoding not supported
"""
def __init(self, enc_type, encoding): def __init__(self, enc_type, encoding):
self.enc_type = enc_type self.enc_type = enc_type
self.encoding = encoding self.encoding = encoding
class UnsupportedProtocol(HTTPException):
"""
Protocol is not supported
"""
def __init__(self, protocol):
self.protocol = protocol
class IncompleteResponse(HTTPException): class IncompleteResponse(HTTPException):
def __init(self, cause): def __init__(self, cause):
self.cause = cause self.cause = cause
class HTTPServerException(Exception): class HTTPServerException(HTTPException):
""" Base class for HTTP Server exceptions """ """ Base class for HTTP Server exceptions """
status_code: str status_code: str
message: str message: str
@@ -68,7 +92,7 @@ class MethodNotAllowed(HTTPServerException):
status_code = 405 status_code = 405
message = "Method Not Allowed" message = "Method Not Allowed"
def __init(self, allowed_methods): def __init__(self, allowed_methods):
self.allowed_methods = allowed_methods self.allowed_methods = allowed_methods

View File

@@ -6,7 +6,7 @@ from urllib.parse import SplitResult
class Message(ABC): class Message(ABC):
version: str version: str
headers: Dict[str, str] headers: Dict[str, str]
raw: str raw: [str]
body: bytes body: bytes
def __init__(self, version: str, headers: Dict[str, str], raw=None, body: bytes = None): def __init__(self, version: str, headers: Dict[str, str], raw=None, body: bytes = None):

View File

@@ -1,4 +1,6 @@
import logging import logging
import os
import pathlib
import re import re
import urllib import urllib
from typing import Dict from typing import Dict
@@ -85,6 +87,11 @@ def parse_request_line(line: str):
def parse_headers(lines): def parse_headers(lines):
"""
Parses the lines from the `lines` iterator as headers.
@param lines: iterator to retrieve the lines from.
@return: A dictionary with header as key and value as value.
"""
headers = [] headers = []
try: try:
@@ -127,17 +134,21 @@ def parse_headers(lines):
def check_next_header(headers, next_header: str, next_value: str): def check_next_header(headers, next_header: str, next_value: str):
if next_header == "content-length": if next_header == "content-length":
if "content-length" in headers: if "content-length" in headers:
logging.error("Multiple content-length headers specified") raise InvalidResponse("Multiple content-length headers specified")
raise InvalidResponse()
if not next_value.isnumeric() or int(next_value) <= 0: if not next_value.isnumeric() or int(next_value) <= 0:
logging.error("Invalid content-length value: %r", next_value) raise InvalidResponse(f"Invalid content-length value: {next_value}")
raise InvalidResponse()
def parse_uri(uri: str): def parse_uri(uri: str):
"""
Parse the specified URI into the host, port and path.
If the URI is invalid, this method will try to create one.
@param uri: the URI to be parsed
@return: A tuple with the host, port and path
"""
parsed = urlsplit(uri) parsed = urlsplit(uri)
# If there is no netloc, the given string is not a valid URI, so split on / # If there is no hostname, the given string is not a valid URI, so split on /
if parsed.hostname: if parsed.hostname:
host = parsed.hostname host = parsed.hostname
path = parsed.path path = parsed.path
@@ -180,6 +191,12 @@ def urljoin(base, url):
def get_charset(headers: Dict[str, str]): def get_charset(headers: Dict[str, str]):
"""
Returns the charset of the content from the headers if found. Otherwise returns `FORMAT`
@param headers: the headers to retrieve the charset from
@return: A charset
"""
if "content-type" in headers: if "content-type" in headers:
content_type = headers["content-type"] content_type = headers["content-type"]
match = re.search(r"charset\s*=\s*([a-z\-0-9]*)", content_type, re.I) match = re.search(r"charset\s*=\s*([a-z\-0-9]*)", content_type, re.I)
@@ -187,3 +204,17 @@ def get_charset(headers: Dict[str, str]):
return match.group(1) return match.group(1)
return FORMAT return FORMAT
def get_relative_save_path(path: str) -> str:
    """
    Returns the specified path relative to the working directory.

    If the path cannot be expressed relative to the working directory (it is already relative, or it
    points outside the working directory), the path itself is returned instead of raising an error.
    @param path: the path to compute
    @return: the relative path, or `path` (normalised by `pathlib`) when it cannot be made relative
    """
    path_obj = pathlib.PurePath(path)
    root = pathlib.PurePath(os.getcwd())

    try:
        return str(path_obj.relative_to(root))
    except ValueError:
        # `relative_to` raises ValueError when `path` is not located below the working directory.
        return str(path_obj)