diff --git a/core/signatures.py b/core/signatures.py index 805ae91..27e7f7d 100644 --- a/core/signatures.py +++ b/core/signatures.py @@ -1,6 +1,6 @@ import base64 import json -from typing import Dict, List, Literal, TypedDict +from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict from urllib.parse import urlparse import httpx @@ -8,7 +8,9 @@ from cryptography.hazmat.primitives import hashes from django.http import HttpRequest from django.utils.http import http_date -from users.models import Identity +# Prevent a circular import +if TYPE_CHECKING: + from users.models import Identity class HttpSignature: @@ -73,7 +75,7 @@ class HttpSignature: self, uri: str, body: Dict, - identity: Identity, + identity: "Identity", content_type: str = "application/json", method: Literal["post"] = "post", ): @@ -105,13 +107,17 @@ class HttpSignature: del headers["(request-target)"] async with httpx.AsyncClient() as client: print(f"Calling {method} {uri}") - print(body) - return await client.request( + response = await client.request( method, uri, headers=headers, content=body_bytes, ) + if response.status_code >= 400: + raise ValueError( + f"Request error: {response.status_code} {response.content}" + ) + return response class SignatureDetails(TypedDict): diff --git a/stator/graph.py b/stator/graph.py index 7a8455c..00ef1c4 100644 --- a/stator/graph.py +++ b/stator/graph.py @@ -41,6 +41,7 @@ class StateGraph: initial_state = state # Collect terminal states if state.terminal: + state.externally_progressed = True terminal_states.add(state) # Ensure they do NOT have a handler try: @@ -52,17 +53,18 @@ class StateGraph: f"Terminal state '{state}' should not have a handler method ({state.handler_name})" ) else: - # Ensure non-terminal states have a try interval and a handler - if not state.try_interval: - raise ValueError( - f"State '{state}' has no try_interval and is not terminal" - ) - try: - state.handler - except AttributeError: - raise ValueError( - f"State '{state}' does not have a handler method ({state.handler_name})" - ) + # Ensure non-terminal/manual states have a try interval and a handler + if not state.externally_progressed: + if not state.try_interval: + raise ValueError( + f"State '{state}' has no try_interval and is not terminal or manual" + ) + try: + state.handler + except AttributeError: + raise ValueError( + f"State '{state}' does not have a handler method ({state.handler_name})" + ) if initial_state is None: raise ValueError("The graph has no initial state") cls.initial_state = initial_state @@ -80,9 +82,11 @@ class State: self, try_interval: Optional[float] = None, handler_name: Optional[str] = None, + externally_progressed: bool = False, ): self.try_interval = try_interval self.handler_name = handler_name + self.externally_progressed = externally_progressed self.parents: Set["State"] = set() self.children: Set["State"] = set() @@ -118,6 +122,7 @@ class State: @property def handler(self) -> Callable[[Any], Optional[str]]: + # Retrieve it by name off the graph if self.handler_name is None: raise AttributeError("No handler defined") return getattr(self.graph, self.handler_name) diff --git a/stator/models.py b/stator/models.py index 50ee622..072a3ed 100644 --- a/stator/models.py +++ b/stator/models.py @@ -80,7 +80,7 @@ class StatorModel(models.Model): q = models.Q() for state in cls.state_graph.states.values(): state = cast(State, state) - if not state.terminal: + if not state.externally_progressed: q = q | models.Q( ( models.Q( @@ -135,17 +135,31 @@ class StatorModel(models.Model): self.state_ready = True self.save() - async def atransition_attempt(self) -> Optional[str]: + async def atransition_attempt(self) -> Optional[State]: """ Attempts to transition the current state by running its handler(s). """ + current_state = self.state_graph.states[self.state] + # If it's a manual progression state don't even try + # We shouldn't really be here in this case, but it could be a race condition + if current_state.externally_progressed: + print("Externally progressed state!") + return None try: - next_state = await self.state_graph.states[self.state].handler(self) + next_state = await current_state.handler(self) except BaseException as e: await StatorError.acreate_from_instance(self, e) traceback.print_exc() else: if next_state: + # Ensure it's a State object + if isinstance(next_state, str): + next_state = self.state_graph.states[next_state] + # Ensure it's a child + if next_state not in current_state.children: + raise ValueError( + f"Cannot transition from {current_state} to {next_state} - not a declared transition" + ) await self.atransition_perform(next_state) return next_state await self.__class__.objects.filter(pk=self.pk).aupdate( diff --git a/stator/runner.py b/stator/runner.py index 1392e4d..0b42b27 100644 --- a/stator/runner.py +++ b/stator/runner.py @@ -50,9 +50,6 @@ class StatorRunner: min(space_remaining, self.MAX_TASKS_PER_MODEL), timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT), ): - print( - f"Attempting transition on {instance._meta.label_lower}#{instance.pk}" - ) self.tasks.append( asyncio.create_task(self.run_transition(instance)) ) @@ -76,6 +73,9 @@ class StatorRunner: Wrapper for atransition_attempt with fallback error handling """ try: + print( + f"Attempting transition on {instance._meta.label_lower}#{instance.pk} from state {instance.state}" + ) await instance.atransition_attempt() except BaseException: traceback.print_exc() diff --git a/statuses/models/status.py b/statuses/models/status.py index bfc8eb9..b12a595 100644 --- a/statuses/models/status.py +++ b/statuses/models/status.py @@ -27,6 +27,9 @@ class Status(models.Model): updated = models.DateTimeField(auto_now=True) deleted = models.DateTimeField(null=True, blank=True) + class Meta: + verbose_name_plural = "statuses" + @classmethod def create_local(cls, identity, text: str): return cls.objects.create( diff --git a/users/admin.py b/users/admin.py index f2b807c..d8f2931 100644 --- a/users/admin.py +++ b/users/admin.py @@ -1,6 +1,6 @@ from django.contrib import admin -from users.models import Domain, Follow, Identity, User, UserEvent +from users.models import Domain, Follow, Identity, InboxMessage, User, UserEvent @admin.register(Domain) @@ -26,3 +26,14 @@ class IdentityAdmin(admin.ModelAdmin): @admin.register(Follow) class FollowAdmin(admin.ModelAdmin): list_display = ["id", "source", "target", "state"] + + +@admin.register(InboxMessage) +class InboxMessageAdmin(admin.ModelAdmin): + list_display = ["id", "state", "message_type"] + actions = ["reset_state"] + + @admin.action(description="Reset State") + def reset_state(self, request, queryset): + for instance in queryset: + instance.transition_perform("received") diff --git a/users/models/domain.py b/users/models/domain.py index 4ac6ee9..a3815ee 100644 --- a/users/models/domain.py +++ b/users/models/domain.py @@ -81,3 +81,15 @@ class Domain(models.Model): def __str__(self): return self.domain + + def save(self, *args, **kwargs): + # Ensure that we are not conflicting with other domains + if Domain.objects.filter(service_domain=self.domain).exists(): + raise ValueError( + f"Domain {self.domain} is already a service domain elsewhere!" + ) + if self.service_domain: + if Domain.objects.filter(domain=self.service_domain).exists(): + raise ValueError( + f"Service domain {self.service_domain} is already a domain elsewhere!" + ) diff --git a/users/models/follow.py b/users/models/follow.py index 6f62481..94ad40f 100644 --- a/users/models/follow.py +++ b/users/models/follow.py @@ -2,24 +2,110 @@ from typing import Optional from django.db import models +from core.ld import canonicalise +from core.signatures import HttpSignature from stator.models import State, StateField, StateGraph, StatorModel class FollowStates(StateGraph): unrequested = State(try_interval=30) - requested = State(try_interval=24 * 60 * 60) - accepted = State() + local_requested = State(try_interval=24 * 60 * 60) + remote_requested = State(try_interval=24 * 60 * 60) + accepted = State(externally_progressed=True) + undone_locally = State(try_interval=60 * 60) + undone_remotely = State() - unrequested.transitions_to(requested) - requested.transitions_to(accepted) + unrequested.transitions_to(local_requested) + unrequested.transitions_to(remote_requested) + local_requested.transitions_to(accepted) + remote_requested.transitions_to(accepted) + accepted.transitions_to(undone_locally) + undone_locally.transitions_to(undone_remotely) @classmethod async def handle_unrequested(cls, instance: "Follow"): - print("Would have tried to follow on", instance) + # Re-retrieve the follow with more things linked + follow = await Follow.objects.select_related( + "source", "source__domain", "target" + ).aget(pk=instance.pk) + # Remote follows should not be here + if not follow.source.local: + return cls.remote_requested + # Construct the request + request = canonicalise( + { + "@context": "https://www.w3.org/ns/activitystreams", + "id": follow.uri, + "type": "Follow", + "actor": follow.source.actor_uri, + "object": follow.target.actor_uri, + } + ) + # Sign it and send it + await HttpSignature.signed_request( + follow.target.inbox_uri, request, follow.source + ) + return cls.local_requested @classmethod - async def handle_requested(cls, instance: "Follow"): - print("Would have tried to requested on", instance) + async def handle_local_requested(cls, instance: "Follow"): + # TODO: Resend follow requests occasionally + pass + + @classmethod + async def handle_remote_requested(cls, instance: "Follow"): + # Re-retrieve the follow with more things linked + follow = await Follow.objects.select_related( + "source", "source__domain", "target" + ).aget(pk=instance.pk) + # Send an accept + request = canonicalise( + { + "@context": "https://www.w3.org/ns/activitystreams", + "id": follow.target.actor_uri + f"follow/{follow.pk}/#accept", + "type": "Follow", + "actor": follow.source.actor_uri, + "object": { + "id": follow.uri, + "type": "Follow", + "actor": follow.source.actor_uri, + "object": follow.target.actor_uri, + }, + } + ) + # Sign it and send it + await HttpSignature.signed_request( + follow.source.inbox_uri, + request, + identity=follow.target, + ) + return cls.accepted + + @classmethod + async def handle_undone_locally(cls, instance: "Follow"): + follow = Follow.objects.select_related( + "source", "source__domain", "target" + ).get(pk=instance.pk) + # Construct the request + request = canonicalise( + { + "@context": "https://www.w3.org/ns/activitystreams", + "id": follow.uri + "#undo", + "type": "Undo", + "actor": follow.source.actor_uri, + "object": { + "id": follow.uri, + "type": "Follow", + "actor": follow.source.actor_uri, + "object": follow.target.actor_uri, + }, + } + ) + # Sign it and send it + await HttpSignature.signed_request( + follow.target.inbox_uri, request, follow.source + ) + return cls.undone_remotely class Follow(StatorModel): @@ -83,11 +169,17 @@ class Follow(StatorModel): follow = cls.maybe_get(source=source, target=target) if follow is None: follow = Follow.objects.create(source=source, target=target, uri=uri) - if follow.state == FollowStates.fresh: - follow.transition_perform(FollowStates.requested) + if follow.state == FollowStates.unrequested: + follow.transition_perform(FollowStates.remote_requested) @classmethod def remote_accepted(cls, source, target): + print(f"accepted follow source {source} target {target}") follow = cls.maybe_get(source=source, target=target) - if follow and follow.state == FollowStates.requested: + print(f"accepting follow {follow}") + if follow and follow.state in [ + FollowStates.unrequested, + FollowStates.local_requested, + ]: follow.transition_perform(FollowStates.accepted) + print("accepted") diff --git a/users/models/inbox_message.py b/users/models/inbox_message.py index 0dbdc3a..54b05e9 100644 --- a/users/models/inbox_message.py +++ b/users/models/inbox_message.py @@ -13,7 +13,7 @@ class InboxMessageStates(StateGraph): @classmethod async def handle_received(cls, instance: "InboxMessage"): - type = instance.message["type"].lower() + type = instance.message_type if type == "follow": await instance.follow_request() elif type == "accept": @@ -30,6 +30,7 @@ class InboxMessageStates(StateGraph): raise ValueError(f"Cannot handle activity of type undo.{inner_type}") else: raise ValueError(f"Cannot handle activity of type {type}") + return cls.processed class InboxMessage(StatorModel): @@ -60,10 +61,17 @@ class InboxMessage(StatorModel): """ Handles an incoming acceptance of one of our follow requests """ - Follow.remote_accepted( - source=Identity.by_actor_uri_with_create(self.message["actor"]), - target=Identity.by_actor_uri(self.message["object"]), - ) + target = Identity.by_actor_uri_with_create(self.message["actor"]) + source = Identity.by_actor_uri(self.message["object"]["actor"]) + if source is None: + raise ValueError( + f"Follow-Accept has invalid source {self.message['object']['actor']}" + ) + Follow.remote_accepted(source=source, target=target) + + @property + def message_type(self): + return self.message["type"].lower() async def follow_undo(self): """ diff --git a/users/shortcuts.py b/users/shortcuts.py index 8e20a09..0726218 100644 --- a/users/shortcuts.py +++ b/users/shortcuts.py @@ -19,7 +19,10 @@ def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity: else: username, domain = handle.split("@", 1) # Resolve the domain to the display domain - domain = Domain.get_remote_domain(domain).domain + domain_instance = Domain.get_domain(domain) + if domain_instance is None: + domain_instance = Domain.get_remote_domain(domain) + domain = domain_instance.domain identity = Identity.by_username_and_domain( username, domain, diff --git a/users/tasks/__init__.py b/users/tasks/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/users/tasks/follow.py b/users/tasks/follow.py deleted file mode 100644 index 0f802cf..0000000 --- a/users/tasks/follow.py +++ /dev/null @@ -1,62 +0,0 @@ -from core.ld import canonicalise -from core.signatures import HttpSignature -from users.models import Follow - - -async def handle_follow_request(task_handler): - """ - Request a follow from a remote server - """ - follow = await Follow.objects.select_related( - "source", "source__domain", "target" - ).aget(pk=task_handler.subject) - # Construct the request - request = canonicalise( - { - "@context": "https://www.w3.org/ns/activitystreams", - "id": follow.uri, - "type": "Follow", - "actor": follow.source.actor_uri, - "object": follow.target.actor_uri, - } - ) - # Sign it and send it - response = await HttpSignature.signed_request( - follow.target.inbox_uri, request, follow.source - ) - if response.status_code >= 400: - raise ValueError(f"Request error: {response.status_code} {response.content}") - await Follow.objects.filter(pk=follow.pk).aupdate(requested=True) - - -def send_follow_undo(id): - """ - Request a follow from a remote server - """ - follow = Follow.objects.select_related("source", "source__domain", "target").get( - pk=id - ) - # Construct the request - request = canonicalise( - { - "@context": "https://www.w3.org/ns/activitystreams", - "id": follow.uri + "#undo", - "type": "Undo", - "actor": follow.source.actor_uri, - "object": { - "id": follow.uri, - "type": "Follow", - "actor": follow.source.actor_uri, - "object": follow.target.actor_uri, - }, - } - ) - # Sign it and send it - from asgiref.sync import async_to_sync - - response = async_to_sync(HttpSignature.signed_request)( - follow.target.inbox_uri, request, follow.source - ) - if response.status_code >= 400: - raise ValueError(f"Request error: {response.status_code} {response.content}") - print(response) diff --git a/users/views/identity.py b/users/views/identity.py index 3e69dae..0aed7fa 100644 --- a/users/views/identity.py +++ b/users/views/identity.py @@ -21,6 +21,10 @@ from users.models import Domain, Follow, Identity, IdentityStates, InboxMessage from users.shortcuts import by_handle_or_404 +class HttpResponseUnauthorized(HttpResponse): + status_code = 401 + + class ViewIdentity(TemplateView): template_name = "identity/view.html" @@ -188,20 +192,26 @@ class Inbox(View): if "HTTP_DIGEST" in request.META: expected_digest = HttpSignature.calculate_digest(request.body) if request.META["HTTP_DIGEST"] != expected_digest: + print("Wrong digest") return HttpResponseBadRequest("Digest is incorrect") # Verify date header if "HTTP_DATE" in request.META: header_date = parse_http_date(request.META["HTTP_DATE"]) if abs(timezone.now().timestamp() - header_date) > 60: + print( + f"Date mismatch - they sent {header_date}, now is {timezone.now().timestamp()}" + ) return HttpResponseBadRequest("Date is too far away") # Get the signature details if "HTTP_SIGNATURE" not in request.META: + print("No signature") return HttpResponseBadRequest("No signature present") signature_details = HttpSignature.parse_signature( request.META["HTTP_SIGNATURE"] ) # Reject unknown algorithms if signature_details["algorithm"] != "rsa-sha256": + print("Unknown sig algo") return HttpResponseBadRequest("Unknown signature algorithm") # Create the signature payload headers_string = HttpSignature.headers_from_request( @@ -217,13 +227,14 @@ class Inbox(View): # See if we can fetch it right now async_to_sync(identity.fetch_actor)() if not identity.public_key: + print("Cannot get actor") return HttpResponseBadRequest("Cannot retrieve actor") if not identity.verify_signature( signature_details["signature"], headers_string ): - return HttpResponseBadRequest("Bad signature") + return HttpResponseUnauthorized("Bad signature") # Hand off the item to be processed by the queue - InboxMessage.objects.create(message=document) + InboxMessage.objects.create(message=document, state_ready=True) return HttpResponse(status=202)