Compare commits

...

29 Commits

Author SHA1 Message Date
baaa3941d6 fix buffer in PreambleRetriever 2021-03-28 20:27:40 +02:00
6fd015c770 Log more as info 2021-03-28 20:21:53 +02:00
032c71144d change default server port 2021-03-28 20:11:16 +02:00
8eae777265 Fix issues 2021-03-28 19:53:14 +02:00
b7315c2348 Improve documentation 2021-03-28 18:55:00 +02:00
c748387b48 Improve documentation 2021-03-28 17:57:08 +02:00
0f2b039e71 Fix issues with shutdown, improve documentation 2021-03-28 17:12:07 +02:00
210c03b73f Improve logging, fix small issues 2021-03-28 15:58:07 +02:00
cd053bc74e Improve documentation, cleanup duplicated code 2021-03-28 15:20:28 +02:00
07b018d2ab Fix small issues, improve error handling and documentation 2021-03-28 14:04:39 +02:00
850535a060 Improve documentation 2021-03-28 03:33:00 +02:00
b42c17c420 Rename response_handler to responsehandler 2021-03-28 03:00:53 +02:00
7ecfedbec7 Command add properties for fields 2021-03-28 03:00:04 +02:00
48c4f207a8 Rename Server- and ClientMessage
Renamed ServerMessage and ClientMessage to respectively ResponseMessage
and RequestMessage to make it more clear.
2021-03-28 01:59:08 +01:00
1f0ade0f09 Parse port argument as int 2021-03-28 01:58:16 +01:00
07023f2837 Remove server_flow.md 2021-03-28 01:17:30 +01:00
7329a2b9a5 Fix recreating ServerSocket on error 2021-03-28 01:16:40 +01:00
0ffdc73a6d server: use queue blocking instead of sleep 2021-03-28 01:09:11 +01:00
a3ce68330f small update 2021-03-28 00:35:58 +01:00
5b5a821522 fix base regex 2021-03-27 23:52:16 +01:00
4473d1bec9 Parse html with regex, fix small issues 2021-03-27 23:41:28 +01:00
bbca6f603b Improve ChunkedRetriever error handling and documentation 2021-03-27 19:08:06 +01:00
9036755a62 Fix some issues, improve documentation 2021-03-27 19:05:09 +01:00
0f7d67c98d small fixes 2021-03-27 17:32:16 +01:00
ff32ce9b39 Cleanup parser and add documentation 2021-03-27 16:58:48 +01:00
3615c56152 update 2021-03-27 16:30:53 +01:00
fdbd865889 Update 2021-03-26 18:25:03 +01:00
7476870acc client: fix relative paths 2021-03-25 18:26:50 +01:00
f15ff38f69 client: fix image url parsing 2021-03-25 17:56:21 +01:00
21 changed files with 1497 additions and 868 deletions

View File

@@ -4,18 +4,20 @@ import logging
import sys import sys
from client import command as cmd from client import command as cmd
from httplib.exceptions import UnhandledHTTPCode
def main(): def main():
parser = argparse.ArgumentParser(description='HTTP Client') parser = argparse.ArgumentParser(description='HTTP Client')
parser.add_argument("--verbose", "-v", action='count', default=0, help="Increase verbosity level of logging") parser.add_argument("--verbose", "-v", action='count', default=0, help="Increase verbosity level of logging")
parser.add_argument("--command", "-c", help="HEAD, GET, PUT or POST", default="GET") parser.add_argument("--command", "-c", help="HEAD, GET, PUT or POST", default="GET")
parser.add_argument("--port", "-p", help="The port used to connect with the server", default=80) parser.add_argument("--port", "-p", help="The port used to connect with the server", default=80, type=int)
parser.add_argument("URI", help="The URI to connect to") parser.add_argument("URI", help="The URI to connect to")
arguments = parser.parse_args() arguments = parser.parse_args()
logging.basicConfig(level=logging.ERROR - (10 * arguments.verbose)) # Setup logging
logging.basicConfig(level=logging.INFO - (10 * arguments.verbose), format="[%(levelname)s] %(message)s")
logging.debug("Arguments: %s", arguments) logging.debug("Arguments: %s", arguments)
command = cmd.create(arguments.command, arguments.URI, arguments.port) command = cmd.create(arguments.command, arguments.URI, arguments.port)
@@ -24,7 +26,10 @@ def main():
try: try:
main() main()
except UnhandledHTTPCode as e:
logging.info(f"[{e.status_code}] {e.cause}:\r\n{e.headers}")
sys.exit(2)
except Exception as e: except Exception as e:
print("[ABRT] Internal error: " + str(e), file=sys.stderr) logging.info("[ABRT] Internal error: %s", e)
logging.debug("Internal error", exc_info=e) logging.debug("Internal error", exc_info=e)
sys.exit(70) sys.exit(1)

View File

@@ -1,63 +1,103 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Tuple from typing import Dict
from urllib.parse import urlparse
from client.httpclient import FORMAT, HTTPClient from client.httpclient import HTTPClient
from httplib import parser from httplib import parser
from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding, UnsupportedProtocol
from httplib.message import Message from httplib.httpsocket import FORMAT
from httplib.message import ResponseMessage as Message
from httplib.retriever import PreambleRetriever from httplib.retriever import PreambleRetriever
sockets: Dict[str, HTTPClient] = {} sockets: Dict[str, HTTPClient] = {}
def create(command: str, url: str, port): def create(method: str, url: str, port):
if command == "GET": """
return GetCommand(url, port) Create a corresponding Command instance of the specified HTTP `method` with the specified `url` and `port`.
elif command == "HEAD": @param method: The command type to create
return HeadCommand(url, port) @param url: The url for the command
elif command == "POST": @param port: The port for the command
return PostCommand(url, port) """
elif command == "PUT":
return PutCommand(url, port) uri = parser.uri_from_url(url)
if method == "GET":
return GetCommand(uri, port)
elif method == "HEAD":
return HeadCommand(uri, port)
elif method == "POST":
return PostCommand(uri, port)
elif method == "PUT":
return PutCommand(uri, port)
else: else:
raise ValueError() raise ValueError("Unknown HTTP method")
class AbstractCommand(ABC): class AbstractCommand(ABC):
uri: str """
host: str A class representing the command for sending an HTTP request.
path: str """
port: Tuple[str, int] _uri: str
_host: str
_path: str
_port: int
def __init__(self, uri: str, port): def __init__(self, uri: str, port):
self.uri = uri self.uri = uri
self.host, _, self.path = parser.parse_uri(uri) self._port = int(port)
self.port = port
@property
def uri(self):
return self._uri
@uri.setter
def uri(self, value):
self._uri = value
self._host, self._port, self._path = parser.parse_uri(value)
@property
def host(self):
return self._host
@property
def path(self):
return self._path
@property
def port(self):
return self._port
@property @property
@abstractmethod @abstractmethod
def command(self): def method(self):
pass pass
def execute(self, sub_request=False): def execute(self, sub_request=False):
(host, path) = self.parse_uri() """
Creates and sends the HTTP message for this Command.
client = sockets.get(host) @param sub_request: If this execution is in function of a prior command.
"""
client = sockets.get(self.host)
if client and client.is_closed(): if client and client.is_closed():
sockets.pop(self.host) sockets.pop(self.host)
client = None client = None
if not client: if not client:
client = HTTPClient(host) logging.info("Connecting to %s", self.host)
client.conn.connect((host, self.port)) client = HTTPClient(self.host)
sockets[host] = client client.conn.connect((self.host, self.port))
logging.info("Connected.")
sockets[self.host] = client
else:
logging.info("Reusing socket for %s", self.host)
message = f"{self.command} {path} HTTP/1.1\r\n" message = f"{self.method} {self.path} HTTP/1.1\r\n"
message += f"Host: {host}:{self.port}\r\n" message += f"Host: {self.host}:{self.port}\r\n"
message += "Accept: */*\r\nAccept-Encoding: identity\r\n" message += "Accept: */*\r\n"
message += "Accept-Encoding: identity\r\n"
encoded_msg = self._build_message(message) encoded_msg = self._build_message(message)
logging.debug("---request begin---\r\n%s---request end---", encoded_msg.decode(FORMAT)) logging.debug("---request begin---\r\n%s---request end---", encoded_msg.decode(FORMAT))
@@ -67,53 +107,98 @@ class AbstractCommand(ABC):
logging.info("HTTP request sent, awaiting response...") logging.info("HTTP request sent, awaiting response...")
try: try:
retriever = PreambleRetriever(client) self._await_response(client)
self._await_response(client, retriever)
except InvalidResponse as e: except InvalidResponse as e:
logging.debug("Internal error: Response could not be parsed", exc_info=e) logging.error("Response could not be parsed")
return logging.debug("", exc_info=e)
except InvalidStatusLine as e: except InvalidStatusLine as e:
logging.debug("Internal error: Invalid status-line in response", exc_info=e) logging.error("Invalid status-line in response")
return logging.debug("", exc_info=e)
except UnsupportedEncoding as e: except UnsupportedEncoding as e:
logging.debug("Internal error: Unsupported encoding in response", exc_info=e) logging.error("Unsupported encoding in response")
logging.debug("", exc_info=e)
except UnsupportedProtocol as e:
logging.error("Unsupported protocol: %s", e.protocol)
logging.debug("", exc_info=e)
finally: finally:
if not sub_request: if not sub_request:
client.close() client.close()
def _await_response(self, client, retriever): def _get_preamble(self, client):
while True: """
line = client.read_line() Returns the preamble (start-line and headers) of the response of this command.
print(line, end="") @param client: the client object to retrieve from
if line in ("\r\n", "\n", ""): @return: A Message object containing the HTTP-version, status code, status message, headers and buffer
break """
retriever = PreambleRetriever(client)
lines = retriever.retrieve()
(version, status, msg) = parser.parse_status_line(next(lines))
headers = parser.parse_headers(lines)
buffer = retriever.buffer
logging.debug("---response begin---\r\n%s---response end---", "".join(buffer))
return Message(version, status, msg, headers, buffer)
def _await_response(self, client):
"""
Simple response method.
Receives the response and prints to stdout.
"""
msg = self._get_preamble(client)
print("".join(msg.raw))
def _build_message(self, message: str) -> bytes: def _build_message(self, message: str) -> bytes:
return (message + "\r\n").encode(FORMAT) return (message + "\r\n").encode(FORMAT)
def parse_uri(self):
parsed = urlparse(self.uri)
# If there is no netloc, the url is invalid, so prepend `//` and try again class HeadCommand(AbstractCommand):
if parsed.netloc == "": """
parsed = urlparse("//" + self.uri) A Command for sending a `HEAD` request.
"""
host = parsed.netloc @property
path = parsed.path def method(self):
if len(path) == 0 or path[0] != '/': return "HEAD"
path = "/" + path
port_pos = host.find(":")
if port_pos >= 0:
host = host[:port_pos]
return host, path class GetCommand(AbstractCommand):
"""
A Command for sending a `GET` request.
"""
dir: str
def __init__(self, uri: str, port, directory=None):
super().__init__(uri, port)
self.dir = directory
self.filename = None
@property
def method(self):
return "GET"
def _await_response(self, client):
"""
Handles the response of this command.
"""
msg = self._get_preamble(client)
from client import responsehandler
self.filename = responsehandler.handle(client, msg, self, self.dir)
class AbstractWithBodyCommand(AbstractCommand, ABC): class AbstractWithBodyCommand(AbstractCommand, ABC):
"""
The building block for creating an HTTP message for an HTTP method with a body (POST and PUT).
"""
def _build_message(self, message: str) -> bytes: def _build_message(self, message: str) -> bytes:
body = input(f"Enter {self.command} data: ").encode(FORMAT) input_line = input(f"Enter {self.method} data: ")
input_line += "\r\n"
body = input_line.encode(FORMAT)
print() print()
message += "Content-Type: text/plain\r\n" message += "Content-Type: text/plain\r\n"
@@ -126,46 +211,21 @@ class AbstractWithBodyCommand(AbstractCommand, ABC):
return message return message
class HeadCommand(AbstractCommand):
@property
def command(self):
return "HEAD"
class GetCommand(AbstractCommand):
def __init__(self, uri: str, port, dir=None):
super().__init__(uri, port)
self.dir = dir
self.filename = None
@property
def command(self):
return "GET"
def _get_preamble(self, retriever):
lines = retriever.retrieve()
(version, status, msg) = parser.parse_status_line(next(lines))
headers = parser.parse_headers(lines)
logging.debug("---response begin---\r\n%s--- response end---", "".join(retriever.buffer))
return Message(version, status, msg, headers)
def _await_response(self, client, retriever):
msg = self._get_preamble(retriever)
from client import response_handler
self.filename = response_handler.handle(client, msg, self, self.dir)
class PostCommand(AbstractWithBodyCommand): class PostCommand(AbstractWithBodyCommand):
"""
A command for sending a `POST` request.
"""
@property @property
def command(self): def method(self):
return "POST" return "POST"
class PutCommand(AbstractWithBodyCommand): class PutCommand(AbstractWithBodyCommand):
"""
A command for sending a `PUT` request.
"""
@property @property
def command(self): def method(self):
return "PUT" return "PUT"

View File

@@ -1,6 +0,0 @@
from bs4 import BeautifulSoup
class HTMLParser:
    """
    Placeholder HTML parser: instances carry no state and expose no behaviour yet.
    """

    def __init__(self, soup: BeautifulSoup):
        """
        Accept a parsed document and ignore it.
        @param soup: the parsed document; it is not stored or inspected
        """

View File

@@ -1,15 +1,27 @@
import socket import socket
from httplib.httpsocket import HTTPSocket from httplib.httpsocket import HTTPSocket, InvalidResponse
BUFSIZE = 4096
TIMEOUT = 3
FORMAT = "UTF-8"
MAXLINE = 4096
class HTTPClient(HTTPSocket): class HTTPClient(HTTPSocket):
"""
Wrapper class for a socket. Represents a client which connects to a server.
"""
host: str host: str
def __init__(self, host: str): def __init__(self, host: str):
super().__init__(socket.socket(socket.AF_INET, socket.SOCK_STREAM), host) super().__init__(socket.socket(socket.AF_INET, socket.SOCK_STREAM))
self.host = host
def read_line(self):
"""
Reads the next line decoded as `httpsocket.FORMAT`
@return: the decoded next line retrieved from the socket
@raise InvalidResponse: If the next line couldn't be decoded, but was expected to
"""
try:
return super().read_line()
except UnicodeDecodeError:
raise InvalidResponse("Unexpected decoding error")

View File

@@ -1,274 +0,0 @@
import logging
import os
import re
from abc import ABC, abstractmethod
from urllib.parse import urlsplit, unquote
from bs4 import BeautifulSoup, Tag
from client.command import AbstractCommand, GetCommand
from client.httpclient import HTTPClient, FORMAT
from httplib import parser
from httplib.exceptions import InvalidResponse
from httplib.message import Message
from httplib.retriever import Retriever
def handle(client: HTTPClient, msg: Message, command: AbstractCommand, dir=None):
    """
    Process the response `msg` received for `command` and download its payload when there is one.

    The status is first handled by a `BasicResponseHandler`. When that yields no retriever, nothing is
    downloaded and `None` is returned. Otherwise the payload is saved by an `HTMLDownloadHandler` when the
    `content-type` header contains `text/html`, and by a `RawDownloadHandler` in every other case.

    @param client: the client the response is read from
    @param msg: the response message (status line and headers)
    @param command: the command that produced the request
    @param dir: directory passed on to the download handler
    @return: the result of the download handler's `handle()`, or `None` when there is nothing to download
    """
    body_retriever = BasicResponseHandler(client, msg, command).handle()
    if body_retriever is None:
        return

    content_type = msg.headers.get("content-type")
    is_html = content_type and "text/html" in content_type
    handler_class = HTMLDownloadHandler if is_html else RawDownloadHandler

    return handler_class(body_retriever, client, msg, command, dir).handle()
class ResponseHandler(ABC):
    """
    Common base of all response handlers. It bundles the client, the retriever, the response message and
    the command that produced the request.
    """
    client: HTTPClient
    retriever: Retriever
    msg: Message
    cmd: AbstractCommand

    def __init__(self, retriever: Retriever, client: HTTPClient, msg, cmd):
        self.client, self.retriever = client, retriever
        self.msg, self.cmd = msg, cmd

    @abstractmethod
    def handle(self):
        """
        Process the response; concrete handlers define what that means.
        """

    @staticmethod
    def parse_uri(uri: str):
        """
        Split `uri` into its network location and its path.
        @param uri: the URI to split; a scheme-less value such as `example.com/a` is accepted
        @return: a `(host, path)` tuple in which `path` always starts with `/`
        """
        parts = urlsplit(uri)

        if not parts.netloc:
            # Without a leading `//` urlsplit treats everything as path, so add it and split again
            parts = urlsplit("//" + uri)

        path = parts.path
        if not path.startswith("/"):
            path = "/" + path

        return parts.netloc, path
class BasicResponseHandler(ResponseHandler):
    """
    First-stage response handler which decides what to do based on the status code.

    Successful responses hand back the retriever so the body can be downloaded, redirects are followed by
    re-executing the command, and 101 or 4xx responses are printed to stdout.
    """

    def __init__(self, client: HTTPClient, msg: Message, cmd: AbstractCommand):
        super().__init__(Retriever.create(client, msg.headers), client, msg, cmd)

    def handle(self):
        """
        @return: the body retriever for a 2xx response, otherwise `None`
        """
        return self._handle_status()

    def _skip_body(self):
        """
        Consume the whole body, writing every retrieved piece to the debug log.
        """
        logging.debug("Skipping body: [")
        for chunk in self.retriever.retrieve():
            try:
                logging.debug("%s", chunk.decode(FORMAT))
            except Exception:
                # Fall back to the raw representation of the chunk
                logging.debug("%r", chunk)
        logging.debug("] done.")

    def _handle_status(self):
        """
        Dispatch on the status code of the response.
        @return: the retriever for 2xx, the result of the redirect handling for 3xx, otherwise `None`
        """
        status = self.msg.status
        logging.info("%d %s", status, self.msg.msg)

        if 200 <= status < 300:
            return self.retriever

        if 300 <= status < 400:
            # Redirect
            return self._do_handle_redirect()

        if status == 101 or 400 <= status < 500:
            # Switching protocols is not supported and client errors are final: dump status line and headers
            print(f"{self.msg.version} {status} {self.msg.msg}")
            print(self.msg.headers)

        return None

    def _do_handle_redirect(self):
        """
        Follow the redirect announced by the `location` header by re-executing the command on the new URI.
        @return: always `None`
        @raise InvalidResponse: if the location is missing, has no hostname or does not use http
        """
        self._skip_body()

        location = self.msg.headers.get("location")
        if not location:
            raise InvalidResponse("No location in redirect")

        target = urlsplit(location)
        if not target.hostname:
            raise InvalidResponse("Invalid location")
        if target.scheme != "http":
            raise InvalidResponse("Only http is supported")

        self.cmd.uri = location
        self.cmd.host, self.cmd.port, self.cmd.path = parser.parse_uri(location)

        if self.msg.status == 301:
            logging.info("Status 301. Closing socket [%s]", self.cmd.host)
            self.client.close()

        self.cmd.execute()
        return None
class DownloadHandler(ResponseHandler, ABC):
    """
    Base class for handlers which save the response body to a file.

    On construction the target path is stored in `self.path`. The path is derived from the request path
    of the command and did not exist on disk at the moment it was computed.
    """

    def __init__(self, retriever: Retriever, client: HTTPClient, msg, cmd, dir=None):
        """
        @param dir: the directory to download to; when falsy a new directory named after the host is created
        """
        super().__init__(retriever, client, msg, cmd)

        if not dir:
            dir = self._create_directory()

        self.path = self._get_duplicate_name(os.path.join(dir, self.get_filename()))

    @staticmethod
    def create(retriever: Retriever, client: HTTPClient, msg, cmd, dir=None):
        """
        Create the download handler matching the `content-type` header of `msg`.
        @return: an `HTMLDownloadHandler` for `text/html`, otherwise a `RawDownloadHandler`
        """
        content_type = msg.headers.get("content-type")
        if content_type and "text/html" in content_type:
            return HTMLDownloadHandler(retriever, client, msg, cmd, dir)
        return RawDownloadHandler(retriever, client, msg, cmd, dir)

    def _create_directory(self):
        """
        Create a new directory named after the host of the client.
        @return: the absolute path of the created directory
        """
        path = self._get_duplicate_name(os.path.abspath(self.client.host))
        os.mkdir(path)
        return path

    def _get_duplicate_name(self, path):
        """
        Return `path` itself when nothing exists there, otherwise the first of `path.1`, `path.2`, ...
        that does not exist yet.
        """
        tmp_path = path
        i = 0
        while os.path.exists(tmp_path):
            i += 1
            tmp_path = "{path}.{counter}".format(path=path, counter=i)

        return tmp_path

    def get_filename(self):
        """
        Returns the filename to download the payload to.

        The name is the last segment of the command path, percent-decoded and stripped of every character
        other than word characters, `.`, `+` and `-`. `index.html` is used when nothing usable remains.
        """
        filename = os.path.basename(self.cmd.path)
        if filename == '':
            return "index.html"

        # Decode repeatedly to undo nested percent-encoding, but stop as soon as decoding no longer changes
        # anything: a '%' which is not part of a valid escape (e.g. the result of decoding "%25") is left
        # untouched by unquote and would otherwise keep this loop running forever.
        while "%" in filename:
            decoded = unquote(filename)
            if decoded == filename:
                break
            filename = decoded

        # Any '%' that is left over is removed here together with the other unsafe characters
        filename = re.sub(r"[^\w.+-]+[.]*", '', filename)
        result = os.path.basename(filename).strip()
        if any(letter.isalnum() for letter in result):
            return result

        return "index.html"
class RawDownloadHandler(DownloadHandler):
    """
    Download handler which writes the retrieved body to `self.path` without modifying it.
    """

    def __init__(self, retriever: Retriever, client: HTTPClient, msg: Message, cmd: AbstractCommand, dir=None):
        super().__init__(retriever, client, msg, cmd, dir)

    def handle(self) -> str:
        """
        Write every buffer produced by the retriever to the target file.
        @return: the path of the written file
        """
        logging.debug("Retrieving payload")

        # The context manager closes the file even when retrieving the body fails halfway
        with open(self.path, "wb") as file:
            for buffer in self.retriever.retrieve():
                file.write(buffer)

        return self.path
class HTMLDownloadHandler(DownloadHandler):
    """
    Download handler for HTML payloads: saves the page, downloads the images referenced by its `<img>`
    tags and rewrites their `src` attributes to the downloaded file names.
    """

    def __init__(self, retriever: Retriever, client: HTTPClient, msg: Message, cmd: AbstractCommand, dir=None):
        super().__init__(retriever, client, msg, cmd, dir)

    def handle(self) -> str:
        """
        Download the HTML body to a temporary file, process its images and write the final HTML file.
        @return: the path of the final HTML file
        """
        (directory, file) = os.path.split(self.path)
        tmp_path = os.path.join(directory, f".{file}.tmp")

        try:
            with open(tmp_path, "wb") as tmp_file:
                for buffer in self.retriever.retrieve():
                    tmp_file.write(buffer)

            self._download_images(tmp_path, self.path)
        finally:
            # Never leave the hidden temporary file behind, also not when the download failed
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        return self.path

    def _download_images(self, tmp_filename, target_filename):
        """
        Download the images referenced in the HTML of `tmp_filename`, point the `src` attributes to the
        downloaded files and write the resulting document to `target_filename`.
        @param tmp_filename: the path of the temporary HTML file
        @param target_filename: the path of the final HTML file
        """
        (host, _) = ResponseHandler.parse_uri(self.cmd.uri)
        with open(tmp_filename, "rb") as fp:
            soup = BeautifulSoup(fp, 'lxml')

        base_url = self.cmd.uri
        base_element = soup.find("base")
        # A <base> tag without href (e.g. only a target attribute) must not abort the whole download
        if base_element and base_element.has_attr("href"):
            base_url = base_element["href"]

        # Maps every handled `src` value to the path it was downloaded to (or None)
        processed = {}
        tag: Tag
        for tag in soup.find_all("img"):
            try:
                if not tag.has_attr("src"):
                    continue

                if tag["src"] in processed:
                    new_url = processed.get(tag["src"])
                else:
                    new_url = self.__download_image(tag["src"], host, base_url)
                    processed[tag["src"]] = new_url
                if new_url:
                    tag["src"] = os.path.basename(new_url)
            except Exception as e:
                # A single broken image should not prevent the page from being saved
                logging.error("Failed to download image: %s, skipping...", tag["src"], exc_info=e)

        with open(target_filename, 'w') as file:
            file.write(str(soup))

    def __download_image(self, img_src, host, base_url):
        """
        Download the image referenced by `img_src` into the directory of `self.path`.
        @param img_src: the value of the `src` attribute
        @param host: the host of the page
        @param base_url: the url which relative references are resolved against
        @return: the filename set on the executed command, or `None` for an unsupported scheme
        """
        logging.debug("Downloading image: %s", img_src)

        parsed = urlsplit(img_src)

        if parsed.scheme not in ("", "http", "https"):
            # Not a valid url
            return None

        if parsed.hostname is None:
            if img_src[0] == "/":
                img_src = host + img_src
            else:
                # NOTE(review): os.path is used on a URL here; this only yields '/' separators on POSIX
                img_src = os.path.join(os.path.dirname(base_url), img_src)

        if parsed.hostname is None or parsed.hostname == host:
            port = self.cmd.port
        else:
            # urlsplit already parses the port as an int and copes with userinfo and IPv6 literals
            port = parsed.port or 80

        command = GetCommand(img_src, port, os.path.dirname(self.path))
        command.execute(True)
        return command.filename

308
client/responsehandler.py Normal file
View File

@@ -0,0 +1,308 @@
import logging
import os
import re
from abc import ABC, abstractmethod
from urllib.parse import urlsplit, unquote
from client.command import AbstractCommand, GetCommand
from client.httpclient import HTTPClient
from httplib import parser
from httplib.exceptions import InvalidResponse, UnhandledHTTPCode, UnsupportedProtocol
from httplib.httpsocket import FORMAT
from httplib.message import ResponseMessage as Message
from httplib.retriever import Retriever
# Matches a <base ...> tag carrying a quoted `href` attribute; group 1 is the attribute value.
# The tag and attribute names are matched case-insensitively.
BASE_REGEX = re.compile(r"<\s*base[^>]*\shref\s*=\s*['\"]([^\"']+)['\"][^>]*>", re.M | re.I)
# Matches an <img ...> tag carrying a quoted `src` attribute; group 1 is the attribute value.
IMG_REGEX = re.compile(r"<\s*img[^>]*\ssrc\s*=\s*['\"]([^\"']+)['\"][^>]*>", re.M | re.I)
def handle(client: HTTPClient, msg: Message, command: AbstractCommand, directory=None):
    """
    Process the response `msg` received for `command` and download its payload when there is one.

    The status is first handled by a `BasicResponseHandler`. When that yields no retriever, nothing is
    downloaded. Otherwise the payload is saved by an `HTMLDownloadHandler` when the `content-type` header
    contains `text/html`, and by a `RawDownloadHandler` in every other case.

    @param client: the client which sent the request
    @param msg: the response message
    @param command: the command of the sent request-message
    @param directory: the directory to download the response to (if available)
    @return: the result of the download handler's `handle()`, or `None` when there is nothing to download
    """
    body_retriever = BasicResponseHandler(client, msg, command).handle()
    if body_retriever is None:
        return

    content_type = msg.headers.get("content-type")
    is_html = content_type and "text/html" in content_type
    handler_class = HTMLDownloadHandler if is_html else RawDownloadHandler

    return handler_class(body_retriever, client, msg, command, directory).handle()
class ResponseHandler(ABC):
    """
    Common base of all response handlers. It bundles the client, the retriever, the response message and
    the command that produced the request.
    """
    client: HTTPClient
    retriever: Retriever
    msg: Message
    cmd: AbstractCommand

    def __init__(self, retriever: Retriever, client: HTTPClient, msg, cmd):
        self.client, self.retriever = client, retriever
        self.msg, self.cmd = msg, cmd

    @abstractmethod
    def handle(self):
        """
        Process the response; concrete handlers define what that means.
        """
class BasicResponseHandler(ResponseHandler):
    """
    First-stage response handler which decides what to do based on the status code.

    Successful responses hand back the retriever so the body can be downloaded, redirects are followed by
    re-executing the command, and 101, 304, 4xx and 5xx responses are reported with `UnhandledHTTPCode`.
    """

    def __init__(self, client: HTTPClient, msg: Message, cmd: AbstractCommand):
        super().__init__(Retriever.create(client, msg.headers), client, msg, cmd)

    def handle(self):
        """
        @return: the body retriever for a 2xx response, otherwise `None`
        """
        return self._handle_status()

    def _skip_body(self):
        """
        Consume the whole body, writing every retrieved piece to the debug log.
        """
        logging.debug("Skipping body: [")
        for chunk in self.retriever.retrieve():
            try:
                logging.debug("%s", chunk.decode(FORMAT))
            except UnicodeDecodeError:
                # Not decodable as text: log the raw representation instead
                logging.debug("%r", chunk)
        logging.debug("] done.")

    def _handle_status(self):
        """
        Dispatch on the status code of the response.
        @return: the retriever for 2xx, the result of the redirect handling for 3xx, otherwise `None`
        @raise UnhandledHTTPCode: for status 101 and for every 4xx and 5xx status
        """
        status = self.msg.status
        logging.info("%d %s", status, self.msg.msg)

        if status == 101:
            # Switching protocols is not supported
            raise UnhandledHTTPCode(status, "".join(self.msg.raw), "Switching protocols is not supported")

        if 200 <= status < 300:
            return self.retriever

        if 300 <= status < 400:
            # Redirect
            self._skip_body()
            return self._handle_redirect()

        if 400 <= status < 600:
            self._skip_body()
            # Dump headers and exit with error
            raise UnhandledHTTPCode(status, "".join(self.msg.raw), self.msg.msg)

        return None

    def _handle_redirect(self):
        """
        Follow the redirect announced by the `location` header by re-executing the command on the new URI.
        @return: always `None`
        @raise UnhandledHTTPCode: if the status is 304
        @raise InvalidResponse: if the location is missing, blank or has no hostname
        @raise UnsupportedProtocol: if the location does not use http
        """
        if self.msg.status == 304:
            raise UnhandledHTTPCode(self.msg.status, "".join(self.msg.raw), self.msg.msg)

        location = self.msg.headers.get("location")
        if not location or not location.strip():
            raise InvalidResponse("No location in redirect")

        # The location may be relative to the requested URI
        location = parser.urljoin(self.cmd.uri, location)
        target = urlsplit(location)
        if not target.hostname:
            raise InvalidResponse("Invalid location")
        if target.scheme != "http":
            raise UnsupportedProtocol(target.scheme)

        self.cmd.uri = location

        if self.msg.status == 301:
            logging.info("Status 301. Closing socket [%s]", self.cmd.host)
            self.client.close()

        self.cmd.execute()
        return None
class DownloadHandler(ResponseHandler, ABC):
    """
    Base class for handlers which save the response body to a file.

    On construction the target path is stored in `self.path`. The path is derived from the request path
    of the command and did not exist on disk at the moment it was computed.
    """

    def __init__(self, retriever: Retriever, client: HTTPClient, msg, cmd, directory=None):
        """
        @param directory: the directory to download to; when falsy a new directory named after the host
            is created
        """
        super().__init__(retriever, client, msg, cmd)

        if not directory:
            directory = self._create_directory()

        self.path = self._get_duplicate_name(os.path.join(directory, self.get_filename()))

    @staticmethod
    def create(retriever: Retriever, client: HTTPClient, msg, cmd, directory=None):
        """
        Create the download handler matching the `content-type` header of `msg`.
        @return: an `HTMLDownloadHandler` for `text/html`, otherwise a `RawDownloadHandler`
        """
        content_type = msg.headers.get("content-type")
        if content_type and "text/html" in content_type:
            return HTMLDownloadHandler(retriever, client, msg, cmd, directory)
        return RawDownloadHandler(retriever, client, msg, cmd, directory)

    def _create_directory(self):
        """
        Create a new directory named after the host of the client.
        @return: the absolute path of the created directory
        """
        path = self._get_duplicate_name(os.path.abspath(self.client.host))
        os.mkdir(path)
        return path

    def _get_duplicate_name(self, path):
        """
        Return `path` itself when nothing exists there, otherwise the first of `path.1`, `path.2`, ...
        that does not exist yet.
        """
        tmp_path = path
        i = 0
        while os.path.exists(tmp_path):
            i += 1
            tmp_path = "{path}.{counter}".format(path=path, counter=i)

        return tmp_path

    def get_filename(self):
        """
        Returns the filename to download the payload to.

        The name is the last segment of the command path, percent-decoded and stripped of every character
        other than word characters, `.`, `+` and `-`. `index.html` is used when nothing usable remains.
        """
        filename = os.path.basename(self.cmd.path)
        if filename == '':
            return "index.html"

        # Decode repeatedly to undo nested percent-encoding, but stop as soon as decoding no longer changes
        # anything: a '%' which is not part of a valid escape (e.g. the result of decoding "%25") is left
        # untouched by unquote and would otherwise keep this loop running forever.
        while "%" in filename:
            decoded = unquote(filename)
            if decoded == filename:
                break
            filename = decoded

        # Any '%' that is left over is removed here together with the other unsafe characters
        filename = re.sub(r"[^\w.+-]+[.]*", '', filename)
        result = os.path.basename(filename).strip()
        if any(letter.isalnum() for letter in result):
            return result

        return "index.html"
class RawDownloadHandler(DownloadHandler):
    """
    Download handler which writes the retrieved body to `self.path` without modifying it.
    """

    def __init__(self, retriever: Retriever, client: HTTPClient, msg: Message, cmd: AbstractCommand, directory=None):
        super().__init__(retriever, client, msg, cmd, directory)

    def handle(self) -> str:
        """
        Write every buffer produced by the retriever to the target file.
        @return: the path of the written file
        """
        logging.info("Saving to '%s'", parser.get_relative_save_path(self.path))

        # The context manager closes the file even when retrieving the body fails halfway
        with open(self.path, "wb") as file:
            for buffer in self.retriever.retrieve():
                file.write(buffer)

        return self.path
class HTMLDownloadHandler(DownloadHandler):
    """
    Download handler for HTML payloads: saves the page, downloads the images referenced by its `<img>`
    tags and rewrites their `src` attributes to the downloaded file names.
    """

    def __init__(self, retriever: Retriever, client: HTTPClient, msg: Message, cmd: AbstractCommand, directory=None):
        super().__init__(retriever, client, msg, cmd, directory)

    def handle(self) -> str:
        """
        Download the HTML body to a temporary file, process its images and write the final HTML file.
        @return: the path of the final HTML file
        """
        (directory, file) = os.path.split(self.path)
        tmp_path = os.path.join(directory, f".{file}.tmp")

        try:
            with open(tmp_path, "wb") as tmp_file:
                for buffer in self.retriever.retrieve():
                    tmp_file.write(buffer)

            charset = parser.get_charset(self.msg.headers)
            self._download_images(tmp_path, self.path, charset)
        finally:
            # Never leave the hidden temporary file behind, also not when the download failed
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        return self.path

    def _download_images(self, tmp_path, target_path, charset=FORMAT):
        """
        Download images referenced in the html of `tmp_path`, replace the references in the html
        and write it to `target_path`.
        @param tmp_path: the path to the temporary html file
        @param target_path: the path for the final html file
        @param charset: the charset to decode `tmp_path`
        """
        try:
            with open(tmp_path, "r", encoding=charset) as fp:
                html = fp.read()
        except (UnicodeDecodeError, LookupError):
            # Either the content does not match the announced charset or the charset is unknown to Python.
            # Both must be listed in a tuple: `except A or B` evaluates to `except A` and misses B.
            with open(tmp_path, "r", encoding=FORMAT, errors="replace") as fp:
                html = fp.read()

        base_element = BASE_REGEX.search(html)
        base_url = self.cmd.uri
        if base_element:
            base_url = parser.urljoin(self.cmd.uri, base_element.group(1))

        # Maps every handled url to the path it was downloaded to (or None when it failed)
        processed = {}
        to_replace = []

        # Find all <img> tags, and the urls from the corresponding `src` fields
        for m in IMG_REGEX.finditer(html):
            url_start = m.start(1)
            url_end = m.end(1)
            target = m.group(1)

            try:
                if len(target) == 0:
                    continue

                if target in processed:
                    # url is already processed
                    new_url = processed.get(target)
                else:
                    new_url = self.__download_image(target, base_url)
                    processed[target] = new_url

                if new_url:
                    local_path = os.path.basename(new_url)
                    to_replace.append((url_start, url_end, local_path))
            except Exception as e:
                # A single broken image should not prevent the page from being saved
                logging.error("Failed to download image: %s, skipping...", target)
                logging.debug("", exc_info=e)
                processed[target] = None

        # Replace from the bottom of the document upwards, otherwise the recorded start and end
        # positions of the remaining urls would no longer be correct.
        for (start, end, path) in reversed(to_replace):
            html = html[:start] + path + html[end:]

        logging.info("Saving HTML to '%s'", parser.get_relative_save_path(target_path))
        with open(target_path, 'w', encoding=FORMAT) as file:
            file.write(html)

    def __download_image(self, img_src, base_url):
        """
        Download image from the specified `img_src` and `base_url`.
        The image is requested with a `GetCommand` which downloads to the directory of `self.path`.
        @param img_src: the url found in the `src` attribute
        @param base_url: the url which `img_src` is resolved against
        @return: the filename set on the executed command
        """
        logging.info("Downloading image: %s", img_src)

        parsed = urlsplit(img_src)
        img_src = parser.urljoin(base_url, img_src)

        # Images on the same host reuse the port of this command; for other hosts the port is taken
        # from the image url, falling back to 80 when it does not specify one.
        if parsed.hostname is None or parsed.hostname == self.cmd.host:
            port = self.cmd.port
        else:
            # urlsplit already parses the port as an int and copes with userinfo and IPv6 literals
            port = parsed.port or 80

        command = GetCommand(img_src, port, os.path.dirname(self.path))
        command.execute(True)

        return command.filename

View File

@@ -1,68 +1,163 @@
class HTTPException(Exception): class HTTPException(Exception):
""" Base class for HTTP exceptions """ """
Base class for HTTP exceptions
"""
class UnhandledHTTPCode(Exception):
"""
Exception thrown if HTTP codes are not further processed.
"""
status_code: str
headers: str
cause: str
def __init__(self, status, headers, cause):
self.status_code = status
self.headers = headers
self.cause = cause
class InvalidResponse(HTTPException): class InvalidResponse(HTTPException):
""" Response message cannot be parsed """ """
Response message cannot be parsed
"""
def __init(self, message): def __init__(self, message):
self.message = message self.message = message
class InvalidStatusLine(HTTPException): class InvalidStatusLine(HTTPException):
""" Response status line is invalid """ """
Response status line is invalid
"""
def __init(self, line): def __init__(self, line):
self.line = line self.line = line
class UnsupportedEncoding(HTTPException): class UnsupportedEncoding(HTTPException):
""" Reponse Encoding not support """ """
Encoding not supported
"""
def __init(self, enc_type, encoding): def __init__(self, enc_type, encoding):
self.enc_type = enc_type self.enc_type = enc_type
self.encoding = encoding self.encoding = encoding
class UnsupportedProtocol(HTTPException):
"""
Protocol is not supported
"""
def __init__(self, protocol):
self.protocol = protocol
class IncompleteResponse(HTTPException): class IncompleteResponse(HTTPException):
def __init(self, cause): def __init__(self, cause):
self.cause = cause self.cause = cause
class HTTPServerException(Exception): class HTTPServerException(HTTPException):
""" Base class for HTTP Server exceptions """ """
Base class for HTTP Server exceptions
"""
status_code: str status_code: str
message: str message: str
arg: str
def __init__(self, arg):
self.arg = arg
class BadRequest(HTTPServerException): class HTTPServerCloseException(HTTPServerException):
""" Malformed HTTP request""" """
When raised, the connection should be closed
"""
class BadRequest(HTTPServerCloseException):
"""
Malformed HTTP request
"""
status_code = 400 status_code = 400
message = "Bad Request" message = "Bad Request"
class Forbidden(HTTPServerException):
"""
Request not allowed
"""
status_code = 403
message = "Forbidden"
class NotFound(HTTPServerException):
"""
Resource not found
"""
status_code = 404
message = "Not Found"
class MethodNotAllowed(HTTPServerException): class MethodNotAllowed(HTTPServerException):
""" Method is not allowed """ """
Method is not allowed
"""
status_code = 405 status_code = 405
message = "Method Not Allowed" message = "Method Not Allowed"
def __init(self, allowed_methods): def __init__(self, allowed_methods):
self.allowed_methods = allowed_methods self.allowed_methods = allowed_methods
class InternalServerError(HTTPServerException): class InternalServerError(HTTPServerCloseException):
""" Internal Server Error """ """
Internal Server Error
"""
status_code = 500 status_code = 500
message = "Internal Server Error" message = "Internal Server Error"
class NotImplemented(HTTPServerException): class NotImplemented(HTTPServerException):
""" Functionality not implemented """ """
Functionality not implemented
"""
status_code = 501 status_code = 501
message = "Not Implemented" message = "Not Implemented"
class NotFound(HTTPServerException): class HTTPVersionNotSupported(HTTPServerCloseException):
""" Resource not found """ """
status_code = 404 The server does not support the major version HTTP used in the request message
message = "Not Found" """
status_code = 505
message = "HTTP Version Not Supported"
class Conflict(HTTPServerException):
"""
Conflict in the current state of the target resource
"""
status_code = 409
message = "Conflict"
class NotModified(HTTPServerException):
"""
Requested resource was not modified
"""
status_code = 304
message = "Not Modified"
class InvalidRequestLine(BadRequest):
"""
Request start-line is invalid
"""
def __init__(self, line, arg):
super().__init__(arg)
self.request_line = line

View File

@@ -1,4 +1,3 @@
import logging
import socket import socket
from io import BufferedReader from io import BufferedReader
from typing import Tuple from typing import Tuple
@@ -10,13 +9,20 @@ MAXLINE = 4096
class HTTPSocket: class HTTPSocket:
host: str """
Wrapper class for a socket. Represents an HTTP connection.
This class adds helper methods to read the underlying socket as a file.
"""
conn: socket.socket conn: socket.socket
file: Tuple[BufferedReader, None] file: BufferedReader
def __init__(self, conn: socket.socket, host: str): def __init__(self, conn: socket.socket):
"""
Initialize an HTTPSocket with the given socket and host.
@param conn: the socket object
"""
self.host = host
self.conn = conn self.conn = conn
self.conn.settimeout(TIMEOUT) self.conn.settimeout(TIMEOUT)
self.conn.setblocking(True) self.conn.setblocking(True)
@@ -24,64 +30,67 @@ class HTTPSocket:
self.file = self.conn.makefile("rb") self.file = self.conn.makefile("rb")
def close(self): def close(self):
"""
Close this socket
"""
self.file.close() self.file.close()
# self.conn.shutdown(socket.SHUT_RDWR)
self.conn.close() self.conn.close()
def is_closed(self): def is_closed(self):
return self.file is None return self.file is None
def reset_request(self): def reset_request(self):
"""
Close the file handle of this socket and create a new one.
"""
self.file.close() self.file.close()
self.file = self.conn.makefile("rb") self.file = self.conn.makefile("rb")
def __do_receive(self):
if self.conn.fileno() == -1:
raise Exception("Connection closed")
result = self.conn.recv(BUFSIZE)
return result
def receive(self):
"""Receive data from the client up to BUFSIZE
"""
count = 0
while True:
count += 1
try:
return self.__do_receive()
except socket.timeout:
logging.debug("Socket receive timed out after %s seconds", TIMEOUT)
if count == 3:
break
logging.debug("Retrying %s", count)
logging.debug("Timed out after waiting %s seconds for response", TIMEOUT * count)
raise TimeoutError("Request timed out")
def read(self, size=BUFSIZE, blocking=True) -> bytes: def read(self, size=BUFSIZE, blocking=True) -> bytes:
"""
Read bytes up to the specified buffer size. This method will block when `blocking` is set to True (Default).
"""
if blocking: if blocking:
return self.file.read(size) buffer = self.file.read(size)
else:
buffer = self.file.read1(size)
return self.file.read1(size) if len(buffer) == 0:
raise ConnectionAbortedError
return buffer
def read_line(self): def read_line(self):
"""
Read a line decoded as `httpsocket.FORMAT`.
@return: the decoded line
@raise: UnicodeDecodeError
"""
return str(self.read_bytes_line(), FORMAT) return str(self.read_bytes_line(), FORMAT)
def read_bytes_line(self) -> bytes: def read_bytes_line(self) -> bytes:
"""
Read a line as bytes.
"""
line = self.file.readline(MAXLINE + 1) line = self.file.readline(MAXLINE + 1)
if len(line) > MAXLINE: if len(line) > MAXLINE:
raise InvalidResponse("Line too long") raise InvalidResponse("Line too long")
elif len(line) == 0:
raise ConnectionAbortedError
return line return line
class HTTPException(Exception): class HTTPException(Exception):
""" Base class for HTTP exceptions """ """
Base class for HTTP exceptions
"""
class InvalidResponse(HTTPException): class InvalidResponse(HTTPException):
""" Response message cannot be parsed """ """
Response message cannot be parsed
"""
def __init(self, message): def __init(self, message):
self.message = message self.message = message

View File

@@ -1,16 +1,36 @@
from abc import ABC
from typing import Dict from typing import Dict
from urllib.parse import SplitResult
class Message: class Message(ABC):
version: str version: str
status: int
msg: str
headers: Dict[str, str] headers: Dict[str, str]
raw: [str]
body: bytes body: bytes
def __init__(self, version: str, status: int, msg: str, headers: Dict[str, str], body: bytes = None): def __init__(self, version: str, headers: Dict[str, str], raw=None, body: bytes = None):
self.version = version self.version = version
self.headers = headers
self.raw = raw
self.body = body
class ResponseMessage(Message):
status: int
msg: str
def __init__(self, version: str, status: int, msg: str, headers: Dict[str, str], raw=None, body: bytes = None):
super().__init__(version, headers, raw, body)
self.status = status self.status = status
self.msg = msg self.msg = msg
self.headers = headers
self.body = body
class RequestMessage(Message):
method: str
target: SplitResult
def __init__(self, version: str, method: str, target, headers: Dict[str, str], raw=None, body: bytes = None):
super().__init__(version, headers, raw, body)
self.method = method
self.target = target

View File

@@ -1,21 +1,25 @@
import logging import logging
import os
import pathlib
import re import re
import urllib
from datetime import datetime
from time import mktime
from typing import Dict
from urllib.parse import urlparse, urlsplit from urllib.parse import urlparse, urlsplit
from wsgiref.handlers import format_date_time
from httplib.exceptions import InvalidStatusLine, InvalidResponse, BadRequest from httplib.exceptions import InvalidStatusLine, InvalidResponse, BadRequest, InvalidRequestLine
from httplib.httpsocket import HTTPSocket from httplib.httpsocket import FORMAT
def _get_start_line(client: HTTPSocket):
line = client.read_line().strip()
split = list(filter(None, line.split(" ", 2)))
if len(split) < 3:
raise InvalidStatusLine(line) # TODO fix exception
return line, split
def _is_valid_http_version(http_version: str): def _is_valid_http_version(http_version: str):
"""
Returns True if the specified HTTP-version is valid.
@param http_version: the string to be checked
@return: True if the specified HTTP-version is valid.
"""
if len(http_version) < 8 or http_version[4] != "/": if len(http_version) < 8 or http_version[4] != "/":
return False return False
@@ -26,28 +30,21 @@ def _is_valid_http_version(http_version: str):
return True return True
def get_status_line(client: HTTPSocket):
line, (http_version, status, reason) = _get_start_line(client)
if not _is_valid_http_version(http_version):
raise InvalidStatusLine(line)
version = http_version[:4]
if not re.match(r"\d{3}", status):
raise InvalidStatusLine(line)
status = int(status)
if status < 100 or status > 999:
raise InvalidStatusLine(line)
return version, status, reason
def parse_status_line(line: str): def parse_status_line(line: str):
"""
Parses the specified line as an HTTP status-line.
@param line: the status-line to be parsed
@raise InvalidStatusLine: if the line couldn't be parsed, if the HTTP-version is invalid or if the status code
is invalid
@return: tuple of the HTTP-version, status and reason
"""
split = list(filter(None, line.strip().split(" ", 2))) split = list(filter(None, line.strip().split(" ", 2)))
if len(split) < 3: if len(split) < 3:
raise InvalidStatusLine(line) # TODO fix exception raise InvalidStatusLine(line)
(http_version, status, reason) = split http_version, status, reason = split
if not _is_valid_http_version(http_version): if not _is_valid_http_version(http_version):
raise InvalidStatusLine(line) raise InvalidStatusLine(line)
@@ -62,140 +59,66 @@ def parse_status_line(line: str):
return version, status, reason return version, status, reason
def parse_request_line(client: HTTPSocket): def parse_request_line(line: str):
line, (method, target, version) = _get_start_line(client) """
Parses the specified line as an HTTP request-line.
Returns the method, target as ParseResult and HTTP version from the request-line.
logging.debug("Parsed request-line=%r, method=%r, target=%r, version=%r", line, method, target, version) @param line: the request-line to be parsed
@raise InvalidRequestLine: if the line couldn't be parsed.
@raise BadRequest: Invalid HTTP method, Invalid HTTP-version or Invalid target
@return: tuple of the method, target and HTTP-version
"""
split = list(filter(None, line.rstrip().split(" ", 2)))
if len(split) < 3:
raise InvalidRequestLine(line, "missing argument in request-line")
method, target, version = split
if method not in ("CONNECT", "DELETE", "GET", "HEAD", "OPTIONS", "POST", "PUT", "TRACE"): if method not in ("CONNECT", "DELETE", "GET", "HEAD", "OPTIONS", "POST", "PUT", "TRACE"):
raise BadRequest() raise BadRequest(f"Invalid method: {method}")
if not _is_valid_http_version(version): if not _is_valid_http_version(version):
logging.debug("[ABRT] request: invalid http-version=%r", version) logging.debug("[ABRT] request: invalid http-version=%r", version)
raise BadRequest() raise BadRequest(f"Invalid HTTP-version: {version}")
if len(target) == "": if len(target) == "":
raise BadRequest() raise BadRequest("request-target not specified")
parsed_target = urlparse(target) parsed_target = urlsplit(target)
if len(parsed_target.path) > 0 and parsed_target.path[0] != "/" and parsed_target.netloc != "":
parsed_target = urlparse(f"//{target}")
return method, parsed_target, version.split("/")[1] return method, parsed_target, version.split("/")[1]
def retrieve_headers(client: HTTPSocket):
raw_headers = []
# first header after the status-line may not contain a space
while True:
line = client.read_line()
if line[0].isspace():
continue
else:
break
while True:
if line in ("\r\n", "\n", " "):
break
if line[0].isspace():
raw_headers[-1] = raw_headers[-1].rstrip("\r\n")
raw_headers.append(line.lstrip())
line = client.read_line()
result = []
header_str = "".join(raw_headers)
for line in header_str.splitlines():
pos = line.find(":")
if pos <= 0 or pos >= len(line) - 1:
continue
(header, value) = line.split(":", 1)
result.append((header.lower(), value.strip().lower()))
return result
def parse_request_headers(client: HTTPSocket):
raw_headers = retrieve_headers(client)
logging.debug("Received headers: %r", raw_headers)
headers = {}
key: str
for (key, value) in raw_headers:
if any((c.isspace()) for c in key):
raise BadRequest()
if key == "content-length":
if key in headers:
logging.error("Multiple content-length headers specified")
raise BadRequest()
if not value.isnumeric() or int(value) <= 0:
logging.error("Invalid content-length value: %r", value)
raise BadRequest()
elif key == "host":
if value != client.host and value != client.host.split(":")[0] or key in headers:
raise BadRequest()
headers[key] = value
return headers
def get_headers(client: HTTPSocket):
headers = []
# first header after the status-line may not contain a space
while True:
line = client.read_line()
if line[0].isspace():
continue
else:
break
while True:
if line in ("\r\n", "\n", " "):
break
if line[0].isspace():
headers[-1] = headers[-1].rstrip("\r\n")
headers.append(line.lstrip())
line = client.read_line()
result = {}
header_str = "".join(headers)
for line in header_str.splitlines():
pos = line.find(":")
if pos <= 0 or pos >= len(line) - 1:
continue
(header, value) = map(str.strip, line.split(":", 1))
check_next_header(result, header, value)
result[header.lower()] = value.lower()
return result
def parse_headers(lines): def parse_headers(lines):
"""
Parses the lines from the `lines` iterator as headers.
@param lines: iterator to retrieve the lines from.
@return: A dictionary with header as key and value as value.
"""
headers = [] headers = []
# first header after the status-line may not contain a space
for line in lines: try:
# first header after the start-line may not start with a space
line = next(lines) line = next(lines)
while True:
if line[0].isspace(): if line[0].isspace():
continue continue
else: else:
break break
for line in lines: while True:
if line in ("\r\n", "\n", " "): if line in ("\r\n", "\r", "\n", ""):
break break
if line[0].isspace(): if line[0].isspace():
headers[-1] = headers[-1].rstrip("\r\n") headers[-1] = headers[-1].rstrip("\r\n")
headers.append(line.lstrip()) headers.append(line.lstrip())
line = next(lines)
except StopIteration:
# No more lines to be parsed
pass
result = {} result = {}
header_str = "".join(headers) header_str = "".join(headers)
@@ -215,17 +138,21 @@ def parse_headers(lines):
def check_next_header(headers, next_header: str, next_value: str): def check_next_header(headers, next_header: str, next_value: str):
if next_header == "content-length": if next_header == "content-length":
if "content-length" in headers: if "content-length" in headers:
logging.error("Multiple content-length headers specified") raise InvalidResponse("Multiple content-length headers specified")
raise InvalidResponse()
if not next_value.isnumeric() or int(next_value) <= 0: if not next_value.isnumeric() or int(next_value) <= 0:
logging.error("Invalid content-length value: %r", next_value) raise InvalidResponse(f"Invalid content-length value: {next_value}")
raise InvalidResponse()
def parse_uri(uri: str): def parse_uri(uri: str):
"""
Parse the specified URI into the host, port and path.
If the URI is invalid, this method will try to create one.
@param uri: the URI to be parsed
@return: A tuple with the host, port and path
"""
parsed = urlsplit(uri) parsed = urlsplit(uri)
# If there is no netloc, the given string is not a valid URI, so split on / # If there is no hostname, the given string is not a valid URI, so split on /
if parsed.hostname: if parsed.hostname:
host = parsed.hostname host = parsed.hostname
path = parsed.path path = parsed.path
@@ -245,3 +172,70 @@ def parse_uri(uri: str):
port = 80 port = 80
return host, port, path return host, port, path
def uri_from_url(url: str):
"""
Returns a valid URI of the specified URL.
"""
parsed = urlsplit(url)
if parsed.hostname is None:
url = f"http://{url}"
parsed = urlsplit(url)
path = parsed.path
if path == "":
path = "/"
result = f"http://{parsed.netloc}{path}"
if parsed.query != "":
result = f"{result}?{parsed.query}"
return result
def urljoin(base, url):
"""
Join a base url, and a URL to form an absolute url.
"""
return urllib.parse.urljoin(base, url)
def get_charset(headers: Dict[str, str]):
"""
Returns the charset of the content from the headers if found. Otherwise, returns `FORMAT`
@param headers: the headers to retrieve the charset from
@return: A charset
"""
if "content-type" in headers:
content_type = headers["content-type"]
match = re.search(r"charset\s*=\s*([a-z\-0-9]*)", content_type, re.I)
if match:
return match.group(1)
return FORMAT
def get_relative_save_path(path: str):
"""
Returns the specified path relative to the working directory.
@param path: the path to compute
@return: the relative path
"""
path_obj = pathlib.PurePath(path)
root = pathlib.PurePath(os.getcwd())
rel = path_obj.relative_to(root)
return str(rel)
def get_date():
"""
Returns a string representation of the current date according to RFC 1123.
"""
now = datetime.now()
stamp = mktime(now.timetuple())
return format_date_time(stamp)

View File

@@ -7,6 +7,9 @@ from httplib.httpsocket import HTTPSocket, BUFSIZE
class Retriever(ABC): class Retriever(ABC):
"""
This is a helper class for retrieving HTTP messages.
"""
client: HTTPSocket client: HTTPSocket
def __init__(self, client: HTTPSocket): def __init__(self, client: HTTPSocket):
@@ -14,10 +17,23 @@ class Retriever(ABC):
@abstractmethod @abstractmethod
def retrieve(self): def retrieve(self):
"""
Creates an iterator of the retrieved message content.
"""
pass pass
@staticmethod @staticmethod
def create(client: HTTPSocket, headers: Dict[str, str]): def create(client: HTTPSocket, headers: Dict[str, str]):
"""
Creates a Retriever instance depending on the give headers.
@param client: the socket to retrieve from
@param headers: the message headers for choosing the retriever instance
@return: ChunkedRetriever if the message uses chunked encoding, ContentLengthRetriever if the message
specifies a content-length, RawRetriever if none of the above is True.
@raise UnsupportedEncoding: if the `transfer-encoding` is not supported or if the `content-encoding` is not
supported.
"""
# only chunked transfer-encoding is supported # only chunked transfer-encoding is supported
transfer_encoding = headers.get("transfer-encoding") transfer_encoding = headers.get("transfer-encoding")
@@ -32,7 +48,7 @@ class Retriever(ABC):
if chunked: if chunked:
return ChunkedRetriever(client) return ChunkedRetriever(client)
else:
content_length = headers.get("content-length") content_length = headers.get("content-length")
if not content_length: if not content_length:
@@ -43,28 +59,56 @@ class Retriever(ABC):
class PreambleRetriever(Retriever): class PreambleRetriever(Retriever):
"""
Retriever instance for retrieving the start-line and headers of an HTTP message.
"""
client: HTTPSocket client: HTTPSocket
buffer: [] _buffer: []
@property
def buffer(self):
"""
Returns a copy of the internal buffer.
Clears the internal buffer afterwards.
@return: A list of the buffered lines.
"""
tmp_buffer = self._buffer
self._buffer = []
return tmp_buffer
def __init__(self, client: HTTPSocket): def __init__(self, client: HTTPSocket):
super().__init__(client) super().__init__(client)
self.client = client self.client = client
self.buffer = [] self._buffer = []
def retrieve(self): def retrieve(self):
"""
Returns an iterator of the retrieved lines.
@return:
"""
line = self.client.read_line() line = self.client.read_line()
while True: while True:
self.buffer.append(line) self._buffer.append(line)
if line in ("\r\n", "\n", " "): if line in ("\r\n", "\r", "\n", ""):
break return line
yield line yield line
line = self.client.read_line() line = self.client.read_line()
def reset_buffer(self, line):
self._buffer.clear()
self._buffer.append(line)
class ContentLengthRetriever(Retriever): class ContentLengthRetriever(Retriever):
"""
Retriever instance for retrieving a message body with a given content-length.
"""
length: int length: int
def __init__(self, client: HTTPSocket, length: int): def __init__(self, client: HTTPSocket, length: int):
@@ -72,6 +116,11 @@ class ContentLengthRetriever(Retriever):
self.length = length self.length = length
def retrieve(self): def retrieve(self):
"""
Returns an iterator of the received message bytes.
The size of each iteration is not necessarily constant.
@raise IncompleteResponse: if the connection is closed or timed out before receiving the complete payload.
"""
cur_payload_size = 0 cur_payload_size = 0
read_size = BUFSIZE read_size = BUFSIZE
@@ -84,11 +133,9 @@ class ContentLengthRetriever(Retriever):
try: try:
buffer = self.client.read(remaining) buffer = self.client.read(remaining)
except TimeoutError: except TimeoutError:
logging.error("Timed out before receiving complete payload")
raise IncompleteResponse("Timed out before receiving complete payload") raise IncompleteResponse("Timed out before receiving complete payload")
except ConnectionError: except ConnectionError:
logging.error("Timed out before receiving complete payload") raise IncompleteResponse("Connection closed before receiving the complete payload")
raise IncompleteResponse("Connection closed before receiving complete payload")
if len(buffer) == 0: if len(buffer) == 0:
logging.warning("Received payload length %s less than expected %s", cur_payload_size, self.length) logging.warning("Received payload length %s less than expected %s", cur_payload_size, self.length)
@@ -99,6 +146,10 @@ class ContentLengthRetriever(Retriever):
class RawRetriever(Retriever): class RawRetriever(Retriever):
"""
Retriever instance for retrieving a message body without any length specifier or encoding.
This retriever will keep waiting until a timeout occurs, or the connection is disconnected.
"""
def retrieve(self): def retrieve(self):
while True: while True:
@@ -109,22 +160,44 @@ class RawRetriever(Retriever):
class ChunkedRetriever(Retriever): class ChunkedRetriever(Retriever):
"""
Retriever instance for retrieving a message body with chunked encoding.
"""
def retrieve(self): def retrieve(self):
"""
Returns an iterator of the received message bytes.
The size of each iteration is not necessarily constant.
@raise IncompleteResponse: if the connection is closed or timed out before receiving the complete payload.
@raise InvalidResponse: if the length of a chunk could not be determined.
"""
try:
while True: while True:
chunk_size = self.__get_chunk_size() chunk_size = self.__get_chunk_size()
logging.debug("chunk-size: %s", chunk_size) logging.debug("chunk-size: %s", chunk_size)
if chunk_size == 0: if chunk_size == 0:
# remove all trailing lines
self.client.reset_request() self.client.reset_request()
break break
buffer = self.client.read(chunk_size) buffer = self.client.read(chunk_size)
logging.debug("chunk: %r", buffer)
yield buffer yield buffer
self.client.read_line() # remove CRLF self.client.read_line() # remove trailing CRLF
except TimeoutError:
raise IncompleteResponse("Timed out before receiving the complete payload!")
except ConnectionError:
raise IncompleteResponse("Connection closed before receiving the complete payload!")
def __get_chunk_size(self): def __get_chunk_size(self):
"""
Returns the next chunk size.
@return: The chunk size in bytes
@raise InvalidResponse: If an error occured when parsing the chunk size.
"""
line = self.client.read_line() line = self.client.read_line()
sep_pos = line.find(";") sep_pos = line.find(";")
if sep_pos >= 0: if sep_pos >= 0:
@@ -133,4 +206,4 @@ class ChunkedRetriever(Retriever):
try: try:
return int(line, 16) return int(line, 16)
except ValueError: except ValueError:
raise InvalidResponse() raise InvalidResponse("Failed to parse chunk size")

View File

@@ -48,6 +48,7 @@
<div> <div>
<h2>Local image</h2> <h2>Local image</h2>
<img width="200px" src="ulyssis.png"> <img width="200px" src="ulyssis.png">
</div>
</body> </body>
</html> </html>

View File

@@ -1,3 +0,0 @@
beautifulsoup4~=4.9.3
lxml~=4.6.2
cssutils~=2.2.0

View File

@@ -14,7 +14,7 @@ def main():
parser.add_argument("--workers", "-w", parser.add_argument("--workers", "-w",
help="The amount of worker processes. This is by default based on the number of cpu threads.", help="The amount of worker processes. This is by default based on the number of cpu threads.",
type=int) type=int)
parser.add_argument("--port", "-p", help="The port to listen on", default=8000) parser.add_argument("--port", "-p", help="The port to listen on", default=5055)
arguments = parser.parse_args() arguments = parser.parse_args()
logging_level = logging.ERROR - (10 * arguments.verbose) logging_level = logging.ERROR - (10 * arguments.verbose)
@@ -47,44 +47,3 @@ except Exception as e:
print("[ABRT] Internal error: " + str(e), file=sys.stderr) print("[ABRT] Internal error: " + str(e), file=sys.stderr)
logging.debug("Internal error", exc_info=e) logging.debug("Internal error", exc_info=e)
sys.exit(70) sys.exit(70)
# import socket
#
# # Get hostname and address
# hostname = socket.gethostname()
# address = socket.gethostbyname(hostname)
#
# # socket heeft een listening and accept method
#
# SERVER = "127.0.0.1" # dynamisch fixen in project
# PORT = 5055
# server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
#
# ADDR = (SERVER, PORT) # hier wordt de socket gebonden aan mijn IP adres, dit moet wel anders
# server.bind(ADDR) # in het project gebeuren
#
# HEADER = 64 # maximum size messages
# FORMAT = 'utf-8'
# DISCONNECT_MESSAGE = "DISCONNECT!" # special message for disconnecting client and server
#
#
# # function for starting server
# def start():
# pass
# server.listen()
# while True: # infinite loop in which server accept incoming connections, we want to run it forever
# conn, addr = server.accept() # Server blocks untill a client connects
# print("new connection: ", addr[0], " connected.")
# connected = True
# while connected: # while client is connected, we want to recieve messages
# msg = conn.recv(HEADER).decode(
# FORMAT).rstrip() # Argument is maximum size of msg (in project look into details of accp), decode is for converting bytes to strings, rstrip is for stripping messages for special hidden characters
# print("message: ", msg)
# if msg == DISCONNECT_MESSAGE:
# connected = False
# print("close connection ", addr[0], " disconnected.")
# conn.close()
#
#
# print("server is starting ... ")
# start()

View File

@@ -1,127 +0,0 @@
import logging
import mimetypes
import os
import sys
from datetime import datetime
from socket import socket
from time import mktime
from typing import Union
from urllib.parse import ParseResultBytes, ParseResult
from wsgiref.handlers import format_date_time
from httplib import parser
from httplib.exceptions import MethodNotAllowed, BadRequest, UnsupportedEncoding, NotImplemented, NotFound
from httplib.httpsocket import HTTPSocket, FORMAT
from httplib.retriever import Retriever
METHODS = ("GET", "HEAD", "PUT", "POST")
class RequestHandler:
conn: HTTPSocket
root = os.path.join(os.path.dirname(sys.argv[0]), "public")
def __init__(self, conn: socket, host):
self.conn = HTTPSocket(conn, host)
def listen(self):
logging.debug("Parsing request line")
(method, target, version) = parser.parse_request_line(self.conn)
headers = parser.parse_request_headers(self.conn)
self._validate_request(method, target, version, headers)
logging.debug("Parsed request-line: method: %s, target: %r", method, target)
body = b""
if self._has_body(headers):
try:
retriever = Retriever.create(self.conn, headers)
except UnsupportedEncoding as e:
logging.error("Encoding not supported: %s=%s", e.enc_type, e.encoding)
raise NotImplemented()
for buffer in retriever.retrieve():
body += buffer
# completed message
self._handle_message(method, target.path, body)
def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version):
if method not in METHODS:
raise MethodNotAllowed(METHODS)
if version not in ("1.0", "1.1"):
raise BadRequest()
# only origin-form and absolute-form are allowed
if target.scheme not in ("", "http"):
# Only http is supported...
raise BadRequest()
if target.netloc != "" and target.netloc != self.conn.host and target.netloc != self.conn.host.split(":")[0]:
raise NotFound()
if target.path == "" or target.path[0] != "/":
raise NotFound()
norm_path = os.path.normpath(target.path)
if not os.path.exists(self.root + norm_path):
raise NotFound()
def _validate_request(self, method, target, version, headers):
if version == "1.1" and "host" not in headers:
raise BadRequest()
self._check_request_line(method, target, version)
def _has_body(self, headers):
return "transfer-encoding" in headers or "content-encoding" in headers
@staticmethod
def _get_date():
now = datetime.now()
stamp = mktime(now.timetuple())
return format_date_time(stamp)
def _handle_message(self, method: str, target, body: bytes):
date = self._get_date()
if method == "GET":
if target == "/":
path = self.root + "/index.html"
else:
path = self.root + target
mime = mimetypes.guess_type(path)[0]
if mime.startswith("text"):
file = open(path, "rb", FORMAT)
else:
file = open(path, "rb")
buffer = file.read()
file.close()
message = "HTTP/1.1 200 OK\r\n"
message += date + "\r\n"
if mime:
message += f"Content-Type: {mime}"
if mime.startswith("text"):
message += "; charset=UTF-8"
message += "\r\n"
message += f"Content-Length: {len(buffer)}\r\n"
message += "\r\n"
message = message.encode(FORMAT)
message += buffer
message += b"\r\n"
logging.debug("Sending: %r", message)
self.conn.conn.sendall(message)
@staticmethod
def send_error(client: socket, code, message):
message = f"HTTP/1.1 {code} {message}\r\n"
message += RequestHandler._get_date() + "\r\n"
message += "Content-Length: 0\r\n"
message += "\r\n"
logging.debug("Sending: %r", message)
client.sendall(message.encode(FORMAT))

View File

@@ -1,34 +1,53 @@
import logging import mimetypes
import os
import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Tuple from datetime import datetime
from urllib.parse import urlparse
from client.httpclient import FORMAT, HTTPClient
from httplib import parser from httplib import parser
from httplib.exceptions import InvalidResponse, InvalidStatusLine, UnsupportedEncoding from httplib.exceptions import NotFound, Forbidden, NotModified, BadRequest
from httplib.message import Message from httplib.httpsocket import FORMAT
from httplib.retriever import PreambleRetriever from httplib.message import RequestMessage as Message
CONTENT_ROOT = os.path.join(os.path.dirname(sys.argv[0]), "public")
status_message = {
200: "OK",
201: "Created",
202: "Accepted",
204: "No Content",
304: "Not Modified",
400: "Bad Request",
404: "Not Found",
500: "Internal Server Error",
}
def create(method: str, message: Message): def create(message: Message):
"""
Creates a Command based on the specified message
@param message: the message to create the Command with.
@return: An instance of `AbstractCommand`
"""
if method == "GET": if message.method == "GET":
return GetCommand(url, port) return GetCommand(message)
elif method == "HEAD": elif message.method == "HEAD":
return HeadCommand(url, port) return HeadCommand(message)
elif method == "POST": elif message.method == "POST":
return PostCommand(url, port) return PostCommand(message)
elif method == "PUT": elif message.method == "PUT":
return PutCommand(url, port) return PutCommand(message)
else: else:
raise ValueError() raise ValueError()
class AbstractCommand(ABC): class AbstractCommand(ABC):
path: str path: str
headers: Dict[str, str] msg: Message
def __init(self): def __init__(self, message: Message):
self.msg = message
pass pass
@property @property
@@ -36,63 +55,253 @@ class AbstractCommand(ABC):
def command(self): def command(self):
pass pass
@property
@abstractmethod
def _conditional_headers(self):
"""
The conditional headers specific to this command instance.
"""
pass
class AbstractWithBodyCommand(AbstractCommand, ABC): @abstractmethod
def execute(self):
"""
Execute the command
"""
pass
def _build_message(self, message: str) -> bytes: def _build_message(self, status: int, content_type: str, body: bytes, extra_headers=None):
body = input(f"Enter {self.command} data: ").encode(FORMAT) """
print() Build the response message.
@param status: The response status code
@param content_type: The response content-type header
@param body: The response body, may be empty.
@param extra_headers: Extra headers needed in the response message
@return: The encoded response message
"""
if extra_headers is None:
extra_headers = {}
self._process_conditional_headers()
message = f"HTTP/1.1 {status} {status_message[status]}\r\n"
message += f"Date: {parser.get_date()}\r\n"
content_length = len(body)
message += f"Content-Length: {content_length}\r\n"
if content_type:
message += f"Content-Type: {content_type}"
if content_type.startswith("text"):
message += f"; charset={FORMAT}"
message += "\r\n"
elif content_length > 0:
message += f"Content-Type: application/octet-stream\r\n"
for header in extra_headers:
message += f"{header}: {extra_headers[header]}\r\n"
message += "Content-Type: text/plain\r\n"
message += f"Content-Length: {len(body)}\r\n"
message += "\r\n" message += "\r\n"
message = message.encode(FORMAT) message = message.encode(FORMAT)
if content_length > 0:
message += body message += body
message += b"\r\n"
return message return message
def _get_path(self, check=True):
    """
    Map the target of the request onto a path below CONTENT_ROOT.

    The target path is normalized first; the root path "/" maps to "/index.html".

    @param check: If True, a path that doesn't exist is reported as an error
    @return: CONTENT_ROOT concatenated with the normalized target path
    @raise NotFound: if `check` is True and the resulting path doesn't exist
    """
    requested = os.path.normpath(self.msg.target.path)
    resource = "/index.html" if requested == "/" else requested
    full_path = CONTENT_ROOT + resource

    if check and not os.path.exists(full_path):
        raise NotFound(full_path)
    return full_path
def _process_conditional_headers(self):
    """
    Run the handler of every conditional header that is present in the request.

    `self._conditional_headers` maps header names to handlers; a handler is only
    called when the request carries a non-empty value for its header.
    """
    request_headers = self.msg.headers
    for name in self._conditional_headers:
        if request_headers.get(name):
            self._conditional_headers[name]()
def _if_modified_since(self):
    """
    Processes the if-modified-since header.

    @return: True if the header is missing or invalid (and thus isn't taken into account), or if
        the resource was modified after the given date.
    @raise NotModified: If the resource wasn't modified after the date of if-modified-since.
    """
    date_val = self.msg.headers.get("if-modified-since")
    if not date_val:
        return True

    # Parse the header before touching the file system: an invalid date is simply ignored.
    try:
        min_date = datetime.strptime(date_val, '%a, %d %b %Y %H:%M:%S GMT')
    except ValueError:
        return True

    # HTTP dates only have a resolution of one second. Drop the sub-second part of the
    # modification time, otherwise a client that sends back the date of the resource itself
    # would (almost) never receive a NotModified.
    modified = datetime.utcfromtimestamp(os.path.getmtime(self._get_path(False))).replace(microsecond=0)

    if modified <= min_date:
        raise NotModified(f"{modified} <= {min_date}")

    return True
@staticmethod
def get_mimetype(path):
"""
Guess the type of file.
@param path: the path to the file to guess the type of
@return: The mimetype based on the extension, or if that fails, returns "text/plain" if the file is text,
otherwise returns "application/octet-stream"
"""
mime = mimetypes.guess_type(path)[0]
if mime:
return mime
try:
file = open(path, "r", encoding=FORMAT)
file.readline()
file.close()
return "text/plain"
except UnicodeDecodeError:
return "application/octet-stream"
class AbstractModifyCommand(AbstractCommand, ABC):
    """
    Shared implementation for commands that write the request body to the target resource.
    """

    @property
    @abstractmethod
    def _file_mode(self):
        """
        The mode used to open the target resource (e.g. 'a' or 'w').
        `execute` appends 'b' to it, so the resource is always opened in binary mode.
        """
        pass

    @property
    def _conditional_headers(self):
        # Modifying commands don't react to any conditional header
        return {}

    def execute(self):
        """
        Write the body of the request to the target resource and build the response.

        @return: the encoded response; status 204 if the resource already existed, 201 otherwise
        @raise Forbidden: if the parent of the target doesn't exist or isn't a directory,
            or if the target itself is a directory
        """
        target = self._get_path(False)
        parent = os.path.dirname(target)

        if not os.path.exists(parent):
            raise Forbidden("Target directory does not exists!")
        if not os.path.isdir(parent):
            raise Forbidden("Target directory is an existing file!")

        # Has to be determined before `open` creates the resource
        already_present = os.path.exists(target)

        try:
            with open(target, mode=f"{self._file_mode}b") as resource:
                resource.write(self.msg.body)
        except IsADirectoryError:
            raise Forbidden("The target resource is a directory!")

        status = 204 if already_present else 201
        location = parser.urljoin("/", os.path.relpath(target, CONTENT_ROOT))
        return self._build_message(status, "text/plain", b"", {"Location": location})
class HeadCommand(AbstractCommand):
    """
    Command for a HEAD request: responds like a GET for the same resource would,
    but always with an empty body.
    """

    @property
    def command(self):
        return "HEAD"

    @property
    def _conditional_headers(self):
        handlers = {'if-modified-since': self._if_modified_since}
        return handlers

    def execute(self):
        """
        Build the response for the requested resource.

        @return: the encoded response with status 200, the mimetype of the resource and an empty body
        """
        resource = self._get_path()
        return self._build_message(200, self.get_mimetype(resource), b"")
class GetCommand(AbstractCommand):
    """
    A Command instance which represents a GET request
    """

    @property
    def command(self):
        return "GET"

    @property
    def _conditional_headers(self):
        return {'if-modified-since': self._if_modified_since}

    def execute(self):
        """
        Read the requested resource and build the response containing it.

        @return: the encoded response with status 200 and the content of the resource as body
        """
        path = self._get_path()
        mime = self.get_mimetype(path)

        # The context manager closes the file, also when reading fails.
        with open(path, "rb") as file:
            buffer = file.read()

        return self._build_message(200, mime, buffer)
class PostCommand(AbstractModifyCommand):
    """
    Command for a POST request: the request body is appended to the target resource.
    """

    @property
    def command(self):
        return "POST"

    @property
    def _file_mode(self):
        # 'a': append to the resource
        return "a"
class PutCommand(AbstractModifyCommand):
    """
    A Command instance which represents a PUT request
    """

    @property
    def command(self):
        return "PUT"

    @property
    def _file_mode(self):
        # 'w': the request body replaces the content of the resource
        return "w"

    def execute(self):
        """
        Replace the target resource with the request body.

        @return: the encoded response message built by AbstractModifyCommand.execute
        @raise BadRequest: if the request contains a Content-Range header
        """
        if "content-range" in self.msg.headers:
            raise BadRequest("PUT request contains a Content-Range header")

        # The response has to be passed on to the caller, which sends it to the client.
        return super().execute()

View File

@@ -1,7 +1,6 @@
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import socket import socket
import time
from multiprocessing.context import Process from multiprocessing.context import Process
from multiprocessing.queues import Queue from multiprocessing.queues import Queue
from multiprocessing.synchronize import Event from multiprocessing.synchronize import Event
@@ -20,54 +19,81 @@ class HTTPServer:
_stop_event: Event _stop_event: Event
def __init__(self, address: str, port: int, worker_count, logging_level): def __init__(self, address: str, port: int, worker_count, logging_level):
"""
Initialize an HTTP server with the specified address, port, worker_count and logging_level
@param address: the address to listen on for connections
@param port: the port to listen on for connections
@param worker_count: The amount of worker processes to create
@param logging_level: verbosity level for the logger
"""
self.address = address self.address = address
self.port = port self.port = port
self.worker_count = worker_count self.worker_count = worker_count
self.logging_level = logging_level self.logging_level = logging_level
mp.set_start_method("spawn") mp.set_start_method("spawn")
self._dispatch_queue = mp.Queue()
# Create a queue with maximum size of worker_count.
self._dispatch_queue = mp.Queue(worker_count)
self._stop_event = mp.Event() self._stop_event = mp.Event()
def start(self): def start(self):
"""
Start the HTTP server.
"""
try: try:
self.__do_start() self.__do_start()
except KeyboardInterrupt: except KeyboardInterrupt:
self.__shutdown() self.__shutdown()
def __do_start(self): def __do_start(self):
"""
Internal method to start the server.
"""
# Create socket # Create socket
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server.bind((self.address, self.port)) self.server.bind((self.address, self.port))
# Create workers processes to handle requests
self.__create_workers() self.__create_workers()
self.__listen() self.__listen()
def __listen(self): def __listen(self):
"""
Start listening for new connections
If a connection is received, it will be dispatched to the worker queue, and picked up by a worker process.
"""
self.server.listen() self.server.listen()
logging.debug("Listening for connections") logging.info("Listening on %s:%d", self.address, self.port)
while True: while True:
if self._dispatch_queue.qsize() > self.worker_count:
time.sleep(0.01)
continue
conn, addr = self.server.accept() conn, addr = self.server.accept()
conn.settimeout(5)
logging.info("New connection: %s", addr[0]) logging.info("New connection: %s", addr[0])
# blocks when the queue is full (contains self.worker_count items).
self._dispatch_queue.put((conn, addr)) self._dispatch_queue.put((conn, addr))
logging.debug("Dispatched connection %s", addr) logging.debug("Dispatched connection %s", addr)
def __shutdown(self): def __shutdown(self):
"""
Cleanly shutdown the server
Notifies the worker processes to shut down and eventually closes the server socket
"""
logging.info("Shutting down server...")
# Set stop event # Set stop event
self._stop_event.set() self._stop_event.set()
# Wake up workers # Wake up workers
logging.debug("Waking up workers") logging.debug("Waking up workers")
for p in self.workers: for _ in self.workers:
self._dispatch_queue.put((None, None)) self._dispatch_queue.put((None, None))
logging.debug("Closing dispatch queue") logging.debug("Closing dispatch queue")
@@ -84,12 +110,16 @@ class HTTPServer:
self.server.close() self.server.close()
def __create_workers(self): def __create_workers(self):
"""
Create worker processes up to `self.worker_count`.
A worker process is created with start method "spawn", target `worker.worker`, and the `self.logging_level`
is passed along with the `self.dispatch_queue` and `self._stop_event`
"""
for i in range(self.worker_count): for i in range(self.worker_count):
logging.debug("Creating worker: %d", i + 1) logging.debug("Creating worker: %d", i + 1)
p = mp.Process(target=worker.worker, p = mp.Process(target=worker.worker,
args=(f"{self.address}:{self.port}", i + 1, self.logging_level, self._dispatch_queue, self._stop_event)) args=(f"{self.address}:{self.port}", i + 1, self.logging_level, self._dispatch_queue,
self._stop_event))
p.start() p.start()
self.workers.append(p) self.workers.append(p)
time.sleep(0.2)
time.sleep(1)

168
server/requesthandler.py Normal file
View File

@@ -0,0 +1,168 @@
import logging
from socket import socket
from typing import Union
from urllib.parse import ParseResultBytes, ParseResult
from httplib import parser
from httplib.exceptions import MethodNotAllowed, BadRequest, UnsupportedEncoding, NotImplemented, NotFound, \
HTTPVersionNotSupported
from httplib.httpsocket import FORMAT
from httplib.message import RequestMessage as Message
from httplib.retriever import Retriever, PreambleRetriever
from server import command
from server.serversocket import ServerSocket
METHODS = ("GET", "HEAD", "PUT", "POST")
class RequestHandler:
    """
    A RequestHandler instance processes incoming HTTP requests messages from a single client.

    RequestHandler instances are created everytime a client connects. They will read the incoming
    messages, parse, verify them and send a response.
    """

    conn: "ServerSocket"
    host: str

    def __init__(self, conn: socket, host):
        """
        @param conn: the socket of the connected client, it is wrapped in a ServerSocket
        @param host: the host (optionally followed by ":port") this server answers for
        """
        self.conn = ServerSocket(conn)
        self.host = host

    def listen(self):
        """
        Listen to incoming messages and process them.

        This method loops forever, it is only left by an exception raised while
        reading or handling a message.
        """
        retriever = PreambleRetriever(self.conn)

        while True:
            line = self.conn.read_line()

            # Skip empty lines in front of the request-line
            if line in ("\r\n", "\r", "\n"):
                continue

            retriever.reset_buffer(line)
            self._handle_message(retriever, line)

    def _handle_message(self, retriever, line):
        """
        Retrieves and processes the request message.

        @param retriever: the retriever instance to retrieve the lines.
        @param line: the first received line.
        @raise NotImplemented: if the body uses an encoding which is not supported
        """
        lines = retriever.retrieve()

        # Parse the request-line and headers
        (method, target, version) = parser.parse_request_line(line)
        headers = parser.parse_headers(lines)

        # Create the request message object
        message = Message(version, method, target, headers, retriever.buffer)
        logging.debug("---request begin---\r\n%s---request end---", "".join(message.raw))

        # validate if the request is valid
        self._validate_request(message)

        # The body (if available) hasn't been retrieved up till now.
        body = b""
        if self._has_body(headers):
            try:
                retriever = Retriever.create(self.conn, headers)
            except UnsupportedEncoding as e:
                logging.error("Encoding not supported: %s=%s", e.enc_type, e.encoding)
                raise NotImplemented(f"{e.enc_type}={e.encoding}")

            # Join once instead of repeatedly copying the growing body
            body = b"".join(retriever.retrieve())

        message.body = body

        # message completed
        cmd = command.create(message)
        msg = cmd.execute()

        # Only the status line and headers of the response are logged
        logging.debug("---response begin---\r\n%s\r\n---response end---", msg.split(b"\r\n\r\n", 1)[0].decode(FORMAT))
        # Send the response message
        self.conn.conn.sendall(msg)

    def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version):
        """
        Checks if the request-line is valid. Throws an appropriate exception if not.

        @param method: HTTP request method
        @param target: The request target
        @param version: The HTTP version
        @raise MethodNotAllowed: if the method is not any of the allowed methods in `METHODS`
        @raise HTTPVersionNotSupported: If the HTTP version is not supported by this server
        @raise BadRequest: If the scheme of the target is not supported
        @raise NotFound: If the target is not found on this server
        """
        if method not in METHODS:
            raise MethodNotAllowed(METHODS)

        if version not in ("1.0", "1.1"):
            raise HTTPVersionNotSupported(version)

        # only origin-form and absolute-form are allowed
        if target.scheme not in ("", "http"):
            # Only http is supported...
            raise BadRequest(f"scheme={target.scheme}")

        # An absolute-form target has to refer to this server, with or without the port
        if target.netloc != "" and target.netloc != self.host and target.netloc != self.host.split(":")[0]:
            raise NotFound(str(target))

        if target.path == "" or target.path[0] != "/":
            raise NotFound(str(target))

    def _validate_request(self, msg):
        """
        Validates the message request-line and headers. Throws an error if the message is invalid.

        @see: _check_request_line for exceptions raised when validating the request-line.
        @param msg: the message to validate
        @raise BadRequest: if HTTP 1.1, and the Host header is missing
        """
        if msg.version == "1.1" and "host" not in msg.headers:
            raise BadRequest("Missing host header")

        self._check_request_line(msg.method, msg.target, msg.version)

    def _has_body(self, headers):
        """
        Check if the headers announce the existence of a message body.

        @param headers: the headers to check
        @return: True if the message has a body. False otherwise.
        @raise BadRequest: if the Content-Length header is not an integer
        """
        if "transfer-encoding" in headers:
            return True

        if "content-length" not in headers:
            return False

        try:
            return int(headers["content-length"]) > 0
        except ValueError as err:
            # A malformed header is an error of the client, not of the server
            raise BadRequest("Invalid Content-Length header") from err

    @staticmethod
    def send_error(client: socket, code, message):
        """
        Send an HTTP error response to the client

        @param client: the client to send the response to
        @param code: the HTTP status code
        @param message: the status code message
        """
        message = f"HTTP/1.1 {code} {message}\r\n"
        # A header line needs the header name in front of the value
        message += f"Date: {parser.get_date()}\r\n"
        message += "Content-Length: 0\r\n"
        message += "\r\n"

        logging.debug("---response begin---\r\n%s---response end---", message)
        client.sendall(message.encode(FORMAT))

20
server/serversocket.py Normal file
View File

@@ -0,0 +1,20 @@
from httplib.exceptions import BadRequest
from httplib.httpsocket import HTTPSocket
class ServerSocket(HTTPSocket):
    """
    Wrapper class for a socket. Represents a client connected to this server.
    """

    def read_line(self):
        """
        Reads the next line from the client, using `HTTPSocket.read_line`.

        @return: the decoded next line retrieved from the socket
        @raise BadRequest: If the next line couldn't be decoded
        """
        try:
            return super().read_line()
        except UnicodeDecodeError as err:
            # Undecodable input is an error of the client; keep the original error as cause
            raise BadRequest("UnicodeDecodeError") from err

View File

@@ -3,9 +3,10 @@ import multiprocessing as mp
import socket import socket
import threading import threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Dict
from httplib.exceptions import HTTPServerException, InternalServerError from httplib.exceptions import HTTPServerException, InternalServerError, HTTPServerCloseException
from server.RequestHandler import RequestHandler from server.requesthandler import RequestHandler
THREAD_LIMIT = 128 THREAD_LIMIT = 128
@@ -18,11 +19,19 @@ def worker(address, name, logging_level, queue: mp.Queue, stop_event: mp.Event):
try: try:
runner.run() runner.run()
except KeyboardInterrupt: except KeyboardInterrupt:
# Catch exit signals and close the threads appropriately.
logging.debug("Ctrl+C pressed, terminating") logging.debug("Ctrl+C pressed, terminating")
runner.shutdown() runner.shutdown()
class Worker: class Worker:
"""
A Worker instance represents a parallel execution process to handle incoming connections.
Worker instances are created when the HTTP server starts. They are used to handle many incoming connections
asynchronously.
"""
host: str host: str
name: str name: str
queue: mp.Queue queue: mp.Queue
@@ -30,24 +39,40 @@ class Worker:
stop_event: mp.Event stop_event: mp.Event
finished_queue: mp.Queue finished_queue: mp.Queue
dispatched_sockets: Dict[int, socket.socket]
def __init__(self, host, name, queue: mp.Queue, stop_event: mp.Event):
    """
    Set up the state of a Worker.

    @param host: The hostname of the HTTP server
    @param name: The name of this Worker instance
    @param queue: The dispatch queue for incoming socket connections
    @param stop_event: The Event that signals when to shut down this worker.
    """
    self.host = host
    self.name = name
    self.queue = queue
    self.stop_event = stop_event
    self.dispatched_sockets = {}
    self.executor = ThreadPoolExecutor(THREAD_LIMIT)

    # Fill the queue with one token per thread of the executor
    self.finished_queue = mp.Queue()
    for slot in range(THREAD_LIMIT):
        self.finished_queue.put(slot)
def run(self): def run(self):
"""
Run this worker.
The worker will start waiting for incoming clients being added to the queue and submit them to
the executor.
"""
while not self.stop_event.is_set(): while not self.stop_event.is_set():
# Blocks until thread is free # Blocks until the thread is free
self.finished_queue.get() self.finished_queue.get()
# Blocks until new client connects # Blocks until a new client connects
conn, addr = self.queue.get() conn, addr = self.queue.get()
if conn is None or addr is None: if conn is None or addr is None:
@@ -55,28 +80,83 @@ class Worker:
logging.debug("Processing new client: %s", addr) logging.debug("Processing new client: %s", addr)
# submit client to thread # submit the client to the executor
self.executor.submit(self._handle_client, conn, addr) self.executor.submit(self._handle_client, conn, addr)
self.shutdown() self.shutdown()
def _handle_client(self, conn: socket.socket, addr):
    """
    Target method for the worker threads.
    Creates a RequestHandler and handles any exceptions which may occur.

    @param conn: The client socket
    @param addr: The address of the client.
    """
    thread_id = threading.get_ident()
    self.dispatched_sockets[thread_id] = conn

    try:
        self.__do_handle_client(conn, addr)
    except Exception:
        # `shutdown` closes the dispatched sockets, which makes the threads using them fail.
        # Only errors which occur while the worker isn't stopping are worth logging.
        if not self.stop_event.is_set():
            logging.debug("Internal error in thread:", exc_info=True)
    finally:
        self.dispatched_sockets.pop(thread_id, None)
        # Finished, put back into queue
        self.finished_queue.put(thread_id)
def __do_handle_client(self, conn: socket.socket, addr):
    """
    Handle the requests of a client until it has to be disconnected.

    The client socket is always shut down and closed when this method is left,
    also when sending an error response fails.

    @param conn: The client socket
    @param addr: The address of the client.
    """
    try:
        handler = RequestHandler(conn, self.host)

        while True:
            try:
                handler.listen()
            except HTTPServerCloseException as e:
                # Exception raised after which the client should be disconnected.
                logging.warning("[HTTP: %s] %s. Reason: %s", e.status_code, e.message, e.arg)
                RequestHandler.send_error(conn, e.status_code, e.message)
                break
            except HTTPServerException as e:
                # Normal HTTP exception raised (e.g. 404), continue listening.
                logging.debug("[HTTP: %s] %s. Reason: %s", e.status_code, e.message, e.arg)
                RequestHandler.send_error(conn, e.status_code, e.message)
            except socket.timeout:
                # socket timed out, disconnect.
                logging.info("Socket for client %s timed out.", addr)
                break
            except ConnectionError:
                # Client aborted or reset the connection, there is nobody left to respond to.
                logging.info("Socket for client %s disconnected.", addr)
                break
            except Exception as e:
                # Unexpected exception raised. Send 500 and disconnect.
                logging.error("Internal error", exc_info=e)
                RequestHandler.send_error(conn, InternalServerError.status_code, InternalServerError.message)
                break
    finally:
        try:
            conn.shutdown(socket.SHUT_RDWR)
        except OSError:
            # The socket is already disconnected, only closing it remains.
            pass
        finally:
            conn.close()
def shutdown(self):
    """
    Stop the executor and disconnect every client which is still being handled.
    """
    logging.info("shutting down")

    # First ask the executor to stop without waiting for the running threads
    self.executor.shutdown(wait=False)

    logging.info("Closing sockets")

    # Iterate over a snapshot, the threads remove their socket when they finish
    for client in list(self.dispatched_sockets.values()):
        try:
            client.shutdown(socket.SHUT_RDWR)
            client.close()
        except OSError:
            # The socket has already been closed
            continue

    # Now wait until all threads have finished
    self.executor.shutdown()

View File

@@ -1,4 +0,0 @@
# Flow
- listen
- dispatch asap
- throw error if too full