Fix small issues, improve error handling and documentation

This commit is contained in:
2021-03-28 14:04:39 +02:00
parent 850535a060
commit 07b018d2ab
6 changed files with 112 additions and 52 deletions

View File

@@ -4,6 +4,7 @@ import logging
import sys
from client import command as cmd
from httplib.exceptions import UnhandledHTTPCode
def main():
@@ -15,7 +16,7 @@ def main():
arguments = parser.parse_args()
logging.basicConfig(level=logging.ERROR - (10 * arguments.verbose), format="[%(levelname)s] %(message)s")
logging.basicConfig(level=logging.INFO - (10 * arguments.verbose), format="[%(levelname)s] %(message)s")
logging.debug("Arguments: %s", arguments)
command = cmd.create(arguments.command, arguments.URI, arguments.port)
@@ -24,7 +25,10 @@ def main():
try:
main()
except UnhandledHTTPCode as e:
print(f"[{e.status_code}] {e.cause}:\r\n{e.headers}")
sys.exit(2)
except Exception as e:
print("[ABRT] Internal error: " + str(e), file=sys.stderr)
logging.debug("Internal error", exc_info=e)
sys.exit(70)
sys.exit(1)

View File

@@ -5,7 +5,7 @@ from urllib.parse import urlparse
from client.httpclient import HTTPClient
from httplib import parser
from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding
from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding, UnsupportedProtocol
from httplib.httpsocket import FORMAT
from httplib.message import ResponseMessage as Message
from httplib.retriever import PreambleRetriever
@@ -42,12 +42,10 @@ class AbstractCommand(ABC):
_host: str
_path: str
_port: int
sub_request: bool
def __init__(self, uri: str, port):
self.uri = uri
self._port = int(port)
self.sub_request = False
@property
def uri(self):
@@ -81,23 +79,24 @@ class AbstractCommand(ABC):
@param sub_request: If this execution is in function of a prior command.
"""
self.uri = ""
self.sub_request = sub_request
(host, path) = self.parse_uri()
client = sockets.get(host)
client = sockets.get(self.host)
if client and client.is_closed():
sockets.pop(self.host)
client = None
if not client:
client = HTTPClient(host)
client.conn.connect((host, self.port))
sockets[host] = client
logging.info("Connecting to %s", self.host)
client = HTTPClient(self.host)
client.conn.connect((self.host, self.port))
logging.info("Connected.")
sockets[self.host] = client
else:
logging.info("Reusing socket for %s", self.host)
message = f"{self.method} {path} HTTP/1.1\r\n"
message += f"Host: {host}:{self.port}\r\n"
message = f"{self.method} {self.path} HTTP/1.1\r\n"
message += f"Host: {self.host}:{self.port}\r\n"
message += "Accept: */*\r\n"
message += "Accept-Encoding: identity\r\n"
encoded_msg = self._build_message(message)
@@ -111,13 +110,17 @@ class AbstractCommand(ABC):
try:
self._await_response(client)
except InvalidResponse as e:
logging.debug("Internal error: Response could not be parsed", exc_info=e)
return
logging.error("Response could not be parsed")
logging.debug("", exc_info=e)
except InvalidStatusLine as e:
logging.debug("Internal error: Invalid status-line in response", exc_info=e)
return
logging.error("Invalid status-line in response")
logging.debug("", exc_info=e)
except UnsupportedEncoding as e:
logging.debug("Internal error: Unsupported encoding in response", exc_info=e)
logging.error("Unsupported encoding in response")
logging.debug("", exc_info=e)
except UnsupportedProtocol as e:
logging.error("Unsupported protocol: %s", e.protocol)
logging.debug("", exc_info=e)
finally:
if not sub_request:
client.close()

View File

@@ -7,7 +7,7 @@ from urllib.parse import urlsplit, unquote
from client.command import AbstractCommand, GetCommand
from client.httpclient import HTTPClient
from httplib import parser
from httplib.exceptions import InvalidResponse
from httplib.exceptions import InvalidResponse, UnhandledHTTPCode, UnsupportedProtocol
from httplib.httpsocket import FORMAT
from httplib.message import ResponseMessage as Message
from httplib.retriever import Retriever
@@ -91,8 +91,7 @@ class BasicResponseHandler(ResponseHandler):
if self.msg.status == 101:
# Switching protocols is not supported
print("".join(self.msg.raw), end="")
return None
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), "Switching protocols is not supported")
if 200 <= self.msg.status < 300:
return self.retriever
@@ -105,16 +104,13 @@ class BasicResponseHandler(ResponseHandler):
if 400 <= self.msg.status < 600:
self._skip_body()
# Dump headers and exit with error
if not self.cmd.sub_request:
print("".join(self.msg.raw), end="")
return None
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
return None
def _handle_redirect(self):
if self.msg.status == 304:
print("".join(self.msg.raw), end="")
return None
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
location = self.msg.headers.get("location")
if not location or len(location.strip()) == 0:
@@ -126,7 +122,7 @@ class BasicResponseHandler(ResponseHandler):
raise InvalidResponse("Invalid location")
if not parsed_location.scheme == "http":
raise InvalidResponse("Only http is supported")
raise UnsupportedProtocol(parsed_location.scheme)
self.cmd.uri = location
@@ -195,7 +191,7 @@ class RawDownloadHandler(DownloadHandler):
super().__init__(retriever, client, msg, cmd, directory)
def handle(self) -> str:
logging.debug("Retrieving payload")
logging.info("Saving to '%s'", parser.get_relative_save_path(self.path))
file = open(self.path, "wb")
for buffer in self.retriever.retrieve():
@@ -225,20 +221,20 @@ class HTMLDownloadHandler(DownloadHandler):
os.remove(tmp_path)
return self.path
def _download_images(self, tmp_filename, target_filename, charset=FORMAT):
def _download_images(self, tmp_path, target_path, charset=FORMAT):
"""
Downloads images referenced in the html of `tmp_filename` and replaces the references in the html
and writes it to `target_filename`.
@param tmp_filename: the path to the temporary html file
@param target_filename: the path for the final html file
@param tmp_path: the path to the temporary html file
@param target_path: the path for the final html file
@param charset: the charset to decode `tmp_filename`
"""
try:
fp = open(tmp_filename, "r", encoding=charset)
fp = open(tmp_path, "r", encoding="yeetus")
html = fp.read()
except UnicodeDecodeError:
fp = open(tmp_filename, "r", encoding=FORMAT, errors="replace")
except UnicodeDecodeError or LookupError:
fp = open(tmp_path, "r", encoding=FORMAT, errors="replace")
html = fp.read()
fp.close()
@@ -281,7 +277,8 @@ class HTMLDownloadHandler(DownloadHandler):
for (start, end, path) in to_replace:
html = html[:start] + path + html[end:]
with open(target_filename, 'w', encoding=FORMAT) as file:
logging.info("Saving to HTML '%s'", parser.get_relative_save_path(target_path))
with open(target_path, 'w', encoding=FORMAT) as file:
file.write(html)
def __download_image(self, img_src, base_url):
@@ -295,6 +292,7 @@ class HTMLDownloadHandler(DownloadHandler):
parsed = urlsplit(img_src)
img_src = parser.urljoin(base_url, img_src)
# Check if the port of the image sh
if parsed.hostname is None or parsed.hostname == self.cmd.host:
port = self.cmd.port
elif ":" in parsed.netloc:

View File

@@ -2,34 +2,58 @@ class HTTPException(Exception):
""" Base class for HTTP exceptions """
class InvalidResponse(HTTPException):
""" Response message cannot be parsed """
class UnhandledHTTPCode(Exception):
status_code: str
headers: str
cause: str
def __init(self, message):
def __init__(self, status, headers, cause):
self.status_code = status
self.headers = headers
self.cause = cause
class InvalidResponse(HTTPException):
"""
Response message cannot be parsed
"""
def __init__(self, message):
self.message = message
class InvalidStatusLine(HTTPException):
""" Response status line is invalid """
"""
Response status line is invalid
"""
def __init(self, line):
def __init__(self, line):
self.line = line
class UnsupportedEncoding(HTTPException):
""" Encoding not supported """
"""
Encoding not supported
"""
def __init(self, enc_type, encoding):
def __init__(self, enc_type, encoding):
self.enc_type = enc_type
self.encoding = encoding
class UnsupportedProtocol(HTTPException):
"""
Protocol is not supported
"""
def __init__(self, protocol):
self.protocol = protocol
class IncompleteResponse(HTTPException):
def __init(self, cause):
def __init__(self, cause):
self.cause = cause
class HTTPServerException(Exception):
class HTTPServerException(HTTPException):
""" Base class for HTTP Server exceptions """
status_code: str
message: str
@@ -68,7 +92,7 @@ class MethodNotAllowed(HTTPServerException):
status_code = 405
message = "Method Not Allowed"
def __init(self, allowed_methods):
def __init__(self, allowed_methods):
self.allowed_methods = allowed_methods

View File

@@ -6,7 +6,7 @@ from urllib.parse import SplitResult
class Message(ABC):
version: str
headers: Dict[str, str]
raw: str
raw: [str]
body: bytes
def __init__(self, version: str, headers: Dict[str, str], raw=None, body: bytes = None):

View File

@@ -1,4 +1,6 @@
import logging
import os
import pathlib
import re
import urllib
from typing import Dict
@@ -85,6 +87,11 @@ def parse_request_line(line: str):
def parse_headers(lines):
"""
Parses the lines from the `lines` iterator as headers.
@param lines: iterator to retrieve the lines from.
@return: A dictionary with header as key and value as value.
"""
headers = []
try:
@@ -127,17 +134,21 @@ def parse_headers(lines):
def check_next_header(headers, next_header: str, next_value: str):
if next_header == "content-length":
if "content-length" in headers:
logging.error("Multiple content-length headers specified")
raise InvalidResponse()
raise InvalidResponse("Multiple content-length headers specified")
if not next_value.isnumeric() or int(next_value) <= 0:
logging.error("Invalid content-length value: %r", next_value)
raise InvalidResponse()
raise InvalidResponse(f"Invalid content-length value: {next_value}")
def parse_uri(uri: str):
"""
Parse the specified URI into the host, port and path.
If the URI is invalid, this method will try to create one.
@param uri: the URI to be parsed
@return: A tuple with the host, port and path
"""
parsed = urlsplit(uri)
# If there is no netloc, the given string is not a valid URI, so split on /
# If there is no hostname, the given string is not a valid URI, so split on /
if parsed.hostname:
host = parsed.hostname
path = parsed.path
@@ -180,6 +191,12 @@ def urljoin(base, url):
def get_charset(headers: Dict[str, str]):
"""
Returns the charset of the content from the headers if found. Otherwise returns `FORMAT`
@param headers: the headers to retrieve the charset from
@return: A charset
"""
if "content-type" in headers:
content_type = headers["content-type"]
match = re.search(r"charset\s*=\s*([a-z\-0-9]*)", content_type, re.I)
@@ -187,3 +204,17 @@ def get_charset(headers: Dict[str, str]):
return match.group(1)
return FORMAT
def get_relative_save_path(path: str):
"""
Returns the specified path relative to the working directory.
@param path: the path to compute
@return: the relative path
"""
path_obj = pathlib.PurePath(path)
root = pathlib.PurePath(os.getcwd())
rel = path_obj.relative_to(root)
return str(rel)