import logging
from typing import Dict

from client.httpclient import HTTPClient, BUFSIZE, IncompleteResponse, InvalidResponse, UnsupportedEncoding


class Retriever:
    """Base class for strategies that stream an HTTP response body from an ``HTTPClient``.

    Subclasses implement ``retrieve`` as a generator yielding body fragments.
    Use ``Retriever.create`` to pick the subclass that matches the response headers.
    """

    client: HTTPClient
    headers: Dict[str, str]

    def __init__(self, client: HTTPClient):
        self.client = client

    def retrieve(self):
        """Yield the response body in pieces; overridden by subclasses."""
        pass

    @staticmethod
    def create(client: HTTPClient, headers: Dict[str, str]):
        """Return the retriever appropriate for a response with the given headers.

        :param client: connection the body will be read from.
        :param headers: response headers. Lookups use lower-case keys, so the
            caller presumably normalises header names — TODO confirm.
        :returns: a ``ChunkedRetriever`` for ``transfer-encoding: chunked``,
            a ``ContentLengthRetriever`` when ``content-length`` is present,
            otherwise a ``RawRetriever`` that reads until the stream ends.
        :raises UnsupportedEncoding: for any transfer-encoding other than
            ``chunked``, or for any content-encoding at all.
        """
        # only chunked transfer-encoding is supported
        transfer_encoding = headers.get("transfer-encoding")
        if transfer_encoding and transfer_encoding != "chunked":
            raise UnsupportedEncoding("transfer-encoding", transfer_encoding)
        chunked = transfer_encoding == "chunked"

        # content-encoding is not supported
        content_encoding = headers.get("content-encoding")
        if content_encoding:
            raise UnsupportedEncoding("content-encoding", content_encoding)

        if chunked:
            return ChunkedRetriever(client)

        content_length = headers.get("content-length")
        if not content_length:
            logging.warning("Transfer-encoding and content-length not specified, trying without")
            return RawRetriever(client)
        # NOTE(review): a non-numeric content-length surfaces as ValueError here.
        return ContentLengthRetriever(client, int(content_length))


class ContentLengthRetriever(Retriever):
    """Reads a body whose size is announced by a ``content-length`` header."""

    length: int

    def __init__(self, client: HTTPClient, length: int):
        super().__init__(client)
        self.length = length

    def retrieve(self):
        """Yield body fragments until ``self.length`` bytes have been received.

        Each read asks for at most ``BUFSIZE`` bytes and never more than is
        still outstanding. If the client returns an empty buffer before the
        full length arrives, a warning is logged and iteration stops early
        without raising.

        :raises IncompleteResponse: on timeout or connection loss; the client
            is closed first.
        """
        cur_payload_size = 0
        while cur_payload_size < self.length:
            # Cap each read at BUFSIZE and at the bytes still expected.
            read_size = min(BUFSIZE, self.length - cur_payload_size)
            try:
                buffer = self.client.read(read_size)
            except TimeoutError as err:
                logging.error("Timed out before receiving complete payload")
                self.client.close()
                raise IncompleteResponse("Timed out before receiving complete payload") from err
            except ConnectionError as err:
                logging.error("Connection closed before receiving complete payload")
                self.client.close()
                raise IncompleteResponse("Connection closed before receiving complete payload") from err
            logging.debug("Received payload length: %s", len(buffer))
            if len(buffer) == 0:
                logging.warning("Received payload length %s less than expected %s", cur_payload_size, self.length)
                break
            cur_payload_size += len(buffer)
            logging.debug("Processed payload: %r", cur_payload_size)
            yield buffer
        return b""


class RawRetriever(Retriever):
    """Reads a body with no framing information until the stream ends."""

    def retrieve(self):
        """Yield whatever the client returns until it times out, disconnects, or hits EOF."""
        while True:
            try:
                buffer = self.client.read()
            except (TimeoutError, ConnectionError):
                # Without length or chunk framing, the end of the stream is the end of the body.
                return b""
            # An empty read means the peer has finished sending; stop instead of spinning.
            if not buffer:
                return b""
            yield buffer


class ChunkedRetriever(Retriever):
    """Reads a body sent with ``transfer-encoding: chunked``."""

    def retrieve(self):
        """Yield chunk data until the zero-size terminating chunk.

        A chunk may arrive over several reads; every fragment is yielded as it
        arrives, and the chunk's trailing line is consumed only after the
        whole chunk has been read so framing stays aligned.

        :raises InvalidResponse: if a chunk-size line cannot be parsed.
        :raises IncompleteResponse: if the stream ends in the middle of a chunk.
        """
        while True:
            chunk_size = self._get_chunk_size()
            logging.debug("chunk-size: %s", chunk_size)
            if chunk_size == 0:
                self.client.reset_request()
                break
            remaining = chunk_size
            while remaining > 0:
                buffer = self.client.read(remaining)
                if not buffer:
                    raise IncompleteResponse("Connection closed before receiving complete chunk")
                logging.debug("chunk: %r", buffer)
                remaining -= len(buffer)
                yield buffer
            self.client.read_line()  # remove CRLF
        return b""

    def _get_chunk_size(self):
        """Read one chunk-size line and return the size it encodes.

        Anything after ``;`` (chunk extensions) is ignored.

        :raises InvalidResponse: if the size is not a non-negative hex number.
        """
        line = self.client.read_line()
        sep_pos = line.find(";")
        if sep_pos >= 0:
            line = line[:sep_pos]
        try:
            size = int(line, 16)
        except ValueError as err:
            raise InvalidResponse() from err
        if size < 0:
            raise InvalidResponse()
        return size