Remove all remaining async code for now

This commit is contained in:
Andrew Godwin 2023-07-17 00:37:47 -06:00
parent 0915b17c4b
commit 188e5a2446
19 changed files with 114 additions and 185 deletions

View File

@ -4,7 +4,6 @@ from typing import ClassVar
import httpx import httpx
import urlman import urlman
from asgiref.sync import sync_to_async
from cachetools import TTLCache, cached from cachetools import TTLCache, cached
from django.conf import settings from django.conf import settings
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
@ -35,13 +34,13 @@ class EmojiStates(StateGraph):
outdated.transitions_to(updated) outdated.transitions_to(updated)
@classmethod @classmethod
async def handle_outdated(cls, instance: "Emoji"): def handle_outdated(cls, instance: "Emoji"):
""" """
Fetches remote emoji and uploads to file for local caching Fetches remote emoji and uploads to file for local caching
""" """
if instance.remote_url and not instance.file: if instance.remote_url and not instance.file:
try: try:
file, mimetype = await get_remote_file( file, mimetype = get_remote_file(
instance.remote_url, instance.remote_url,
timeout=settings.SETUP.REMOTE_TIMEOUT, timeout=settings.SETUP.REMOTE_TIMEOUT,
max_size=settings.SETUP.EMOJI_MAX_IMAGE_FILESIZE_KB * 1024, max_size=settings.SETUP.EMOJI_MAX_IMAGE_FILESIZE_KB * 1024,
@ -55,7 +54,7 @@ class EmojiStates(StateGraph):
instance.file = file instance.file = file
instance.mimetype = mimetype instance.mimetype = mimetype
await sync_to_async(instance.save)() instance.save()
return cls.updated return cls.updated

View File

@ -1,5 +1,4 @@
import httpx import httpx
from asgiref.sync import async_to_sync
from django.db import models from django.db import models
from activities.models.timeline_event import TimelineEvent from activities.models.timeline_event import TimelineEvent
@ -77,7 +76,7 @@ class FanOutStates(StateGraph):
post = instance.subject_post post = instance.subject_post
# Sign it and send it # Sign it and send it
try: try:
async_to_sync(post.author.signed_request)( post.author.signed_request(
method="post", method="post",
uri=( uri=(
instance.identity.shared_inbox_uri instance.identity.shared_inbox_uri
@ -93,7 +92,7 @@ class FanOutStates(StateGraph):
post = instance.subject_post post = instance.subject_post
# Sign it and send it # Sign it and send it
try: try:
async_to_sync(post.author.signed_request)( post.author.signed_request(
method="post", method="post",
uri=( uri=(
instance.identity.shared_inbox_uri instance.identity.shared_inbox_uri
@ -119,7 +118,7 @@ class FanOutStates(StateGraph):
post = instance.subject_post post = instance.subject_post
# Send it to the remote inbox # Send it to the remote inbox
try: try:
async_to_sync(post.author.signed_request)( post.author.signed_request(
method="post", method="post",
uri=( uri=(
instance.identity.shared_inbox_uri instance.identity.shared_inbox_uri
@ -172,7 +171,7 @@ class FanOutStates(StateGraph):
body = interaction.to_add_ap() body = interaction.to_add_ap()
else: else:
body = interaction.to_create_ap() body = interaction.to_create_ap()
async_to_sync(interaction.identity.signed_request)( interaction.identity.signed_request(
method="post", method="post",
uri=( uri=(
instance.identity.shared_inbox_uri instance.identity.shared_inbox_uri
@ -202,7 +201,7 @@ class FanOutStates(StateGraph):
body = interaction.to_remove_ap() body = interaction.to_remove_ap()
else: else:
body = interaction.to_undo_ap() body = interaction.to_undo_ap()
async_to_sync(interaction.identity.signed_request)( interaction.identity.signed_request(
method="post", method="post",
uri=( uri=(
instance.identity.shared_inbox_uri instance.identity.shared_inbox_uri
@ -217,7 +216,7 @@ class FanOutStates(StateGraph):
case (FanOut.Types.identity_edited, False): case (FanOut.Types.identity_edited, False):
identity = instance.subject_identity identity = instance.subject_identity
try: try:
async_to_sync(identity.signed_request)( identity.signed_request(
method="post", method="post",
uri=( uri=(
instance.identity.shared_inbox_uri instance.identity.shared_inbox_uri
@ -232,7 +231,7 @@ class FanOutStates(StateGraph):
case (FanOut.Types.identity_deleted, False): case (FanOut.Types.identity_deleted, False):
identity = instance.subject_identity identity = instance.subject_identity
try: try:
async_to_sync(identity.signed_request)( identity.signed_request(
method="post", method="post",
uri=( uri=(
instance.identity.shared_inbox_uri instance.identity.shared_inbox_uri

View File

@ -8,7 +8,6 @@ from urllib.parse import urlparse
import httpx import httpx
import urlman import urlman
from asgiref.sync import async_to_sync
from django.contrib.postgres.indexes import GinIndex from django.contrib.postgres.indexes import GinIndex
from django.contrib.postgres.search import SearchVector from django.contrib.postgres.search import SearchVector
from django.db import models, transaction from django.db import models, transaction
@ -831,7 +830,7 @@ class Post(StatorModel):
# If the author is not fetched yet, try again later # If the author is not fetched yet, try again later
if author.domain is None: if author.domain is None:
if fetch_author: if fetch_author:
async_to_sync(author.fetch_actor)() author.fetch_actor()
# perhaps the entire "try again" logic below # perhaps the entire "try again" logic below
# could be replaced with TryAgainLater for # could be replaced with TryAgainLater for
# _all_ fetches, to let it handle pinned posts? # _all_ fetches, to let it handle pinned posts?
@ -981,7 +980,7 @@ class Post(StatorModel):
except cls.DoesNotExist: except cls.DoesNotExist:
if fetch: if fetch:
try: try:
response = async_to_sync(SystemActor().signed_request)( response = SystemActor().signed_request(
method="get", uri=object_uri method="get", uri=object_uri
) )
except (httpx.HTTPError, ssl.SSLCertVerificationError): except (httpx.HTTPError, ssl.SSLCertVerificationError):
@ -1008,7 +1007,7 @@ class Post(StatorModel):
) from err ) from err
# We may need to fetch the author too # We may need to fetch the author too
if post.author.state == IdentityStates.outdated: if post.author.state == IdentityStates.outdated:
async_to_sync(post.author.fetch_actor)() post.author.fetch_actor()
return post return post
else: else:
raise cls.DoesNotExist(f"Cannot find Post with URI {object_uri}") raise cls.DoesNotExist(f"Cannot find Post with URI {object_uri}")

View File

@ -1,5 +1,4 @@
import httpx import httpx
from asgiref.sync import async_to_sync
from activities.models import Hashtag, Post from activities.models import Hashtag, Post
from core.ld import canonicalise from core.ld import canonicalise
@ -49,7 +48,7 @@ class SearchService:
username, domain_instance or domain, fetch=True username, domain_instance or domain, fetch=True
) )
if identity and identity.state == IdentityStates.outdated: if identity and identity.state == IdentityStates.outdated:
async_to_sync(identity.fetch_actor)() identity.fetch_actor()
except ValueError: except ValueError:
pass pass
@ -74,7 +73,7 @@ class SearchService:
# Fetch the provided URL as the system actor to retrieve the AP JSON # Fetch the provided URL as the system actor to retrieve the AP JSON
try: try:
response = async_to_sync(SystemActor().signed_request)( response = SystemActor().signed_request(
method="get", method="get",
uri=self.query, uri=self.query,
) )
@ -90,7 +89,7 @@ class SearchService:
# Try and retrieve the profile by actor URI # Try and retrieve the profile by actor URI
identity = Identity.by_actor_uri(document["id"], create=True) identity = Identity.by_actor_uri(document["id"], create=True)
if identity and identity.state == IdentityStates.outdated: if identity and identity.state == IdentityStates.outdated:
async_to_sync(identity.fetch_actor)() identity.fetch_actor()
return identity return identity
# Is it a post? # Is it a post?

View File

@ -1,7 +1,6 @@
import json import json
import httpx import httpx
from asgiref.sync import async_to_sync
from django import forms from django import forms
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator
from django.views.generic import FormView, TemplateView from django.views.generic import FormView, TemplateView
@ -13,7 +12,6 @@ from users.models import SystemActor
@method_decorator(admin_required, name="dispatch") @method_decorator(admin_required, name="dispatch")
class JsonViewer(FormView): class JsonViewer(FormView):
template_name = "activities/debug_json.html" template_name = "activities/debug_json.html"
class form_class(forms.Form): class form_class(forms.Form):
@ -31,7 +29,7 @@ class JsonViewer(FormView):
context = self.get_context_data(form=form) context = self.get_context_data(form=form)
try: try:
response = async_to_sync(SystemActor().signed_request)( response = SystemActor().signed_request(
method="get", method="get",
uri=uri, uri=uri,
) )
@ -64,18 +62,15 @@ class JsonViewer(FormView):
class NotFound(TemplateView): class NotFound(TemplateView):
template_name = "404.html" template_name = "404.html"
class ServerError(TemplateView): class ServerError(TemplateView):
template_name = "500.html" template_name = "500.html"
@method_decorator(admin_required, name="dispatch") @method_decorator(admin_required, name="dispatch")
class OauthAuthorize(TemplateView): class OauthAuthorize(TemplateView):
template_name = "api/oauth_authorize.html" template_name = "api/oauth_authorize.html"
def get_context_data(self): def get_context_data(self):

View File

@ -57,7 +57,7 @@ def blurhash_image(file) -> str:
return blurhash.encode(file, 4, 4) return blurhash.encode(file, 4, 4)
async def get_remote_file( def get_remote_file(
url: str, url: str,
*, *,
timeout: float = settings.SETUP.REMOTE_TIMEOUT, timeout: float = settings.SETUP.REMOTE_TIMEOUT,
@ -70,8 +70,8 @@ async def get_remote_file(
"User-Agent": settings.TAKAHE_USER_AGENT, "User-Agent": settings.TAKAHE_USER_AGENT,
} }
async with httpx.AsyncClient(headers=headers) as client: with httpx.Client(headers=headers) as client:
async with client.stream( with client.stream(
"GET", url, timeout=timeout, follow_redirects=True "GET", url, timeout=timeout, follow_redirects=True
) as stream: ) as stream:
allow_download = max_size is None allow_download = max_size is None
@ -82,7 +82,7 @@ async def get_remote_file(
except (KeyError, TypeError): except (KeyError, TypeError):
pass pass
if allow_download: if allow_download:
file = ContentFile(await stream.aread(), name=url) file = ContentFile(stream.read(), name=url)
return file, stream.headers.get( return file, stream.headers.get(
"content-type", "application/octet-stream" "content-type", "application/octet-stream"
) )

View File

@ -177,7 +177,7 @@ class HttpSignature:
) )
@classmethod @classmethod
async def signed_request( def signed_request(
cls, cls,
uri: str, uri: str,
body: dict | None, body: dict | None,
@ -241,9 +241,9 @@ class HttpSignature:
# Send the request with all those headers except the pseudo one # Send the request with all those headers except the pseudo one
del headers["(request-target)"] del headers["(request-target)"]
async with httpx.AsyncClient(timeout=timeout) as client: with httpx.Client(timeout=timeout) as client:
try: try:
response = await client.request( response = client.request(
method, method,
uri, uri,
headers=headers, headers=headers,

View File

@ -238,10 +238,6 @@ def stator(config_system) -> StatorRunner:
""" """
Return an initialized StatorRunner for tests that need state transitioning Return an initialized StatorRunner for tests that need state transitioning
to happen. to happen.
Example:
# Do some tasks with state side effects
async_to_sync(stator_runner.fetch_and_process_tasks)()
""" """
runner = StatorRunner( runner = StatorRunner(
StatorModel.subclasses, StatorModel.subclasses,

View File

@ -1,5 +1,4 @@
import pytest import pytest
from asgiref.sync import async_to_sync
from django.test.client import RequestFactory from django.test.client import RequestFactory
from pytest_httpx import HTTPXMock from pytest_httpx import HTTPXMock
@ -75,7 +74,7 @@ def test_sign_http(httpx_mock: HTTPXMock, keypair):
} }
# Send the signed request to the mock library # Send the signed request to the mock library
httpx_mock.add_response() httpx_mock.add_response()
async_to_sync(HttpSignature.signed_request)( HttpSignature.signed_request(
uri="https://example.com/test-actor", uri="https://example.com/test-actor",
body=document, body=document,
private_key=keypair["private_key"], private_key=keypair["private_key"],

View File

@ -1,5 +1,4 @@
import pytest import pytest
from asgiref.sync import async_to_sync
from pytest_httpx import HTTPXMock from pytest_httpx import HTTPXMock
from core.models import Config from core.models import Config
@ -169,7 +168,7 @@ def test_fetch_actor(httpx_mock, config_system):
"url": "https://example.com/test-actor/view/", "url": "https://example.com/test-actor/view/",
}, },
) )
async_to_sync(identity.fetch_actor)() identity.fetch_actor()
# Verify the data arrived # Verify the data arrived
identity = Identity.objects.get(pk=identity.pk) identity = Identity.objects.get(pk=identity.pk)
@ -189,15 +188,14 @@ def test_fetch_actor(httpx_mock, config_system):
@pytest.mark.django_db @pytest.mark.django_db
@pytest.mark.asyncio def test_fetch_webfinger_url(httpx_mock: HTTPXMock, config_system):
async def test_fetch_webfinger_url(httpx_mock: HTTPXMock, config_system):
""" """
Ensures that we can deal with various kinds of webfinger URLs Ensures that we can deal with various kinds of webfinger URLs
""" """
# With no host-meta, it should be the default # With no host-meta, it should be the default
assert ( assert (
await Identity.fetch_webfinger_url("example.com") Identity.fetch_webfinger_url("example.com")
== "https://example.com/.well-known/webfinger?resource={uri}" == "https://example.com/.well-known/webfinger?resource={uri}"
) )
@ -210,7 +208,7 @@ async def test_fetch_webfinger_url(httpx_mock: HTTPXMock, config_system):
</XRD>""", </XRD>""",
) )
assert ( assert (
await Identity.fetch_webfinger_url("example.com") Identity.fetch_webfinger_url("example.com")
== "https://fedi.example.com/.well-known/webfinger?resource={uri}" == "https://fedi.example.com/.well-known/webfinger?resource={uri}"
) )
@ -223,7 +221,7 @@ async def test_fetch_webfinger_url(httpx_mock: HTTPXMock, config_system):
</XRD>""", </XRD>""",
) )
assert ( assert (
await Identity.fetch_webfinger_url("example.com") Identity.fetch_webfinger_url("example.com")
== "https://example.com/amazing-webfinger?query={uri}" == "https://example.com/amazing-webfinger?query={uri}"
) )
@ -237,7 +235,7 @@ async def test_fetch_webfinger_url(httpx_mock: HTTPXMock, config_system):
</XRD>""", </XRD>""",
) )
assert ( assert (
await Identity.fetch_webfinger_url("example.com") Identity.fetch_webfinger_url("example.com")
== "https://example.com/.well-known/webfinger?resource={uri}" == "https://example.com/.well-known/webfinger?resource={uri}"
) )

View File

@ -1,5 +1,4 @@
import pytest import pytest
from asgiref.sync import async_to_sync
from django.test.client import RequestFactory from django.test.client import RequestFactory
from pytest_httpx import HTTPXMock from pytest_httpx import HTTPXMock
@ -16,7 +15,7 @@ def test_system_actor_signed(config_system, httpx_mock: HTTPXMock):
system_actor.generate_keys() system_actor.generate_keys()
# Send a fake outbound request # Send a fake outbound request
httpx_mock.add_response() httpx_mock.add_response()
async_to_sync(system_actor.signed_request)( system_actor.signed_request(
method="get", method="get",
uri="http://example.com/test-actor", uri="http://example.com/test-actor",
) )

View File

@ -1,4 +1,3 @@
from asgiref.sync import async_to_sync
from django.contrib import admin from django.contrib import admin
from django.db import models from django.db import models
from django.utils import formats from django.utils import formats
@ -60,7 +59,7 @@ class DomainAdmin(admin.ModelAdmin):
@admin.action(description="Fetch nodeinfo") @admin.action(description="Fetch nodeinfo")
def fetch_nodeinfo(self, request, queryset): def fetch_nodeinfo(self, request, queryset):
for instance in queryset: for instance in queryset:
info = async_to_sync(instance.fetch_nodeinfo)() info = instance.fetch_nodeinfo()
if info: if info:
instance.nodeinfo = info.dict() instance.nodeinfo = info.dict()
instance.save() instance.save()

View File

@ -30,7 +30,7 @@ class BlockStates(StateGraph):
return [cls.new, cls.sent, cls.awaiting_expiry] return [cls.new, cls.sent, cls.awaiting_expiry]
@classmethod @classmethod
async def handle_new(cls, instance: "Block"): def handle_new(cls, instance: "Block"):
""" """
Block that are new need us to deliver the Block object Block that are new need us to deliver the Block object
to the target server. to the target server.
@ -38,20 +38,18 @@ class BlockStates(StateGraph):
# Mutes don't send but might need expiry # Mutes don't send but might need expiry
if instance.mute: if instance.mute:
return cls.awaiting_expiry return cls.awaiting_expiry
# Fetch more info
block = await instance.afetch_full()
# Remote blocks should not be here, local blocks just work # Remote blocks should not be here, local blocks just work
if not block.source.local or block.target.local: if not instance.source.local or instance.target.local:
return cls.sent return cls.sent
# Don't try if the other identity didn't fetch yet # Don't try if the other identity didn't fetch yet
if not block.target.inbox_uri: if not instance.target.inbox_uri:
return return
# Sign it and send it # Sign it and send it
try: try:
await block.source.signed_request( instance.source.signed_request(
method="post", method="post",
uri=block.target.inbox_uri, uri=instance.target.inbox_uri,
body=canonicalise(block.to_ap()), body=canonicalise(instance.to_ap()),
) )
except httpx.RequestError: except httpx.RequestError:
return return
@ -66,19 +64,18 @@ class BlockStates(StateGraph):
return cls.undone return cls.undone
@classmethod @classmethod
async def handle_undone(cls, instance: "Block"): def handle_undone(cls, instance: "Block"):
""" """
Delivers the Undo object to the target server Delivers the Undo object to the target server
""" """
block = await instance.afetch_full()
# Remote blocks should not be here, mutes don't send, local blocks just work # Remote blocks should not be here, mutes don't send, local blocks just work
if not block.source.local or block.target.local or instance.mute: if not instance.source.local or instance.target.local or instance.mute:
return cls.undone_sent return cls.undone_sent
try: try:
await block.source.signed_request( instance.source.signed_request(
method="post", method="post",
uri=block.target.inbox_uri, uri=instance.target.inbox_uri,
body=canonicalise(block.to_undo_ap()), body=canonicalise(instance.to_undo_ap()),
) )
except httpx.RequestError: except httpx.RequestError:
return return
@ -227,16 +224,6 @@ class Block(StatorModel):
def active(self): def active(self):
return self.state in BlockStates.group_active() return self.state in BlockStates.group_active()
### Async helpers ###
async def afetch_full(self):
"""
Returns a version of the object with all relations pre-loaded
"""
return await Block.objects.select_related(
"source", "source__domain", "target"
).aget(pk=self.pk)
### ActivityPub (outbound) ### ### ActivityPub (outbound) ###
def to_ap(self): def to_ap(self):

View File

@ -6,7 +6,6 @@ from typing import Optional
import httpx import httpx
import pydantic import pydantic
import urlman import urlman
from asgiref.sync import sync_to_async
from django.conf import settings from django.conf import settings
from django.db import models from django.db import models
@ -33,15 +32,15 @@ class DomainStates(StateGraph):
outdated.times_out_to(connection_issue, 60 * 60 * 24) outdated.times_out_to(connection_issue, 60 * 60 * 24)
@classmethod @classmethod
async def handle_outdated(cls, instance: "Domain"): def handle_outdated(cls, instance: "Domain"):
info = await instance.fetch_nodeinfo() info = instance.fetch_nodeinfo()
if info: if info:
instance.nodeinfo = info.dict() instance.nodeinfo = info.dict()
await sync_to_async(instance.save)() instance.save()
return cls.updated return cls.updated
@classmethod @classmethod
async def handle_updated(cls, instance: "Domain"): def handle_updated(cls, instance: "Domain"):
return cls.outdated return cls.outdated
@ -157,18 +156,18 @@ class Domain(StatorModel):
) )
super().save(*args, **kwargs) super().save(*args, **kwargs)
async def fetch_nodeinfo(self) -> NodeInfo | None: def fetch_nodeinfo(self) -> NodeInfo | None:
""" """
Fetch the /NodeInfo/2.0 for the domain Fetch the /NodeInfo/2.0 for the domain
""" """
nodeinfo20_url = f"https://{self.domain}/nodeinfo/2.0" nodeinfo20_url = f"https://{self.domain}/nodeinfo/2.0"
async with httpx.AsyncClient( with httpx.Client(
timeout=settings.SETUP.REMOTE_TIMEOUT, timeout=settings.SETUP.REMOTE_TIMEOUT,
headers={"User-Agent": settings.TAKAHE_USER_AGENT}, headers={"User-Agent": settings.TAKAHE_USER_AGENT},
) as client: ) as client:
try: try:
response = await client.get( response = client.get(
f"https://{self.domain}/.well-known/nodeinfo", f"https://{self.domain}/.well-known/nodeinfo",
follow_redirects=True, follow_redirects=True,
headers={"Accept": "application/json"}, headers={"Accept": "application/json"},
@ -190,7 +189,7 @@ class Domain(StatorModel):
pass pass
try: try:
response = await client.get( response = client.get(
nodeinfo20_url, nodeinfo20_url,
follow_redirects=True, follow_redirects=True,
headers={"Accept": "application/json"}, headers={"Accept": "application/json"},

View File

@ -34,26 +34,25 @@ class FollowStates(StateGraph):
return [cls.unrequested, cls.local_requested, cls.accepted] return [cls.unrequested, cls.local_requested, cls.accepted]
@classmethod @classmethod
async def handle_unrequested(cls, instance: "Follow"): def handle_unrequested(cls, instance: "Follow"):
""" """
Follows that are unrequested need us to deliver the Follow object Follows that are unrequested need us to deliver the Follow object
to the target server. to the target server.
""" """
follow = await instance.afetch_full()
# Remote follows should not be here # Remote follows should not be here
if not follow.source.local: if not instance.source.local:
return cls.remote_requested return cls.remote_requested
if follow.target.local: if instance.target.local:
return cls.accepted return cls.accepted
# Don't try if the other identity didn't fetch yet # Don't try if the other identity didn't fetch yet
if not follow.target.inbox_uri: if not instance.target.inbox_uri:
return return
# Sign it and send it # Sign it and send it
try: try:
await follow.source.signed_request( instance.source.signed_request(
method="post", method="post",
uri=follow.target.inbox_uri, uri=instance.target.inbox_uri,
body=canonicalise(follow.to_ap()), body=canonicalise(instance.to_ap()),
) )
except httpx.RequestError: except httpx.RequestError:
return return
@ -65,33 +64,31 @@ class FollowStates(StateGraph):
pass pass
@classmethod @classmethod
async def handle_remote_requested(cls, instance: "Follow"): def handle_remote_requested(cls, instance: "Follow"):
""" """
Items in remote_requested need us to send an Accept object to the Items in remote_requested need us to send an Accept object to the
source server. source server.
""" """
follow = await instance.afetch_full()
try: try:
await follow.target.signed_request( instance.target.signed_request(
method="post", method="post",
uri=follow.source.inbox_uri, uri=instance.source.inbox_uri,
body=canonicalise(follow.to_accept_ap()), body=canonicalise(instance.to_accept_ap()),
) )
except httpx.RequestError: except httpx.RequestError:
return return
return cls.accepted return cls.accepted
@classmethod @classmethod
async def handle_undone(cls, instance: "Follow"): def handle_undone(cls, instance: "Follow"):
""" """
Delivers the Undo object to the target server Delivers the Undo object to the target server
""" """
follow = await instance.afetch_full()
try: try:
await follow.source.signed_request( instance.source.signed_request(
method="post", method="post",
uri=follow.target.inbox_uri, uri=instance.target.inbox_uri,
body=canonicalise(follow.to_undo_ap()), body=canonicalise(instance.to_undo_ap()),
) )
except httpx.RequestError: except httpx.RequestError:
return return
@ -204,16 +201,6 @@ class Follow(StatorModel):
follow.save() follow.save()
return follow return follow
### Async helpers ###
async def afetch_full(self):
"""
Returns a version of the object with all relations pre-loaded
"""
return await Follow.objects.select_related(
"source", "source__domain", "target"
).aget(pk=self.pk)
### Properties ### ### Properties ###
@property @property

View File

@ -5,7 +5,6 @@ from urllib.parse import urlparse
import httpx import httpx
import urlman import urlman
from asgiref.sync import async_to_sync, sync_to_async
from django.conf import settings from django.conf import settings
from django.db import IntegrityError, models from django.db import IntegrityError, models
from django.utils import timezone from django.utils import timezone
@ -66,13 +65,13 @@ class IdentityStates(StateGraph):
return [cls.deleted, cls.deleted_fanned_out] return [cls.deleted, cls.deleted_fanned_out]
@classmethod @classmethod
async def targets_fan_out(cls, identity: "Identity", type_: str) -> None: def targets_fan_out(cls, identity: "Identity", type_: str) -> None:
from activities.models import FanOut from activities.models import FanOut
from users.models import Follow from users.models import Follow
# Fan out to each target # Fan out to each target
shared_inboxes = set() shared_inboxes = set()
async for follower in Follow.objects.select_related("source", "target").filter( for follower in Follow.objects.select_related("source", "target").filter(
target=identity target=identity
): ):
# Dedupe shared_inbox_uri # Dedupe shared_inbox_uri
@ -80,7 +79,7 @@ class IdentityStates(StateGraph):
if shared_uri and shared_uri in shared_inboxes: if shared_uri and shared_uri in shared_inboxes:
continue continue
await FanOut.objects.acreate( FanOut.objects.create(
identity=follower.source, identity=follower.source,
type=type_, type=type_,
subject_identity=identity, subject_identity=identity,
@ -88,34 +87,32 @@ class IdentityStates(StateGraph):
shared_inboxes.add(shared_uri) shared_inboxes.add(shared_uri)
@classmethod @classmethod
async def handle_edited(cls, instance: "Identity"): def handle_edited(cls, instance: "Identity"):
from activities.models import FanOut from activities.models import FanOut
if not instance.local: if not instance.local:
return cls.updated return cls.updated
identity = await instance.afetch_full() cls.targets_fan_out(instance, FanOut.Types.identity_edited)
await cls.targets_fan_out(identity, FanOut.Types.identity_edited)
return cls.updated return cls.updated
@classmethod @classmethod
async def handle_deleted(cls, instance: "Identity"): def handle_deleted(cls, instance: "Identity"):
from activities.models import FanOut from activities.models import FanOut
if not instance.local: if not instance.local:
return cls.updated return cls.updated
identity = await instance.afetch_full() cls.targets_fan_out(instance, FanOut.Types.identity_deleted)
await cls.targets_fan_out(identity, FanOut.Types.identity_deleted)
return cls.deleted_fanned_out return cls.deleted_fanned_out
@classmethod @classmethod
async def handle_outdated(cls, identity: "Identity"): def handle_outdated(cls, identity: "Identity"):
# Local identities never need fetching # Local identities never need fetching
if identity.local: if identity.local:
return cls.updated return cls.updated
# Run the actor fetch and progress to updated if it succeeds # Run the actor fetch and progress to updated if it succeeds
if await identity.fetch_actor(): if identity.fetch_actor():
return cls.updated return cls.updated
@classmethod @classmethod
@ -365,9 +362,7 @@ class Identity(StatorModel):
) )
except cls.DoesNotExist: except cls.DoesNotExist:
if fetch and not local: if fetch and not local:
actor_uri, handle = async_to_sync(cls.fetch_webfinger)( actor_uri, handle = cls.fetch_webfinger(f"{username}@{domain}")
f"{username}@{domain}"
)
if handle is None: if handle is None:
return None return None
# See if this actually does match an existing actor # See if this actually does match an existing actor
@ -449,14 +444,6 @@ class Identity(StatorModel):
def limited(self) -> bool: def limited(self) -> bool:
return self.restriction == self.Restriction.limited return self.restriction == self.Restriction.limited
### Async helpers ###
async def afetch_full(self):
"""
Returns a version of the object with all relations pre-loaded
"""
return await Identity.objects.select_related("domain").aget(pk=self.pk)
### ActivityPub (outbound) ### ### ActivityPub (outbound) ###
def to_webfinger(self): def to_webfinger(self):
@ -637,17 +624,17 @@ class Identity(StatorModel):
### Actor/Webfinger fetching ### ### Actor/Webfinger fetching ###
@classmethod @classmethod
async def fetch_webfinger_url(cls, domain: str): def fetch_webfinger_url(cls, domain: str):
""" """
Given a domain (hostname), returns the correct webfinger URL to use Given a domain (hostname), returns the correct webfinger URL to use
based on probing host-meta. based on probing host-meta.
""" """
async with httpx.AsyncClient( with httpx.Client(
timeout=settings.SETUP.REMOTE_TIMEOUT, timeout=settings.SETUP.REMOTE_TIMEOUT,
headers={"User-Agent": settings.TAKAHE_USER_AGENT}, headers={"User-Agent": settings.TAKAHE_USER_AGENT},
) as client: ) as client:
try: try:
response = await client.get( response = client.get(
f"https://{domain}/.well-known/host-meta", f"https://{domain}/.well-known/host-meta",
follow_redirects=True, follow_redirects=True,
headers={"Accept": "application/xml"}, headers={"Accept": "application/xml"},
@ -669,24 +656,24 @@ class Identity(StatorModel):
return f"https://{domain}/.well-known/webfinger?resource={{uri}}" return f"https://{domain}/.well-known/webfinger?resource={{uri}}"
@classmethod @classmethod
async def fetch_webfinger(cls, handle: str) -> tuple[str | None, str | None]: def fetch_webfinger(cls, handle: str) -> tuple[str | None, str | None]:
""" """
Given a username@domain handle, returns a tuple of Given a username@domain handle, returns a tuple of
(actor uri, canonical handle) or None, None if it does not resolve. (actor uri, canonical handle) or None, None if it does not resolve.
""" """
domain = handle.split("@")[1].lower() domain = handle.split("@")[1].lower()
try: try:
webfinger_url = await cls.fetch_webfinger_url(domain) webfinger_url = cls.fetch_webfinger_url(domain)
except ssl.SSLCertVerificationError: except ssl.SSLCertVerificationError:
return None, None return None, None
# Go make a Webfinger request # Go make a Webfinger request
async with httpx.AsyncClient( with httpx.Client(
timeout=settings.SETUP.REMOTE_TIMEOUT, timeout=settings.SETUP.REMOTE_TIMEOUT,
headers={"User-Agent": settings.TAKAHE_USER_AGENT}, headers={"User-Agent": settings.TAKAHE_USER_AGENT},
) as client: ) as client:
try: try:
response = await client.get( response = client.get(
webfinger_url.format(uri=f"acct:{handle}"), webfinger_url.format(uri=f"acct:{handle}"),
follow_redirects=True, follow_redirects=True,
headers={"Accept": "application/json"}, headers={"Accept": "application/json"},
@ -730,16 +717,16 @@ class Identity(StatorModel):
return None, None return None, None
@classmethod @classmethod
async def fetch_pinned_post_uris(cls, uri: str) -> list[str]: def fetch_pinned_post_uris(cls, uri: str) -> list[str]:
""" """
Fetch an identity's featured collection. Fetch an identity's featured collection.
""" """
async with httpx.AsyncClient( with httpx.Client(
timeout=settings.SETUP.REMOTE_TIMEOUT, timeout=settings.SETUP.REMOTE_TIMEOUT,
headers={"User-Agent": settings.TAKAHE_USER_AGENT}, headers={"User-Agent": settings.TAKAHE_USER_AGENT},
) as client: ) as client:
try: try:
response = await client.get( response = client.get(
uri, uri,
follow_redirects=True, follow_redirects=True,
headers={"Accept": "application/activity+json"}, headers={"Accept": "application/activity+json"},
@ -785,7 +772,7 @@ class Identity(StatorModel):
response.content, response.content,
) )
async def fetch_actor(self) -> bool: def fetch_actor(self) -> bool:
""" """
Fetches the user's actor information, as well as their domain from Fetches the user's actor information, as well as their domain from
webfinger if it's available. webfinger if it's available.
@ -796,7 +783,7 @@ class Identity(StatorModel):
if self.local: if self.local:
raise ValueError("Cannot fetch local identities") raise ValueError("Cannot fetch local identities")
try: try:
response = await SystemActor().signed_request( response = SystemActor().signed_request(
method="get", method="get",
uri=self.actor_uri, uri=self.actor_uri,
) )
@ -810,7 +797,7 @@ class Identity(StatorModel):
if status_code >= 400: if status_code >= 400:
if status_code == 410 and self.pk: if status_code == 410 and self.pk:
# Their account got deleted, so let's do the same. # Their account got deleted, so let's do the same.
await Identity.objects.filter(pk=self.pk).adelete() Identity.objects.filter(pk=self.pk).delete()
if status_code < 500 and status_code not in [401, 403, 404, 406, 410]: if status_code < 500 and status_code not in [401, 403, 404, 406, 410]:
capture_message( capture_message(
@ -866,44 +853,43 @@ class Identity(StatorModel):
) )
# Now go do webfinger with that info to see if we can get a canonical domain # Now go do webfinger with that info to see if we can get a canonical domain
actor_url_parts = urlparse(self.actor_uri) actor_url_parts = urlparse(self.actor_uri)
get_domain = sync_to_async(Domain.get_remote_domain)
if self.username: if self.username:
webfinger_actor, webfinger_handle = await self.fetch_webfinger( webfinger_actor, webfinger_handle = self.fetch_webfinger(
f"{self.username}@{actor_url_parts.hostname}" f"{self.username}@{actor_url_parts.hostname}"
) )
if webfinger_handle: if webfinger_handle:
webfinger_username, webfinger_domain = webfinger_handle.split("@") webfinger_username, webfinger_domain = webfinger_handle.split("@")
self.username = webfinger_username self.username = webfinger_username
self.domain = await get_domain(webfinger_domain) self.domain = Domain.get_remote_domain(webfinger_domain)
else: else:
self.domain = await get_domain(actor_url_parts.hostname) self.domain = Domain.get_remote_domain(actor_url_parts.hostname)
else: else:
self.domain = await get_domain(actor_url_parts.hostname) self.domain = Domain.get_remote_domain(actor_url_parts.hostname)
# Emojis (we need the domain so we do them here) # Emojis (we need the domain so we do them here)
for tag in get_list(document, "tag"): for tag in get_list(document, "tag"):
if tag["type"].lower() in ["toot:emoji", "emoji"]: if tag["type"].lower() in ["toot:emoji", "emoji"]:
await sync_to_async(Emoji.by_ap_tag)(self.domain, tag, create=True) Emoji.by_ap_tag(self.domain, tag, create=True)
# Mark as fetched # Mark as fetched
self.fetched = timezone.now() self.fetched = timezone.now()
try: try:
await sync_to_async(self.save)() self.save()
except IntegrityError as e: except IntegrityError as e:
# See if we can fetch a PK and save there # See if we can fetch a PK and save there
if self.pk is None: if self.pk is None:
try: try:
other_row = await Identity.objects.aget(actor_uri=self.actor_uri) other_row = Identity.objects.get(actor_uri=self.actor_uri)
except Identity.DoesNotExist: except Identity.DoesNotExist:
raise ValueError( raise ValueError(
f"Could not save Identity at end of actor fetch: {e}" f"Could not save Identity at end of actor fetch: {e}"
) )
self.pk: int | None = other_row.pk self.pk: int | None = other_row.pk
await sync_to_async(self.save)() self.save()
# Fetch pinned posts after identity has been fetched and saved # Fetch pinned posts after identity has been fetched and saved
if self.featured_collection_uri: if self.featured_collection_uri:
featured = await self.fetch_pinned_post_uris(self.featured_collection_uri) featured = self.fetch_pinned_post_uris(self.featured_collection_uri)
service = IdentityService(self) service = IdentityService(self)
await sync_to_async(service.sync_pins)(featured) service.sync_pins(featured)
return True return True
@ -1016,7 +1002,7 @@ class Identity(StatorModel):
### Cryptography ### ### Cryptography ###
async def signed_request( def signed_request(
self, self,
method: Literal["get", "post"], method: Literal["get", "post"],
uri: str, uri: str,
@ -1025,7 +1011,7 @@ class Identity(StatorModel):
""" """
Performs a signed request on behalf of the System Actor. Performs a signed request on behalf of the System Actor.
""" """
return await HttpSignature.signed_request( return HttpSignature.signed_request(
method=method, method=method,
uri=uri, uri=uri,
body=body, body=body,

View File

@ -2,7 +2,6 @@ from urllib.parse import urlparse
import httpx import httpx
import urlman import urlman
from asgiref.sync import sync_to_async
from django.conf import settings from django.conf import settings
from django.core.mail import EmailMultiAlternatives from django.core.mail import EmailMultiAlternatives
from django.db import models from django.db import models
@ -22,26 +21,25 @@ class ReportStates(StateGraph):
new.transitions_to(sent) new.transitions_to(sent)
@classmethod @classmethod
async def handle_new(cls, instance: "Report"): def handle_new(cls, instance: "Report"):
""" """
Sends the report to the remote server if we need to Sends the report to the remote server if we need to
""" """
from users.models import SystemActor, User from users.models import SystemActor, User
recipients = [] recipients = []
report = await instance.afetch_full() for mod in User.objects.filter(
async for mod in User.objects.filter(
models.Q(moderator=True) | models.Q(admin=True) models.Q(moderator=True) | models.Q(admin=True)
).values_list("email", flat=True): ).values_list("email", flat=True):
recipients.append(mod) recipients.append(mod)
if report.forward and not report.subject_identity.domain.local: if instance.forward and not instance.subject_identity.domain.local:
system_actor = SystemActor() system_actor = SystemActor()
try: try:
await system_actor.signed_request( system_actor.signed_request(
method="post", method="post",
uri=report.subject_identity.inbox_uri, uri=instance.subject_identity.inbox_uri,
body=canonicalise(report.to_ap()), body=canonicalise(instance.to_ap()),
) )
except httpx.RequestError: except httpx.RequestError:
pass pass
@ -50,7 +48,7 @@ class ReportStates(StateGraph):
body=render_to_string( body=render_to_string(
"emails/report_new.txt", "emails/report_new.txt",
{ {
"report": report, "report": instance,
"config": Config.system, "config": Config.system,
"settings": settings, "settings": settings,
}, },
@ -62,14 +60,14 @@ class ReportStates(StateGraph):
content=render_to_string( content=render_to_string(
"emails/report_new.html", "emails/report_new.html",
{ {
"report": report, "report": instance,
"config": Config.system, "config": Config.system,
"settings": settings, "settings": settings,
}, },
), ),
mimetype="text/html", mimetype="text/html",
) )
await sync_to_async(email.send)() email.send()
return cls.sent return cls.sent
@ -145,15 +143,6 @@ class Report(StatorModel):
### ActivityPub ### ### ActivityPub ###
async def afetch_full(self) -> "Report":
return await Report.objects.select_related(
"source_identity",
"source_domain",
"subject_identity__domain",
"subject_identity",
"subject_post",
).aget(pk=self.pk)
@classmethod @classmethod
def handle_ap(cls, data): def handle_ap(cls, data):
""" """

View File

@ -79,7 +79,7 @@ class SystemActor:
], ],
} }
async def signed_request( def signed_request(
self, self,
method: Literal["get", "post"], method: Literal["get", "post"],
uri: str, uri: str,
@ -88,7 +88,7 @@ class SystemActor:
""" """
Performs a signed request on behalf of the System Actor. Performs a signed request on behalf of the System Actor.
""" """
return await HttpSignature.signed_request( return HttpSignature.signed_request(
method=method, method=method,
uri=uri, uri=uri,
body=body, body=body,

View File

@ -1,6 +1,5 @@
import json import json
from asgiref.sync import async_to_sync
from django.conf import settings from django.conf import settings
from django.http import Http404, HttpResponse, HttpResponseBadRequest, JsonResponse from django.http import Http404, HttpResponse, HttpResponseBadRequest, JsonResponse
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator
@ -140,7 +139,7 @@ class Inbox(View):
if not identity.public_key: if not identity.public_key:
# See if we can fetch it right now # See if we can fetch it right now
async_to_sync(identity.fetch_actor)() identity.fetch_actor()
if not identity.public_key: if not identity.public_key:
exceptions.capture_message( exceptions.capture_message(