Create rate limited requests session

This commit is contained in:
Nathan Thomas 2023-06-20 20:35:31 -07:00
parent fd353d57cc
commit ef34756046
2 changed files with 66 additions and 14 deletions

View File

@ -40,7 +40,7 @@ from .exceptions import (
NonStreamable,
)
from .spoofbuz import Spoofer
from .utils import gen_threadsafe_session, get_quality, safe_get
from .utils import SRSession, gen_threadsafe_session, get_quality, safe_get
logger = logging.getLogger("streamrip")
@ -134,7 +134,7 @@ class QobuzClient(Client):
str(kwargs["app_id"]),
kwargs["secrets"],
)
self.session = gen_threadsafe_session(
self.session = SRSession(
headers={"User-Agent": AGENT, "X-App-Id": self.app_id}
)
self._validate_secrets()
@ -223,7 +223,7 @@ class QobuzClient(Client):
if not hasattr(self, "sec"):
if not hasattr(self, "session"):
self.session = gen_threadsafe_session(
self.session = SRSession(
headers={"User-Agent": AGENT, "X-App-Id": self.app_id}
)
self._validate_secrets()
@ -343,7 +343,9 @@ class QobuzClient(Client):
return self._gen_pages(epoint, params)
def _api_login(self, use_auth_token: bool, email_or_userid: str, password_or_token: str):
def _api_login(
self, use_auth_token: bool, email_or_userid: str, password_or_token: str
):
"""Log into the api to get the user authentication token.
:param use_auth_token:
@ -380,7 +382,7 @@ class QobuzClient(Client):
raise IneligibleError("Free accounts are not eligible to download tracks.")
self.uat = resp["user_auth_token"]
self.session.headers.update({"X-User-Auth-Token": self.uat})
self.session.update_headers({"X-User-Auth-Token": self.uat})
self.label = resp["user"]["credential"]["parameters"]["short_label"]
def _api_get_file_url(
@ -472,7 +474,6 @@ class DeezerClient(Client):
def __init__(self):
"""Create a DeezerClient."""
self.client = deezer.Deezer()
# self.session = gen_threadsafe_session()
# no login required
self.logged_in = False
@ -645,7 +646,7 @@ class DeezloaderClient(Client):
def __init__(self):
"""Create a DeezloaderClient."""
self.session = gen_threadsafe_session()
self.session = SRSession()
# no login required
self.logged_in = True
@ -735,7 +736,7 @@ class TidalClient(Client):
self.refresh_token = None
self.expiry = None
self.session = gen_threadsafe_session()
self.session = SRSession()
def login(
self,
@ -994,7 +995,7 @@ class TidalClient(Client):
def _update_authorization(self):
"""Update the requests session headers with the auth token."""
self.session.headers.update(self.authorization)
self.session.update_headers(self.authorization)
@property
def authorization(self):
@ -1094,8 +1095,7 @@ class TidalClient(Client):
:param data:
:param auth:
"""
r = self.session.post(url, data=data, auth=auth, verify=False).json()
return r
return self.session.post(url, data=data, auth=auth, verify=False).json()
class SoundCloudClient(Client):
@ -1110,7 +1110,7 @@ class SoundCloudClient(Client):
def __init__(self):
"""Create a SoundCloudClient."""
self.session = gen_threadsafe_session(
self.session = SRSession(
headers={
"User-Agent": AGENT,
}

View File

@ -9,6 +9,8 @@ import os
import shutil
import subprocess
import tempfile
import time
from multiprocessing import Lock
from string import Formatter
from typing import Dict, Hashable, Iterator, List, Optional, Tuple, Union
@ -307,6 +309,56 @@ def ext(quality: int, source: str):
return ".flac"
class SRSession:
# requests per minute
PERIOD = 60.0
def __init__(
self,
headers: Optional[dict] = None,
pool_connections: int = 100,
pool_maxsize: int = 100,
requests_per_min: Optional[int] = None,
):
if headers is None:
headers = {}
self.session = requests.Session()
adapter = requests.adapters.HTTPAdapter(pool_connections, pool_maxsize)
self.session.mount("https://", adapter)
self.session.headers.update(headers)
self.has_rate_limit = requests_per_min is not None
self.rpm = requests_per_min
self.last_minute: float = time.time()
self.call_no: int = 0
self.rate_limit_lock = Lock() if self.has_rate_limit else None
def get(self, *args, **kwargs):
if self.has_rate_limit: # only use locks if there is a rate limit
assert self.rate_limit_lock is not None
assert self.rpm is not None
with self.rate_limit_lock:
now = time.time()
if self.call_no >= self.rpm:
if now - self.last_minute < SRSession.PERIOD:
time.sleep(SRSession.PERIOD - (now - self.last_minute))
self.last_minute = time.time()
self.call_no = 0
self.call_no += 1
return self.session.get(*args, **kwargs)
def update_headers(self, headers: dict):
self.session.headers.update(headers)
# No rate limit on post
def post(self, *args, **kwargs) -> requests.Response:
self.session.post(*args, **kwargs)
def gen_threadsafe_session(
headers: dict = None, pool_connections: int = 100, pool_maxsize: int = 100
) -> requests.Session:
@ -324,7 +376,7 @@ def gen_threadsafe_session(
headers = {}
session = requests.Session()
adapter = requests.adapters.HTTPAdapter(pool_connections=100, pool_maxsize=100)
adapter = requests.adapters.HTTPAdapter(pool_connections, pool_maxsize)
session.mount("https://", adapter)
session.headers.update(headers)
return session
@ -373,7 +425,7 @@ def get_cover_urls(resp: dict, source: str) -> Optional[dict]:
if source == "qobuz":
cover_urls = resp["image"]
cover_urls["original"] = "org".join(cover_urls["large"].rsplit('600', 1))
cover_urls["original"] = "org".join(cover_urls["large"].rsplit("600", 1))
return cover_urls
if source == "tidal":