Fix small issues, improve error handling and documentation

This commit is contained in:
2021-03-28 14:04:39 +02:00
parent 850535a060
commit 07b018d2ab
6 changed files with 112 additions and 52 deletions

View File

@@ -5,7 +5,7 @@ from urllib.parse import urlparse
from client.httpclient import HTTPClient
from httplib import parser
from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding
from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding, UnsupportedProtocol
from httplib.httpsocket import FORMAT
from httplib.message import ResponseMessage as Message
from httplib.retriever import PreambleRetriever
@@ -42,12 +42,10 @@ class AbstractCommand(ABC):
_host: str
_path: str
_port: int
sub_request: bool
def __init__(self, uri: str, port):
self.uri = uri
self._port = int(port)
self.sub_request = False
@property
def uri(self):
@@ -81,23 +79,24 @@ class AbstractCommand(ABC):
@param sub_request: If this execution is in function of a prior command.
"""
self.uri = ""
self.sub_request = sub_request
(host, path) = self.parse_uri()
client = sockets.get(host)
client = sockets.get(self.host)
if client and client.is_closed():
sockets.pop(self.host)
client = None
if not client:
client = HTTPClient(host)
client.conn.connect((host, self.port))
sockets[host] = client
logging.info("Connecting to %s", self.host)
client = HTTPClient(self.host)
client.conn.connect((self.host, self.port))
logging.info("Connected.")
sockets[self.host] = client
else:
logging.info("Reusing socket for %s", self.host)
message = f"{self.method} {path} HTTP/1.1\r\n"
message += f"Host: {host}:{self.port}\r\n"
message = f"{self.method} {self.path} HTTP/1.1\r\n"
message += f"Host: {self.host}:{self.port}\r\n"
message += "Accept: */*\r\n"
message += "Accept-Encoding: identity\r\n"
encoded_msg = self._build_message(message)
@@ -111,13 +110,17 @@ class AbstractCommand(ABC):
try:
self._await_response(client)
except InvalidResponse as e:
logging.debug("Internal error: Response could not be parsed", exc_info=e)
return
logging.error("Response could not be parsed")
logging.debug("", exc_info=e)
except InvalidStatusLine as e:
logging.debug("Internal error: Invalid status-line in response", exc_info=e)
return
logging.error("Invalid status-line in response")
logging.debug("", exc_info=e)
except UnsupportedEncoding as e:
logging.debug("Internal error: Unsupported encoding in response", exc_info=e)
logging.error("Unsupported encoding in response")
logging.debug("", exc_info=e)
except UnsupportedProtocol as e:
logging.error("Unsupported protocol: %s", e.protocol)
logging.debug("", exc_info=e)
finally:
if not sub_request:
client.close()

View File

@@ -7,7 +7,7 @@ from urllib.parse import urlsplit, unquote
from client.command import AbstractCommand, GetCommand
from client.httpclient import HTTPClient
from httplib import parser
from httplib.exceptions import InvalidResponse
from httplib.exceptions import InvalidResponse, UnhandledHTTPCode, UnsupportedProtocol
from httplib.httpsocket import FORMAT
from httplib.message import ResponseMessage as Message
from httplib.retriever import Retriever
@@ -91,8 +91,7 @@ class BasicResponseHandler(ResponseHandler):
if self.msg.status == 101:
# Switching protocols is not supported
print("".join(self.msg.raw), end="")
return None
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), "Switching protocols is not supported")
if 200 <= self.msg.status < 300:
return self.retriever
@@ -105,16 +104,13 @@ class BasicResponseHandler(ResponseHandler):
if 400 <= self.msg.status < 600:
self._skip_body()
# Dump headers and exit with error
if not self.cmd.sub_request:
print("".join(self.msg.raw), end="")
return None
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
return None
def _handle_redirect(self):
if self.msg.status == 304:
print("".join(self.msg.raw), end="")
return None
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
location = self.msg.headers.get("location")
if not location or len(location.strip()) == 0:
@@ -126,7 +122,7 @@ class BasicResponseHandler(ResponseHandler):
raise InvalidResponse("Invalid location")
if not parsed_location.scheme == "http":
raise InvalidResponse("Only http is supported")
raise UnsupportedProtocol(parsed_location.scheme)
self.cmd.uri = location
@@ -195,7 +191,7 @@ class RawDownloadHandler(DownloadHandler):
super().__init__(retriever, client, msg, cmd, directory)
def handle(self) -> str:
logging.debug("Retrieving payload")
logging.info("Saving to '%s'", parser.get_relative_save_path(self.path))
file = open(self.path, "wb")
for buffer in self.retriever.retrieve():
@@ -225,20 +221,20 @@ class HTMLDownloadHandler(DownloadHandler):
os.remove(tmp_path)
return self.path
def _download_images(self, tmp_filename, target_filename, charset=FORMAT):
def _download_images(self, tmp_path, target_path, charset=FORMAT):
"""
Downloads images referenced in the html of `tmp_filename` and replaces the references in the html
and writes it to `target_filename`.
@param tmp_filename: the path to the temporary html file
@param target_filename: the path for the final html file
@param tmp_path: the path to the temporary html file
@param target_path: the path for the final html file
@param charset: the charset to decode `tmp_filename`
"""
try:
fp = open(tmp_filename, "r", encoding=charset)
fp = open(tmp_path, "r", encoding="yeetus")
html = fp.read()
except UnicodeDecodeError:
fp = open(tmp_filename, "r", encoding=FORMAT, errors="replace")
except UnicodeDecodeError or LookupError:
fp = open(tmp_path, "r", encoding=FORMAT, errors="replace")
html = fp.read()
fp.close()
@@ -281,7 +277,8 @@ class HTMLDownloadHandler(DownloadHandler):
for (start, end, path) in to_replace:
html = html[:start] + path + html[end:]
with open(target_filename, 'w', encoding=FORMAT) as file:
logging.info("Saving to HTML '%s'", parser.get_relative_save_path(target_path))
with open(target_path, 'w', encoding=FORMAT) as file:
file.write(html)
def __download_image(self, img_src, base_url):
@@ -295,6 +292,7 @@ class HTMLDownloadHandler(DownloadHandler):
parsed = urlsplit(img_src)
img_src = parser.urljoin(base_url, img_src)
# Check if the port of the image sh
if parsed.hostname is None or parsed.hostname == self.cmd.host:
port = self.cmd.port
elif ":" in parsed.netloc: