Fix small issues, improve error handling and documentation

This commit is contained in:
2021-03-28 14:04:39 +02:00
parent 850535a060
commit 07b018d2ab
6 changed files with 112 additions and 52 deletions

View File

@@ -7,7 +7,7 @@ from urllib.parse import urlsplit, unquote
from client.command import AbstractCommand, GetCommand
from client.httpclient import HTTPClient
from httplib import parser
from httplib.exceptions import InvalidResponse
from httplib.exceptions import InvalidResponse, UnhandledHTTPCode, UnsupportedProtocol
from httplib.httpsocket import FORMAT
from httplib.message import ResponseMessage as Message
from httplib.retriever import Retriever
@@ -91,8 +91,7 @@ class BasicResponseHandler(ResponseHandler):
if self.msg.status == 101:
# Switching protocols is not supported
print("".join(self.msg.raw), end="")
return None
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), "Switching protocols is not supported")
if 200 <= self.msg.status < 300:
return self.retriever
@@ -105,16 +104,13 @@ class BasicResponseHandler(ResponseHandler):
if 400 <= self.msg.status < 600:
self._skip_body()
# Dump headers and exit with error
if not self.cmd.sub_request:
print("".join(self.msg.raw), end="")
return None
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
return None
def _handle_redirect(self):
if self.msg.status == 304:
print("".join(self.msg.raw), end="")
return None
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
location = self.msg.headers.get("location")
if not location or len(location.strip()) == 0:
@@ -126,7 +122,7 @@ class BasicResponseHandler(ResponseHandler):
raise InvalidResponse("Invalid location")
if not parsed_location.scheme == "http":
raise InvalidResponse("Only http is supported")
raise UnsupportedProtocol(parsed_location.scheme)
self.cmd.uri = location
@@ -195,7 +191,7 @@ class RawDownloadHandler(DownloadHandler):
super().__init__(retriever, client, msg, cmd, directory)
def handle(self) -> str:
logging.debug("Retrieving payload")
logging.info("Saving to '%s'", parser.get_relative_save_path(self.path))
file = open(self.path, "wb")
for buffer in self.retriever.retrieve():
@@ -225,20 +221,20 @@ class HTMLDownloadHandler(DownloadHandler):
os.remove(tmp_path)
return self.path
def _download_images(self, tmp_filename, target_filename, charset=FORMAT):
def _download_images(self, tmp_path, target_path, charset=FORMAT):
"""
Downloads images referenced in the html of `tmp_filename` and replaces the references in the html
and writes it to `target_filename`.
@param tmp_filename: the path to the temporary html file
@param target_filename: the path for the final html file
@param tmp_path: the path to the temporary html file
@param target_path: the path for the final html file
@param charset: the charset to decode `tmp_filename`
"""
try:
fp = open(tmp_filename, "r", encoding=charset)
fp = open(tmp_path, "r", encoding="yeetus")
html = fp.read()
except UnicodeDecodeError:
fp = open(tmp_filename, "r", encoding=FORMAT, errors="replace")
except UnicodeDecodeError or LookupError:
fp = open(tmp_path, "r", encoding=FORMAT, errors="replace")
html = fp.read()
fp.close()
@@ -281,7 +277,8 @@ class HTMLDownloadHandler(DownloadHandler):
for (start, end, path) in to_replace:
html = html[:start] + path + html[end:]
with open(target_filename, 'w', encoding=FORMAT) as file:
logging.info("Saving to HTML '%s'", parser.get_relative_save_path(target_path))
with open(target_path, 'w', encoding=FORMAT) as file:
file.write(html)
def __download_image(self, img_src, base_url):
@@ -295,6 +292,7 @@ class HTMLDownloadHandler(DownloadHandler):
parsed = urlsplit(img_src)
img_src = parser.urljoin(base_url, img_src)
# Check if the port of the image sh
if parsed.hostname is None or parsed.hostname == self.cmd.host:
port = self.cmd.port
elif ":" in parsed.netloc: