handle rate limits

This commit is contained in:
Kay Faraday 2022-07-02 23:35:16 +00:00
parent b1a9b37c5e
commit abc05ba85e
3 changed files with 37 additions and 18 deletions

View File

@ -14,7 +14,7 @@ import contextlib
import qtoml as toml import qtoml as toml
from pleroma import Pleroma from pleroma import Pleroma
from functools import partial from functools import partial
from utils import suppress, loading_spinner, http_session_factory from utils import loading_spinner, http_session_factory, HandleRateLimits
JSON_CONTENT_TYPE = 'application/json' JSON_CONTENT_TYPE = 'application/json'
ACTIVITYPUB_CONTENT_TYPE = 'application/activity+json' ACTIVITYPUB_CONTENT_TYPE = 'application/activity+json'
@ -37,6 +37,7 @@ class PostMirror:
raise_for_status=True, raise_for_status=True,
), ),
) )
self._rl_handler = HandleRateLimits(self._http)
self._ctx_stack = stack self._ctx_stack = stack
return self return self
@ -59,7 +60,7 @@ class PostMirror:
cursor.hide() cursor.hide()
done = False done = False
while not done: while not done:
async with self._http.get(page_url) as resp: page = await resp.json() async with self._rl_handler.request('GET', page_url) as resp: page = await resp.json()
try: try:
page_url = page['next'] page_url = page['next']
except KeyError: except KeyError:

View File

@ -7,7 +7,7 @@ import hashlib
import aiohttp import aiohttp
from http import HTTPStatus from http import HTTPStatus
from multidict import MultiDict from multidict import MultiDict
from utils import http_session_factory from utils import http_session_factory, HandleRateLimits
class BadRequest(Exception): class BadRequest(Exception):
pass pass
@ -20,6 +20,7 @@ class Pleroma:
self.api_base_url = api_base_url.rstrip('/') self.api_base_url = api_base_url.rstrip('/')
self.access_token = access_token.strip() self.access_token = access_token.strip()
self._session = http_session_factory({'Authorization': 'Bearer ' + self.access_token}) self._session = http_session_factory({'Authorization': 'Bearer ' + self.access_token})
self._rl_handler = HandleRateLimits(self._session)
self._logged_in_id = None self._logged_in_id = None
async def __aenter__(self): async def __aenter__(self):
@ -41,7 +42,7 @@ class Pleroma:
}: }:
raise RuntimeError('stop being a chud') raise RuntimeError('stop being a chud')
async with self._session.request(method, self.api_base_url + path, **kwargs) as resp: async with self._rl_handler.request(method, self.api_base_url + path, **kwargs) as resp:
if resp.status == HTTPStatus.BAD_REQUEST: if resp.status == HTTPStatus.BAD_REQUEST:
raise BadRequest((await resp.json())['error']) raise BadRequest((await resp.json())['error'])
#resp.raise_for_status() #resp.raise_for_status()

View File

@ -5,7 +5,7 @@ import aiohttp
import platform import platform
import itertools import itertools
import contextlib import contextlib
from functools import wraps from datetime import datetime, timezone
def http_session_factory(headers={}, **kwargs): def http_session_factory(headers={}, **kwargs):
user_agent = ( user_agent = (
@ -18,21 +18,38 @@ def http_session_factory(headers={}, **kwargs):
**kwargs, **kwargs,
) )
def as_corofunc(f): async def sleep_until(dt):
@wraps(f) await anyio.sleep((dt - datetime.now(timezone.utc)).total_seconds())
async def wrapped(*args, **kwargs):
# can't decide if i want an `anyio.sleep(0)` here.
return f(*args, **kwargs)
return wrapped
def as_async_cm(cls): class HandleRateLimits:
@wraps(cls, updated=()) # cls.__dict__ doesn't support .update() def __init__(self, http):
class wrapped(cls, contextlib.AbstractAsyncContextManager): self.http = http
__aenter__ = as_corofunc(cls.__enter__)
__aexit__ = as_corofunc(cls.__exit__)
return wrapped
suppress = as_async_cm(contextlib.suppress) def request(self, *args, **kwargs):
return _RateLimitContextManager(self.http, args, kwargs)
class _RateLimitContextManager(contextlib.AbstractAsyncContextManager):
def __init__(self, http, args, kwargs):
self.http = http
self.args = args
self.kwargs = kwargs
async def __aenter__(self):
self._request_cm = self.http.request(*self.args, **self.kwargs)
return await self._do_enter()
async def _do_enter(self):
resp = await self._request_cm.__aenter__()
if resp.headers.get('X-RateLimit-Remaining') not in {'0', '1'}:
return resp
print('Hit rate limit for', self.args)
await sleep_until(datetime.fromisoformat(resp.headers['X-RateLimit-Reset']))
await self._request_cm.__aexit__(*(None,)*3)
return await self.__aenter__()
async def __aexit__(self, *excinfo):
return await self._request_cm.__aexit__(*excinfo)
def loading_spinner(): def loading_spinner():
return itertools.cycle('\b' + x for x in [ return itertools.cycle('\b' + x for x in [