"""Parsing helpers for HTTP/1.x start lines and header sections."""
import logging
import re
from urllib.parse import urlparse

from httplib.exceptions import InvalidStatusLine, InvalidResponse, BadRequest
from httplib.httpsocket import HTTPSocket

# Compiled once; all of them are applied with fullmatch() so trailing junk
# (including a stray line terminator) is rejected instead of silently accepted.
_HTTP_VERSION_RE = re.compile(r"HTTP/1\.[01]")
_STATUS_CODE_RE = re.compile(r"[0-9]{3}")
_DIGITS_RE = re.compile(r"[0-9]+")

_SUPPORTED_METHODS = (
    "CONNECT", "DELETE", "GET", "HEAD", "OPTIONS", "POST", "PUT", "TRACE",
)

# Lines that mark the end of the header section.
_END_OF_HEADERS = ("\r\n", "\n", " ")


def _get_start_line(client: HTTPSocket, min_parts: int = 3):
    """Read one line from *client* and split it into at most three tokens.

    The line is split on runs of whitespace with ``maxsplit=2`` so that a
    multi-word reason phrase ("Not Found") stays in one piece, and the line
    terminator is stripped before splitting so it never ends up glued to the
    last token.

    :param client: socket wrapper providing ``read_line()``.
    :param min_parts: minimum number of tokens required (3 for a request
        line; a status line may legitimately omit the reason phrase).
    :return: ``(raw_line, tokens)`` where ``tokens`` has ``min_parts`` to 3
        items.
    :raises InvalidStatusLine: if fewer than ``min_parts`` tokens are present.
    """
    line = client.read_line()
    parts = line.strip().split(None, 2)
    if len(parts) < min_parts:
        raise InvalidStatusLine(line)  # TODO fix exception
    return line, parts


def _is_valid_http_version(http_version: str) -> bool:
    """Return True only for exactly ``HTTP/1.0`` or ``HTTP/1.1``."""
    return _HTTP_VERSION_RE.fullmatch(http_version) is not None


def _is_positive_integer(value: str) -> bool:
    """Return True if *value* is a run of ASCII digits greater than zero.

    ``str.isnumeric()`` is deliberately not used: it accepts characters such
    as "²" that ``int()`` cannot convert.
    """
    return _DIGITS_RE.fullmatch(value) is not None and int(value) > 0


def get_status_line(client: HTTPSocket):
    """Read and validate a response status line.

    :return: ``(version, status, reason)`` where ``version`` is the numeric
        part of the HTTP version (e.g. ``"1.1"``), ``status`` is an ``int``
        and ``reason`` is the reason phrase (``""`` when absent).
    :raises InvalidStatusLine: if the line is malformed, the HTTP version is
        unsupported, or the status code is not three digits >= 100.
    """
    line, parts = _get_start_line(client, min_parts=2)
    http_version, status = parts[0], parts[1]
    reason = parts[2] if len(parts) > 2 else ""

    if not _is_valid_http_version(http_version):
        raise InvalidStatusLine(line)
    # Everything after "HTTP/" is the version number.
    version = http_version[5:]

    if not _STATUS_CODE_RE.fullmatch(status):
        raise InvalidStatusLine(line)
    status = int(status)
    if status < 100:  # three digits guarantee the upper bound of 999
        raise InvalidStatusLine(line)

    return version, status, reason


def parse_request_line(client: HTTPSocket):
    """Read and validate a request line.

    :return: ``(method, parsed_target, version)`` where ``parsed_target`` is
        the result of ``urlparse`` on the request target and ``version`` is
        the full version token (e.g. ``"HTTP/1.1"``).
    :raises InvalidStatusLine: if the line has fewer than three tokens.
    :raises BadRequest: if the method or HTTP version is unsupported or the
        target is empty.
    """
    _, (method, target, version) = _get_start_line(client)

    if method not in _SUPPORTED_METHODS:
        raise BadRequest()
    if not _is_valid_http_version(version):
        raise BadRequest()
    if not target:
        raise BadRequest()

    return method, urlparse(target), version


def _read_header_lines(client: HTTPSocket):
    """Read the header section and return one unfolded string per header.

    Reading stops at the blank line that ends the header section, or when
    ``read_line()`` returns an empty string.  Lines that start with
    whitespace are treated as continuations of the previous header and are
    joined to it with a single space.
    """
    line = client.read_line()

    # The first header may not start with whitespace: skip such lines, but
    # never skip the blank terminator itself, otherwise a message without
    # headers would make us read into the body.
    while line and line not in _END_OF_HEADERS and line[0].isspace():
        line = client.read_line()

    fields = []
    while line and line not in _END_OF_HEADERS:
        if line[0].isspace():
            continuation = line.strip()
            if continuation:
                fields[-1] = f"{fields[-1]} {continuation}"
        else:
            fields.append(line.rstrip("\r\n"))
        line = client.read_line()

    return fields


def retrieve_headers(client: HTTPSocket):
    """Read the header section as a list of ``(name, value)`` pairs.

    Names and values are lower-cased; values are additionally stripped of
    surrounding whitespace.  Lines without a colon, with an empty name, or
    with nothing after the colon are skipped.  Duplicate headers are kept,
    in order of appearance.
    """
    result = []
    for line in _read_header_lines(client):
        pos = line.find(":")
        if pos <= 0 or pos >= len(line) - 1:
            continue
        header, value = line[:pos], line[pos + 1:]
        # The name is intentionally not stripped: callers reject names that
        # contain whitespace.
        result.append((header.lower(), value.strip().lower()))
    return result


def parse_request_headers(client: HTTPSocket):
    """Read and validate request headers into a dict keyed by lower-case name.

    :raises BadRequest: if a header name contains whitespace, if
        ``content-length`` is repeated or is not a positive integer, or if
        ``host`` is repeated or differs from ``client.host``.
    """
    headers = {}

    for key, value in retrieve_headers(client):
        if any(c.isspace() for c in key):
            raise BadRequest()

        if key == "content-length":
            if key in headers:
                logging.error("Multiple content-length headers specified")
                raise BadRequest()
            # NOTE(review): a value of 0 is rejected, as before; RFC 7230
            # allows it - confirm this is intended.
            if not _is_positive_integer(value):
                logging.error("Invalid content-length value: %r", value)
                raise BadRequest()
        elif key == "host":
            # NOTE(review): value has been lower-cased, so this presumably
            # expects client.host to be lower-case too - confirm.
            if value != client.host or key in headers:
                raise BadRequest()

        headers[key] = value

    return headers


def get_headers(client: HTTPSocket):
    """Read response headers into a dict keyed by lower-case name.

    Names and values are stripped and lower-cased.  When a header is
    repeated, the last value wins (except ``content-length``, which may
    appear only once).

    :raises InvalidResponse: if ``content-length`` is repeated or is not a
        positive integer.
    """
    result = {}
    for header, value in retrieve_headers(client):
        header = header.strip()
        check_next_header(result, header, value)
        result[header] = value
    return result


def check_next_header(headers, next_header: str, next_value: str):
    """Validate *next_header* before it is added to *headers*.

    Only ``content-length`` is checked; the name is compared
    case-insensitively.

    :param headers: headers collected so far, keyed by lower-case name.
    :raises InvalidResponse: if ``content-length`` is already present or its
        value is not a positive integer.
    """
    if next_header.lower() != "content-length":
        return

    if "content-length" in headers:
        logging.error("Multiple content-length headers specified")
        raise InvalidResponse()
    # NOTE(review): a value of 0 is rejected, as before; RFC 7230 allows it
    # - confirm this is intended.
    if not _is_positive_integer(next_value):
        logging.error("Invalid content-length value: %r", next_value)
        raise InvalidResponse()