Fix small issues, improve error handling and documentation
This commit is contained in:
@@ -7,7 +7,7 @@ from urllib.parse import urlsplit, unquote
|
||||
from client.command import AbstractCommand, GetCommand
|
||||
from client.httpclient import HTTPClient
|
||||
from httplib import parser
|
||||
from httplib.exceptions import InvalidResponse
|
||||
from httplib.exceptions import InvalidResponse, UnhandledHTTPCode, UnsupportedProtocol
|
||||
from httplib.httpsocket import FORMAT
|
||||
from httplib.message import ResponseMessage as Message
|
||||
from httplib.retriever import Retriever
|
||||
@@ -91,8 +91,7 @@ class BasicResponseHandler(ResponseHandler):
|
||||
|
||||
if self.msg.status == 101:
|
||||
# Switching protocols is not supported
|
||||
print("".join(self.msg.raw), end="")
|
||||
return None
|
||||
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), "Switching protocols is not supported")
|
||||
|
||||
if 200 <= self.msg.status < 300:
|
||||
return self.retriever
|
||||
@@ -105,16 +104,13 @@ class BasicResponseHandler(ResponseHandler):
|
||||
if 400 <= self.msg.status < 600:
|
||||
self._skip_body()
|
||||
# Dump headers and exit with error
|
||||
if not self.cmd.sub_request:
|
||||
print("".join(self.msg.raw), end="")
|
||||
return None
|
||||
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
|
||||
|
||||
return None
|
||||
|
||||
def _handle_redirect(self):
|
||||
if self.msg.status == 304:
|
||||
print("".join(self.msg.raw), end="")
|
||||
return None
|
||||
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
|
||||
|
||||
location = self.msg.headers.get("location")
|
||||
if not location or len(location.strip()) == 0:
|
||||
@@ -126,7 +122,7 @@ class BasicResponseHandler(ResponseHandler):
|
||||
raise InvalidResponse("Invalid location")
|
||||
|
||||
if not parsed_location.scheme == "http":
|
||||
raise InvalidResponse("Only http is supported")
|
||||
raise UnsupportedProtocol(parsed_location.scheme)
|
||||
|
||||
self.cmd.uri = location
|
||||
|
||||
@@ -195,7 +191,7 @@ class RawDownloadHandler(DownloadHandler):
|
||||
super().__init__(retriever, client, msg, cmd, directory)
|
||||
|
||||
def handle(self) -> str:
|
||||
logging.debug("Retrieving payload")
|
||||
logging.info("Saving to '%s'", parser.get_relative_save_path(self.path))
|
||||
file = open(self.path, "wb")
|
||||
|
||||
for buffer in self.retriever.retrieve():
|
||||
@@ -225,20 +221,20 @@ class HTMLDownloadHandler(DownloadHandler):
|
||||
os.remove(tmp_path)
|
||||
return self.path
|
||||
|
||||
def _download_images(self, tmp_filename, target_filename, charset=FORMAT):
|
||||
def _download_images(self, tmp_path, target_path, charset=FORMAT):
|
||||
"""
|
||||
Downloads images referenced in the html of `tmp_filename` and replaces the references in the html
|
||||
and writes it to `target_filename`.
|
||||
@param tmp_filename: the path to the temporary html file
|
||||
@param target_filename: the path for the final html file
|
||||
@param tmp_path: the path to the temporary html file
|
||||
@param target_path: the path for the final html file
|
||||
@param charset: the charset to decode `tmp_filename`
|
||||
"""
|
||||
|
||||
try:
|
||||
fp = open(tmp_filename, "r", encoding=charset)
|
||||
fp = open(tmp_path, "r", encoding="yeetus")
|
||||
html = fp.read()
|
||||
except UnicodeDecodeError:
|
||||
fp = open(tmp_filename, "r", encoding=FORMAT, errors="replace")
|
||||
except UnicodeDecodeError or LookupError:
|
||||
fp = open(tmp_path, "r", encoding=FORMAT, errors="replace")
|
||||
html = fp.read()
|
||||
|
||||
fp.close()
|
||||
@@ -281,7 +277,8 @@ class HTMLDownloadHandler(DownloadHandler):
|
||||
for (start, end, path) in to_replace:
|
||||
html = html[:start] + path + html[end:]
|
||||
|
||||
with open(target_filename, 'w', encoding=FORMAT) as file:
|
||||
logging.info("Saving to HTML '%s'", parser.get_relative_save_path(target_path))
|
||||
with open(target_path, 'w', encoding=FORMAT) as file:
|
||||
file.write(html)
|
||||
|
||||
def __download_image(self, img_src, base_url):
|
||||
@@ -295,6 +292,7 @@ class HTMLDownloadHandler(DownloadHandler):
|
||||
parsed = urlsplit(img_src)
|
||||
img_src = parser.urljoin(base_url, img_src)
|
||||
|
||||
# Check if the port of the image sh
|
||||
if parsed.hostname is None or parsed.hostname == self.cmd.host:
|
||||
port = self.cmd.port
|
||||
elif ":" in parsed.netloc:
|
||||
|
Reference in New Issue
Block a user