[core] Improve HTTP redirect handling (#7094)

Aligns HTTP redirect handling with what browsers commonly do and RFC standards. 

Fixes issues afac4caa7d missed.

Authored by: coletdjnz
This commit is contained in:
coletdjnz 2023-05-27 19:06:13 +12:00 committed by GitHub
parent 66468bbf49
commit 08916a49c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 281 additions and 72 deletions

View File

@ -10,7 +10,6 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import copy import copy
import json import json
import urllib.error
from test.helper import FakeYDL, assertRegexpMatches from test.helper import FakeYDL, assertRegexpMatches
from yt_dlp import YoutubeDL from yt_dlp import YoutubeDL
@ -1097,11 +1096,6 @@ class TestYoutubeDL(unittest.TestCase):
test_selection({'playlist_items': '-15::2'}, INDICES[1::2], True) test_selection({'playlist_items': '-15::2'}, INDICES[1::2], True)
test_selection({'playlist_items': '-15::15'}, [], True) test_selection({'playlist_items': '-15::15'}, [], True)
def test_urlopen_no_file_protocol(self):
# see https://github.com/ytdl-org/youtube-dl/issues/8227
ydl = YDL()
self.assertRaises(urllib.error.URLError, ydl.urlopen, 'file:///etc/passwd')
def test_do_not_override_ie_key_in_url_transparent(self): def test_do_not_override_ie_key_in_url_transparent(self):
ydl = YDL() ydl = YDL()

View File

@ -7,40 +7,163 @@ import unittest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import gzip
import http.cookiejar
import http.server import http.server
import io
import pathlib
import ssl import ssl
import tempfile
import threading import threading
import urllib.error
import urllib.request import urllib.request
from test.helper import http_server_port from test.helper import http_server_port
from yt_dlp import YoutubeDL from yt_dlp import YoutubeDL
from yt_dlp.utils import sanitized_Request, urlencode_postdata
from .helper import FakeYDL
TEST_DIR = os.path.dirname(os.path.abspath(__file__)) TEST_DIR = os.path.dirname(os.path.abspath(__file__))
class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler): class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
def log_message(self, format, *args): def log_message(self, format, *args):
pass pass
def _headers(self):
payload = str(self.headers).encode('utf-8')
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.send_header('Content-Length', str(len(payload)))
self.end_headers()
self.wfile.write(payload)
def _redirect(self):
self.send_response(int(self.path[len('/redirect_'):]))
self.send_header('Location', '/method')
self.send_header('Content-Length', '0')
self.end_headers()
def _method(self, method, payload=None):
self.send_response(200)
self.send_header('Content-Length', str(len(payload or '')))
self.send_header('Method', method)
self.end_headers()
if payload:
self.wfile.write(payload)
def _status(self, status):
payload = f'<html>{status} NOT FOUND</html>'.encode()
self.send_response(int(status))
self.send_header('Content-Type', 'text/html; charset=utf-8')
self.send_header('Content-Length', str(len(payload)))
self.end_headers()
self.wfile.write(payload)
def _read_data(self):
if 'Content-Length' in self.headers:
return self.rfile.read(int(self.headers['Content-Length']))
def do_POST(self):
data = self._read_data()
if self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
self._method('POST', data)
elif self.path.startswith('/headers'):
self._headers()
else:
self._status(404)
def do_HEAD(self):
if self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
self._method('HEAD')
else:
self._status(404)
def do_PUT(self):
data = self._read_data()
if self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
self._method('PUT', data)
else:
self._status(404)
def do_GET(self): def do_GET(self):
if self.path == '/video.html': if self.path == '/video.html':
payload = b'<html><video src="/vid.mp4" /></html>'
self.send_response(200) self.send_response(200)
self.send_header('Content-Type', 'text/html; charset=utf-8') self.send_header('Content-Type', 'text/html; charset=utf-8')
self.send_header('Content-Length', str(len(payload))) # required for persistent connections
self.end_headers() self.end_headers()
self.wfile.write(b'<html><video src="/vid.mp4" /></html>') self.wfile.write(payload)
elif self.path == '/vid.mp4': elif self.path == '/vid.mp4':
payload = b'\x00\x00\x00\x00\x20\x66\x74[video]'
self.send_response(200) self.send_response(200)
self.send_header('Content-Type', 'video/mp4') self.send_header('Content-Type', 'video/mp4')
self.send_header('Content-Length', str(len(payload)))
self.end_headers() self.end_headers()
self.wfile.write(b'\x00\x00\x00\x00\x20\x66\x74[video]') self.wfile.write(payload)
elif self.path == '/%E4%B8%AD%E6%96%87.html': elif self.path == '/%E4%B8%AD%E6%96%87.html':
payload = b'<html><video src="/vid.mp4" /></html>'
self.send_response(200) self.send_response(200)
self.send_header('Content-Type', 'text/html; charset=utf-8') self.send_header('Content-Type', 'text/html; charset=utf-8')
self.send_header('Content-Length', str(len(payload)))
self.end_headers()
self.wfile.write(payload)
elif self.path == '/%c7%9f':
payload = b'<html><video src="/vid.mp4" /></html>'
self.send_response(200)
self.send_header('Content-Type', 'text/html; charset=utf-8')
self.send_header('Content-Length', str(len(payload)))
self.end_headers()
self.wfile.write(payload)
elif self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
self._method('GET')
elif self.path.startswith('/headers'):
self._headers()
elif self.path == '/trailing_garbage':
payload = b'<html><video src="/vid.mp4" /></html>'
self.send_response(200)
self.send_header('Content-Type', 'text/html; charset=utf-8')
self.send_header('Content-Encoding', 'gzip')
buf = io.BytesIO()
with gzip.GzipFile(fileobj=buf, mode='wb') as f:
f.write(payload)
compressed = buf.getvalue() + b'trailing garbage'
self.send_header('Content-Length', str(len(compressed)))
self.end_headers()
self.wfile.write(compressed)
elif self.path == '/302-non-ascii-redirect':
new_url = f'http://127.0.0.1:{http_server_port(self.server)}/中文.html'
self.send_response(301)
self.send_header('Location', new_url)
self.send_header('Content-Length', '0')
self.end_headers() self.end_headers()
self.wfile.write(b'<html><video src="/vid.mp4" /></html>')
else: else:
assert False self._status(404)
def send_header(self, keyword, value):
"""
Forcibly allow HTTP server to send non percent-encoded non-ASCII characters in headers.
This is against what is defined in RFC 3986, however we need to test we support this
since some sites incorrectly do this.
"""
if keyword.lower() == 'connection':
return super().send_header(keyword, value)
if not hasattr(self, '_headers_buffer'):
self._headers_buffer = []
self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode())
class FakeLogger: class FakeLogger:
@ -56,36 +179,128 @@ class FakeLogger:
class TestHTTP(unittest.TestCase): class TestHTTP(unittest.TestCase):
def setUp(self): def setUp(self):
self.httpd = http.server.HTTPServer( # HTTP server
self.http_httpd = http.server.ThreadingHTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler) ('127.0.0.1', 0), HTTPTestRequestHandler)
self.port = http_server_port(self.httpd) self.http_port = http_server_port(self.http_httpd)
self.server_thread = threading.Thread(target=self.httpd.serve_forever) self.http_server_thread = threading.Thread(target=self.http_httpd.serve_forever)
self.server_thread.daemon = True # FIXME: we should probably stop the http server thread after each test
self.server_thread.start() # See: https://github.com/yt-dlp/yt-dlp/pull/7094#discussion_r1199746041
self.http_server_thread.daemon = True
self.http_server_thread.start()
# HTTPS server
class TestHTTPS(unittest.TestCase):
def setUp(self):
certfn = os.path.join(TEST_DIR, 'testcert.pem') certfn = os.path.join(TEST_DIR, 'testcert.pem')
self.httpd = http.server.HTTPServer( self.https_httpd = http.server.ThreadingHTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler) ('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None) sslctx.load_cert_chain(certfn, None)
self.httpd.socket = sslctx.wrap_socket(self.httpd.socket, server_side=True) self.https_httpd.socket = sslctx.wrap_socket(self.https_httpd.socket, server_side=True)
self.port = http_server_port(self.httpd) self.https_port = http_server_port(self.https_httpd)
self.server_thread = threading.Thread(target=self.httpd.serve_forever) self.https_server_thread = threading.Thread(target=self.https_httpd.serve_forever)
self.server_thread.daemon = True self.https_server_thread.daemon = True
self.server_thread.start() self.https_server_thread.start()
def test_nocheckcertificate(self): def test_nocheckcertificate(self):
ydl = YoutubeDL({'logger': FakeLogger()}) with FakeYDL({'logger': FakeLogger()}) as ydl:
self.assertRaises( with self.assertRaises(urllib.error.URLError):
Exception, ydl.urlopen(sanitized_Request(f'https://127.0.0.1:{self.https_port}/headers'))
ydl.extract_info, 'https://127.0.0.1:%d/video.html' % self.port)
ydl = YoutubeDL({'logger': FakeLogger(), 'nocheckcertificate': True}) with FakeYDL({'logger': FakeLogger(), 'nocheckcertificate': True}) as ydl:
r = ydl.extract_info('https://127.0.0.1:%d/video.html' % self.port) r = ydl.urlopen(sanitized_Request(f'https://127.0.0.1:{self.https_port}/headers'))
self.assertEqual(r['url'], 'https://127.0.0.1:%d/vid.mp4' % self.port) self.assertEqual(r.status, 200)
r.close()
def test_percent_encode(self):
with FakeYDL() as ydl:
# Unicode characters should be encoded with uppercase percent-encoding
res = ydl.urlopen(sanitized_Request(f'http://127.0.0.1:{self.http_port}/中文.html'))
self.assertEqual(res.status, 200)
res.close()
# don't normalize existing percent encodings
res = ydl.urlopen(sanitized_Request(f'http://127.0.0.1:{self.http_port}/%c7%9f'))
self.assertEqual(res.status, 200)
res.close()
def test_unicode_path_redirection(self):
with FakeYDL() as ydl:
r = ydl.urlopen(sanitized_Request(f'http://127.0.0.1:{self.http_port}/302-non-ascii-redirect'))
self.assertEqual(r.url, f'http://127.0.0.1:{self.http_port}/%E4%B8%AD%E6%96%87.html')
r.close()
def test_redirect(self):
with FakeYDL() as ydl:
def do_req(redirect_status, method):
data = b'testdata' if method in ('POST', 'PUT') else None
res = ydl.urlopen(sanitized_Request(
f'http://127.0.0.1:{self.http_port}/redirect_{redirect_status}', method=method, data=data))
return res.read().decode('utf-8'), res.headers.get('method', '')
# A 303 must either use GET or HEAD for subsequent request
self.assertEqual(do_req(303, 'POST'), ('', 'GET'))
self.assertEqual(do_req(303, 'HEAD'), ('', 'HEAD'))
self.assertEqual(do_req(303, 'PUT'), ('', 'GET'))
# 301 and 302 turn POST only into a GET
self.assertEqual(do_req(301, 'POST'), ('', 'GET'))
self.assertEqual(do_req(301, 'HEAD'), ('', 'HEAD'))
self.assertEqual(do_req(302, 'POST'), ('', 'GET'))
self.assertEqual(do_req(302, 'HEAD'), ('', 'HEAD'))
self.assertEqual(do_req(301, 'PUT'), ('testdata', 'PUT'))
self.assertEqual(do_req(302, 'PUT'), ('testdata', 'PUT'))
# 307 and 308 should not change method
for m in ('POST', 'PUT'):
self.assertEqual(do_req(307, m), ('testdata', m))
self.assertEqual(do_req(308, m), ('testdata', m))
self.assertEqual(do_req(307, 'HEAD'), ('', 'HEAD'))
self.assertEqual(do_req(308, 'HEAD'), ('', 'HEAD'))
# These should not redirect and instead raise an HTTPError
for code in (300, 304, 305, 306):
with self.assertRaises(urllib.error.HTTPError):
do_req(code, 'GET')
def test_content_type(self):
# https://github.com/yt-dlp/yt-dlp/commit/379a4f161d4ad3e40932dcf5aca6e6fb9715ab28
with FakeYDL({'nocheckcertificate': True}) as ydl:
# method should be auto-detected as POST
r = sanitized_Request(f'https://localhost:{self.https_port}/headers', data=urlencode_postdata({'test': 'test'}))
headers = ydl.urlopen(r).read().decode('utf-8')
self.assertIn('Content-Type: application/x-www-form-urlencoded', headers)
# test http
r = sanitized_Request(f'http://localhost:{self.http_port}/headers', data=urlencode_postdata({'test': 'test'}))
headers = ydl.urlopen(r).read().decode('utf-8')
self.assertIn('Content-Type: application/x-www-form-urlencoded', headers)
def test_cookiejar(self):
with FakeYDL() as ydl:
ydl.cookiejar.set_cookie(http.cookiejar.Cookie(
0, 'test', 'ytdlp', None, False, '127.0.0.1', True,
False, '/headers', True, False, None, False, None, None, {}))
data = ydl.urlopen(sanitized_Request(f'http://127.0.0.1:{self.http_port}/headers')).read()
self.assertIn(b'Cookie: test=ytdlp', data)
def test_no_compression_compat_header(self):
with FakeYDL() as ydl:
data = ydl.urlopen(
sanitized_Request(
f'http://127.0.0.1:{self.http_port}/headers',
headers={'Youtubedl-no-compression': True})).read()
self.assertIn(b'Accept-Encoding: identity', data)
self.assertNotIn(b'youtubedl-no-compression', data.lower())
def test_gzip_trailing_garbage(self):
# https://github.com/ytdl-org/youtube-dl/commit/aa3e950764337ef9800c936f4de89b31c00dfcf5
# https://github.com/ytdl-org/youtube-dl/commit/6f2ec15cee79d35dba065677cad9da7491ec6e6f
with FakeYDL() as ydl:
data = ydl.urlopen(sanitized_Request(f'http://localhost:{self.http_port}/trailing_garbage')).read().decode('utf-8')
self.assertEqual(data, '<html><video src="/vid.mp4" /></html>')
class TestClientCert(unittest.TestCase): class TestClientCert(unittest.TestCase):
@ -112,8 +327,8 @@ class TestClientCert(unittest.TestCase):
'nocheckcertificate': True, 'nocheckcertificate': True,
**params, **params,
}) })
r = ydl.extract_info('https://127.0.0.1:%d/video.html' % self.port) r = ydl.extract_info(f'https://127.0.0.1:{self.port}/video.html')
self.assertEqual(r['url'], 'https://127.0.0.1:%d/vid.mp4' % self.port) self.assertEqual(r['url'], f'https://127.0.0.1:{self.port}/vid.mp4')
def test_certificate_combined_nopass(self): def test_certificate_combined_nopass(self):
self._run_test(client_certificate=os.path.join(self.certdir, 'clientwithkey.crt')) self._run_test(client_certificate=os.path.join(self.certdir, 'clientwithkey.crt'))
@ -188,5 +403,22 @@ class TestProxy(unittest.TestCase):
self.assertEqual(response, 'normal: http://xn--fiq228c.tw/') self.assertEqual(response, 'normal: http://xn--fiq228c.tw/')
class TestFileURL(unittest.TestCase):
# See https://github.com/ytdl-org/youtube-dl/issues/8227
def test_file_urls(self):
tf = tempfile.NamedTemporaryFile(delete=False)
tf.write(b'foobar')
tf.close()
url = pathlib.Path(tf.name).as_uri()
with FakeYDL() as ydl:
self.assertRaisesRegex(
urllib.error.URLError, 'file:// URLs are explicitly disabled in yt-dlp for security reasons', ydl.urlopen, url)
with FakeYDL({'enable_file_urls': True}) as ydl:
res = ydl.urlopen(url)
self.assertEqual(res.read(), b'foobar')
res.close()
os.unlink(tf.name)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1664,61 +1664,44 @@ class YoutubeDLRedirectHandler(urllib.request.HTTPRedirectHandler):
The code is based on HTTPRedirectHandler implementation from CPython [1]. The code is based on HTTPRedirectHandler implementation from CPython [1].
This redirect handler solves two issues: This redirect handler fixes and improves the logic to better align with RFC7261
- ensures redirect URL is always unicode under python 2 and what browsers tend to do [2][3]
- introduces support for experimental HTTP response status code
308 Permanent Redirect [2] used by some sites [3]
1. https://github.com/python/cpython/blob/master/Lib/urllib/request.py 1. https://github.com/python/cpython/blob/master/Lib/urllib/request.py
2. https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/308 2. https://datatracker.ietf.org/doc/html/rfc7231
3. https://github.com/ytdl-org/youtube-dl/issues/28768 3. https://github.com/python/cpython/issues/91306
""" """
http_error_301 = http_error_303 = http_error_307 = http_error_308 = urllib.request.HTTPRedirectHandler.http_error_302 http_error_301 = http_error_303 = http_error_307 = http_error_308 = urllib.request.HTTPRedirectHandler.http_error_302
def redirect_request(self, req, fp, code, msg, headers, newurl): def redirect_request(self, req, fp, code, msg, headers, newurl):
"""Return a Request or None in response to a redirect. if code not in (301, 302, 303, 307, 308):
This is called by the http_error_30x methods when a
redirection response is received. If a redirection should
take place, return a new Request to allow http_error_30x to
perform the redirect. Otherwise, raise HTTPError if no-one
else should try to handle this url. Return None if you can't
but another Handler might.
"""
m = req.get_method()
if (not (code in (301, 302, 303, 307, 308) and m in ("GET", "HEAD")
or code in (301, 302, 303) and m == "POST")):
raise urllib.error.HTTPError(req.full_url, code, msg, headers, fp) raise urllib.error.HTTPError(req.full_url, code, msg, headers, fp)
# Strictly (according to RFC 2616), 301 or 302 in response to
# a POST MUST NOT cause a redirection without confirmation
# from the user (of urllib.request, in this case). In practice,
# essentially all clients do redirect in this case, so we do
# the same.
# Be conciliant with URIs containing a space. This is mainly
# redundant with the more complete encoding done in http_error_302(),
# but it is kept for compatibility with other callers.
newurl = newurl.replace(' ', '%20')
CONTENT_HEADERS = ("content-length", "content-type")
# NB: don't use dict comprehension for python 2.6 compatibility
newheaders = {k: v for k, v in req.headers.items() if k.lower() not in CONTENT_HEADERS}
new_method = req.get_method()
new_data = req.data
remove_headers = []
# A 303 must either use GET or HEAD for subsequent request # A 303 must either use GET or HEAD for subsequent request
# https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.4 # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.4
if code == 303 and m != 'HEAD': if code == 303 and req.get_method() != 'HEAD':
m = 'GET' new_method = 'GET'
# 301 and 302 redirects are commonly turned into a GET from a POST # 301 and 302 redirects are commonly turned into a GET from a POST
# for subsequent requests by browsers, so we'll do the same. # for subsequent requests by browsers, so we'll do the same.
# https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.2 # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.2
# https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.3 # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.3
if code in (301, 302) and m == 'POST': elif code in (301, 302) and req.get_method() == 'POST':
m = 'GET' new_method = 'GET'
# only remove payload if method changed (e.g. POST to GET)
if new_method != req.get_method():
new_data = None
remove_headers.extend(['Content-Length', 'Content-Type'])
new_headers = {k: v for k, v in req.headers.items() if k.lower() not in remove_headers}
return urllib.request.Request( return urllib.request.Request(
newurl, headers=newheaders, origin_req_host=req.origin_req_host, newurl, headers=new_headers, origin_req_host=req.origin_req_host,
unverifiable=True, method=m) unverifiable=True, method=new_method, data=new_data)
def extract_timezone(date_str): def extract_timezone(date_str):