Create custom async downloader for HLS streams

This commit is contained in:
Nathan Thomas 2021-09-11 10:49:27 -07:00
parent 5b2aaf5ad2
commit 9f5cd49aab
7 changed files with 364 additions and 184 deletions

1
.gitignore vendored
View File

@@ -19,3 +19,4 @@ StreamripDownloads
*.pyc
*test.py
/.mypy_cache
.DS_Store

View File

@ -42,3 +42,12 @@ ignore_missing_imports = True
[mypy-appdirs.*]
ignore_missing_imports = True
[mypy-m3u8.*]
ignore_missing_imports = True
[mypy-aiohttp.*]
ignore_missing_imports = True
[mypy-aiofiles.*]
ignore_missing_imports = True

View File

@ -2,4 +2,4 @@
__version__ = "1.4"
from . import clients, constants, converter, media
from . import clients, constants, converter, downloadtools, media

241
streamrip/downloadtools.py Normal file
View File

@@ -0,0 +1,241 @@
import asyncio
import functools
import hashlib
import logging
import os
import re
from tempfile import gettempdir
from typing import Callable, Dict, Generator, Iterator, List, Optional
import aiofiles
import aiohttp
from Cryptodome.Cipher import Blowfish
from .exceptions import NonStreamable
from .utils import gen_threadsafe_session
logger = logging.getLogger("streamrip")
class DownloadStream:
    """An iterator over chunks of a stream.

    The HTTP request is issued eagerly in the constructor (streamed, with
    redirects followed), so ``len(stream)`` is known before iteration.
    Encrypted Deezer streams are transparently decrypted while iterating.

    Usage:
    >>> stream = DownloadStream('https://google.com', None)
    >>> with open('google.html', 'wb') as file:
    >>>     for chunk in stream:
    >>>         file.write(chunk)
    """

    # Deezer URLs whose path contains /mobile/ or /media/ are served
    # Blowfish-encrypted (see __iter__).
    is_encrypted = re.compile("/m(?:obile|edia)/")

    def __init__(
        self,
        url: str,
        source: Optional[str] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
        item_id: Optional[str] = None,  # ints are accepted and coerced below
    ):
        """Create an iterable DownloadStream of a URL.

        :param url: The url to download
        :type url: str
        :param source: Only applicable for Deezer
        :type source: str
        :param params: Parameters to pass in the request
        :type params: dict
        :param headers: Headers to pass in the request
        :type headers: dict
        :param item_id: (Only for Deezer) the ID of the track
        :type item_id: str
        :raises NonStreamable: if the server answers with a JSON error
            payload, or the response looks like a missing file.
        """
        self.source = source
        self.session = gen_threadsafe_session(headers=headers)

        self.id = item_id
        if isinstance(self.id, int):
            # The Blowfish key derivation below requires a string id.
            self.id = str(self.id)

        if params is None:
            params = {}

        self.request = self.session.get(
            url, allow_redirects=True, stream=True, params=params
        )
        self.file_size = int(self.request.headers.get("Content-Length", 0))

        # Anything under ~20 kB that isn't a cover image is almost
        # certainly an error document rather than media; try to surface
        # the server's own error message.
        if self.file_size < 20000 and not self.url.endswith(".jpg"):
            import json

            try:
                info = self.request.json()
                try:
                    # Usually happens with deezloader downloads
                    raise NonStreamable(
                        f"{info['error']} -- {info['message']}"
                    )
                except KeyError:
                    raise NonStreamable(info)
            except json.JSONDecodeError:
                raise NonStreamable("File not found.")

    def __iter__(self) -> Iterator:
        """Iterate through chunks of the stream.

        For encrypted Deezer streams, chunks of 3 * 2048 bytes are
        fetched and only the first 2048 bytes of each chunk are
        Blowfish-decrypted; the remainder passes through unchanged.

        :rtype: Iterator
        """
        if (
            self.source == "deezer"
            and self.is_encrypted.search(self.url) is not None
        ):
            assert isinstance(self.id, str), self.id

            blowfish_key = self._generate_blowfish_key(self.id)
            # decryptor = self._create_deezer_decryptor(blowfish_key)
            CHUNK_SIZE = 2048 * 3
            return (
                # (decryptor.decrypt(chunk[:2048]) + chunk[2048:])
                (
                    self._decrypt_chunk(blowfish_key, chunk[:2048])
                    + chunk[2048:]
                )
                if len(chunk) >= 2048
                else chunk
                for chunk in self.request.iter_content(CHUNK_SIZE)
            )

        return self.request.iter_content(chunk_size=1024)

    @property
    def url(self):
        """Return the requested url."""
        return self.request.url

    def __len__(self) -> int:
        """Return the value of the "Content-Length" header.

        :rtype: int
        """
        return self.file_size

    def _create_deezer_decryptor(self, key) -> Blowfish:
        # NOTE(review): appears unused -- its call site in __iter__ is
        # commented out in favor of per-chunk ciphers (_decrypt_chunk).
        return Blowfish.new(
            key, Blowfish.MODE_CBC, b"\x00\x01\x02\x03\x04\x05\x06\x07"
        )

    @staticmethod
    def _generate_blowfish_key(track_id: str):
        """Generate the blowfish key for Deezer downloads.

        Derived by XOR-ing the two halves of the track id's md5 digest
        with a fixed secret.

        :param track_id:
        :type track_id: str
        """
        SECRET = "g4el58wc0zvf9na1"
        md5_hash = hashlib.md5(track_id.encode()).hexdigest()
        # good luck :)
        return "".join(
            chr(functools.reduce(lambda x, y: x ^ y, map(ord, t)))
            for t in zip(md5_hash[:16], md5_hash[16:], SECRET)
        ).encode()

    @staticmethod
    def _decrypt_chunk(key, data):
        """Decrypt a chunk of a Deezer stream.

        A fresh CBC cipher with a fixed IV is created for every chunk --
        presumably each 2048-byte block is encrypted independently.

        :param key:
        :param data:
        """
        return Blowfish.new(
            key,
            Blowfish.MODE_CBC,
            b"\x00\x01\x02\x03\x04\x05\x06\x07",
        ).decrypt(data)
class DownloadPool:
    """Asynchronously download a set of urls."""

    def __init__(
        self,
        urls: Generator,
        tempdir: str = None,
        chunk_callback: Optional[Callable] = None,
    ):
        """Create a DownloadPool.

        :param urls: An iterable of urls to download.
        :param tempdir: Directory for the temporary files. Defaults to
            the system temp directory.
        :param chunk_callback: Called with no arguments after each url
            finishes downloading (e.g. to advance a progress bar).
        """
        self.finished: bool = False
        # Enumerate urls to know the order
        self.urls = dict(enumerate(urls))
        self._downloaded_urls: List[str] = []
        # {url: path}
        self._paths: Dict[str, str] = {}
        self.task: Optional[asyncio.Task] = None
        # BUG FIX: was stored as `self.callack`, while _download_url read
        # `self.callback`, raising AttributeError after every download.
        self.callback = chunk_callback

        if tempdir is None:
            tempdir = gettempdir()
        self.tempdir = tempdir

    async def getfn(self, url):
        """Return (and remember) the temp-file path used for ``url``."""
        path = os.path.join(
            self.tempdir, f"__streamrip_partial_{abs(hash(url))}"
        )
        self._paths[url] = path
        return path

    async def _download_urls(self):
        """Download every url concurrently within a single HTTP session."""
        async with aiohttp.ClientSession() as session:
            tasks = [
                asyncio.ensure_future(self._download_url(session, url))
                for url in self.urls.values()
            ]
            await asyncio.gather(*tasks)

    async def _download_url(self, session, url):
        """Download a single url to its temp file, then fire the callback."""
        filename = await self.getfn(url)
        logger.debug("Downloading %s", url)
        async with session.get(url) as response, aiofiles.open(
            filename, "wb"
        ) as f:
            # without aiofiles 3.6632679780000004s
            # with aiofiles 2.504482839s
            await f.write(await response.content.read())

        if self.callback:
            self.callback()

        logger.debug("Finished %s", url)

    def download(self):
        """Download all urls, blocking until every one has finished."""
        asyncio.run(self._download_urls())

    @property
    def files(self):
        """Return the downloaded file paths in the original url order.

        :raises Exception: if download() has not been run yet.
        """
        if len(self._paths) != len(self.urls):
            # Not all of them have downloaded
            raise Exception(
                "Must run DownloadPool.download() before accessing files"
            )

        # BUG FIX: the stored paths already include tempdir (see getfn);
        # joining tempdir on again duplicated the directory whenever
        # tempdir was a relative path.
        return [self._paths[self.urls[i]] for i in range(len(self.urls))]

    def __len__(self):
        return len(self.urls)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Best-effort cleanup of the partial files; never suppress the
        # caller's exception.
        logger.debug("Removing tempfiles %s", self._paths)
        for file in self._paths.values():
            try:
                os.remove(file)
            except FileNotFoundError:
                pass

        return False

View File

@ -39,6 +39,7 @@ from .constants import (
FOLDER_FORMAT,
TRACK_FORMAT,
)
from .downloadtools import DownloadPool, DownloadStream
from .exceptions import (
InvalidQuality,
InvalidSourceError,
@ -49,7 +50,6 @@ from .exceptions import (
)
from .metadata import TrackMetadata
from .utils import (
DownloadStream,
clean_filename,
clean_format,
decrypt_mqa_file,
@ -450,21 +450,44 @@ class Track(Media):
"""
logger.debug("dl_info: %s", dl_info)
if dl_info["type"] == "mp3":
import m3u8
import requests
parsed_m3u = m3u8.loads(requests.get(dl_info["url"]).text)
self.path += ".mp3"
# convert hls stream to mp3
subprocess.call(
[
"ffmpeg",
"-i",
dl_info["url"],
"-c",
"copy",
"-y",
self.path,
"-loglevel",
"fatal",
]
)
with DownloadPool(
segment.uri for segment in parsed_m3u.segments
) as pool:
pool.download()
subprocess.call(
[
"ffmpeg",
"-i",
f"concat:{'|'.join(pool.files)}",
"-acodec",
"copy",
"-loglevel",
"panic",
self.path,
]
)
# self.path += ".mp3"
# # convert hls stream to mp3
# subprocess.call(
# [
# "ffmpeg",
# "-i",
# dl_info["url"],
# "-c",
# "copy",
# "-y",
# self.path,
# "-loglevel",
# "fatal",
# ]
# )
elif dl_info["type"] == "original":
_quick_download(
dl_info["url"], self.path, desc=self._progress_desc
@ -857,6 +880,9 @@ class Video(Media):
:param kwargs:
"""
import m3u8
import requests
secho(
f"Downloading {self.title} (Video). This may take a while.",
fg="blue",
@ -864,19 +890,41 @@ class Video(Media):
self.parent_folder = kwargs.get("parent_folder", "StreamripDownloads")
url = self.client.get_file_url(self.id, video=True)
# it's more convenient to have ffmpeg download the hls
command = [
"ffmpeg",
"-i",
url,
"-c",
"copy",
"-loglevel",
"panic",
self.path,
]
p = subprocess.Popen(command)
p.wait() # remove this?
parsed_m3u = m3u8.loads(requests.get(url).text)
# Asynchronously download the streams
with DownloadPool(
segment.uri for segment in parsed_m3u.segments
) as pool:
pool.download()
# Put the filenames in a tempfile that ffmpeg
# can read from
file_list_path = os.path.join(
gettempdir(), "__streamrip_video_files"
)
with open(file_list_path, "w") as file_list:
text = "\n".join(f"file '{path}'" for path in pool.files)
file_list.write(text)
# Use ffmpeg to concat the files
p = subprocess.Popen(
[
"ffmpeg",
"-f",
"concat",
"-safe",
"0",
"-i",
file_list_path,
"-c",
"copy",
self.path,
]
)
p.wait()
os.remove(file_list_path)
def tag(self, *args, **kwargs):
"""Return False.
@ -1396,12 +1444,11 @@ class Tracklist(list):
class Album(Tracklist, Media):
"""Represents a downloadable album.
Usage:
>>> resp = client.get('fleetwood mac rumours', 'album')
>>> album = Album.from_api(resp['items'][0], client)
>>> album.load_meta()
>>> album.download()
"""
downloaded_ids: set = set()

View File

@ -3,166 +3,23 @@
from __future__ import annotations
import base64
import functools
import hashlib
import logging
import re
from string import Formatter
from typing import Dict, Hashable, Iterator, Optional, Tuple, Union
import requests
from click import secho, style
from Cryptodome.Cipher import Blowfish
from pathvalidate import sanitize_filename
from requests.packages import urllib3
from tqdm import tqdm
from .constants import COVER_SIZES, TIDAL_COVER_URL
from .exceptions import InvalidQuality, InvalidSourceError, NonStreamable
from .exceptions import InvalidQuality, InvalidSourceError
urllib3.disable_warnings()
logger = logging.getLogger("streamrip")
class DownloadStream:
    """An iterator over chunks of a stream.

    The HTTP request is issued eagerly in the constructor (streamed, with
    redirects followed), so ``len(stream)`` is known before iteration.
    Encrypted Deezer streams are transparently decrypted while iterating.

    Usage:
    >>> stream = DownloadStream('https://google.com', None)
    >>> with open('google.html', 'wb') as file:
    >>>     for chunk in stream:
    >>>         file.write(chunk)
    """

    # Deezer URLs whose path contains /mobile/ or /media/ are served
    # Blowfish-encrypted (see __iter__).
    is_encrypted = re.compile("/m(?:obile|edia)/")

    def __init__(
        self,
        url: str,
        source: Optional[str] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
        item_id: Optional[str] = None,  # ints are accepted and coerced below
    ):
        """Create an iterable DownloadStream of a URL.

        :param url: The url to download
        :type url: str
        :param source: Only applicable for Deezer
        :type source: str
        :param params: Parameters to pass in the request
        :type params: dict
        :param headers: Headers to pass in the request
        :type headers: dict
        :param item_id: (Only for Deezer) the ID of the track
        :type item_id: str
        :raises NonStreamable: if the server answers with a JSON error
            payload, or the response looks like a missing file.
        """
        self.source = source
        self.session = gen_threadsafe_session(headers=headers)

        self.id = item_id
        if isinstance(self.id, int):
            # The Blowfish key derivation below requires a string id.
            self.id = str(self.id)

        if params is None:
            params = {}

        self.request = self.session.get(
            url, allow_redirects=True, stream=True, params=params
        )
        self.file_size = int(self.request.headers.get("Content-Length", 0))

        # Anything under ~20 kB that isn't a cover image is almost
        # certainly an error document rather than media; try to surface
        # the server's own error message.
        if self.file_size < 20000 and not self.url.endswith(".jpg"):
            import json

            try:
                info = self.request.json()
                try:
                    # Usually happens with deezloader downloads
                    raise NonStreamable(
                        f"{info['error']} -- {info['message']}"
                    )
                except KeyError:
                    raise NonStreamable(info)
            except json.JSONDecodeError:
                raise NonStreamable("File not found.")

    def __iter__(self) -> Iterator:
        """Iterate through chunks of the stream.

        For encrypted Deezer streams, chunks of 3 * 2048 bytes are
        fetched and only the first 2048 bytes of each chunk are
        Blowfish-decrypted; the remainder passes through unchanged.

        :rtype: Iterator
        """
        if (
            self.source == "deezer"
            and self.is_encrypted.search(self.url) is not None
        ):
            assert isinstance(self.id, str), self.id

            blowfish_key = self._generate_blowfish_key(self.id)
            # decryptor = self._create_deezer_decryptor(blowfish_key)
            CHUNK_SIZE = 2048 * 3
            return (
                # (decryptor.decrypt(chunk[:2048]) + chunk[2048:])
                (
                    self._decrypt_chunk(blowfish_key, chunk[:2048])
                    + chunk[2048:]
                )
                if len(chunk) >= 2048
                else chunk
                for chunk in self.request.iter_content(CHUNK_SIZE)
            )

        return self.request.iter_content(chunk_size=1024)

    @property
    def url(self):
        """Return the requested url."""
        return self.request.url

    def __len__(self) -> int:
        """Return the value of the "Content-Length" header.

        :rtype: int
        """
        return self.file_size

    def _create_deezer_decryptor(self, key) -> Blowfish:
        # NOTE(review): appears unused -- its call site in __iter__ is
        # commented out in favor of per-chunk ciphers (_decrypt_chunk).
        return Blowfish.new(
            key, Blowfish.MODE_CBC, b"\x00\x01\x02\x03\x04\x05\x06\x07"
        )

    @staticmethod
    def _generate_blowfish_key(track_id: str):
        """Generate the blowfish key for Deezer downloads.

        Derived by XOR-ing the two halves of the track id's md5 digest
        with a fixed secret.

        :param track_id:
        :type track_id: str
        """
        SECRET = "g4el58wc0zvf9na1"
        md5_hash = hashlib.md5(track_id.encode()).hexdigest()
        # good luck :)
        return "".join(
            chr(functools.reduce(lambda x, y: x ^ y, map(ord, t)))
            for t in zip(md5_hash[:16], md5_hash[16:], SECRET)
        ).encode()

    @staticmethod
    def _decrypt_chunk(key, data):
        """Decrypt a chunk of a Deezer stream.

        A fresh CBC cipher with a fixed IV is created for every chunk --
        presumably each 2048-byte block is encrypted independently.

        :param key:
        :param data:
        """
        return Blowfish.new(
            key,
            Blowfish.MODE_CBC,
            b"\x00\x01\x02\x03\x04\x05\x06\x07",
        ).decrypt(data)
def safe_get(d: dict, *keys: Hashable, default=None):
"""Traverse dict layers safely.
@ -567,9 +424,7 @@ def set_progress_bar_theme(theme: str):
TQDM_BAR_FORMAT = TQDM_THEMES[theme]
def tqdm_stream(
iterator: DownloadStream, desc: Optional[str] = None
) -> Iterator[bytes]:
def tqdm_stream(iterator, desc: Optional[str] = None) -> Iterator[bytes]:
"""Return a tqdm bar with presets appropriate for downloading large files.
:param iterator:
@ -578,15 +433,19 @@ def tqdm_stream(
:type desc: Optional[str]
:rtype: Iterator
"""
with tqdm(
total=len(iterator),
with get_tqdm_bar(len(iterator), desc=desc) as bar:
for chunk in iterator:
bar.update(len(chunk))
yield chunk
def get_tqdm_bar(total, desc: Optional[str] = None):
return tqdm(
total=total,
unit="B",
unit_scale=True,
unit_divisor=1024,
desc=desc,
dynamic_ncols=True,
bar_format=TQDM_BAR_FORMAT,
) as bar:
for chunk in iterator:
bar.update(len(chunk))
yield chunk
)

23
tests/test_download.py Normal file
View File

@@ -0,0 +1,23 @@
import os
import time
from pprint import pprint
from streamrip.downloadtools import DownloadPool
def test_downloadpool(tmpdir):
    """Smoke-test DownloadPool against a live API (requires network).

    Downloads 150 small JSON responses concurrently, checks each one
    produced a temp file, and that the context manager removed them all.
    """
    start = time.perf_counter()
    with DownloadPool(
        (
            f"https://pokeapi.co/api/v2/pokemon/{number}"
            for number in range(1, 151)
        ),
        tempdir=tmpdir,
    ) as pool:
        pool.download()
        # BUG FIX: range(1, 151) yields 150 urls, not 151.
        assert len(os.listdir(tmpdir)) == 150

    # the tempfiles should be removed at this point
    assert len(os.listdir(tmpdir)) == 0

    print(f"Finished in {time.perf_counter() - start}s")