Fix small issues, improve error handling and documentation
This commit is contained in:
@@ -4,6 +4,7 @@ import logging
|
||||
import sys
|
||||
|
||||
from client import command as cmd
|
||||
from httplib.exceptions import UnhandledHTTPCode
|
||||
|
||||
|
||||
def main():
|
||||
@@ -15,7 +16,7 @@ def main():
|
||||
|
||||
arguments = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.ERROR - (10 * arguments.verbose), format="[%(levelname)s] %(message)s")
|
||||
logging.basicConfig(level=logging.INFO - (10 * arguments.verbose), format="[%(levelname)s] %(message)s")
|
||||
logging.debug("Arguments: %s", arguments)
|
||||
|
||||
command = cmd.create(arguments.command, arguments.URI, arguments.port)
|
||||
@@ -24,7 +25,10 @@ def main():
|
||||
|
||||
try:
|
||||
main()
|
||||
except UnhandledHTTPCode as e:
|
||||
print(f"[{e.status_code}] {e.cause}:\r\n{e.headers}")
|
||||
sys.exit(2)
|
||||
except Exception as e:
|
||||
print("[ABRT] Internal error: " + str(e), file=sys.stderr)
|
||||
logging.debug("Internal error", exc_info=e)
|
||||
sys.exit(70)
|
||||
sys.exit(1)
|
||||
|
@@ -5,7 +5,7 @@ from urllib.parse import urlparse
|
||||
|
||||
from client.httpclient import HTTPClient
|
||||
from httplib import parser
|
||||
from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding
|
||||
from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding, UnsupportedProtocol
|
||||
from httplib.httpsocket import FORMAT
|
||||
from httplib.message import ResponseMessage as Message
|
||||
from httplib.retriever import PreambleRetriever
|
||||
@@ -42,12 +42,10 @@ class AbstractCommand(ABC):
|
||||
_host: str
|
||||
_path: str
|
||||
_port: int
|
||||
sub_request: bool
|
||||
|
||||
def __init__(self, uri: str, port):
|
||||
self.uri = uri
|
||||
self._port = int(port)
|
||||
self.sub_request = False
|
||||
|
||||
@property
|
||||
def uri(self):
|
||||
@@ -81,23 +79,24 @@ class AbstractCommand(ABC):
|
||||
|
||||
@param sub_request: If this execution is in function of a prior command.
|
||||
"""
|
||||
self.uri = ""
|
||||
self.sub_request = sub_request
|
||||
(host, path) = self.parse_uri()
|
||||
|
||||
client = sockets.get(host)
|
||||
client = sockets.get(self.host)
|
||||
|
||||
if client and client.is_closed():
|
||||
sockets.pop(self.host)
|
||||
client = None
|
||||
|
||||
if not client:
|
||||
client = HTTPClient(host)
|
||||
client.conn.connect((host, self.port))
|
||||
sockets[host] = client
|
||||
logging.info("Connecting to %s", self.host)
|
||||
client = HTTPClient(self.host)
|
||||
client.conn.connect((self.host, self.port))
|
||||
logging.info("Connected.")
|
||||
sockets[self.host] = client
|
||||
else:
|
||||
logging.info("Reusing socket for %s", self.host)
|
||||
|
||||
message = f"{self.method} {path} HTTP/1.1\r\n"
|
||||
message += f"Host: {host}:{self.port}\r\n"
|
||||
message = f"{self.method} {self.path} HTTP/1.1\r\n"
|
||||
message += f"Host: {self.host}:{self.port}\r\n"
|
||||
message += "Accept: */*\r\n"
|
||||
message += "Accept-Encoding: identity\r\n"
|
||||
encoded_msg = self._build_message(message)
|
||||
@@ -111,13 +110,17 @@ class AbstractCommand(ABC):
|
||||
try:
|
||||
self._await_response(client)
|
||||
except InvalidResponse as e:
|
||||
logging.debug("Internal error: Response could not be parsed", exc_info=e)
|
||||
return
|
||||
logging.error("Response could not be parsed")
|
||||
logging.debug("", exc_info=e)
|
||||
except InvalidStatusLine as e:
|
||||
logging.debug("Internal error: Invalid status-line in response", exc_info=e)
|
||||
return
|
||||
logging.error("Invalid status-line in response")
|
||||
logging.debug("", exc_info=e)
|
||||
except UnsupportedEncoding as e:
|
||||
logging.debug("Internal error: Unsupported encoding in response", exc_info=e)
|
||||
logging.error("Unsupported encoding in response")
|
||||
logging.debug("", exc_info=e)
|
||||
except UnsupportedProtocol as e:
|
||||
logging.error("Unsupported protocol: %s", e.protocol)
|
||||
logging.debug("", exc_info=e)
|
||||
finally:
|
||||
if not sub_request:
|
||||
client.close()
|
||||
|
@@ -7,7 +7,7 @@ from urllib.parse import urlsplit, unquote
|
||||
from client.command import AbstractCommand, GetCommand
|
||||
from client.httpclient import HTTPClient
|
||||
from httplib import parser
|
||||
from httplib.exceptions import InvalidResponse
|
||||
from httplib.exceptions import InvalidResponse, UnhandledHTTPCode, UnsupportedProtocol
|
||||
from httplib.httpsocket import FORMAT
|
||||
from httplib.message import ResponseMessage as Message
|
||||
from httplib.retriever import Retriever
|
||||
@@ -91,8 +91,7 @@ class BasicResponseHandler(ResponseHandler):
|
||||
|
||||
if self.msg.status == 101:
|
||||
# Switching protocols is not supported
|
||||
print("".join(self.msg.raw), end="")
|
||||
return None
|
||||
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), "Switching protocols is not supported")
|
||||
|
||||
if 200 <= self.msg.status < 300:
|
||||
return self.retriever
|
||||
@@ -105,16 +104,13 @@ class BasicResponseHandler(ResponseHandler):
|
||||
if 400 <= self.msg.status < 600:
|
||||
self._skip_body()
|
||||
# Dump headers and exit with error
|
||||
if not self.cmd.sub_request:
|
||||
print("".join(self.msg.raw), end="")
|
||||
return None
|
||||
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
|
||||
|
||||
return None
|
||||
|
||||
def _handle_redirect(self):
|
||||
if self.msg.status == 304:
|
||||
print("".join(self.msg.raw), end="")
|
||||
return None
|
||||
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
|
||||
|
||||
location = self.msg.headers.get("location")
|
||||
if not location or len(location.strip()) == 0:
|
||||
@@ -126,7 +122,7 @@ class BasicResponseHandler(ResponseHandler):
|
||||
raise InvalidResponse("Invalid location")
|
||||
|
||||
if not parsed_location.scheme == "http":
|
||||
raise InvalidResponse("Only http is supported")
|
||||
raise UnsupportedProtocol(parsed_location.scheme)
|
||||
|
||||
self.cmd.uri = location
|
||||
|
||||
@@ -195,7 +191,7 @@ class RawDownloadHandler(DownloadHandler):
|
||||
super().__init__(retriever, client, msg, cmd, directory)
|
||||
|
||||
def handle(self) -> str:
|
||||
logging.debug("Retrieving payload")
|
||||
logging.info("Saving to '%s'", parser.get_relative_save_path(self.path))
|
||||
file = open(self.path, "wb")
|
||||
|
||||
for buffer in self.retriever.retrieve():
|
||||
@@ -225,20 +221,20 @@ class HTMLDownloadHandler(DownloadHandler):
|
||||
os.remove(tmp_path)
|
||||
return self.path
|
||||
|
||||
def _download_images(self, tmp_filename, target_filename, charset=FORMAT):
|
||||
def _download_images(self, tmp_path, target_path, charset=FORMAT):
|
||||
"""
|
||||
Downloads images referenced in the html of `tmp_filename` and replaces the references in the html
|
||||
and writes it to `target_filename`.
|
||||
@param tmp_filename: the path to the temporary html file
|
||||
@param target_filename: the path for the final html file
|
||||
@param tmp_path: the path to the temporary html file
|
||||
@param target_path: the path for the final html file
|
||||
@param charset: the charset to decode `tmp_filename`
|
||||
"""
|
||||
|
||||
# Read the temporary HTML file using the charset passed by the caller.
# An unknown charset name makes `open` raise LookupError and undecodable
# bytes make `read` raise UnicodeDecodeError; in both cases fall back to
# FORMAT and substitute undecodable bytes instead of aborting the download.
# Both exception types must be listed in a tuple: `except A or B` only
# catches `A`. The `with` blocks close the file even when reading fails.
try:
    with open(tmp_path, "r", encoding=charset) as fp:
        html = fp.read()
except (UnicodeDecodeError, LookupError):
    with open(tmp_path, "r", encoding=FORMAT, errors="replace") as fp:
        html = fp.read()
|
||||
@@ -281,7 +277,8 @@ class HTMLDownloadHandler(DownloadHandler):
|
||||
for (start, end, path) in to_replace:
|
||||
html = html[:start] + path + html[end:]
|
||||
|
||||
with open(target_filename, 'w', encoding=FORMAT) as file:
|
||||
logging.info("Saving to HTML '%s'", parser.get_relative_save_path(target_path))
|
||||
with open(target_path, 'w', encoding=FORMAT) as file:
|
||||
file.write(html)
|
||||
|
||||
def __download_image(self, img_src, base_url):
|
||||
@@ -295,6 +292,7 @@ class HTMLDownloadHandler(DownloadHandler):
|
||||
parsed = urlsplit(img_src)
|
||||
img_src = parser.urljoin(base_url, img_src)
|
||||
|
||||
# Check if the port of the image should be the same as the port of this command
|
||||
if parsed.hostname is None or parsed.hostname == self.cmd.host:
|
||||
port = self.cmd.port
|
||||
elif ":" in parsed.netloc:
|
||||
|
@@ -2,34 +2,58 @@ class HTTPException(Exception):
|
||||
""" Base class for HTTP exceptions """
|
||||
|
||||
|
||||
class UnhandledHTTPCode(Exception):
    """
    Raised for a response whose status code is not handled by the client.

    @param status: the status code of the response
    @param headers: the raw status line and headers of the response
    @param cause: short description of why the response is not handled
    """
    # The status is compared against integers before being passed in here,
    # so it is annotated as `int` rather than `str`.
    status_code: int
    headers: str
    cause: str

    def __init__(self, status, headers, cause):
        # Forward the arguments so the exception has a meaningful `args`/repr.
        super().__init__(status, headers, cause)
        self.status_code = status
        self.headers = headers
        self.cause = cause
|
||||
|
||||
|
||||
class InvalidResponse(HTTPException):
    """
    Response message cannot be parsed

    @param message: description of what made the response invalid
    """

    def __init__(self, message):
        # Forward the message so `str(exception)` shows it.
        super().__init__(message)
        self.message = message
|
||||
|
||||
|
||||
class InvalidStatusLine(HTTPException):
    """
    Response status line is invalid

    @param line: the offending status line
    """

    def __init__(self, line):
        # The constructor must be spelled `__init__`; a method named `__init`
        # is never invoked on construction and would leave `line` unset.
        super().__init__(line)
        self.line = line
|
||||
|
||||
|
||||
class UnsupportedEncoding(HTTPException):
    """
    Encoding not supported

    @param enc_type: the kind of encoding that was encountered
    @param encoding: the encoding value that is not supported
    """

    def __init__(self, enc_type, encoding):
        # The constructor must be spelled `__init__`; a method named `__init`
        # is never invoked on construction and would leave the fields unset.
        super().__init__(enc_type, encoding)
        self.enc_type = enc_type
        self.encoding = encoding
|
||||
|
||||
class UnsupportedProtocol(HTTPException):
    """
    Protocol is not supported

    @param protocol: the protocol (URI scheme) that is not supported
    """

    def __init__(self, protocol):
        super().__init__(protocol)
        self.protocol = protocol
|
||||
|
||||
|
||||
class IncompleteResponse(HTTPException):
    """
    Response is incomplete

    @param cause: description of why the response is considered incomplete
    """

    def __init__(self, cause):
        # The constructor must be spelled `__init__`; a method named `__init`
        # is never invoked on construction and would leave `cause` unset.
        super().__init__(cause)
        self.cause = cause
|
||||
|
||||
|
||||
class HTTPServerException(Exception):
|
||||
class HTTPServerException(HTTPException):
|
||||
""" Base class for HTTP Server exceptions """
|
||||
status_code: str
|
||||
message: str
|
||||
@@ -68,7 +92,7 @@ class MethodNotAllowed(HTTPServerException):
|
||||
status_code = 405
|
||||
message = "Method Not Allowed"
|
||||
|
||||
def __init(self, allowed_methods):
|
||||
def __init__(self, allowed_methods):
|
||||
self.allowed_methods = allowed_methods
|
||||
|
||||
|
||||
|
@@ -6,7 +6,7 @@ from urllib.parse import SplitResult
|
||||
class Message(ABC):
|
||||
version: str
|
||||
headers: Dict[str, str]
|
||||
raw: str
|
||||
raw: [str]
|
||||
body: bytes
|
||||
|
||||
def __init__(self, version: str, headers: Dict[str, str], raw=None, body: bytes = None):
|
||||
|
@@ -1,4 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import urllib
|
||||
from typing import Dict
|
||||
@@ -85,6 +87,11 @@ def parse_request_line(line: str):
|
||||
|
||||
|
||||
def parse_headers(lines):
|
||||
"""
|
||||
Parses the lines from the `lines` iterator as headers.
|
||||
@param lines: iterator to retrieve the lines from.
|
||||
@return: A dictionary with header as key and value as value.
|
||||
"""
|
||||
headers = []
|
||||
|
||||
try:
|
||||
@@ -127,17 +134,21 @@ def parse_headers(lines):
|
||||
def check_next_header(headers, next_header: str, next_value: str):
    """
    Validates a header before it is added to the already parsed headers.

    Only `content-length` is checked: it may occur once and its value must be
    a positive integer written in plain ASCII digits.

    @param headers: the headers parsed so far
    @param next_header: the (lowercase) name of the header to be added
    @param next_value: the value of the header to be added
    @raise InvalidResponse: if the content-length header is duplicated or invalid
    """
    if next_header != "content-length":
        return

    if "content-length" in headers:
        raise InvalidResponse("Multiple content-length headers specified")

    # `str.isnumeric()` also accepts characters such as superscripts and
    # fractions, which `int()` rejects with a ValueError. Only accept plain
    # ASCII digits so that every malformed value ends up as InvalidResponse.
    is_plain_number = bool(next_value) and all(ch in "0123456789" for ch in next_value)

    # NOTE(review): a content-length of 0 is rejected here, as before.
    # Confirm whether empty bodies should be accepted.
    if not is_plain_number or int(next_value) <= 0:
        raise InvalidResponse(f"Invalid content-length value: {next_value}")
|
||||
|
||||
|
||||
def parse_uri(uri: str):
|
||||
"""
|
||||
Parse the specified URI into the host, port and path.
|
||||
If the URI is invalid, this method will try to create one.
|
||||
@param uri: the URI to be parsed
|
||||
@return: A tuple with the host, port and path
|
||||
"""
|
||||
parsed = urlsplit(uri)
|
||||
|
||||
# If there is no netloc, the given string is not a valid URI, so split on /
|
||||
# If there is no hostname, the given string is not a valid URI, so split on /
|
||||
if parsed.hostname:
|
||||
host = parsed.hostname
|
||||
path = parsed.path
|
||||
@@ -180,6 +191,12 @@ def urljoin(base, url):
|
||||
|
||||
|
||||
def get_charset(headers: Dict[str, str]):
|
||||
"""
|
||||
Returns the charset of the content from the headers if found. Otherwise returns `FORMAT`
|
||||
|
||||
@param headers: the headers to retrieve the charset from
|
||||
@return: A charset
|
||||
"""
|
||||
if "content-type" in headers:
|
||||
content_type = headers["content-type"]
|
||||
match = re.search(r"charset\s*=\s*([a-z\-0-9]*)", content_type, re.I)
|
||||
@@ -187,3 +204,17 @@ def get_charset(headers: Dict[str, str]):
|
||||
return match.group(1)
|
||||
|
||||
return FORMAT
|
||||
|
||||
|
||||
def get_relative_save_path(path: str):
    """
    Returns the specified path relative to the working directory.

    If the path is not located inside the working directory, it is returned
    as given (normalised by `pathlib`) instead of raising an error.

    @param path: the path to compute
    @return: the relative path, or the path itself when it cannot be made relative
    """

    path_obj = pathlib.PurePath(path)
    root = pathlib.PurePath(os.getcwd())

    try:
        rel = path_obj.relative_to(root)
    except ValueError:
        # `relative_to` raises when `path` is not below the working directory
        # (this includes every relative path, since `root` is absolute).
        return str(path_obj)

    return str(rel)
|
||||
|
Reference in New Issue
Block a user