client: cleanup

This commit is contained in:
2021-03-21 00:01:31 +01:00
parent fa8d08d63d
commit d8a5765fd8
4 changed files with 242 additions and 374 deletions

249
client.py
View File

@@ -1,207 +1,9 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import logging import logging
import re
import socket
import sys import sys
import time
from urllib.parse import urlparse
from client import ResponseHandler from client.command import Command
from client.httpclient import HTTPClient
FORMAT = 'utf-8'
BUFSIZE = 4096
def receive_bytes_chunk(client: socket.socket):
    """Read from ``client`` in ``BUFSIZE`` pieces and log every read.

    Each read is logged at debug level (its size and its raw bytes). Nothing
    is returned. Reading stops as soon as ``recv`` returns no data, which is
    how a closed connection is reported; previously the loop had no exit
    condition at all.

    NOTE(review): the body visible here may have been truncated; confirm that
    logging the received data is all this helper is meant to do.
    """
    while True:
        received = client.recv(BUFSIZE)
        logging.debug("Received size: %s", len(received))
        logging.debug("Received: %r", received)
        if not received:
            # Empty read: the peer closed the connection, nothing more to log.
            break
def receive_bytes(client: socket.socket):
    """Yield the blank-line separated parts of the data read from ``client``.

    Data is read in ``BUFSIZE`` pieces. Every part that is followed by a
    blank line (LF LF or CR LF CR LF) is yielded with a CR LF CR LF
    terminator appended. Data that is not followed by a blank line is
    yielded unchanged once reading stops.

    Reading stops when a read returns fewer than ``BUFSIZE`` bytes and no
    unterminated data is pending, or when the peer closes the connection
    (``recv`` returns no data). The latter case used to loop forever while
    unterminated data was still buffered.
    """
    buffer = b''
    while True:
        received = client.recv(BUFSIZE)
        logging.debug("Received size: %s", len(received))
        logging.debug("Received: %r", received)
        buffer += received

        # Emit every complete part currently held in the buffer.
        while True:
            lf_pos = buffer.find(b"\n\n")
            crlf_pos = buffer.find(b"\r\n\r\n")
            # Use whichever separator occurs first; -1 means "not present"
            # and must never win the comparison.
            if lf_pos != -1 and (crlf_pos == -1 or lf_pos < crlf_pos):
                end, skip = lf_pos, 2
            elif crlf_pos != -1:
                end, skip = crlf_pos, 4
            else:
                break
            yield buffer[:end] + b"\r\n\r\n"
            buffer = buffer[end + skip:]

        if not received:
            # Connection closed: no further data will ever arrive.
            break
        if len(received) < BUFSIZE and not buffer:
            # Short read with nothing pending: treat the message as complete.
            break

    if buffer:
        yield buffer
def receive(client: socket.socket):
    """Read up to ``BUFSIZE`` bytes from ``client``, retrying once if empty.

    Raises ``Exception("Connection closed")`` when the socket's file
    descriptor is already -1. When the first read returns no data, the call
    waits 0.1 seconds and reads a second time; the result of that second
    read is returned even if it is empty as well.
    """
    if client.fileno() == -1:
        raise Exception("Connection closed")

    data = client.recv(BUFSIZE)
    if not data:
        # Give the peer a short moment before the single retry.
        time.sleep(0.1)
        data = client.recv(BUFSIZE)
    return data
def parse_header(data: bytes):
    """Split a raw header section into its start-line and a header mapping.

    ``data`` is decoded as UTF-8 and split on ``\\n``. Blank lines are
    dropped, and so are lines that begin with whitespace (continuation
    lines). The original filter attempted the latter but tested the line
    only after stripping it, so the test could never succeed.

    Returns a ``(start_line, headers)`` tuple. ``headers`` maps the
    upper-cased header name to its stripped value. Lines without a colon,
    with an empty name, or with nothing after the colon are ignored.

    Raises ``Exception("No start-line")`` when no usable line is left.
    """
    header_lines = []
    for raw_line in data.decode("utf-8").split("\n"):
        stripped = raw_line.strip()
        # Compare by value/truthiness; ``is not ""`` only tests identity.
        if not stripped:
            continue
        # Leading whitespace must be checked on the raw, unstripped line.
        if raw_line[0].isspace():
            continue
        header_lines.append(stripped)

    if not header_lines:
        raise Exception("No start-line")

    start_line = header_lines.pop(0)
    logging.debug("start-line: %r", start_line)

    headers = {}
    for line in header_lines:
        pos = line.find(":")
        # Need a non-empty name before the colon and some text after it.
        if pos <= 0 or pos >= len(line) - 1:
            continue
        (header, value) = map(str.strip, line.split(":", 1))
        headers[header.upper()] = value

    logging.debug("Parsed headers: %r", headers)
    return start_line, headers
def validate_status_line(status_line: str):
    """Return True when ``status_line`` looks like an HTTP/1.0 or 1.1 status line.

    The line is split on single spaces (empty tokens are ignored) and must
    contain at least three tokens: the HTTP version, the status code and a
    reason. The version must be exactly ``HTTP/1.0`` or ``HTTP/1.1`` and the
    status code exactly three digits.
    """
    tokens = [token for token in status_line.split(" ") if token]
    if len(tokens) < 3:
        return False

    # Check HTTP version
    http_version = tokens[0]
    if len(http_version) < 8 or http_version[4] != "/":
        return False
    (name, version) = http_version[:4], http_version[5:]
    # fullmatch + ``[01]``: the old ``[0|1]`` class also accepted "|" and the
    # unanchored match accepted trailing garbage such as "1.15".
    if name != "HTTP" or not re.fullmatch(r"1\.[01]", version):
        return False

    # Anchored so that "2000" is not accepted as a status code.
    if not re.fullmatch(r"\d{3}", tokens[1]):
        return False

    return True
def get_chunk(buffer: bytes):
    """Split ``buffer`` at the first blank line.

    A blank line is either LF LF or CR LF CR LF; whichever occurs first is
    used. Returns ``(before, after)`` with the separator itself removed.
    When the buffer holds no blank line at all, the whole buffer is returned
    as ``before`` and ``after`` is empty.
    """
    lf_pos = buffer.find(b"\n\n")
    crlf_pos = buffer.find(b"\r\n\r\n")

    # ``find`` returns -1 for "absent"; that value must never be used as a
    # slice position, which the previous comparison allowed.
    if lf_pos != -1 and (crlf_pos == -1 or lf_pos < crlf_pos):
        return buffer[:lf_pos], buffer[lf_pos + 2:]
    if crlf_pos != -1:
        return buffer[:crlf_pos], buffer[crlf_pos + 4:]
    return buffer, b""
def response_parser(client: socket.socket):
    """Read one response from ``client`` and store its body in a file.

    The first ``BUFSIZE`` bytes are split into the header section and the
    start of the body. The status line is validated and the headers parsed.
    A body is only stored when no ``Transfer-Encoding`` header is present
    and ``Content-Length`` is given and non-zero; the function then keeps
    reading until that many bytes were written or a read returns no data.

    Returns ``None`` in every case. Raises ``Exception("Invalid status-line")``
    when the status line does not validate. A socket timeout on the first
    read is logged at debug level and ends the call.
    """
    try:
        buffer = client.recv(BUFSIZE)
    except (socket.timeout, TimeoutError) as err:
        # TODO handle error appropriately
        # The message used to contain a "%r" placeholder without an argument.
        logging.debug("[ERR] Socket timeout", exc_info=err)
        return

    (header_chunk, buffer) = get_chunk(buffer)
    (status_line, headers) = parse_header(header_chunk)

    if not validate_status_line(status_line):
        raise Exception("Invalid status-line")
    logging.debug("valid status-line: %r", status_line)

    encoding = headers.get("TRANSFER-ENCODING", "plain")
    if encoding != "plain" or "CONTENT-LENGTH" not in headers:
        return

    payload_size = int(headers["CONTENT-LENGTH"])
    if payload_size == 0:
        return

    # NOTE(review): ``util`` is not among the imports visible for this file;
    # confirm it is in scope where this function lives.
    filename = util.get_html_filename(headers)

    # ``with`` guarantees the file is closed even if a read raises.
    with open(filename, "wb") as f:
        f.write(buffer)
        cur_payload_size = len(buffer)

        while cur_payload_size < payload_size:
            buffer = receive(client)
            logging.debug("Received payload: %r", buffer)

            if len(buffer) == 0:
                # Arguments are now in the order the message names them:
                # received length first, expected length second.
                logging.warning("Received payload length %s less than expected %s",
                                cur_payload_size, payload_size)
                break

            cur_payload_size += len(buffer)
            f.write(buffer)
def http_parser(client: socket.socket):
    """Validate the status line of a response and log the remaining parts.

    The first part produced by ``receive_bytes`` is parsed as the header
    section. Raises ``Exception("Invalid header")`` when its status line
    does not validate. Every following part is only logged at debug level.
    """
    parts = receive_bytes(client)
    first_part = next(parts)
    status_line, headers = parse_header(first_part)

    if not validate_status_line(status_line):
        raise Exception("Invalid header")
    logging.debug("valid status-line: %r", status_line)

    for chunk in parts:
        logging.debug("chunk: %r", chunk)
def parse_uri(uri: str):
    """Split ``uri`` into a ``(host, path)`` tuple.

    A URI without a network location (for example ``example.com/a``) is
    parsed a second time with ``//`` prepended. The returned path always
    starts with ``/``, and anything from the first ``:`` onwards is removed
    from the host.
    """
    parts = urlparse(uri)
    if not parts.netloc:
        # No netloc means the scheme-less form was used; retry as "//host/...".
        parts = urlparse("//" + uri)

    path = parts.path
    if not path.startswith("/"):
        path = "/" + path

    # Keep only what precedes the first colon (drops a ":port" suffix).
    host = parts.netloc.partition(":")[0]
    return host, path
def main(): def main():
@@ -216,53 +18,8 @@ def main():
logging.basicConfig(level=logging.ERROR - (10 * arguments.verbose)) logging.basicConfig(level=logging.ERROR - (10 * arguments.verbose))
logging.debug("Arguments: %s", arguments) logging.debug("Arguments: %s", arguments)
(host, path) = parse_uri(arguments.URI) command = Command.create(arguments.command, arguments.URI, arguments.port)
client = HTTPClient(host) command.execute()
client.connect((host, int(arguments.port)))
message = "GET {path} HTTP/1.1\r\n".format(path=path)
message += "Accept: */*\r\nAccept-Encoding: identity\r\n"
message += "Host: {host}\r\n\r\n".format(host=host)
message = message.encode(FORMAT)
logging.debug("Sending HTTP message: %r", message)
client.sendall(message)
ResponseHandler.handle(client, arguments.URI)
# response_parser(client)
# http_parser(client)
# tmp = b''
# keep = False
# count = 0
# for line in receive_bytes(client):
#
# if count > 0:
# tmp += line.rstrip(b"\r\n")
# if keep:
# count += 1
#
# if line == b'\r\n':
# keep = True
#
# logging.debug('end of part 1')
#
# logging.debug("attempt 2")
# while True:
# logging.debug("attempt")
# keep = False
# for line in receive_bytes(client):
# if line == b"0\r\n":
# break
# if keep:
# tmp += line.rstrip(b"\r\n")
# keep = True
#
# if b"0\r\n" == line:
# break
# logging.debug("content: %s", tmp)
# # logging.debug("content: %r", tmp.replace(b"\r\n", b"").decode("utf-8"))
#
# f = open("test.jpeg", "wb")
# f.write(tmp)
try: try:

View File

@@ -10,121 +10,6 @@ from client.Retriever import Retriever
from client.httpclient import HTTPClient, UnsupportedEncoding, FORMAT, InvalidResponse, InvalidStatusLine from client.httpclient import HTTPClient, UnsupportedEncoding, FORMAT, InvalidResponse, InvalidStatusLine
def handle(client: HTTPClient, url: str):
    """Read a response from ``client`` and hand it to the matching handler.

    The status line and the headers are parsed, a response handler is built
    with ``construct`` and its ``handle`` method is run. ``InvalidResponse``,
    ``InvalidStatusLine`` and ``UnsupportedEncoding`` are logged at debug
    level and otherwise swallowed; the function always returns ``None``.
    """
    logging.debug("Waiting for response")

    try:
        version, status, _ = get_status_line(client)
        logging.debug("Parsed status-line: version: %s, status: %s", version, status)

        response_headers = get_headers(client)
        logging.debug("Parsed headers: %r", response_headers)

        construct(client, response_headers, status, url).handle()
    except InvalidResponse as err:
        logging.debug("Internal error: Response could not be parsed", exc_info=err)
    except InvalidStatusLine as err:
        logging.debug("Internal error: Invalid status-line in response", exc_info=err)
    except UnsupportedEncoding as err:
        logging.debug("Internal error: Unsupported encoding in response", exc_info=err)
def get_status_line(client: HTTPClient):
    """Read and validate the status line of a response.

    One line is read with ``client.read_line()``. It must consist of an
    HTTP version (exactly ``HTTP/1.0`` or ``HTTP/1.1``), a three digit
    status code of at least 100 and an optional reason phrase, separated by
    spaces.

    Returns ``(version, status, reason)``: ``version`` is the text after
    ``HTTP/``, ``status`` is an ``int`` and ``reason`` is the complete reason
    phrase (possibly empty) without the line terminator.

    Raises ``InvalidStatusLine`` (carrying the offending line) when any of
    these checks fails.
    """
    line = client.read_line()
    # Remove the terminator first so it cannot end up inside a token, then
    # ignore the empty tokens produced by repeated spaces.
    parts = [part for part in line.rstrip("\r\n").split(" ") if part]
    # The reason phrase may be empty, so two tokens are enough.
    if len(parts) < 2:
        raise InvalidStatusLine(line)

    # Check HTTP version
    http_version = parts[0]
    if len(http_version) < 8 or http_version[4] != "/":
        raise InvalidStatusLine(line)
    (name, version) = http_version[:4], http_version[5:]
    # fullmatch + ``[01]``: ``[0|1]`` also matched "|" and the unanchored
    # match let versions such as "1.15" through.
    if name != "HTTP" or not re.fullmatch(r"1\.[01]", version):
        raise InvalidStatusLine(line)

    status = parts[1]
    # Exactly three digits; an unanchored match accepted e.g. "2000".
    if not re.fullmatch(r"\d{3}", status):
        raise InvalidStatusLine(line)
    status = int(status)
    if status < 100:
        raise InvalidStatusLine(line)

    # The reason phrase may contain spaces ("Not Found"); keep all of it.
    reason = " ".join(parts[2:])
    return version, status, reason
def get_headers(client: HTTPClient):
    """Read the header section of a response into a dictionary.

    Lines are read with ``client.read_line()`` until a blank line is found.
    Lines starting with whitespace directly after the status line are
    skipped. Later lines starting with whitespace are treated as a
    continuation of the previous header and joined to it with one space.

    Returns a dict mapping the lower-cased header name to the lower-cased
    value. Lines without a colon, with an empty name, or with nothing after
    the colon are ignored. Each header is passed to ``check_next_header``
    before it is stored, so whatever that function raises propagates.
    """
    # " " is kept from the original terminator list. "" is added on the
    # assumption that ``read_line`` returns an empty string once the peer
    # has closed the connection -- TODO confirm against HTTPClient.
    end_markers = ("\r\n", "\n", " ", "")

    header_lines = []

    # first header after the status-line may not contain a space
    line = client.read_line()
    while line not in end_markers and line[0].isspace():
        # A blank line is an end marker and therefore never skipped here,
        # otherwise a response without headers would be read into its body.
        line = client.read_line()

    while line not in end_markers:
        if header_lines and line[0].isspace():
            # Folded header: continue the previous line, separated by a space.
            header_lines[-1] = header_lines[-1].rstrip("\r\n") + " "
        header_lines.append(line.lstrip())
        line = client.read_line()

    result = {}
    for line in "".join(header_lines).splitlines():
        pos = line.find(":")
        # Need a non-empty name before the colon and some text after it.
        if pos <= 0 or pos >= len(line) - 1:
            continue

        (header, value) = map(str.strip, line.split(":", 1))
        check_next_header(result, header, value)
        # NOTE(review): values are lower-cased as well, which also affects
        # case-sensitive values; callers compare against lower-case literals.
        result[header.lower()] = value.lower()

    return result
def check_next_header(headers, next_header: str, next_value: str):
    """Validate a header that is about to be added to ``headers``.

    Only ``Content-Length`` is checked; any other header passes. The name is
    compared case-insensitively, because header names arrive in the server's
    spelling while ``headers`` is keyed by lower-case names.

    Raises ``InvalidResponse`` when ``headers`` already holds a
    ``content-length`` entry, or when ``next_value`` is not a non-empty
    string of ASCII digits. A value of ``0`` is valid.
    """
    if next_header.lower() != "content-length":
        return

    if "content-length" in headers:
        logging.error("Multiple content-length headers specified")
        raise InvalidResponse()

    # Plain ASCII digits only: ``isnumeric`` also accepts characters such as
    # "²" for which ``int`` raises ValueError.
    if not next_value or any(ch not in "0123456789" for ch in next_value):
        logging.error("Invalid content-length value: %r", next_value)
        raise InvalidResponse()
def construct(client: HTTPClient, headers, status_code, url):
    """Build the download handler that fits the response headers.

    Raises ``UnsupportedEncoding`` when a ``transfer-encoding`` other than
    ``chunked`` is given or when any ``content-encoding`` is present.
    Otherwise a retriever is created and wrapped in an
    ``HTMLDownloadHandler`` if the ``content-type`` contains ``text/html``,
    else in a ``RawDownloadHandler``. ``status_code`` is accepted but unused.
    """
    # only chunked transfer-encoding is supported
    transfer_encoding = headers.get("transfer-encoding")
    if transfer_encoding and transfer_encoding != "chunked":
        raise UnsupportedEncoding("transfer-encoding", transfer_encoding)

    # content-encoding is not supported
    content_encoding = headers.get("content-encoding")
    if content_encoding:
        raise UnsupportedEncoding("content-encoding", content_encoding)

    retriever = Retriever.create(client, headers)

    content_type = headers.get("content-type")
    is_html = bool(content_type) and "text/html" in content_type
    handler_cls = HTMLDownloadHandler if is_html else RawDownloadHandler
    return handler_cls(retriever, client, headers, url)
def parse_uri(uri: str): def parse_uri(uri: str):
parsed = urlparse(uri) parsed = urlparse(uri)
@@ -156,6 +41,98 @@ class ResponseHandler:
def handle(self): def handle(self):
pass pass
@staticmethod
def create(client: HTTPClient, headers, status_code, url):
    """Build the download handler that fits the response headers.

    Raises ``UnsupportedEncoding`` when a ``transfer-encoding`` other than
    ``chunked`` is given or when any ``content-encoding`` is present.
    Otherwise a retriever is created and wrapped in an
    ``HTMLDownloadHandler`` if the ``content-type`` contains ``text/html``,
    else in a ``RawDownloadHandler``. ``status_code`` is accepted but unused.
    """
    # only chunked transfer-encoding is supported
    transfer_encoding = headers.get("transfer-encoding")
    if transfer_encoding and transfer_encoding != "chunked":
        raise UnsupportedEncoding("transfer-encoding", transfer_encoding)

    # content-encoding is not supported
    content_encoding = headers.get("content-encoding")
    if content_encoding:
        raise UnsupportedEncoding("content-encoding", content_encoding)

    retriever = Retriever.create(client, headers)

    content_type = headers.get("content-type")
    is_html = bool(content_type) and "text/html" in content_type
    handler_cls = HTMLDownloadHandler if is_html else RawDownloadHandler
    return handler_cls(retriever, client, headers, url)
@staticmethod
def get_status_line(client: HTTPClient):
    """Read and validate the status line of a response.

    One line is read with ``client.read_line()``. It must consist of an
    HTTP version (exactly ``HTTP/1.0`` or ``HTTP/1.1``), a three digit
    status code of at least 100 and an optional reason phrase, separated by
    spaces.

    Returns ``(version, status, reason)``: ``version`` is the text after
    ``HTTP/``, ``status`` is an ``int`` and ``reason`` is the complete reason
    phrase (possibly empty) without the line terminator.

    Raises ``InvalidStatusLine`` (carrying the offending line) when any of
    these checks fails.
    """
    line = client.read_line()
    # Remove the terminator first so it cannot end up inside a token, then
    # ignore the empty tokens produced by repeated spaces.
    parts = [part for part in line.rstrip("\r\n").split(" ") if part]
    # The reason phrase may be empty, so two tokens are enough.
    if len(parts) < 2:
        raise InvalidStatusLine(line)

    # Check HTTP version
    http_version = parts[0]
    if len(http_version) < 8 or http_version[4] != "/":
        raise InvalidStatusLine(line)
    (name, version) = http_version[:4], http_version[5:]
    # fullmatch + ``[01]``: ``[0|1]`` also matched "|" and the unanchored
    # match let versions such as "1.15" through.
    if name != "HTTP" or not re.fullmatch(r"1\.[01]", version):
        raise InvalidStatusLine(line)

    status = parts[1]
    # Exactly three digits; an unanchored match accepted e.g. "2000".
    if not re.fullmatch(r"\d{3}", status):
        raise InvalidStatusLine(line)
    status = int(status)
    if status < 100:
        raise InvalidStatusLine(line)

    # The reason phrase may contain spaces ("Not Found"); keep all of it.
    reason = " ".join(parts[2:])
    return version, status, reason
@staticmethod
def get_headers(client: HTTPClient):
    """Read the header section of a response into a dictionary.

    Lines are read with ``client.read_line()`` until a blank line is found.
    Lines starting with whitespace directly after the status line are
    skipped. Later lines starting with whitespace are treated as a
    continuation of the previous header and joined to it with one space.

    Returns a dict mapping the lower-cased header name to the lower-cased
    value. Lines without a colon, with an empty name, or with nothing after
    the colon are ignored. Each header is passed to
    ``ResponseHandler.check_next_header`` before it is stored, so whatever
    that method raises propagates.
    """
    # " " is kept from the original terminator list. "" is added on the
    # assumption that ``read_line`` returns an empty string once the peer
    # has closed the connection -- TODO confirm against HTTPClient.
    end_markers = ("\r\n", "\n", " ", "")

    header_lines = []

    # first header after the status-line may not contain a space
    line = client.read_line()
    while line not in end_markers and line[0].isspace():
        # A blank line is an end marker and therefore never skipped here,
        # otherwise a response without headers would be read into its body.
        line = client.read_line()

    while line not in end_markers:
        if header_lines and line[0].isspace():
            # Folded header: continue the previous line, separated by a space.
            header_lines[-1] = header_lines[-1].rstrip("\r\n") + " "
        header_lines.append(line.lstrip())
        line = client.read_line()

    result = {}
    for line in "".join(header_lines).splitlines():
        pos = line.find(":")
        # Need a non-empty name before the colon and some text after it.
        if pos <= 0 or pos >= len(line) - 1:
            continue

        (header, value) = map(str.strip, line.split(":", 1))
        ResponseHandler.check_next_header(result, header, value)
        # NOTE(review): values are lower-cased as well, which also affects
        # case-sensitive values; callers compare against lower-case literals.
        result[header.lower()] = value.lower()

    return result
@staticmethod
def check_next_header(headers, next_header: str, next_value: str):
    """Validate a header that is about to be added to ``headers``.

    Only ``Content-Length`` is checked; any other header passes. The name is
    compared case-insensitively, because header names arrive in the server's
    spelling while ``headers`` is keyed by lower-case names.

    Raises ``InvalidResponse`` when ``headers`` already holds a
    ``content-length`` entry, or when ``next_value`` is not a non-empty
    string of ASCII digits. A value of ``0`` is valid.
    """
    if next_header.lower() != "content-length":
        return

    if "content-length" in headers:
        logging.error("Multiple content-length headers specified")
        raise InvalidResponse()

    # Plain ASCII digits only: ``isnumeric`` also accepts characters such as
    # "²" for which ``int`` raises ValueError.
    if not next_value or any(ch not in "0123456789" for ch in next_value):
        logging.error("Invalid content-length value: %r", next_value)
        raise InvalidResponse()
class DownloadHandler(ResponseHandler): class DownloadHandler(ResponseHandler):
path: str path: str
@@ -220,9 +197,9 @@ class DownloadHandler(ResponseHandler):
def _handle_sub_request(self, client, url): def _handle_sub_request(self, client, url):
(version, status, _) = get_status_line(client) (version, status, _) = self.get_status_line(client)
logging.debug("Parsed status-line: version: %s, status: %s", version, status) logging.debug("Parsed status-line: version: %s, status: %s", version, status)
headers = get_headers(client) headers = self.get_headers(client)
logging.debug("Parsed headers: %r", headers) logging.debug("Parsed headers: %r", headers)
if status != 200: if status != 200:
@@ -275,29 +252,37 @@ class HTMLDownloadHandler(DownloadHandler):
with open(tmp_filename, "rb") as fp: with open(tmp_filename, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser') soup = BeautifulSoup(fp, 'html.parser')
base_url = self.url
base_element = soup.find("base")
if base_element:
base_url = base_element["href"]
for tag in soup.find_all("img"): for tag in soup.find_all("img"):
try: try:
tag["src"] = self.__download_image(tag["src"], host, path) tag["src"] = self.__download_image(tag["src"], host, base_url)
except Exception as e: except Exception as e:
logging.error("Failed to download image: %s, skipping...", tag["src"], exc_info=e) logging.debug(e)
logging.error("Failed to download image: %s, skipping...", tag["src"])
with open(target_filename, 'w') as file: with open(target_filename, 'w') as file:
file.write(str(soup)) file.write(str(soup))
def __download_image(self, img_src, host, path): def __download_image(self, img_src, host, base_url):
parsed = urlparse(img_src) parsed = urlparse(img_src)
logging.debug("Downloading image: %s", img_src) logging.debug("Downloading image: %s", img_src)
same_host = True if len(parsed.netloc) == 0 and parsed.path != "/":
# relative url, append base_url
img_src = os.path.join(os.path.dirname(base_url), parsed.path)
parsed = urlparse(img_src)
# Check if the image is located on the same server
if len(parsed.netloc) == 0 or parsed.netloc == host: if len(parsed.netloc) == 0 or parsed.netloc == host:
same_host = True
img_host = host img_host = host
if parsed.path[0] != "/":
base = os.path.split(path)[0]
if base[-1] != '/':
base += "/"
img_path = base + parsed.path
else:
img_path = parsed.path img_path = parsed.path
else: else:
same_host = False same_host = False

126
client/command.py Normal file
View File

@@ -0,0 +1,126 @@
import logging
from urllib.parse import urlparse
from client.ResponseHandler import ResponseHandler
from client.httpclient import FORMAT, HTTPClient, InvalidResponse, InvalidStatusLine, UnsupportedEncoding
class Command:
    """Base class for one HTTP request; subclasses represent the methods.

    A subclass sets ``command`` to the method name used in the request line
    and overrides ``_await_response`` to process the reply. It may override
    ``_build_message`` to add headers and a body.
    """

    # HTTP method name placed in the request line; set by each subclass.
    command: str

    def __init__(self, url: str, port: str):
        """Store the target ``url`` and the ``port`` to connect to."""
        self.url = url
        self.port = port

    @staticmethod
    def create(command: str, url: str, port: str):
        """Return the command object for the HTTP method named ``command``.

        Supported names are ``GET``, ``HEAD``, ``POST`` and ``PUT``.
        Raises ``ValueError`` for any other name.
        """
        # Built at call time: the subclasses are defined after this class.
        command_classes = {
            "GET": GetCommand,
            "HEAD": HeadCommand,
            "POST": PostCommand,
            "PUT": PutCommand,
        }
        try:
            command_cls = command_classes[command]
        except KeyError:
            raise ValueError(f"Unsupported command: {command!r}") from None
        return command_cls(url, port)

    def execute(self):
        """Connect, send the request and let the subclass handle the reply.

        ``InvalidResponse``, ``InvalidStatusLine`` and ``UnsupportedEncoding``
        raised while handling the reply are logged at debug level and
        swallowed. The client is closed in every case, including when
        connecting or sending fails.
        """
        (host, path) = self.parse_uri()

        client = HTTPClient(host)
        try:
            client.connect((host, int(self.port)))

            message = f"{self.command} {path} HTTP/1.1\r\n"
            message += f"Host: {host}\r\n"
            message += "Accept: */*\r\nAccept-Encoding: identity\r\n"

            encoded_msg = self._build_message(message)

            logging.info("---request begin---\r\n%s---request end---", encoded_msg.decode(FORMAT))
            logging.debug("Sending HTTP message: %r", encoded_msg)

            client.sendall(encoded_msg)

            logging.info("HTTP request sent, awaiting response...")

            try:
                self._await_response(client)
            except InvalidResponse as e:
                logging.debug("Internal error: Response could not be parsed", exc_info=e)
            except InvalidStatusLine as e:
                logging.debug("Internal error: Invalid status-line in response", exc_info=e)
            except UnsupportedEncoding as e:
                logging.debug("Internal error: Unsupported encoding in response", exc_info=e)
        finally:
            # Covers connect/send failures too, not only response handling.
            client.close()

    def _await_response(self, client: HTTPClient):
        """Process the reply read from ``client``; the base class does nothing."""
        pass

    def _build_message(self, message: str) -> bytes:
        """Terminate the header section in ``message`` and encode it."""
        return (message + "\r\n").encode(FORMAT)

    def parse_uri(self):
        """Split ``self.url`` into a ``(host, path)`` tuple.

        A URL without a network location is parsed a second time with
        ``//`` prepended. The path always starts with ``/`` and everything
        from the first ``:`` onwards is removed from the host.
        """
        parsed = urlparse(self.url)

        # If there is no netloc, the url is invalid, so prepend `//` and try again
        if parsed.netloc == "":
            parsed = urlparse("//" + self.url)

        host = parsed.netloc
        path = parsed.path
        if len(path) == 0 or path[0] != '/':
            path = "/" + path

        # NOTE(review): cutting at the first ":" also truncates a bracketed
        # IPv6 literal; confirm whether such hosts need to be supported.
        port_pos = host.find(":")
        if port_pos >= 0:
            host = host[:port_pos]

        return host, path
class HeadCommand(Command):
    """HEAD request: prints the reply line by line up to the blank line."""

    command = "HEAD"

    def _await_response(self, client):
        """Print every line read from ``client`` until a blank or empty line.

        The terminating line itself is printed as well.
        """
        line = client.read_line()
        print(line, end="")
        while line not in ("\r\n", "\n", ""):
            line = client.read_line()
            print(line, end="")
class GetCommand(Command):
    """GET request: parses the reply and passes it to a response handler."""

    command = "GET"

    def _await_response(self, client):
        """Parse status line and headers, then run the matching handler."""
        version, status, _ = ResponseHandler.get_status_line(client)
        logging.debug("Parsed status-line: version: %s, status: %s", version, status)

        response_headers = ResponseHandler.get_headers(client)
        logging.debug("Parsed headers: %r", response_headers)

        ResponseHandler.create(client, response_headers, status, self.url).handle()
class PostCommand(HeadCommand):
    """POST request whose body is read interactively from standard input.

    The reply is handled as in ``HeadCommand``: it is printed line by line
    up to the blank line.
    """

    command = "POST"

    def _build_message(self, message: str) -> bytes:
        """Append body headers and the prompted body to ``message``.

        The user is asked for the body text, which is sent as
        ``text/plain``. ``Content-Length`` is the length of the encoded
        body, and exactly that many bytes follow the header section.
        """
        body = input("Enter POST data: ").encode(FORMAT)
        print()

        message += "Content-Type: text/plain\r\n"
        message += f"Content-Length: {len(body)}\r\n"
        message += "\r\n"

        # No CRLF after the body: it would be two bytes beyond the declared
        # Content-Length, and a request must not be followed by extra CRLF.
        return message.encode(FORMAT) + body
class PutCommand(PostCommand):
    """PUT request; inherits everything from ``PostCommand`` except the method name."""
    # Method name written into the request line.
    command = "PUT"

0
server/__init__.py Normal file
View File