import logging
import os
import re
from abc import ABC, abstractmethod
from urllib.parse import urlsplit, unquote

from client.command import AbstractCommand, GetCommand
from client.httpclient import HTTPClient
from httplib import parser
from httplib.exceptions import InvalidResponse, UnhandledHTTPCode, UnsupportedProtocol
from httplib.httpsocket import FORMAT
from httplib.message import ResponseMessage as Message
from httplib.retriever import Retriever

# Match the `href` of a <base> element and the `src` of an <img> element.
# Group 1 is the quoted attribute value.
BASE_REGEX = re.compile(r"<\s*base[^>]*\shref\s*=\s*['\"]([^\"']+)['\"][^>]*>", re.M | re.I)
IMG_REGEX = re.compile(r"<\s*img[^>]*\ssrc\s*=\s*['\"]([^\"']+)['\"][^>]*>", re.M | re.I)


def handle(client: HTTPClient, msg: Message, command: AbstractCommand, directory=None):
    """
    Handle the response of the request message.

    The status line is processed first (redirects are followed, error codes raise).
    If a body should be downloaded, it is passed on to the download handler matching
    the `content-type` header.

    @param client: the client which sent the request.
    @param msg: the response message
    @param command: the command of the sent request-message
    @param directory: the directory to download the response to (if available)
    @return: the path the body was saved to, or None when there is nothing to download
    """
    retriever = BasicResponseHandler(client, msg, command).handle()
    if retriever is None:
        return None

    return DownloadHandler.create(retriever, client, msg, command, directory).handle()


class ResponseHandler(ABC):
    """
    Helper class for handling response messages.
    """
    client: HTTPClient
    retriever: Retriever
    msg: Message
    cmd: AbstractCommand

    def __init__(self, retriever: Retriever, client: HTTPClient, msg, cmd):
        self.client = client
        self.retriever = retriever
        self.msg = msg
        self.cmd = cmd

    @abstractmethod
    def handle(self):
        """
        Handle the response.
        """


class BasicResponseHandler(ResponseHandler):
    """
    Response handler which will handle redirects and other HTTP status codes.
    In case of a redirect, it will process it and pass it to the appropriate response handler.
    """

    def __init__(self, client: HTTPClient, msg: Message, cmd: AbstractCommand):
        retriever = Retriever.create(client, msg.headers)
        super().__init__(retriever, client, msg, cmd)

    def handle(self):
        """
        Process the status code of the response.

        @return: the retriever for the body when the status is 2xx, otherwise None
        """
        return self._handle_status()

    def _skip_body(self):
        """
        Consume the body of the response, logging every chunk at debug level.
        """
        logging.debug("Skipping body: [")
        for line in self.retriever.retrieve():
            try:
                logging.debug("%s", line.decode(FORMAT))
            except UnicodeDecodeError:
                # Not decodable with FORMAT, log the raw bytes instead.
                logging.debug("%r", line)
        logging.debug("] done.")

    def _handle_status(self):
        """
        Dispatch on the status code of the response.

        @return: the retriever for 2xx, the result of the redirect handling for 3xx,
                 None for any status that is not handled explicitly
        @raise UnhandledHTTPCode: for status 101 and for any 4xx or 5xx status
        """
        status = self.msg.status
        logging.info("%d %s", status, self.msg.msg)

        if status == 101:
            # Switching protocols is not supported
            raise UnhandledHTTPCode(status, "".join(self.msg.raw), "Switching protocols is not supported")

        if 200 <= status < 300:
            return self.retriever

        if 300 <= status < 400:
            # Redirect
            self._skip_body()
            return self._handle_redirect()

        if 400 <= status < 600:
            self._skip_body()
            # Dump headers and exit with error
            raise UnhandledHTTPCode(status, "".join(self.msg.raw), self.msg.msg)

        return None

    def _handle_redirect(self):
        """
        Follow the redirect given by the `location` header by re-executing the command.

        @return: always None, the redirected request is executed by the command itself
        @raise UnhandledHTTPCode: if the status is 304
        @raise InvalidResponse: if the location is missing, blank or has no hostname
        @raise UnsupportedProtocol: if the scheme of the location is not `http`
        """
        if self.msg.status == 304:
            raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)

        location = self.msg.headers.get("location")
        if not location or not location.strip():
            raise InvalidResponse("No location in redirect")

        location = parser.urljoin(self.cmd.uri, location)
        parsed_location = urlsplit(location)
        if not parsed_location.hostname:
            raise InvalidResponse("Invalid location")

        if parsed_location.scheme != "http":
            raise UnsupportedProtocol(parsed_location.scheme)

        self.cmd.uri = location

        if self.msg.status == 301:
            logging.info("Status 301. Closing socket [%s]", self.cmd.host)
            self.client.close()

        self.cmd.execute()
        return None


class DownloadHandler(ResponseHandler, ABC):
    """
    Base class for handlers which save the body of a response to a file.
    The target path is stored in `self.path`.
    """

    def __init__(self, retriever: Retriever, client: HTTPClient, msg, cmd, directory=None):
        super().__init__(retriever, client, msg, cmd)

        if not directory:
            directory = self._create_directory()

        self.path = self._get_duplicate_name(os.path.join(directory, self.get_filename()))

    @staticmethod
    def create(retriever: Retriever, client: HTTPClient, msg, cmd, directory=None):
        """
        Create the download handler matching the `content-type` header of `msg`.

        @return: a HTMLDownloadHandler if the content type contains `text/html`,
                 otherwise a RawDownloadHandler
        """
        content_type = msg.headers.get("content-type")
        if content_type and "text/html" in content_type:
            return HTMLDownloadHandler(retriever, client, msg, cmd, directory)
        return RawDownloadHandler(retriever, client, msg, cmd, directory)

    def _create_directory(self):
        """
        Create a new directory named after the host of the client.

        @return: the absolute path of the created directory
        """
        path = self._get_duplicate_name(os.path.abspath(self.client.host))
        os.mkdir(path)
        return path

    def _get_duplicate_name(self, path):
        """
        Return `path` if it does not exist yet, otherwise the first of
        `path.1`, `path.2`, ... which does not exist.
        """
        candidate = path
        counter = 0
        while os.path.exists(candidate):
            counter += 1
            candidate = "{path}.{counter}".format(path=path, counter=counter)
        return candidate

    def get_filename(self):
        """
        Returns the filename to download the payload to.

        The name is derived from the last component of the command's path: it is
        percent-decoded and stripped of every character which is not a word character,
        `.`, `+` or `-`. Falls back to `index.html` when nothing usable remains.
        """
        filename = os.path.basename(self.cmd.path)
        if filename == '':
            return "index.html"

        # Decode repeatedly to undo nested escapes, but stop once decoding no longer
        # changes anything: a literal '%' (e.g. "100%.png" or a decoded "%25") would
        # otherwise keep this loop spinning forever.
        while "%" in filename:
            decoded = unquote(filename)
            if decoded == filename:
                break
            filename = decoded

        filename = re.sub(r"[^\w.+-]+[.]*", '', filename)
        result = os.path.basename(filename).strip()
        if any(letter.isalnum() for letter in result):
            return result

        return "index.html"


class RawDownloadHandler(DownloadHandler):
    """
    Download handler which writes the body to `self.path` unmodified.
    """

    def __init__(self, retriever: Retriever, client: HTTPClient, msg: Message, cmd: AbstractCommand, directory=None):
        super().__init__(retriever, client, msg, cmd, directory)

    def handle(self) -> str:
        """
        Write every buffer of the retriever to `self.path`.

        @return: the path of the written file
        """
        logging.info("Saving to '%s'", parser.get_relative_save_path(self.path))

        # `with` guarantees the file is closed even if retrieving the body fails.
        with open(self.path, "wb") as file:
            for buffer in self.retriever.retrieve():
                file.write(buffer)

        return self.path


class HTMLDownloadHandler(DownloadHandler):
    """
    Download handler for HTML bodies. The images referenced by the HTML are
    downloaded as well and their references are rewritten to the local copies.
    """

    def __init__(self, retriever: Retriever, client: HTTPClient, msg: Message, cmd: AbstractCommand, directory=None):
        super().__init__(retriever, client, msg, cmd, directory)

    def handle(self) -> str:
        """
        Download the HTML to a temporary file, fetch its images and write the
        rewritten HTML to `self.path`.

        @return: the path of the final html file
        """
        directory, filename = os.path.split(self.path)
        tmp_path = os.path.join(directory, f".{filename}.tmp")

        try:
            with open(tmp_path, "wb") as tmp_file:
                for buffer in self.retriever.retrieve():
                    tmp_file.write(buffer)

            charset = parser.get_charset(self.msg.headers)
            self._download_images(tmp_path, self.path, charset)
        finally:
            # Never leave the temporary file behind, not even when something failed.
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                # The temporary file was never created; nothing to clean up.
                pass

        return self.path

    @staticmethod
    def _read_html(path, charset):
        """
        Read the file at `path` as text, decoded with `charset`. Falls back to
        FORMAT with replacement characters if the charset is unknown or does not
        match the content.
        """
        try:
            with open(path, "r", encoding=charset) as fp:
                return fp.read()
        except (UnicodeDecodeError, LookupError):
            # LookupError: unknown charset name, UnicodeDecodeError: wrong charset.
            with open(path, "r", encoding=FORMAT, errors="replace") as fp:
                return fp.read()

    def _download_images(self, tmp_path, target_path, charset=FORMAT):
        """
        Download images referenced in the html of `tmp_path`, replace the references
        in the html with the local file names and write the result to `target_path`.

        Images which fail to download are logged and skipped.

        @param tmp_path: the path to the temporary html file
        @param target_path: the path for the final html file
        @param charset: the charset to decode `tmp_path`
        """
        html = self._read_html(tmp_path, charset)

        base_url = self.cmd.uri
        base_element = BASE_REGEX.search(html)
        if base_element:
            base_url = parser.urljoin(self.cmd.uri, base_element.group(1))

        processed = {}
        to_replace = []

        # Find all <img> tags, and the urls from the corresponding `src` fields
        for match in IMG_REGEX.finditer(html):
            target = match.group(1)

            try:
                if not target:
                    continue

                if target in processed:
                    # url is already processed
                    new_url = processed[target]
                else:
                    new_url = self.__download_image(target, base_url)
                    processed[target] = new_url

                if new_url:
                    local_path = os.path.basename(new_url)
                    to_replace.append((match.start(1), match.end(1), local_path))
            except Exception as e:
                logging.error("Failed to download image: %s, skipping...", target, exc_info=e)

        # Replace from the bottom of the html upwards.
        # Otherwise, the recorded start and end positions won't be correct anymore.
        for start, end, path in reversed(to_replace):
            html = html[:start] + path + html[end:]

        logging.info("Saving to HTML '%s'", parser.get_relative_save_path(target_path))
        with open(target_path, 'w', encoding=FORMAT) as file:
            file.write(html)

    def __download_image(self, img_src, base_url):
        """
        Download image from the specified `img_src` and `base_url`.
        If the image is available, it will be downloaded to the directory of `self.path`

        @param img_src: the value of the `src` attribute of the image
        @param base_url: the url to resolve `img_src` against
        @return: the `filename` attribute of the executed GetCommand
        """
        logging.info("Downloading image: %s", img_src)

        parsed = urlsplit(img_src)
        img_src = parser.urljoin(base_url, img_src)

        # Determine the port to connect to: reuse the port of the current command when
        # the image is on the same host, otherwise use the port from the url (default 80).
        if parsed.hostname is None or parsed.hostname == self.cmd.host:
            port = self.cmd.port
        else:
            # `SplitResult.port` yields an int (or None) and copes with userinfo
            # and IPv6 literals, unlike splitting the netloc on ':'.
            port = parsed.port or 80

        command = GetCommand(img_src, port, os.path.dirname(self.path))
        command.execute(True)

        return command.filename