Compare commits

..

3 Commits

Author SHA1 Message Date
io 290bb33538 use external pleroma.py 2023-01-11 06:30:47 +00:00
lucdev d932256125 Rate limiter: fix datetime parser (fixes #5) 2023-01-11 06:28:57 +00:00
Joel Beckmeyer e2a18f8888 fix some pleroma errors with async and 500 errors (#4)
* fix some pleroma errors with async and 500 errors

* add better recovery/handling of HTTP 500

* remove unnecessary else
2023-01-11 06:28:54 +00:00
5 changed files with 21 additions and 202 deletions

View File

@ -9,11 +9,11 @@ import operator
import aiosqlite import aiosqlite
import contextlib import contextlib
from yarl import URL from yarl import URL
from pleroma import Pleroma from pleroma import Pleroma, HandleRateLimits
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from functools import partial from functools import partial
from typing import Iterable, NewType from typing import Iterable, NewType
from utils import shield, HandleRateLimits, suppress, http_session_factory from utils import shield, suppress, http_session_factory
from third_party.utils import extract_post_content from third_party.utils import extract_post_content
UTC = pendulum.timezone('UTC') UTC = pendulum.timezone('UTC')
@ -34,9 +34,7 @@ class PostFetcher:
) )
self._http = await stack.enter_async_context( self._http = await stack.enter_async_context(
http_session_factory( http_session_factory(
headers={ headers={'Accept': ', '.join([JSON_CONTENT_TYPE, ACTIVITYPUB_CONTENT_TYPE])},
'Accept': ', '.join([JSON_CONTENT_TYPE, ACTIVITYPUB_CONTENT_TYPE]),
},
trust_env=True, trust_env=True,
raise_for_status=True, raise_for_status=True,
), ),

View File

@ -1,160 +0,0 @@
# SPDX-License-Identifier: AGPL-3.0-only
import sys
import yarl
import json
import hashlib
import aiohttp
from http import HTTPStatus
from utils import http_session_factory
class BadRequest(Exception):
pass
class LoginFailed(Exception):
pass
class Pleroma:
def __init__(self, *, api_base_url, access_token):
self.api_base_url = api_base_url.rstrip('/')
self.access_token = access_token.strip()
self._session = http_session_factory({'Authorization': 'Bearer ' + self.access_token})
self._logged_in_id = None
async def __aenter__(self):
self._session = await self._session.__aenter__()
return self
async def __aexit__(self, *excinfo):
return await self._session.__aexit__(*excinfo)
async def request(self, method, path, **kwargs):
# blocklist of some horrible instances
if hashlib.sha256(
yarl.URL(self.api_base_url).host.encode()
+ bytes.fromhex('d590e3c48d599db6776e89dfc8ebaf53c8cd84866a76305049d8d8c5d4126ce1')
).hexdigest() in {
'56704d4d95b882e81c8e7765e9079be0afc4e353925ba9add8fd65976f52db83',
'1932431fa41a0baaccce7815115b01e40e0237035bb155713712075b887f5a19',
'a42191105a9f3514a1d5131969c07a95e06d0fdf0058f18e478823bf299881c9',
}:
raise RuntimeError('stop being a chud')
async with self._session.request(method, self.api_base_url + path, **kwargs) as resp:
if resp.status == HTTPStatus.BAD_REQUEST:
raise BadRequest((await resp.json())['error'])
#resp.raise_for_status()
return await resp.json()
async def verify_credentials(self):
return await self.request('GET', '/api/v1/accounts/verify_credentials')
me = verify_credentials
async def _get_logged_in_id(self):
if self._logged_in_id is not None:
return self._logged_in_id
me = await self.me()
try:
self._logged_in_id = me['id']
except KeyError:
raise LoginFailed(me)
return self._logged_in_id
async def following(self, account_id=None):
account_id = account_id or await self._get_logged_in_id()
return await self.request('GET', f'/api/v1/accounts/{account_id}/following')
@staticmethod
def _unpack_id(obj):
if isinstance(obj, dict) and 'id' in obj:
return obj['id']
return obj
async def status_context(self, id):
id = self._unpack_id(id)
return await self.request('GET', f'/api/v1/statuses/{id}/context')
async def post(self, content, *, in_reply_to_id=None, cw=None, visibility=None):
if visibility not in {None, 'private', 'public', 'unlisted', 'direct'}:
raise ValueError('invalid visibility', visibility)
data = dict(status=content)
if in_reply_to_id := self._unpack_id(in_reply_to_id):
data['in_reply_to_id'] = in_reply_to_id
if visibility is not None:
data['visibility'] = visibility
# normally, this would be a check against None.
# however, apparently Pleroma serializes posts without CWs as posts with an empty string
# as a CW, so per the robustness principle we'll accept that too.
if cw:
data['spoiler_text'] = cw
return await self.request('POST', '/api/v1/statuses', data=data)
async def reply(self, to_status, content, *, cw=None):
user_id = await self._get_logged_in_id()
mentioned_accounts = {}
mentioned_accounts[to_status['account']['id']] = to_status['account']['acct']
for account in to_status['mentions']:
if account['id'] != user_id and account['id'] not in mentioned_accounts:
mentioned_accounts[account['id']] = account['acct']
content = ''.join('@' + x + ' ' for x in mentioned_accounts.values()) + content
visibility = 'unlisted' if to_status['visibility'] == 'public' else to_status['visibility']
if not cw and 'spoiler_text' in to_status and to_status['spoiler_text']:
cw = 're: ' + to_status['spoiler_text']
return await self.post(content, in_reply_to_id=to_status['id'], cw=cw, visibility=visibility)
async def favorite(self, id):
id = self._unpack_id(id)
return await self.request('POST', f'/api/v1/statuses/{id}/favourite')
async def unfavorite(self, id):
id = self._unpack_id(id)
return await self.request('POST', f'/api/v1/statuses/{id}/unfavourite')
async def react(self, id, reaction):
id = self._unpack_id(id)
return await self.request('PUT', f'/api/v1/pleroma/statuses/{id}/reactions/{reaction}')
async def remove_reaction(self, id, reaction):
id = self._unpack_id(id)
return await self.request('DELETE', f'/api/v1/pleroma/statuses/{id}/reactions/{reaction}')
async def pin(self, id):
id = self._unpack_id(id)
return await self.request('POST', f'/api/v1/statuses/{id}/pin')
async def unpin(self, id):
id = self._unpack_id(id)
return await self.request('POST', f'/api/v1/statuses/{id}/unpin')
async def stream(self, stream_name, *, target_event_type=None):
async with self._session.ws_connect(
self.api_base_url + f'/api/v1/streaming?stream={stream_name}&access_token={self.access_token}'
) as ws:
async for msg in ws:
if msg.type == aiohttp.WSMsgType.TEXT:
event = msg.json()
# the only event type that doesn't define `payload` is `filters_changed`
if event['event'] == 'filters_changed':
yield event
elif target_event_type is None or event['event'] == target_event_type:
# don't ask me why the payload is also JSON encoded smh
yield json.loads(event['payload'])
async def stream_notifications(self):
async for notif in self.stream('user:notification', target_event_type='notification'):
yield notif
async def stream_mentions(self):
async for notif in self.stream_notifications():
if notif['type'] == 'mention':
yield notif

View File

@ -22,10 +22,22 @@ class ReplyBot:
async for notification in self.pleroma.stream_mentions(): async for notification in self.pleroma.stream_mentions():
await self.process_notification(notification) await self.process_notification(notification)
async def process_notification(self, notification): async def process_notification(self, notification, retry_count=0):
acct = "@" + notification['account']['acct'] # get the account's @ acct = "@" + notification['account']['acct'] # get the account's @
post_id = notification['status']['id'] post_id = notification['status']['id']
# catch HTTP 500 and backoff on requests
retry_count = retry_count + 1
try:
context = await self.pleroma.status_context(post_id) context = await self.pleroma.status_context(post_id)
except pleroma.BadResponse as exc:
if retry_count < 3:
await anyio.sleep(2**retry_count)
await self.process_notification(notification, retry_count)
else:
# failed too many times in a row, logging
print(f"Received HTTP 500 {retry_count} times in a row, aborting reply attempt.")
return
# check if we've already been participating in this thread # check if we've already been participating in this thread
if self.check_thread_length(context): if self.check_thread_length(context):
@ -69,12 +81,12 @@ class ReplyBot:
await self.pleroma.react(post_id, '') await self.pleroma.react(post_id, '')
async def reply(self, notification): async def reply(self, notification):
toot = utils.make_toot(self.cfg) # generate a toot toot = await utils.make_post(self.cfg) # generate a toot
await self.pleroma.reply(notification['status'], toot, cw=self.cfg['cw']) await self.pleroma.reply(notification['status'], toot, cw=self.cfg['cw'])
@staticmethod @staticmethod
def extract_toot(toot): def extract_toot(toot):
text = utils.extract_toot(toot) text = utils.extract_post_content(toot)
text = re.sub(r"^@\S+\s", r"", text) # remove the initial mention text = re.sub(r"^@\S+\s", r"", text) # remove the initial mention
text = text.lower() # treat text as lowercase for easier keyword matching (if this bot uses it) text = text.lower() # treat text as lowercase for easier keyword matching (if this bot uses it)
return text return text

View File

@ -1,4 +1,5 @@
beautifulsoup4 ~= 4.9 beautifulsoup4 ~= 4.9
pleroma.py ~= 0.0.1
aiohttp ~= 3.0 aiohttp ~= 3.0
json5 ~= 0.9.5 json5 ~= 0.9.5
anyio ~= 3.0 anyio ~= 3.0

View File

@ -1,11 +1,11 @@
# SPDX-License-Identifier: AGPL-3.0-only # SPDX-License-Identifier: AGPL-3.0-only
import sys
import anyio import anyio
import aiohttp import aiohttp
import platform import platform
import contextlib import contextlib
from functools import wraps from functools import wraps
from datetime import datetime, timezone
def http_session_factory(headers={}, **kwargs): def http_session_factory(headers={}, **kwargs):
user_agent = ( user_agent = (
@ -47,35 +47,3 @@ def removeprefix(s, prefix):
except AttributeError: except AttributeError:
# compatibility for pre-3.9 # compatibility for pre-3.9
return s[len(prefix):] if s.startswith(prefix) else s return s[len(prefix):] if s.startswith(prefix) else s
async def sleep_until(dt):
await anyio.sleep((dt - datetime.now(timezone.utc)).total_seconds())
class HandleRateLimits:
def __init__(self, http):
self.http = http
def request(self, *args, **kwargs):
return _RateLimitContextManager(self.http, args, kwargs)
class _RateLimitContextManager(contextlib.AbstractAsyncContextManager):
def __init__(self, http, args, kwargs):
self.http = http
self.args = args
self.kwargs = kwargs
async def __aenter__(self):
self._request_cm = self.http.request(*self.args, **self.kwargs)
return await self._do_enter()
async def _do_enter(self):
resp = await self._request_cm.__aenter__()
if resp.headers.get('X-RateLimit-Remaining') not in {'0', '1'}:
return resp
await sleep_until(datetime.fromisoformat(resp.headers['X-RateLimit-Reset']))
await self._request_cm.__aexit__(*(None,)*3)
return await self.__aenter__()
async def __aexit__(self, *excinfo):
return await self._request_cm.__aexit__(*excinfo)