Fix small issues, improve error handling and documentation
This commit is contained in:
@@ -5,7 +5,7 @@ from urllib.parse import urlparse
|
||||
|
||||
from client.httpclient import HTTPClient
|
||||
from httplib import parser
|
||||
from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding
|
||||
from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding, UnsupportedProtocol
|
||||
from httplib.httpsocket import FORMAT
|
||||
from httplib.message import ResponseMessage as Message
|
||||
from httplib.retriever import PreambleRetriever
|
||||
@@ -42,12 +42,10 @@ class AbstractCommand(ABC):
|
||||
_host: str
|
||||
_path: str
|
||||
_port: int
|
||||
sub_request: bool
|
||||
|
||||
def __init__(self, uri: str, port):
|
||||
self.uri = uri
|
||||
self._port = int(port)
|
||||
self.sub_request = False
|
||||
|
||||
@property
|
||||
def uri(self):
|
||||
@@ -81,23 +79,24 @@ class AbstractCommand(ABC):
|
||||
|
||||
@param sub_request: If this execution is in function of a prior command.
|
||||
"""
|
||||
self.uri = ""
|
||||
self.sub_request = sub_request
|
||||
(host, path) = self.parse_uri()
|
||||
|
||||
client = sockets.get(host)
|
||||
client = sockets.get(self.host)
|
||||
|
||||
if client and client.is_closed():
|
||||
sockets.pop(self.host)
|
||||
client = None
|
||||
|
||||
if not client:
|
||||
client = HTTPClient(host)
|
||||
client.conn.connect((host, self.port))
|
||||
sockets[host] = client
|
||||
logging.info("Connecting to %s", self.host)
|
||||
client = HTTPClient(self.host)
|
||||
client.conn.connect((self.host, self.port))
|
||||
logging.info("Connected.")
|
||||
sockets[self.host] = client
|
||||
else:
|
||||
logging.info("Reusing socket for %s", self.host)
|
||||
|
||||
message = f"{self.method} {path} HTTP/1.1\r\n"
|
||||
message += f"Host: {host}:{self.port}\r\n"
|
||||
message = f"{self.method} {self.path} HTTP/1.1\r\n"
|
||||
message += f"Host: {self.host}:{self.port}\r\n"
|
||||
message += "Accept: */*\r\n"
|
||||
message += "Accept-Encoding: identity\r\n"
|
||||
encoded_msg = self._build_message(message)
|
||||
@@ -111,13 +110,17 @@ class AbstractCommand(ABC):
|
||||
try:
|
||||
self._await_response(client)
|
||||
except InvalidResponse as e:
|
||||
logging.debug("Internal error: Response could not be parsed", exc_info=e)
|
||||
return
|
||||
logging.error("Response could not be parsed")
|
||||
logging.debug("", exc_info=e)
|
||||
except InvalidStatusLine as e:
|
||||
logging.debug("Internal error: Invalid status-line in response", exc_info=e)
|
||||
return
|
||||
logging.error("Invalid status-line in response")
|
||||
logging.debug("", exc_info=e)
|
||||
except UnsupportedEncoding as e:
|
||||
logging.debug("Internal error: Unsupported encoding in response", exc_info=e)
|
||||
logging.error("Unsupported encoding in response")
|
||||
logging.debug("", exc_info=e)
|
||||
except UnsupportedProtocol as e:
|
||||
logging.error("Unsupported protocol: %s", e.protocol)
|
||||
logging.debug("", exc_info=e)
|
||||
finally:
|
||||
if not sub_request:
|
||||
client.close()
|
||||
|
@@ -7,7 +7,7 @@ from urllib.parse import urlsplit, unquote
|
||||
from client.command import AbstractCommand, GetCommand
|
||||
from client.httpclient import HTTPClient
|
||||
from httplib import parser
|
||||
from httplib.exceptions import InvalidResponse
|
||||
from httplib.exceptions import InvalidResponse, UnhandledHTTPCode, UnsupportedProtocol
|
||||
from httplib.httpsocket import FORMAT
|
||||
from httplib.message import ResponseMessage as Message
|
||||
from httplib.retriever import Retriever
|
||||
@@ -91,8 +91,7 @@ class BasicResponseHandler(ResponseHandler):
|
||||
|
||||
if self.msg.status == 101:
|
||||
# Switching protocols is not supported
|
||||
print("".join(self.msg.raw), end="")
|
||||
return None
|
||||
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), "Switching protocols is not supported")
|
||||
|
||||
if 200 <= self.msg.status < 300:
|
||||
return self.retriever
|
||||
@@ -105,16 +104,13 @@ class BasicResponseHandler(ResponseHandler):
|
||||
if 400 <= self.msg.status < 600:
|
||||
self._skip_body()
|
||||
# Dump headers and exit with error
|
||||
if not self.cmd.sub_request:
|
||||
print("".join(self.msg.raw), end="")
|
||||
return None
|
||||
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
|
||||
|
||||
return None
|
||||
|
||||
def _handle_redirect(self):
|
||||
if self.msg.status == 304:
|
||||
print("".join(self.msg.raw), end="")
|
||||
return None
|
||||
raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)
|
||||
|
||||
location = self.msg.headers.get("location")
|
||||
if not location or len(location.strip()) == 0:
|
||||
@@ -126,7 +122,7 @@ class BasicResponseHandler(ResponseHandler):
|
||||
raise InvalidResponse("Invalid location")
|
||||
|
||||
if not parsed_location.scheme == "http":
|
||||
raise InvalidResponse("Only http is supported")
|
||||
raise UnsupportedProtocol(parsed_location.scheme)
|
||||
|
||||
self.cmd.uri = location
|
||||
|
||||
@@ -195,7 +191,7 @@ class RawDownloadHandler(DownloadHandler):
|
||||
super().__init__(retriever, client, msg, cmd, directory)
|
||||
|
||||
def handle(self) -> str:
|
||||
logging.debug("Retrieving payload")
|
||||
logging.info("Saving to '%s'", parser.get_relative_save_path(self.path))
|
||||
file = open(self.path, "wb")
|
||||
|
||||
for buffer in self.retriever.retrieve():
|
||||
@@ -225,20 +221,20 @@ class HTMLDownloadHandler(DownloadHandler):
|
||||
os.remove(tmp_path)
|
||||
return self.path
|
||||
|
||||
def _download_images(self, tmp_filename, target_filename, charset=FORMAT):
|
||||
def _download_images(self, tmp_path, target_path, charset=FORMAT):
|
||||
"""
|
||||
Downloads images referenced in the html of `tmp_filename` and replaces the references in the html
|
||||
and writes it to `target_filename`.
|
||||
@param tmp_filename: the path to the temporary html file
|
||||
@param target_filename: the path for the final html file
|
||||
@param tmp_path: the path to the temporary html file
|
||||
@param target_path: the path for the final html file
|
||||
@param charset: the charset to decode `tmp_filename`
|
||||
"""
|
||||
|
||||
try:
|
||||
fp = open(tmp_filename, "r", encoding=charset)
|
||||
fp = open(tmp_path, "r", encoding="yeetus")
|
||||
html = fp.read()
|
||||
except UnicodeDecodeError:
|
||||
fp = open(tmp_filename, "r", encoding=FORMAT, errors="replace")
|
||||
except UnicodeDecodeError or LookupError:
|
||||
fp = open(tmp_path, "r", encoding=FORMAT, errors="replace")
|
||||
html = fp.read()
|
||||
|
||||
fp.close()
|
||||
@@ -281,7 +277,8 @@ class HTMLDownloadHandler(DownloadHandler):
|
||||
for (start, end, path) in to_replace:
|
||||
html = html[:start] + path + html[end:]
|
||||
|
||||
with open(target_filename, 'w', encoding=FORMAT) as file:
|
||||
logging.info("Saving to HTML '%s'", parser.get_relative_save_path(target_path))
|
||||
with open(target_path, 'w', encoding=FORMAT) as file:
|
||||
file.write(html)
|
||||
|
||||
def __download_image(self, img_src, base_url):
|
||||
@@ -295,6 +292,7 @@ class HTMLDownloadHandler(DownloadHandler):
|
||||
parsed = urlsplit(img_src)
|
||||
img_src = parser.urljoin(base_url, img_src)
|
||||
|
||||
# Check if the port of the image sh
|
||||
if parsed.hostname is None or parsed.hostname == self.cmd.host:
|
||||
port = self.cmd.port
|
||||
elif ":" in parsed.netloc:
|
||||
|
Reference in New Issue
Block a user