pleroma-ebooks/utils.py

82 lines
2.2 KiB
Python
Raw Normal View History

# SPDX-License-Identifier: AGPL-3.0-only
import anyio
2022-06-18 19:02:32 -07:00
import aiohttp
import platform
2021-09-16 23:34:44 -07:00
import contextlib
2021-07-25 23:29:20 -07:00
from functools import wraps
2021-09-16 23:34:44 -07:00
from datetime import datetime, timezone
2022-06-18 19:02:32 -07:00
def http_session_factory(headers=None, **kwargs):
    """Create an ``aiohttp.ClientSession`` whose User-Agent identifies this bot.

    :param headers: optional mapping of extra default headers. It is merged
        after the generated ``User-Agent``, so a caller-supplied
        ``'User-Agent'`` key overrides it. ``None`` (the default) means no
        extra headers.
    :param kwargs: forwarded unchanged to ``aiohttp.ClientSession``.
    :returns: a new ``aiohttp.ClientSession``; the caller owns it and must
        close it (e.g. by using it with ``async with``).
    """
    user_agent = (
        'pleroma-ebooks (https://lab.freak.university/KayFaraday/pleroma-ebooks); '
        f'aiohttp/{aiohttp.__version__}; '
        f'{platform.python_implementation()}/{platform.python_version()}'
    )
    # `None` rather than a `{}` default: a mutable default object would be
    # shared by every call of this function.
    merged_headers = {'User-Agent': user_agent}
    if headers is not None:
        merged_headers.update(headers)
    return aiohttp.ClientSession(headers=merged_headers, **kwargs)
2021-09-16 23:34:44 -07:00
def as_corofunc(f):
    """Turn the synchronous callable *f* into a coroutine function.

    Awaiting the result calls *f* directly with the same arguments and
    returns its value. *f*'s name, docstring and other metadata are copied
    onto the wrapper via ``functools.wraps``.
    """
    async def coro_wrapper(*call_args, **call_kwargs):
        # No checkpoint (such as `anyio.sleep(0)`) is inserted here; whether
        # one is wanted is still undecided.
        result = f(*call_args, **call_kwargs)
        return result

    return wraps(f)(coro_wrapper)
def as_async_cm(cls):
    """Derive an async-context-manager class from the sync context manager *cls*.

    The returned class subclasses *cls* and gains ``__aenter__``/``__aexit__``
    coroutine methods that just call ``cls.__enter__``/``cls.__exit__``, so
    instances work with ``async with``. The new class takes over *cls*'s name,
    qualname, module and docstring via ``functools.wraps``.
    """
    async_enter = as_corofunc(cls.__enter__)
    async_exit = as_corofunc(cls.__exit__)

    class wrapped(cls, contextlib.AbstractAsyncContextManager):
        __aenter__ = async_enter
        __aexit__ = async_exit

    # updated=() skips wraps' `__dict__.update` step, which a class's
    # read-only `__dict__` does not support.
    return wraps(cls, updated=())(wrapped)
# Async-usable counterpart of `contextlib.suppress`:
# `async with suppress(SomeError): ...`
suppress = as_async_cm(cls=contextlib.suppress)
def shield(f):
    """Decorate the coroutine function *f* so each call runs in a shielded cancel scope.

    The awaited call to *f* happens inside ``anyio.CancelScope(shield=True)``,
    i.e. it is shielded from cancellation coming from outer scopes. Its
    return value is passed through.
    """
    async def run_shielded(*call_args, **call_kwargs):
        scope = anyio.CancelScope(shield=True)
        with scope:
            return await f(*call_args, **call_kwargs)

    return wraps(f)(run_shielded)
def removeprefix(s, prefix):
    """Return *s* without a leading *prefix*; return *s* unchanged if it doesn't start with it.

    Delegates to ``s.removeprefix`` when the object has it (Python 3.9+) and
    otherwise falls back to ``startswith`` plus slicing.
    """
    try:
        stripped = s.removeprefix(prefix)
    except AttributeError:
        # Pre-3.9 interpreters have no str.removeprefix.
        if not s.startswith(prefix):
            return s
        stripped = s[len(prefix):]
    return stripped
2021-09-16 23:35:54 -07:00
async def sleep_until(dt):
    """Sleep until the datetime *dt*, returning at once if it has already passed.

    :param dt: the target time. It is subtracted from
        ``datetime.now(timezone.utc)``, so it must be timezone-aware; a naive
        value raises ``TypeError``.
    """
    remaining = (dt - datetime.now(timezone.utc)).total_seconds()
    # A target in the past (e.g. a stale reset header or clock skew) gives a
    # negative delay. asyncio treats that as 0, but trio -- and therefore
    # anyio's trio backend -- raises ValueError for negative durations.
    await anyio.sleep(max(remaining, 0))
class HandleRateLimits:
    """Wrap an HTTP client so that its requests wait out the server's rate limit.

    :param http: an object whose ``request(*args, **kwargs)`` returns an async
        context manager yielding a response with a ``headers`` mapping
        (aiohttp-style, e.g. an ``aiohttp.ClientSession``).
    """

    def __init__(self, http):
        self.http = http

    def request(self, *args, **kwargs):
        """Rate-limit-aware ``self.http.request(*args, **kwargs)``.

        Use as ``async with handler.request('GET', url) as resp: ...``.
        """
        return _RateLimitContextManager(self.http, args, kwargs)


class _RateLimitContextManager(contextlib.AbstractAsyncContextManager):
    """Async context manager returned by ``HandleRateLimits.request``.

    On entry the request is performed. If the response's
    ``X-RateLimit-Remaining`` header is ``'0'`` or ``'1'`` and an
    ``X-RateLimit-Reset`` time is given, the response is released, we sleep
    until that time, and the same request is sent again. Otherwise the
    response is returned as-is, and it is released on exit.

    NOTE(review): the discarded response may itself have succeeded, so a
    non-idempotent request (e.g. a POST) can end up performed more than once.
    """

    def __init__(self, http, args, kwargs):
        self.http = http
        self.args = args
        self.kwargs = kwargs

    async def __aenter__(self):
        return await self._do_enter()

    async def _do_enter(self):
        # A loop rather than recursion, so a long run of rate-limited
        # responses cannot grow the call stack.
        while True:
            self._request_cm = self.http.request(*self.args, **self.kwargs)
            resp = await self._request_cm.__aenter__()
            if resp.headers.get('X-RateLimit-Remaining') not in {'0', '1'}:
                return resp
            reset = resp.headers.get('X-RateLimit-Reset')
            if reset is None:
                # No reset time advertised, so there is nothing to wait for:
                # hand the response back instead of failing with it still open.
                return resp
            # Release the response *before* sleeping. This way the connection
            # isn't held for the whole wait, and a cancellation during the sleep
            # can't leak it (__aexit__ is never called when __aenter__ raises).
            await self._request_cm.__aexit__(None, None, None)
            await sleep_until(self._parse_reset(reset))

    @staticmethod
    def _parse_reset(value):
        """Parse an ISO 8601 ``X-RateLimit-Reset`` value into an aware datetime.

        A trailing ``Z`` is accepted (``datetime.fromisoformat`` only handles
        it natively from Python 3.11 on). Raises ``ValueError`` for values
        that are not ISO 8601.
        """
        if value[-1:] in ('Z', 'z'):
            value = value[:-1] + '+00:00'
        reset = datetime.fromisoformat(value)
        if reset.tzinfo is None:
            # NOTE(review): assumes an offset-less reset time is UTC -- confirm
            # against the servers this talks to.
            reset = reset.replace(tzinfo=timezone.utc)
        return reset

    async def __aexit__(self, *excinfo):
        return await self._request_cm.__aexit__(*excinfo)