import logging
import re
from urllib.parse import urlparse, urlsplit

from httplib.exceptions import InvalidStatusLine, InvalidResponse, BadRequest
from httplib.httpsocket import HTTPSocket

# Exactly "HTTP/1.0" or "HTTP/1.1"; always used with fullmatch().
_HTTP_VERSION_RE = re.compile(r"HTTP/1\.[01]")
# Three ASCII digits; always used with fullmatch().
_STATUS_CODE_RE = re.compile(r"[0-9]{3}")
# One or more ASCII digits, i.e. text that int() is guaranteed to accept.
_DIGITS_RE = re.compile(r"[0-9]+")
# Lines that mark the end of the header section.
_HEADER_TERMINATORS = ("\r\n", "\n", " ")
_REQUEST_METHODS = ("CONNECT", "DELETE", "GET", "HEAD", "OPTIONS", "POST", "PUT", "TRACE")


def _get_start_line(client: HTTPSocket):
    """Read one line from *client* and split it into its three start-line fields.

    Returns:
        ``(line, parts)`` where ``line`` is the stripped text that was read and
        ``parts`` is the list of its three space-separated, non-empty fields.

    Raises:
        InvalidStatusLine: if the line has fewer than three non-empty fields.
    """
    line = client.read_line().strip()
    parts = [part for part in line.split(" ", 2) if part]
    if len(parts) < 3:
        raise InvalidStatusLine(line)  # TODO fix exception

    return line, parts


def _is_valid_http_version(http_version: str):
    """Return True only if *http_version* is exactly ``HTTP/1.0`` or ``HTTP/1.1``."""
    # fullmatch rejects trailing garbage such as "HTTP/1.1x", and the character
    # class is [01] rather than [0|1], which would also accept a literal "|".
    return _HTTP_VERSION_RE.fullmatch(http_version) is not None


def _parse_status_fields(line: str, parts):
    """Validate the three fields of a status line and convert the status code.

    Args:
        line: the text reported in the exception when validation fails.
        parts: the ``[http_version, status, reason]`` fields.

    Returns:
        ``(version, status, reason)`` with ``status`` converted to ``int``.

    Raises:
        InvalidStatusLine: if the HTTP version or the status code is invalid.
    """
    http_version, status, reason = parts

    if not _is_valid_http_version(http_version):
        raise InvalidStatusLine(line)

    # fullmatch keeps inputs such as "200abc" away from int(), which would
    # otherwise raise ValueError instead of InvalidStatusLine.
    if not _STATUS_CODE_RE.fullmatch(status):
        raise InvalidStatusLine(line)
    status_code = int(status)
    if status_code < 100:
        raise InvalidStatusLine(line)

    # NOTE(review): this slice yields the protocol name ("HTTP"), not the
    # numeric version; kept as-is for existing callers -- confirm the intent.
    version = http_version[:4]
    return version, status_code, reason


def get_status_line(client: HTTPSocket):
    """Read a status line from *client* and return ``(version, status, reason)``.

    Raises:
        InvalidStatusLine: if the line is malformed.
    """
    line, parts = _get_start_line(client)
    return _parse_status_fields(line, parts)


def parse_status_line(line: str):
    """Parse an already-read status line into ``(version, status, reason)``.

    Raises:
        InvalidStatusLine: if the line is malformed.
    """
    parts = [part for part in line.strip().split(" ", 2) if part]
    if len(parts) < 3:
        raise InvalidStatusLine(line)  # TODO fix exception

    return _parse_status_fields(line, parts)


def parse_request_line(client: HTTPSocket):
    """Read and validate a request line from *client*.

    Returns:
        ``(method, parsed_target, version)`` where ``parsed_target`` is the
        ``urlparse`` result for the request target and ``version`` is the text
        after the slash of the HTTP version (``"1.0"`` or ``"1.1"``).

    Raises:
        InvalidStatusLine: if the line has fewer than three fields.
        BadRequest: if the method, HTTP version or target is not acceptable.
    """
    line, (method, target, version) = _get_start_line(client)

    logging.debug("Parsed request-line=%r, method=%r, target=%r, version=%r", line, method, target, version)

    if method not in _REQUEST_METHODS:
        raise BadRequest()

    if not _is_valid_http_version(version):
        logging.debug("[ABRT] request: invalid http-version=%r", version)
        raise BadRequest()

    if not target:
        raise BadRequest()

    parsed_target = urlparse(target)
    path = parsed_target.path
    # A target such as "example.com/index.html" parses with an empty netloc and
    # the host inside the path; re-parse it as a network-path reference so the
    # host ends up in netloc. The asterisk-form "*" is left untouched.
    if path and not path.startswith("/") and parsed_target.netloc == "" and target != "*":
        parsed_target = urlparse(f"//{target}")

    return method, parsed_target, version.split("/")[1]


def _is_end_of_headers(line: str):
    """Return True if *line* terminates the header section.

    An empty string is treated as a terminator too, so a source that yields
    ``""`` (for example lines split without their line endings) ends the
    section instead of failing on ``line[0]``.
    """
    return not line or line in _HEADER_TERMINATORS


def _collect_header_lines(read_line):
    """Collect the logical header lines produced by repeated ``read_line()`` calls.

    Args:
        read_line: zero-argument callable returning the next raw line.
            NOTE(review): assumed to return "" once no more data is available
            -- confirm against HTTPSocket.read_line.

    Returns:
        A list of header lines without line endings, with folded continuation
        lines merged into the header they continue.
    """
    line = read_line()
    # The first header after the start line may not begin with whitespace, so
    # such lines are skipped. The terminator check comes first because "\r\n"
    # itself begins with a whitespace character.
    while not _is_end_of_headers(line) and line[0].isspace():
        line = read_line()

    collected = []
    while not _is_end_of_headers(line):
        if line[0].isspace():
            # Folded continuation: join with a single space (RFC 7230 3.2.4).
            collected[-1] = f"{collected[-1]} {line.strip()}"
        else:
            collected.append(line.rstrip("\r\n"))
        line = read_line()

    return collected


def _split_header_line(line: str):
    """Split ``name:value`` at the first colon.

    Returns:
        ``(name, value)`` unstripped, or ``None`` when the line has no colon,
        starts with the colon, or has nothing after the colon.
    """
    pos = line.find(":")
    if pos <= 0 or pos >= len(line) - 1:
        return None
    return line[:pos], line[pos + 1:]


def _build_header_dict(header_lines):
    """Turn logical header lines into a dict of lower-cased names and values.

    Each header is passed through ``check_next_header`` before being stored, so
    ``InvalidResponse`` may be raised. A repeated name overwrites the earlier
    value.
    """
    result = {}
    for line in header_lines:
        field = _split_header_line(line)
        if field is None:
            continue
        header, value = (part.strip() for part in field)
        check_next_header(result, header, value)
        result[header.lower()] = value.lower()
    return result


def retrieve_headers(client: HTTPSocket):
    """Read the header section from *client* as a list of ``(name, value)`` pairs.

    Names are lower-cased but not stripped, so whitespace before the colon
    stays visible to the caller. Values are stripped and lower-cased. Lines
    without a usable ``name:value`` shape are skipped.
    """
    result = []
    for line in _collect_header_lines(client.read_line):
        field = _split_header_line(line)
        if field is None:
            continue
        header, value = field
        result.append((header.lower(), value.strip().lower()))

    return result


def parse_request_headers(client: HTTPSocket):
    """Read and validate request headers from *client*.

    Returns:
        A dict mapping lower-cased header names to lower-cased values.

    Raises:
        BadRequest: if a header name contains whitespace, if ``content-length``
            is repeated or invalid, or if ``host`` is repeated or matches
            neither ``client.host`` nor its part before the first colon.
    """
    raw_headers = retrieve_headers(client)
    logging.debug("Received headers: %r", raw_headers)
    headers = {}

    for key, value in raw_headers:
        if any(c.isspace() for c in key):
            raise BadRequest()

        if key == "content-length":
            if key in headers:
                logging.error("Multiple content-length headers specified")
                raise BadRequest()
            # ASCII digits only: str.isnumeric() also accepts characters such
            # as "²" that make int() raise ValueError.
            # NOTE(review): zero is rejected here although RFC 7230 3.3.2
            # allows it; kept because it is existing request-side behaviour.
            if not _DIGITS_RE.fullmatch(value) or int(value) <= 0:
                logging.error("Invalid content-length value: %r", value)
                raise BadRequest()
        elif key == "host":
            host_matches = value in (client.host, client.host.split(":")[0])
            if not host_matches or key in headers:
                raise BadRequest()

        headers[key] = value

    return headers


def get_headers(client: HTTPSocket):
    """Read the header section from *client* into a dict.

    Names and values are stripped and lower-cased.

    Raises:
        InvalidResponse: if ``check_next_header`` rejects a header.
    """
    return _build_header_dict(_collect_header_lines(client.read_line))


def parse_headers(lines):
    """Parse headers from an iterable of raw lines into a dict.

    Applies the same rules as ``get_headers``. Any iterable is accepted (a list
    as well as an iterator) and every line is examined, so the first header is
    not lost.

    Raises:
        InvalidResponse: if ``check_next_header`` rejects a header.
    """
    iterator = iter(lines)
    # An exhausted iterator yields "", which ends the header section.
    return _build_header_dict(_collect_header_lines(lambda: next(iterator, "")))


def check_next_header(headers, next_header: str, next_value: str):
    """Validate *next_header* against the headers collected so far.

    Only ``content-length`` is checked; the name is compared
    case-insensitively, so ``Content-Length`` is validated as well.

    Raises:
        InvalidResponse: if ``content-length`` is already present in *headers*
            or its value is not a sequence of ASCII digits.
    """
    if next_header.lower() != "content-length":
        return

    if "content-length" in headers:
        logging.error("Multiple content-length headers specified")
        raise InvalidResponse()

    # Zero is a valid length (RFC 7230 3.3.2), so only non-digit text is
    # rejected. ASCII digits only: str.isnumeric() admits characters that make
    # int() raise ValueError.
    if not _DIGITS_RE.fullmatch(next_value):
        logging.error("Invalid content-length value: %r", next_value)
        raise InvalidResponse()


def _explicit_port(parsed):
    """Return the port given in *parsed*, or None if it is absent or invalid."""
    try:
        return parsed.port
    except ValueError:
        # urlsplit raises on a malformed port; treat it as "not specified".
        return None


def parse_uri(uri: str):
    """Split *uri* into ``(host, port, path)``.

    With a network location, the host and port come from it and the path
    includes the query string. Without one, the text before the first ``/`` is
    taken as the host, optionally followed by ``:port``.

    The port falls back to 443 for the ``https`` scheme and 80 otherwise. A
    numeric port is returned as ``int``; non-numeric port text in the
    no-netloc form is returned unchanged.

    NOTE(review): in the no-netloc form the returned path has no leading
    slash (``"host/a"`` gives ``"a"``), unlike the other branches -- confirm
    what callers expect.
    """
    parsed = urlsplit(uri)
    default_port = 443 if parsed.scheme == "https" else 80

    if parsed.hostname:
        path = parsed.path
        if parsed.query != "":
            path = f"{path}?{parsed.query}"
        # Take the port from the parsed netloc rather than splitting the
        # hostname on ":", which would mangle IPv6 literals.
        port = _explicit_port(parsed)
        if port is None:
            port = default_port
        return parsed.hostname, port, path

    # If there is no netloc, the given string is not a valid URI, so split on /
    if "/" in uri:
        host, path = uri.split("/", 1)
    else:
        host, path = uri, "/"

    if ":" in host:
        host, port_text = host.split(":", 1)
        port = int(port_text) if _DIGITS_RE.fullmatch(port_text) else port_text
    else:
        port = default_port

    return host, port, path


def base_url(uri: str):
    """Return the scheme, host, explicit port (if any) and directory of *uri*.

    The result always ends with a slash, e.g. ``http://h:8080/a/b.html`` gives
    ``http://h:8080/a/``.
    """
    parsed = urlsplit(uri)
    directory = parsed.path.rsplit("/", 1)[0]

    authority = parsed.hostname
    port = _explicit_port(parsed)
    if port is not None:
        # Keep an explicit port so relative references resolve to the same server.
        authority = f"{authority}:{port}"

    return f"{parsed.scheme}://{authority}{directory}/"