Databases working, add no progress flag

This commit is contained in:
Nathan Thomas 2023-11-24 18:22:50 -08:00
parent 1964a0e488
commit 3e6284b04d
12 changed files with 197 additions and 112 deletions

View File

@ -7,6 +7,7 @@ from . import progress
from .artwork import download_artwork
from .client import Client
from .config import Config
from .db import Database
from .exceptions import NonStreamable
from .media import Media, Pending
from .metadata import AlbumMetadata
@ -23,6 +24,7 @@ class Album(Media):
config: Config
# folder where the tracks will be downloaded
folder: str
db: Database
async def preprocess(self):
progress.add_title(self.meta.album)
@ -45,6 +47,7 @@ class PendingAlbum(Pending):
id: str
client: Client
config: Config
db: Database
async def resolve(self) -> Album | None:
resp = await self.client.get_metadata(self.id, "album")
@ -75,12 +78,13 @@ class PendingAlbum(Pending):
client=self.client,
config=self.config,
folder=album_folder,
db=self.db,
cover_path=embed_cover,
)
for id in tracklist
]
logger.debug("Pending tracks: %s", pending_tracks)
return Album(meta, pending_tracks, self.config, album_folder)
return Album(meta, pending_tracks, self.config, album_folder, self.db)
def _album_folder(self, parent: str, meta: AlbumMetadata) -> str:
formatter = self.config.session.filepaths.folder_format

View File

@ -4,6 +4,7 @@ from .album import PendingAlbum
from .album_list import AlbumList
from .client import Client
from .config import Config
from .db import Database
from .media import Pending
from .metadata import ArtistMetadata
@ -17,12 +18,13 @@ class PendingArtist(Pending):
id: str
client: Client
config: Config
db: Database
async def resolve(self) -> Artist:
resp = await self.client.get_metadata(self.id, "artist")
meta = ArtistMetadata.from_resp(resp, self.client.source)
albums = [
PendingAlbum(album_id, self.client, self.config)
PendingAlbum(album_id, self.client, self.config, self.db)
for album_id in meta.album_ids()
]
return Artist(meta.name, albums, self.client, self.config)

View File

@ -6,7 +6,6 @@ import subprocess
from functools import wraps
import click
from click import secho
from click_help_colors import HelpColorsGroup # type: ignore
from rich.logging import RichHandler
from rich.prompt import Confirm
@ -15,19 +14,7 @@ from rich.traceback import install
from .config import Config, set_user_defaults
from .console import console
from .main import Main
from .user_paths import BLANK_CONFIG_PATH, CONFIG_PATH
def echo_i(msg, **kwargs):
    """Print an informational message in green.

    Any extra keyword arguments are forwarded unchanged to ``secho``.
    """
    return secho(msg, fg="green", **kwargs)
def echo_w(msg, **kwargs):
    """Print a warning message in yellow.

    Any extra keyword arguments are forwarded unchanged to ``secho``.
    """
    return secho(msg, fg="yellow", **kwargs)
def echo_e(msg, **kwargs):
    """Print an error message in red.

    Any extra keyword arguments are forwarded unchanged to ``secho``.
    """
    # Errors get their own colour; yellow is what the warning helper uses,
    # so an error printed in yellow is indistinguishable from a warning.
    secho(msg, fg="red", **kwargs)
from .user_paths import BLANK_CONFIG_PATH, DEFAULT_CONFIG_PATH
def coro(f):
@ -45,7 +32,7 @@ def coro(f):
)
@click.version_option(version="2.0")
@click.option(
"--config-path", default=CONFIG_PATH, help="Path to the configuration file"
"--config-path", default=DEFAULT_CONFIG_PATH, help="Path to the configuration file"
)
@click.option("-f", "--folder", help="The folder to download items into.")
@click.option(
@ -61,18 +48,20 @@ def coro(f):
"--convert",
help="Convert the downloaded files to an audio codec (ALAC, FLAC, MP3, AAC, or OGG)",
)
@click.option(
"--no-progress", help="Do not show progress bars", is_flag=True, default=False
)
@click.option(
"-v", "--verbose", help="Enable verbose output (debug mode)", is_flag=True
)
@click.pass_context
def rip(ctx, config_path, folder, no_db, quality, convert, verbose):
def rip(ctx, config_path, folder, no_db, quality, convert, no_progress, verbose):
"""
Streamrip: the all in one music downloader.
"""
print(ctx, config_path, folder, no_db, quality, convert, verbose)
global logger
logging.basicConfig(
level="WARNING", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]
level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]
)
logger = logging.getLogger("streamrip")
if verbose:
@ -89,7 +78,9 @@ def rip(ctx, config_path, folder, no_db, quality, convert, verbose):
logger.setLevel(logging.WARNING)
if not os.path.isfile(config_path):
echo_i(f"No file found at {config_path}, creating default config.")
console.print(
f"No file found at [bold cyan]{config_path}[/bold cyan], creating default config."
)
shutil.copy(BLANK_CONFIG_PATH, config_path)
set_user_defaults(config_path)
@ -98,18 +89,24 @@ def rip(ctx, config_path, folder, no_db, quality, convert, verbose):
c = Config(config_path)
# set session config values to command line args
c.session.database.downloads_enabled = not no_db
if folder is not None:
c.session.downloads.folder = folder
c.session.database.downloads_enabled = not no_db
c.session.qobuz.quality = quality
c.session.tidal.quality = quality
c.session.deezer.quality = quality
c.session.soundcloud.quality = quality
if quality is not None:
c.session.qobuz.quality = quality
c.session.tidal.quality = quality
c.session.deezer.quality = quality
c.session.soundcloud.quality = quality
if convert is not None:
c.session.conversion.enabled = True
assert convert.upper() in ("ALAC", "FLAC", "OGG", "MP3", "AAC")
c.session.conversion.codec = convert.upper()
if no_progress:
c.session.cli.progress_bars = False
ctx.obj["config"] = c
@ -118,16 +115,10 @@ def rip(ctx, config_path, folder, no_db, quality, convert, verbose):
@click.pass_context
@coro
async def url(ctx, urls):
"""Download content from URLs.
Example usage:
rip url TODO: find url
"""
"""Download content from URLs."""
with ctx.obj["config"] as cfg:
main = Main(cfg)
for u in urls:
await main.add(u)
await main.add_all(urls)
await main.resolve()
await main.rip()
@ -146,8 +137,7 @@ async def file(ctx, path):
with ctx.obj["config"] as cfg:
main = Main(cfg)
with open(path) as f:
for url in f:
await main.add(url)
await main.add_all([line for line in f])
await main.resolve()
await main.rip()
@ -164,7 +154,7 @@ def config():
def config_open(ctx, vim):
"""Open the config file in a text editor."""
config_path = ctx.obj["config_path"]
echo_i(f"Opening file at {config_path}")
console.log(f"Opening file at [bold cyan]{config_path}")
if vim:
if shutil.which("nvim") is not None:
subprocess.run(["nvim", config_path])
@ -189,7 +179,7 @@ def config_reset(ctx, yes):
shutil.copy(BLANK_CONFIG_PATH, config_path)
set_user_defaults(config_path)
echo_i(f"Reset the config file at {config_path}!")
console.print(f"Reset the config file at [bold cyan]{config_path}!")
@rip.command()
@ -199,14 +189,15 @@ def config_reset(ctx, yes):
async def search(query, source):
"""
Search for content using a specific source.
"""
echo_i(f'Searching for "{query}" in source: {source}')
raise NotImplementedError
@rip.command()
@click.argument("url", required=True)
def lastfm(url):
pass
raise NotImplementedError
if __name__ == "__main__":

View File

@ -2,13 +2,13 @@
import copy
import logging
import os
from dataclasses import dataclass, fields
from tomlkit.api import dumps, parse
from tomlkit.toml_document import TOMLDocument
from .user_paths import (
DEFAULT_CONFIG_PATH,
DEFAULT_DOWNLOADS_DB_PATH,
DEFAULT_DOWNLOADS_FOLDER,
DEFAULT_FAILED_DOWNLOADS_DB_PATH,
@ -19,8 +19,6 @@ logger = logging.getLogger("streamrip")
CURRENT_CONFIG_VERSION = "2.0"
DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.toml")
@dataclass(slots=True)
class QobuzConfig:

View File

@ -4,13 +4,12 @@ import logging
import os
import sqlite3
from abc import ABC, abstractmethod
from dataclasses import dataclass
logger = logging.getLogger("streamrip")
# apologies to anyone reading this file
class Database(ABC):
class DatabaseInterface(ABC):
@abstractmethod
def create(self):
pass
@ -27,8 +26,31 @@ class Database(ABC):
def remove(self, kvs):
pass
@abstractmethod
def all(self) -> list:
    """Return the stored rows as a list; concrete databases must implement this."""
class DatabaseBase(Database):
class Dummy(DatabaseInterface):
    """No-op stand-in used in place of a real database when databases are disabled.

    Writes are discarded, lookups always report that nothing is stored, and
    listing the contents yields an empty list.
    """

    def create(self):
        # Nothing to create.
        return None

    def contains(self, **_):
        # Nothing is ever stored, so no lookup can match.
        return False

    def add(self, *_):
        # Discard the row.
        return None

    def remove(self, *_):
        # Nothing to remove.
        return None

    def all(self):
        # Always a fresh, empty list.
        return list()
class DatabaseBase(DatabaseInterface):
"""A wrapper for an sqlite database."""
structure: dict
@ -41,6 +63,7 @@ class DatabaseBase(Database):
"""
assert self.structure != {}
assert self.name
assert path
self.path = path
@ -124,10 +147,10 @@ class DatabaseBase(Database):
logger.debug(command)
conn.execute(command, tuple(items.values()))
def __iter__(self):
def all(self):
"""Iterate through the rows of the table."""
with sqlite3.connect(self.path) as conn:
return conn.execute(f"SELECT * FROM {self.name}")
return list(conn.execute(f"SELECT * FROM {self.name}"))
def reset(self):
"""Delete the database file."""
@ -137,20 +160,6 @@ class DatabaseBase(Database):
pass
class Dummy(Database):
def create(self):
pass
def contains(self):
return False
def add(self):
pass
def remove(self):
pass
class Downloads(DatabaseBase):
"""A table that stores the downloaded IDs."""
@ -160,7 +169,7 @@ class Downloads(DatabaseBase):
}
class FailedDownloads(DatabaseBase):
class Failed(DatabaseBase):
"""A table that stores information about failed downloads."""
name = "failed_downloads"
@ -169,3 +178,21 @@ class FailedDownloads(DatabaseBase):
"media_type": ["text"],
"id": ["text", "unique"],
}
@dataclass(slots=True)
class Database:
    """Pair of databases: one for completed downloads, one for failed ones.

    Each method delegates to the matching `DatabaseInterface` method
    (`contains`, `add` or `all`).
    """

    # Backing store for the IDs of items that were downloaded.
    downloads: DatabaseInterface
    # Backing store for (source, media_type, id) rows of failed downloads.
    failed: DatabaseInterface

    def downloaded(self, item_id: str) -> bool:
        """Return whether `item_id` is present in the downloads database."""
        query = {"id": item_id}
        return self.downloads.contains(**query)

    def set_downloaded(self, item_id: str):
        """Add `item_id` to the downloads database as a one-column row."""
        row = (item_id,)
        self.downloads.add(row)

    def get_failed_downloads(self) -> list[tuple[str, str, str]]:
        """Return everything stored in the failed-downloads database."""
        rows = self.failed.all()
        return rows

    def set_failed(self, source: str, media_type: str, id: str):
        """Add a `(source, media_type, id)` row to the failed-downloads database."""
        row = (source, media_type, id)
        self.failed.add(row)

View File

@ -1,10 +1,10 @@
import asyncio
from dataclasses import dataclass
from .album import PendingAlbum
from .album_list import AlbumList
from .client import Client
from .config import Config
from .db import Database
from .media import Pending
from .metadata import LabelMetadata
@ -18,12 +18,13 @@ class PendingLabel(Pending):
id: str
client: Client
config: Config
db: Database
async def resolve(self) -> Label:
resp = await self.client.get_metadata(self.id, "label")
meta = LabelMetadata.from_resp(resp, self.client.source)
albums = [
PendingAlbum(album_id, self.client, self.config)
PendingAlbum(album_id, self.client, self.config, self.db)
for album_id in meta.album_ids()
]
return Label(meta.name, albums, self.client, self.config)

View File

@ -1,6 +1,7 @@
import asyncio
import logging
from . import db
from .artwork import remove_artwork_tempdirs
from .client import Client
from .config import Config
@ -26,9 +27,8 @@ class Main:
"""
def __init__(self, config: Config):
# Pipeline:
# input URL -> (URL) -> (Pending) -> (Media) -> (Downloadable)
# -> downloaded audio file
# Data pipeline:
# input URL -> (URL) -> (Pending) -> (Media) -> (Downloadable) -> audio file
self.pending: list[Pending] = []
self.media: list[Media] = []
self.config = config
@ -37,20 +37,55 @@ class Main:
# "tidal": TidalClient(config),
# "deezer": DeezerClient(config),
"soundcloud": SoundcloudClient(config),
# "deezloader": DeezloaderClient(config),
}
self.database: db.Database
c = self.config.session.database
if c.downloads_enabled:
downloads_db = db.Downloads(c.downloads_path)
else:
downloads_db = db.Dummy()
if c.failed_downloads_enabled:
failed_downloads_db = db.Failed(c.failed_downloads_path)
else:
failed_downloads_db = db.Dummy()
self.database = db.Database(downloads_db, failed_downloads_db)
async def add(self, url: str):
"""Add url as a pending item. Do not `asyncio.gather` calls to this!"""
"""Add url as a pending item.
Do not `asyncio.gather` calls to this! Use `add_all` for concurrency.
"""
parsed = parse_url(url)
if parsed is None:
raise Exception(f"Unable to parse url {url}")
client = await self.get_logged_in_client(parsed.source)
self.pending.append(await parsed.into_pending(client, self.config))
self.pending.append(
await parsed.into_pending(client, self.config, self.database)
)
logger.debug("Added url=%s", url)
async def add_all(self, urls: list[str]):
    """Add several urls as pending items.

    Each url is parsed with `parse_url`. Urls that cannot be parsed are
    logged as errors and skipped; blank entries (e.g. empty lines read from
    a file) are skipped without a message. The remaining urls are still
    added, so one bad url does not abort the batch.

    Clients are awaited one at a time; only the `into_pending` calls are
    run concurrently via `asyncio.gather`.

    Args:
        urls: The urls to add.
    """
    url_w_client = []
    for url in urls:
        parsed = parse_url(url)
        if parsed is None:
            # Report urls we could not understand instead of dropping them
            # without a trace. Whitespace-only entries are not worth a message.
            if url.strip():
                logger.error("Unable to parse url %s, skipping.", url.strip())
            continue
        client = await self.get_logged_in_client(parsed.source)
        url_w_client.append((parsed, client))

    pendings = await asyncio.gather(
        *(
            parsed.into_pending(client, self.config, self.database)
            for parsed, client in url_w_client
        )
    )
    self.pending.extend(pendings)
async def get_logged_in_client(self, source: str):
"""Return a functioning client instance for `source`."""
client = self.clients[source]
if not client.logged_in:
prompter = get_prompter(client, self.config)
@ -81,5 +116,9 @@ class Main:
if hasattr(client, "session"):
await client.session.close()
# close global progress bar manager
clear_progress()
# Artwork tempdirs are removed here, at the end of execution, because multiple
# singles in the same `rip` session may share downloaded artwork. We cannot
# know that a cover will not be used again until everything has finished.
remove_artwork_tempdirs()

View File

@ -7,6 +7,7 @@ from . import progress
from .artwork import download_artwork
from .client import Client
from .config import Config
from .db import Database
from .filepath_utils import clean_filename
from .media import Media, Pending
from .metadata import AlbumMetadata, Covers, PlaylistMetadata, TrackMetadata
@ -23,8 +24,12 @@ class PendingPlaylistTrack(Pending):
folder: str
playlist_name: str
position: int
db: Database
async def resolve(self) -> Track | None:
if self.db.downloaded(self.id):
logger.info(f"Track ({self.id}) already logged in database. Skipping.")
return None
resp = await self.client.get_metadata(self.id, "track")
album = AlbumMetadata.from_track_resp(resp, self.client.source)
@ -33,6 +38,7 @@ class PendingPlaylistTrack(Pending):
logger.error(
f"Track ({self.id}) not available for stream on {self.client.source}"
)
self.db.set_failed(self.client.source, "track", self.id)
return None
c = self.config.session.metadata
@ -46,7 +52,9 @@ class PendingPlaylistTrack(Pending):
self._download_cover(album.covers, self.folder),
self.client.get_downloadable(self.id, quality),
)
return Track(meta, downloadable, self.config, self.folder, embedded_cover_path)
return Track(
meta, downloadable, self.config, self.folder, embedded_cover_path, self.db
)
async def _download_cover(self, covers: Covers, folder: str) -> str | None:
embed_path, _ = await download_artwork(
@ -90,6 +98,7 @@ class PendingPlaylist(Pending):
id: str
client: Client
config: Config
db: Database
async def resolve(self) -> Playlist | None:
resp = await self.client.get_metadata(self.id, "playlist")
@ -99,7 +108,7 @@ class PendingPlaylist(Pending):
folder = os.path.join(parent, clean_filename(name))
tracks = [
PendingPlaylistTrack(
id, self.client, self.config, folder, name, position + 1
id, self.client, self.config, folder, name, position + 1, self.db
)
for position, id in enumerate(meta.ids())
]

View File

@ -14,9 +14,11 @@ class ProgressManager:
def __init__(self):
self.started = False
self.progress = Progress(console=console)
self.prefix = Text.assemble(("Downloading ", "bold cyan"), overflow="ellipsis")
self.live = Live(Group(self.prefix, self.progress), refresh_per_second=10)
self.task_titles = []
self.prefix = Text.assemble(("Downloading ", "bold cyan"), overflow="ellipsis")
self.live = Live(
Group(self.get_title_text(), self.progress), refresh_per_second=10
)
def get_callback(self, total: int, desc: str):
if not self.started:

View File

@ -7,8 +7,8 @@ from . import converter
from .artwork import download_artwork
from .client import Client
from .config import Config
from .db import Database
from .downloadable import Downloadable
from .exceptions import NonStreamable
from .filepath_utils import clean_filename
from .media import Media, Pending
from .metadata import AlbumMetadata, Covers, TrackMetadata
@ -27,6 +27,7 @@ class Track(Media):
folder: str
# Is None if a cover doesn't exist for the track
cover_path: str | None
db: Database
# change?
download_path: str = ""
@ -45,15 +46,11 @@ class Track(Media):
await self.downloadable.download(self.download_path, callback)
async def postprocess(self):
await self._tag()
await tag_file(self.download_path, self.meta, self.cover_path)
if self.config.session.conversion.enabled:
await self._convert()
# if self.cover_path is not None:
# os.remove(self.cover_path)
async def _tag(self):
await tag_file(self.download_path, self.meta, self.cover_path)
self.db.set_downloaded(self.meta.info.id)
async def _convert(self):
c = self.config.session.conversion
@ -88,22 +85,30 @@ class PendingTrack(Pending):
client: Client
config: Config
folder: str
db: Database
# cover_path is None <==> Artwork for this track doesn't exist in API
cover_path: str | None
async def resolve(self) -> Track | None:
resp = await self.client.get_metadata(self.id, "track")
meta = TrackMetadata.from_resp(self.album, self.client.source, resp)
if meta is None:
logger.error(
f"Track {self.id} not available for stream on {self.client.source}"
if self.db.downloaded(self.id):
logger.info(
f"Skipping track {self.id}. Marked as downloaded in the database."
)
return None
quality = getattr(self.config.session, self.client.source).quality
assert isinstance(quality, int)
resp = await self.client.get_metadata(self.id, "track")
source = self.client.source
meta = TrackMetadata.from_resp(self.album, source, resp)
if meta is None:
logger.error(f"Track {self.id} not available for stream on {source}")
self.db.set_failed(source, "track", self.id)
return None
quality = self.config.session.get_source(source).quality
downloadable = await self.client.get_downloadable(self.id, quality)
return Track(meta, downloadable, self.config, self.folder, self.cover_path)
return Track(
meta, downloadable, self.config, self.folder, self.cover_path, self.db
)
@dataclass(slots=True)
@ -117,6 +122,7 @@ class PendingSingle(Pending):
id: str
client: Client
config: Config
db: Database
async def resolve(self) -> Track | None:
resp = await self.client.get_metadata(self.id, "track")
@ -126,6 +132,7 @@ class PendingSingle(Pending):
meta = TrackMetadata.from_resp(album, self.client.source, resp)
if meta is None:
self.db.set_failed(self.client.source, "track", self.id)
logger.error(f"Cannot stream track ({self.id}) on {self.client.source}")
return None
@ -140,7 +147,9 @@ class PendingSingle(Pending):
self._download_cover(album.covers, folder),
self.client.get_downloadable(self.id, quality),
)
return Track(meta, downloadable, self.config, folder, embedded_cover_path)
return Track(
meta, downloadable, self.config, folder, embedded_cover_path, self.db
)
def _format_folder(self, meta: AlbumMetadata) -> str:
c = self.config.session

View File

@ -7,6 +7,7 @@ from .album import PendingAlbum
from .artist import PendingArtist
from .client import Client
from .config import Config
from .db import Database
from .label import PendingLabel
from .media import Pending
from .playlist import PendingPlaylist
@ -23,7 +24,6 @@ from .validation_regexps import (
class URL(ABC):
match: re.Match
source: str
def __init__(self, match: re.Match, source: str):
@ -35,7 +35,9 @@ class URL(ABC):
raise NotImplementedError
@abstractmethod
async def into_pending(self, client: Client, config: Config) -> Pending:
async def into_pending(
self, client: Client, config: Config, db: Database
) -> Pending:
raise NotImplementedError
@ -48,20 +50,22 @@ class GenericURL(URL):
source = generic_url.group(1)
return cls(generic_url, source)
async def into_pending(self, client: Client, config: Config) -> Pending:
async def into_pending(
self, client: Client, config: Config, db: Database
) -> Pending:
source, media_type, item_id = self.match.groups()
assert client.source == source
if media_type == "track":
return PendingSingle(item_id, client, config)
return PendingSingle(item_id, client, config, db)
elif media_type == "album":
return PendingAlbum(item_id, client, config)
return PendingAlbum(item_id, client, config, db)
elif media_type == "playlist":
return PendingPlaylist(item_id, client, config)
return PendingPlaylist(item_id, client, config, db)
elif media_type == "artist":
return PendingArtist(item_id, client, config)
return PendingArtist(item_id, client, config, db)
elif media_type == "label":
return PendingLabel(item_id, client, config)
return PendingLabel(item_id, client, config, db)
else:
raise NotImplementedError
@ -76,10 +80,12 @@ class QobuzInterpreterURL(URL):
return None
return cls(qobuz_interpreter_url, "qobuz")
async def into_pending(self, client: Client, config: Config) -> Pending:
async def into_pending(
self, client: Client, config: Config, db: Database
) -> Pending:
url = self.match.group(0)
artist_id = await self.extract_interpreter_url(url, client)
return PendingArtist(artist_id, client, config)
return PendingArtist(artist_id, client, config, db)
@staticmethod
async def extract_interpreter_url(url: str, client: Client) -> str:
@ -113,14 +119,16 @@ class SoundcloudURL(URL):
def __init__(self, url: str):
self.url = url
async def into_pending(self, client: SoundcloudClient, config: Config) -> Pending:
async def into_pending(
self, client: SoundcloudClient, config: Config, db: Database
) -> Pending:
resolved = await client._resolve_url(self.url)
media_type = resolved["kind"]
item_id = str(resolved["id"])
if media_type == "track":
return PendingSingle(item_id, client, config)
return PendingSingle(item_id, client, config, db)
elif media_type == "playlist":
return PendingPlaylist(item_id, client, config)
return PendingPlaylist(item_id, client, config, db)
else:
raise NotImplementedError(media_type)

View File

@ -8,10 +8,7 @@ APP_DIR = user_config_dir(APPNAME)
HOME = Path.home()
LOG_DIR = CACHE_DIR = CONFIG_DIR = APP_DIR
CONFIG_PATH = os.path.join(CONFIG_DIR, "config.toml")
DB_PATH = os.path.join(LOG_DIR, "downloads.db")
FAILED_DB_PATH = os.path.join(LOG_DIR, "failed_downloads.db")
DEFAULT_CONFIG_PATH = os.path.join(CONFIG_DIR, "config.toml")
DOWNLOADS_DIR = os.path.join(HOME, "StreamripDownloads")
# file shipped with script
@ -20,6 +17,4 @@ BLANK_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.toml")
DEFAULT_DOWNLOADS_FOLDER = os.path.join(HOME, "StreamripDownloads")
DEFAULT_DOWNLOADS_DB_PATH = os.path.join(LOG_DIR, "downloads.db")
DEFAULT_FAILED_DOWNLOADS_DB_PATH = os.path.join(LOG_DIR, "failed_downloads.db")
DEFAULT_YOUTUBE_VIDEO_DOWNLOADS_FOLDER = os.path.join(
HOME, "StreamripDownloads", "YouTubeVideos"
)
DEFAULT_YOUTUBE_VIDEO_DOWNLOADS_FOLDER = os.path.join(DOWNLOADS_DIR, "YouTubeVideos")