"Fix img url parsing"

This commit is contained in:
2021-03-24 17:20:40 +01:00
parent d14252f707
commit 7639383782
3 changed files with 16 additions and 10 deletions

View File

@@ -164,12 +164,11 @@ class GetCommand(AbstractCommand):
return Message(version, status, msg, headers)
def _await_response(self, client, retriever) -> str:
def _await_response(self, client, retriever):
msg = self._get_preamble(retriever)
from client import response_handler
self.filename = response_handler.handle(client, msg, self, self.dir)
return
class PostCommand(AbstractWithBodyCommand):

View File

@@ -247,25 +247,27 @@ class HTMLDownloadHandler(DownloadHandler):
file.write(str(soup))
def __download_image(self, img_src, host, base_url):
parsed = urlsplit(img_src)
logging.debug("Downloading image: %s", img_src)
parsed = urlsplit(img_src)
if parsed.scheme not in ("", "http", "https"):
# Not a valid url
return None
if parsed.hostname == host:
if parsed.hostname is None:
if img_src[0] == "/":
img_src = host + img_src
else:
img_src = os.path.join(os.path.dirname(base_url), img_src)
if parsed.hostname is None or parsed.hostname == host:
port = self.cmd.port
elif ":" in parsed.netloc:
port = parsed.netloc.split(":", 1)[1]
else:
port = 80
if len(parsed.netloc) == 0 and parsed.path != "/":
# relative url, append base_url
img_src = os.path.join(os.path.dirname(base_url), parsed.path)
command = GetCommand(img_src, port, os.path.dirname(self.path))
command.execute(True)

View File

@@ -177,6 +177,7 @@ def get_headers(client: HTTPSocket):
return result
def parse_headers(lines):
headers = []
# first header after the status-line may not contain a space
@@ -210,6 +211,7 @@ def parse_headers(lines):
return result
def check_next_header(headers, next_header: str, next_value: str):
if next_header == "content-length":
if "content-length" in headers:
@@ -229,8 +231,11 @@ def parse_uri(uri: str):
path = parsed.path
if parsed.query != '':
path = f"{path}?{parsed.query}"
else:
elif "/" in uri:
(host, path) = uri.split("/", 1)
else:
host = uri
path = "/"
if ":" in host:
host, port = host.split(":", 1)