update
This commit is contained in:
@@ -40,6 +40,7 @@ class AbstractCommand(ABC):
|
||||
message = f"{self.command} {path} HTTP/1.1\r\n"
|
||||
message += f"Host: {host}\r\n"
|
||||
message += "Accept: */*\r\nAccept-Encoding: identity\r\n"
|
||||
message += "User-Agent: Mozilla/5.0 (X11; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0\r\n"
|
||||
encoded_msg = self._build_message(message)
|
||||
|
||||
logging.info("---request begin---\r\n%s---request end---", encoded_msg.decode(FORMAT))
|
||||
|
@@ -4,12 +4,13 @@ from abc import ABC, abstractmethod
|
||||
from typing import Dict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
import cssutils
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
|
||||
from client.httpclient import HTTPClient, FORMAT
|
||||
from httplib.retriever import Retriever
|
||||
from httplib import parser
|
||||
from httplib.exceptions import InvalidResponse
|
||||
from httplib.retriever import Retriever
|
||||
|
||||
|
||||
class ResponseHandler(ABC):
|
||||
@@ -159,15 +160,15 @@ class HTMLDownloadHandler(DownloadHandler):
|
||||
file.write(buffer)
|
||||
file.close()
|
||||
|
||||
self.__download_images(tmp_path, self.path)
|
||||
self._download_images(tmp_path, self.path)
|
||||
os.remove(tmp_path)
|
||||
return self.path
|
||||
|
||||
def __download_images(self, tmp_filename, target_filename):
|
||||
def _download_images(self, tmp_filename, target_filename):
|
||||
|
||||
(host, path) = ResponseHandler.parse_uri(self.url)
|
||||
with open(tmp_filename, "rb") as fp:
|
||||
soup = BeautifulSoup(fp, 'html.parser')
|
||||
soup = BeautifulSoup(fp, 'lxml')
|
||||
|
||||
base_url = self.url
|
||||
base_element = soup.find("base")
|
||||
@@ -175,13 +176,51 @@ class HTMLDownloadHandler(DownloadHandler):
|
||||
if base_element:
|
||||
base_url = base_element["href"]
|
||||
|
||||
processed = {}
|
||||
tag: Tag
|
||||
for tag in soup.find_all("img"):
|
||||
try:
|
||||
tag["src"] = self.__download_image(tag["src"], host, base_url)
|
||||
if tag["src"] in processed:
|
||||
new_url = processed.get(tag["src"])
|
||||
else:
|
||||
new_url = self.__download_image(tag["src"], host, base_url)
|
||||
processed[tag["src"]] = new_url
|
||||
if new_url:
|
||||
tag["src"] = new_url
|
||||
except Exception as e:
|
||||
logging.debug(e)
|
||||
logging.error("Failed to download image: %s, skipping...", tag["src"])
|
||||
|
||||
for tag in soup.find_all("div"):
|
||||
if not tag.has_attr("style"):
|
||||
continue
|
||||
style = cssutils.parseStyle(tag["style"])
|
||||
|
||||
if "background" in style and "url(" in style["background"]:
|
||||
el_name = "background"
|
||||
elif "background-image" in style and "url(" in style["background-image"]:
|
||||
el_name = "background-image"
|
||||
else:
|
||||
continue
|
||||
el = style[el_name]
|
||||
start = el.find("url(") + 4
|
||||
end = el.find(")", start)
|
||||
url = el[start:end].strip()
|
||||
|
||||
try:
|
||||
if url in processed:
|
||||
new_url = url
|
||||
else:
|
||||
new_url = self.__download_image(url, host, base_url)
|
||||
processed[url] = new_url
|
||||
if new_url:
|
||||
el = el[:start] + new_url + el[end:]
|
||||
style[el_name] = el
|
||||
tag["style"] = style.cssText
|
||||
except Exception as e:
|
||||
logging.debug("Internal error", exc_info=e)
|
||||
logging.error("Failed to download image: %s, skipping...", tag["src"])
|
||||
|
||||
with open(target_filename, 'w') as file:
|
||||
file.write(str(soup))
|
||||
|
||||
@@ -190,6 +229,10 @@ class HTMLDownloadHandler(DownloadHandler):
|
||||
|
||||
logging.debug("Downloading image: %s", img_src)
|
||||
|
||||
if parsed.scheme not in ("", "http"):
|
||||
# Not a valid url
|
||||
return None
|
||||
|
||||
if len(parsed.netloc) == 0 and parsed.path != "/":
|
||||
# relative url, append base_url
|
||||
img_src = os.path.join(os.path.dirname(base_url), parsed.path)
|
||||
|
@@ -38,4 +38,10 @@ class BadRequest(HTTPServerException):
|
||||
class MethodNotAllowed(HTTPServerException):
|
||||
""" Method is not allowed """
|
||||
def __init(self, allowed_methods):
|
||||
self.allowed_methods = allowed_methods
|
||||
self.allowed_methods = allowed_methods
|
||||
|
||||
class NotImplemented(HTTPServerException):
|
||||
""" Functionality not implemented """
|
||||
|
||||
class NotFound(HTTPServerException):
|
||||
""" Resource not found """
|
@@ -19,7 +19,7 @@ class HTTPSocket:
|
||||
self.conn = conn
|
||||
self.conn.settimeout(TIMEOUT)
|
||||
self.conn.setblocking(True)
|
||||
self.conn.settimeout(3.0)
|
||||
self.conn.settimeout(60)
|
||||
self.file = self.conn.makefile("rb")
|
||||
|
||||
def close(self):
|
||||
|
@@ -7,7 +7,7 @@ from httplib.httpsocket import HTTPSocket
|
||||
|
||||
|
||||
def _get_start_line(client: HTTPSocket):
|
||||
line = client.read_line()
|
||||
line = client.read_line().strip()
|
||||
split = list(filter(None, line.split(" ")))
|
||||
if len(split) < 3:
|
||||
raise InvalidStatusLine(line) # TODO fix exception
|
||||
@@ -23,6 +23,8 @@ def _is_valid_http_version(http_version: str):
|
||||
if name != "HTTP" or not re.match(r"1\.[0|1]", version):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_status_line(client: HTTPSocket):
|
||||
line, (http_version, status, reason) = _get_start_line(client)
|
||||
@@ -43,17 +45,22 @@ def get_status_line(client: HTTPSocket):
|
||||
def parse_request_line(client: HTTPSocket):
|
||||
line, (method, target, version) = _get_start_line(client)
|
||||
|
||||
logging.debug("Parsed request-line=%r, method=%r, target=%r, version=%r", line, method, target, version)
|
||||
|
||||
if method not in ("CONNECT", "DELETE", "GET", "HEAD", "OPTIONS", "POST", "PUT", "TRACE"):
|
||||
raise BadRequest()
|
||||
|
||||
if not _is_valid_http_version(version):
|
||||
logging.debug("[ABRT] request: invalid http-version=%r", version)
|
||||
raise BadRequest()
|
||||
|
||||
if len(target) == "":
|
||||
raise BadRequest()
|
||||
parsed_target = urlparse(target)
|
||||
if len(parsed_target.path) > 0 and parsed_target.path[0] != "/" and parsed_target.netloc != "":
|
||||
parsed_target = urlparse(f"//{target}")
|
||||
|
||||
return method, parsed_target, version
|
||||
return method, parsed_target, version.split("/")[1]
|
||||
|
||||
|
||||
def retrieve_headers(client: HTTPSocket):
|
||||
@@ -85,13 +92,14 @@ def retrieve_headers(client: HTTPSocket):
|
||||
continue
|
||||
|
||||
(header, value) = line.split(":", 1)
|
||||
result.append((header.lower(), value.lower()))
|
||||
result.append((header.lower(), value.strip().lower()))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parse_request_headers(client: HTTPSocket):
|
||||
raw_headers = retrieve_headers(client)
|
||||
logging.debug("Received headers: %r", raw_headers)
|
||||
headers = {}
|
||||
|
||||
key: str
|
||||
@@ -107,7 +115,7 @@ def parse_request_headers(client: HTTPSocket):
|
||||
logging.error("Invalid content-length value: %r", value)
|
||||
raise BadRequest()
|
||||
elif key == "host":
|
||||
if value != client.host or key in headers:
|
||||
if value != client.host and value != client.host.split(":")[0] or key in headers:
|
||||
raise BadRequest()
|
||||
|
||||
headers[key] = value
|
||||
|
@@ -1,2 +1,3 @@
|
||||
beautifulsoup4~=4.9.3
|
||||
lxml==4.6.2
|
||||
lxml~=4.6.2
|
||||
cssutils~=2.2.0
|
@@ -18,7 +18,7 @@ def main():
|
||||
arguments = parser.parse_args()
|
||||
|
||||
logging_level = logging.ERROR - (10 * arguments.verbose)
|
||||
logging.basicConfig(level=logging_level)
|
||||
logging.basicConfig(level=logging_level, format="%(levelname)s:[SERVER] %(message)s")
|
||||
logging.debug("Arguments: %s", arguments)
|
||||
|
||||
# Set workers
|
||||
|
@@ -1,12 +1,17 @@
|
||||
import logging
|
||||
from logging import Logger
|
||||
import mimetypes
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from socket import socket
|
||||
from time import mktime
|
||||
from typing import Union
|
||||
from urllib.parse import ParseResultBytes, ParseResult
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
from httplib import parser
|
||||
from httplib.exceptions import MethodNotAllowed, BadRequest
|
||||
from httplib.httpsocket import HTTPSocket
|
||||
from httplib.exceptions import MethodNotAllowed, BadRequest, UnsupportedEncoding, NotImplemented, NotFound
|
||||
from httplib.httpsocket import HTTPSocket, FORMAT
|
||||
from httplib.retriever import Retriever
|
||||
|
||||
METHODS = ("GET", "HEAD", "PUT", "POST")
|
||||
@@ -14,44 +19,98 @@ METHODS = ("GET", "HEAD", "PUT", "POST")
|
||||
|
||||
class RequestHandler:
|
||||
conn: HTTPSocket
|
||||
logger: Logger
|
||||
root = os.path.join(os.path.dirname(sys.argv[0]), "public")
|
||||
|
||||
def __init__(self, conn: socket, logger, host):
|
||||
def __init__(self, conn: socket, host):
|
||||
self.conn = HTTPSocket(conn, host)
|
||||
self.logger = logger
|
||||
|
||||
def listen(self):
|
||||
self.logger.debug("Parsing request line")
|
||||
logging.debug("test logger")
|
||||
logging.debug("Parsing request line")
|
||||
(method, target, version) = parser.parse_request_line(self.conn)
|
||||
headers = parser.parse_request_headers(self.conn)
|
||||
|
||||
self._validate_request(method, target, version, headers)
|
||||
|
||||
self.logger.debug("Parsed request-line: version: %s, target: %r", method, target)
|
||||
headers = parser.get_headers(self.conn)
|
||||
self.logger.debug("Parsed headers: %r", headers)
|
||||
retriever = Retriever.create(self.conn, headers)
|
||||
body = retriever.retrieve()
|
||||
logging.debug("Parsed request-line: method: %s, target: %r", method, target)
|
||||
|
||||
self.logger.debug("body: %r", body)
|
||||
body = b""
|
||||
if self._has_body(headers):
|
||||
try:
|
||||
retriever = Retriever.create(self.conn, headers)
|
||||
except UnsupportedEncoding as e:
|
||||
logging.error("Encoding not supported: %s=%s", e.enc_type, e.encoding)
|
||||
raise NotImplemented()
|
||||
|
||||
for buffer in retriever.retrieve():
|
||||
body += buffer
|
||||
|
||||
# completed message
|
||||
self._handle_message(method, target.path, body)
|
||||
|
||||
def _check_request_line(self, method: str, target: Union[ParseResultBytes, ParseResult], version):
|
||||
|
||||
if method not in METHODS:
|
||||
raise MethodNotAllowed(METHODS)
|
||||
|
||||
# only origin-form and absolute-form are allowed
|
||||
if len(target.path) < 1 or target.path[0] != "/" or \
|
||||
target.netloc not in ("http", "https") and target.hostname == "":
|
||||
raise BadRequest()
|
||||
|
||||
if version not in ("1.0", "1.1"):
|
||||
raise BadRequest()
|
||||
|
||||
# only origin-form and absolute-form are allowed
|
||||
if target.scheme not in ("", "http"):
|
||||
# Only http is supported...
|
||||
raise BadRequest()
|
||||
if target.netloc != "" and target.netloc != self.conn.host and target.netloc != self.conn.host.split(":")[0]:
|
||||
raise NotFound()
|
||||
|
||||
if target.path == "" or target.path[0] != "/":
|
||||
raise NotFound()
|
||||
|
||||
norm_path = os.path.normpath(target.path)
|
||||
if not os.path.exists(self.root + norm_path):
|
||||
raise NotFound()
|
||||
|
||||
def _validate_request(self, method, target, version, headers):
|
||||
if version == "1.1" and "host" not in headers:
|
||||
raise BadRequest()
|
||||
|
||||
self._check_request_line(method, target, version)
|
||||
|
||||
if version == "1.1" and "host" not in headers:
|
||||
raise BadRequest()
|
||||
def _has_body(self, headers):
|
||||
return "transfer-encoding" in headers or "content-encoding" in headers
|
||||
|
||||
def _get_date(self):
|
||||
now = datetime.now()
|
||||
stamp = mktime(now.timetuple())
|
||||
return format_date_time(stamp)
|
||||
|
||||
def _handle_message(self, method: str, target, body: bytes):
|
||||
date = self._get_date()
|
||||
|
||||
if method == "GET":
|
||||
if target == "/":
|
||||
path = self.root + "/index.html"
|
||||
else:
|
||||
path = self.root + target
|
||||
mime = mimetypes.guess_type(path)[0]
|
||||
if mime.startswith("test"):
|
||||
file = open(path, "rb", FORMAT)
|
||||
else:
|
||||
file = open(path, "rb")
|
||||
buffer = file.read()
|
||||
file.close()
|
||||
|
||||
message = "HTTP/1.1 200 OK\r\n"
|
||||
message += date + "\r\n"
|
||||
if mime:
|
||||
message += f"Content-Type: {mime}"
|
||||
if mime.startswith("test"):
|
||||
message += "; charset=UTF-8"
|
||||
message += "\r\n"
|
||||
message += f"Content-Length: {len(buffer)}\r\n"
|
||||
message += "\r\n"
|
||||
message = message.encode(FORMAT)
|
||||
message += buffer
|
||||
message += b"\r\n"
|
||||
|
||||
logging.debug("Sending: %r", message)
|
||||
self.conn.conn.sendall(message)
|
||||
|
@@ -87,8 +87,9 @@ class HTTPServer:
|
||||
for i in range(self.worker_count):
|
||||
logging.debug("Creating worker: %d", i + 1)
|
||||
p = mp.Process(target=worker.worker,
|
||||
args=(self.address, i + 1, self.logging_level, self._dispatch_queue, self._stop_event))
|
||||
args=(f"{self.address}:{self.port}", i + 1, self.logging_level, self._dispatch_queue, self._stop_event))
|
||||
p.start()
|
||||
self.workers.append(p)
|
||||
|
||||
time.sleep(0.1)
|
||||
time.sleep(0.2)
|
||||
time.sleep(1)
|
||||
|
@@ -4,23 +4,22 @@ import multiprocessing as mp
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from logging import Logger
|
||||
from socket import socket
|
||||
import socket
|
||||
|
||||
from server.RequestHandler import RequestHandler
|
||||
|
||||
THREAD_LIMIT = 20
|
||||
THREAD_LIMIT = 128
|
||||
|
||||
|
||||
def worker(address, name, log_level, queue: mp.Queue, stop_event: mp.Event):
|
||||
logging.basicConfig(level=log_level)
|
||||
logger = multiprocessing.log_to_stderr(level=log_level)
|
||||
runner = Worker(address, name, logger, queue, stop_event)
|
||||
runner.logger.debug("Worker %s started", name)
|
||||
def worker(address, name, logging_level, queue: mp.Queue, stop_event: mp.Event):
|
||||
logging.basicConfig(level=logging_level, format="%(levelname)s:[WORKER " + str(name) + "] %(message)s")
|
||||
runner = Worker(address, name, queue, stop_event)
|
||||
logging.debug("started")
|
||||
|
||||
try:
|
||||
runner.run()
|
||||
except KeyboardInterrupt:
|
||||
logger.debug("Ctrl+C pressed, terminating")
|
||||
logging.debug("Ctrl+C pressed, terminating")
|
||||
runner.shutdown()
|
||||
|
||||
|
||||
@@ -35,10 +34,9 @@ class Worker:
|
||||
|
||||
finished_queue: mp.Queue
|
||||
|
||||
def __init__(self, host, name, logger, queue: mp.Queue, stop_event: mp.Event):
|
||||
def __init__(self, host, name, queue: mp.Queue, stop_event: mp.Event):
|
||||
self.host = host
|
||||
self.name = name
|
||||
self.logger = logger
|
||||
self.queue = queue
|
||||
self.executor = ThreadPoolExecutor(THREAD_LIMIT)
|
||||
self.stop_event = stop_event
|
||||
@@ -58,26 +56,27 @@ class Worker:
|
||||
if conn is None or addr is None:
|
||||
break
|
||||
|
||||
self.logger.debug("Received new client: %s", addr)
|
||||
logging.debug("Processing new client: %s", addr)
|
||||
|
||||
# submit client to thread
|
||||
print(threading.get_ident())
|
||||
self.executor.submit(self._handle_client, conn, addr)
|
||||
|
||||
self.shutdown()
|
||||
|
||||
def _handle_client(self, conn: socket, addr):
|
||||
def _handle_client(self, conn: socket.socket, addr):
|
||||
try:
|
||||
self.logger.debug("Handling client: %s", addr)
|
||||
logging.debug("Handling client: %s", addr)
|
||||
|
||||
handler = RequestHandler(conn, self.logger, self.host)
|
||||
handler = RequestHandler(conn, self.host)
|
||||
handler.listen()
|
||||
except Exception as e:
|
||||
self.logger.debug("Internal error", exc_info=e)
|
||||
logging.debug("Internal error")
|
||||
|
||||
conn.shutdown(socket.SHUT_RDWR)
|
||||
conn.close()
|
||||
# Finished, put back into queue
|
||||
self.finished_queue.put(threading.get_ident())
|
||||
|
||||
def shutdown(self):
|
||||
self.logger.info("shutting down")
|
||||
logging.info("shutting down")
|
||||
self.executor.shutdown()
|
||||
|
Reference in New Issue
Block a user