Inbound and outbound follows basic working

This commit is contained in:
Andrew Godwin 2022-11-10 23:42:43 -07:00
parent 2c3a129970
commit fbfad9fbf5
13 changed files with 206 additions and 103 deletions

View File

@ -1,6 +1,6 @@
import base64 import base64
import json import json
from typing import Dict, List, Literal, TypedDict from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
@ -8,7 +8,9 @@ from cryptography.hazmat.primitives import hashes
from django.http import HttpRequest from django.http import HttpRequest
from django.utils.http import http_date from django.utils.http import http_date
from users.models import Identity # Prevent a circular import
if TYPE_CHECKING:
from users.models import Identity
class HttpSignature: class HttpSignature:
@ -73,7 +75,7 @@ class HttpSignature:
self, self,
uri: str, uri: str,
body: Dict, body: Dict,
identity: Identity, identity: "Identity",
content_type: str = "application/json", content_type: str = "application/json",
method: Literal["post"] = "post", method: Literal["post"] = "post",
): ):
@ -105,13 +107,17 @@ class HttpSignature:
del headers["(request-target)"] del headers["(request-target)"]
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
print(f"Calling {method} {uri}") print(f"Calling {method} {uri}")
print(body) response = await client.request(
return await client.request(
method, method,
uri, uri,
headers=headers, headers=headers,
content=body_bytes, content=body_bytes,
) )
if response.status_code >= 400:
raise ValueError(
f"Request error: {response.status_code} {response.content}"
)
return response
class SignatureDetails(TypedDict): class SignatureDetails(TypedDict):

View File

@ -41,6 +41,7 @@ class StateGraph:
initial_state = state initial_state = state
# Collect terminal states # Collect terminal states
if state.terminal: if state.terminal:
state.externally_progressed = True
terminal_states.add(state) terminal_states.add(state)
# Ensure they do NOT have a handler # Ensure they do NOT have a handler
try: try:
@ -52,10 +53,11 @@ class StateGraph:
f"Terminal state '{state}' should not have a handler method ({state.handler_name})" f"Terminal state '{state}' should not have a handler method ({state.handler_name})"
) )
else: else:
# Ensure non-terminal states have a try interval and a handler # Ensure non-terminal/manual states have a try interval and a handler
if not state.externally_progressed:
if not state.try_interval: if not state.try_interval:
raise ValueError( raise ValueError(
f"State '{state}' has no try_interval and is not terminal" f"State '{state}' has no try_interval and is not terminal or manual"
) )
try: try:
state.handler state.handler
@ -80,9 +82,11 @@ class State:
self, self,
try_interval: Optional[float] = None, try_interval: Optional[float] = None,
handler_name: Optional[str] = None, handler_name: Optional[str] = None,
externally_progressed: bool = False,
): ):
self.try_interval = try_interval self.try_interval = try_interval
self.handler_name = handler_name self.handler_name = handler_name
self.externally_progressed = externally_progressed
self.parents: Set["State"] = set() self.parents: Set["State"] = set()
self.children: Set["State"] = set() self.children: Set["State"] = set()
@ -118,6 +122,7 @@ class State:
@property @property
def handler(self) -> Callable[[Any], Optional[str]]: def handler(self) -> Callable[[Any], Optional[str]]:
# Retrieve it by name off the graph
if self.handler_name is None: if self.handler_name is None:
raise AttributeError("No handler defined") raise AttributeError("No handler defined")
return getattr(self.graph, self.handler_name) return getattr(self.graph, self.handler_name)

View File

@ -80,7 +80,7 @@ class StatorModel(models.Model):
q = models.Q() q = models.Q()
for state in cls.state_graph.states.values(): for state in cls.state_graph.states.values():
state = cast(State, state) state = cast(State, state)
if not state.terminal: if not state.externally_progressed:
q = q | models.Q( q = q | models.Q(
( (
models.Q( models.Q(
@ -135,17 +135,31 @@ class StatorModel(models.Model):
self.state_ready = True self.state_ready = True
self.save() self.save()
async def atransition_attempt(self) -> Optional[str]: async def atransition_attempt(self) -> Optional[State]:
""" """
Attempts to transition the current state by running its handler(s). Attempts to transition the current state by running its handler(s).
""" """
current_state = self.state_graph.states[self.state]
# If it's a manual progression state don't even try
# We shouldn't really be here in this case, but it could be a race condition
if current_state.externally_progressed:
print("Externally progressed state!")
return None
try: try:
next_state = await self.state_graph.states[self.state].handler(self) next_state = await current_state.handler(self)
except BaseException as e: except BaseException as e:
await StatorError.acreate_from_instance(self, e) await StatorError.acreate_from_instance(self, e)
traceback.print_exc() traceback.print_exc()
else: else:
if next_state: if next_state:
# Ensure it's a State object
if isinstance(next_state, str):
next_state = self.state_graph.states[next_state]
# Ensure it's a child
if next_state not in current_state.children:
raise ValueError(
f"Cannot transition from {current_state} to {next_state} - not a declared transition"
)
await self.atransition_perform(next_state) await self.atransition_perform(next_state)
return next_state return next_state
await self.__class__.objects.filter(pk=self.pk).aupdate( await self.__class__.objects.filter(pk=self.pk).aupdate(

View File

@ -50,9 +50,6 @@ class StatorRunner:
min(space_remaining, self.MAX_TASKS_PER_MODEL), min(space_remaining, self.MAX_TASKS_PER_MODEL),
timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT), timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),
): ):
print(
f"Attempting transition on {instance._meta.label_lower}#{instance.pk}"
)
self.tasks.append( self.tasks.append(
asyncio.create_task(self.run_transition(instance)) asyncio.create_task(self.run_transition(instance))
) )
@ -76,6 +73,9 @@ class StatorRunner:
Wrapper for atransition_attempt with fallback error handling Wrapper for atransition_attempt with fallback error handling
""" """
try: try:
print(
f"Attempting transition on {instance._meta.label_lower}#{instance.pk} from state {instance.state}"
)
await instance.atransition_attempt() await instance.atransition_attempt()
except BaseException: except BaseException:
traceback.print_exc() traceback.print_exc()

View File

@ -27,6 +27,9 @@ class Status(models.Model):
updated = models.DateTimeField(auto_now=True) updated = models.DateTimeField(auto_now=True)
deleted = models.DateTimeField(null=True, blank=True) deleted = models.DateTimeField(null=True, blank=True)
class Meta:
verbose_name_plural = "statuses"
@classmethod @classmethod
def create_local(cls, identity, text: str): def create_local(cls, identity, text: str):
return cls.objects.create( return cls.objects.create(

View File

@ -1,6 +1,6 @@
from django.contrib import admin from django.contrib import admin
from users.models import Domain, Follow, Identity, User, UserEvent from users.models import Domain, Follow, Identity, InboxMessage, User, UserEvent
@admin.register(Domain) @admin.register(Domain)
@ -26,3 +26,14 @@ class IdentityAdmin(admin.ModelAdmin):
@admin.register(Follow) @admin.register(Follow)
class FollowAdmin(admin.ModelAdmin): class FollowAdmin(admin.ModelAdmin):
list_display = ["id", "source", "target", "state"] list_display = ["id", "source", "target", "state"]
@admin.register(InboxMessage)
class InboxMessageAdmin(admin.ModelAdmin):
list_display = ["id", "state", "message_type"]
actions = ["reset_state"]
@admin.action(description="Reset State")
def reset_state(self, request, queryset):
for instance in queryset:
instance.transition_perform("received")

View File

@ -81,3 +81,15 @@ class Domain(models.Model):
def __str__(self): def __str__(self):
return self.domain return self.domain
def save(self, *args, **kwargs):
# Ensure that we are not conflicting with other domains
if Domain.objects.filter(service_domain=self.domain).exists():
raise ValueError(
f"Domain {self.domain} is already a service domain elsewhere!"
)
if self.service_domain:
if Domain.objects.filter(domain=self.service_domain).exists():
raise ValueError(
f"Service domain {self.service_domain} is already a domain elsewhere!"
)

View File

@ -2,24 +2,110 @@ from typing import Optional
from django.db import models from django.db import models
from core.ld import canonicalise
from core.signatures import HttpSignature
from stator.models import State, StateField, StateGraph, StatorModel from stator.models import State, StateField, StateGraph, StatorModel
class FollowStates(StateGraph): class FollowStates(StateGraph):
unrequested = State(try_interval=30) unrequested = State(try_interval=30)
requested = State(try_interval=24 * 60 * 60) local_requested = State(try_interval=24 * 60 * 60)
accepted = State() remote_requested = State(try_interval=24 * 60 * 60)
accepted = State(externally_progressed=True)
undone_locally = State(try_interval=60 * 60)
undone_remotely = State()
unrequested.transitions_to(requested) unrequested.transitions_to(local_requested)
requested.transitions_to(accepted) unrequested.transitions_to(remote_requested)
local_requested.transitions_to(accepted)
remote_requested.transitions_to(accepted)
accepted.transitions_to(undone_locally)
undone_locally.transitions_to(undone_remotely)
@classmethod @classmethod
async def handle_unrequested(cls, instance: "Follow"): async def handle_unrequested(cls, instance: "Follow"):
print("Would have tried to follow on", instance) # Re-retrieve the follow with more things linked
follow = await Follow.objects.select_related(
"source", "source__domain", "target"
).aget(pk=instance.pk)
# Remote follows should not be here
if not follow.source.local:
return cls.remote_requested
# Construct the request
request = canonicalise(
{
"@context": "https://www.w3.org/ns/activitystreams",
"id": follow.uri,
"type": "Follow",
"actor": follow.source.actor_uri,
"object": follow.target.actor_uri,
}
)
# Sign it and send it
await HttpSignature.signed_request(
follow.target.inbox_uri, request, follow.source
)
return cls.local_requested
@classmethod @classmethod
async def handle_requested(cls, instance: "Follow"): async def handle_local_requested(cls, instance: "Follow"):
print("Would have tried to requested on", instance) # TODO: Resend follow requests occasionally
pass
@classmethod
async def handle_remote_requested(cls, instance: "Follow"):
# Re-retrieve the follow with more things linked
follow = await Follow.objects.select_related(
"source", "source__domain", "target"
).aget(pk=instance.pk)
# Send an accept
request = canonicalise(
{
"@context": "https://www.w3.org/ns/activitystreams",
"id": follow.target.actor_uri + f"follow/{follow.pk}/#accept",
"type": "Follow",
"actor": follow.source.actor_uri,
"object": {
"id": follow.uri,
"type": "Follow",
"actor": follow.source.actor_uri,
"object": follow.target.actor_uri,
},
}
)
# Sign it and send it
await HttpSignature.signed_request(
follow.source.inbox_uri,
request,
identity=follow.target,
)
return cls.accepted
@classmethod
async def handle_undone_locally(cls, instance: "Follow"):
follow = Follow.objects.select_related(
"source", "source__domain", "target"
).get(pk=instance.pk)
# Construct the request
request = canonicalise(
{
"@context": "https://www.w3.org/ns/activitystreams",
"id": follow.uri + "#undo",
"type": "Undo",
"actor": follow.source.actor_uri,
"object": {
"id": follow.uri,
"type": "Follow",
"actor": follow.source.actor_uri,
"object": follow.target.actor_uri,
},
}
)
# Sign it and send it
await HttpSignature.signed_request(
follow.target.inbox_uri, request, follow.source
)
return cls.undone_remotely
class Follow(StatorModel): class Follow(StatorModel):
@ -83,11 +169,17 @@ class Follow(StatorModel):
follow = cls.maybe_get(source=source, target=target) follow = cls.maybe_get(source=source, target=target)
if follow is None: if follow is None:
follow = Follow.objects.create(source=source, target=target, uri=uri) follow = Follow.objects.create(source=source, target=target, uri=uri)
if follow.state == FollowStates.fresh: if follow.state == FollowStates.unrequested:
follow.transition_perform(FollowStates.requested) follow.transition_perform(FollowStates.remote_requested)
@classmethod @classmethod
def remote_accepted(cls, source, target): def remote_accepted(cls, source, target):
print(f"accepted follow source {source} target {target}")
follow = cls.maybe_get(source=source, target=target) follow = cls.maybe_get(source=source, target=target)
if follow and follow.state == FollowStates.requested: print(f"accepting follow {follow}")
if follow and follow.state in [
FollowStates.unrequested,
FollowStates.local_requested,
]:
follow.transition_perform(FollowStates.accepted) follow.transition_perform(FollowStates.accepted)
print("accepted")

View File

@ -13,7 +13,7 @@ class InboxMessageStates(StateGraph):
@classmethod @classmethod
async def handle_received(cls, instance: "InboxMessage"): async def handle_received(cls, instance: "InboxMessage"):
type = instance.message["type"].lower() type = instance.message_type
if type == "follow": if type == "follow":
await instance.follow_request() await instance.follow_request()
elif type == "accept": elif type == "accept":
@ -30,6 +30,7 @@ class InboxMessageStates(StateGraph):
raise ValueError(f"Cannot handle activity of type undo.{inner_type}") raise ValueError(f"Cannot handle activity of type undo.{inner_type}")
else: else:
raise ValueError(f"Cannot handle activity of type {type}") raise ValueError(f"Cannot handle activity of type {type}")
return cls.processed
class InboxMessage(StatorModel): class InboxMessage(StatorModel):
@ -60,10 +61,17 @@ class InboxMessage(StatorModel):
""" """
Handles an incoming acceptance of one of our follow requests Handles an incoming acceptance of one of our follow requests
""" """
Follow.remote_accepted( target = Identity.by_actor_uri_with_create(self.message["actor"])
source=Identity.by_actor_uri_with_create(self.message["actor"]), source = Identity.by_actor_uri(self.message["object"]["actor"])
target=Identity.by_actor_uri(self.message["object"]), if source is None:
raise ValueError(
f"Follow-Accept has invalid source {self.message['object']['actor']}"
) )
Follow.remote_accepted(source=source, target=target)
@property
def message_type(self):
return self.message["type"].lower()
async def follow_undo(self): async def follow_undo(self):
""" """

View File

@ -19,7 +19,10 @@ def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity:
else: else:
username, domain = handle.split("@", 1) username, domain = handle.split("@", 1)
# Resolve the domain to the display domain # Resolve the domain to the display domain
domain = Domain.get_remote_domain(domain).domain domain_instance = Domain.get_domain(domain)
if domain_instance is None:
domain_instance = Domain.get_remote_domain(domain)
domain = domain_instance.domain
identity = Identity.by_username_and_domain( identity = Identity.by_username_and_domain(
username, username,
domain, domain,

View File

@ -1,62 +0,0 @@
from core.ld import canonicalise
from core.signatures import HttpSignature
from users.models import Follow
async def handle_follow_request(task_handler):
"""
Request a follow from a remote server
"""
follow = await Follow.objects.select_related(
"source", "source__domain", "target"
).aget(pk=task_handler.subject)
# Construct the request
request = canonicalise(
{
"@context": "https://www.w3.org/ns/activitystreams",
"id": follow.uri,
"type": "Follow",
"actor": follow.source.actor_uri,
"object": follow.target.actor_uri,
}
)
# Sign it and send it
response = await HttpSignature.signed_request(
follow.target.inbox_uri, request, follow.source
)
if response.status_code >= 400:
raise ValueError(f"Request error: {response.status_code} {response.content}")
await Follow.objects.filter(pk=follow.pk).aupdate(requested=True)
def send_follow_undo(id):
"""
Request a follow from a remote server
"""
follow = Follow.objects.select_related("source", "source__domain", "target").get(
pk=id
)
# Construct the request
request = canonicalise(
{
"@context": "https://www.w3.org/ns/activitystreams",
"id": follow.uri + "#undo",
"type": "Undo",
"actor": follow.source.actor_uri,
"object": {
"id": follow.uri,
"type": "Follow",
"actor": follow.source.actor_uri,
"object": follow.target.actor_uri,
},
}
)
# Sign it and send it
from asgiref.sync import async_to_sync
response = async_to_sync(HttpSignature.signed_request)(
follow.target.inbox_uri, request, follow.source
)
if response.status_code >= 400:
raise ValueError(f"Request error: {response.status_code} {response.content}")
print(response)

View File

@ -21,6 +21,10 @@ from users.models import Domain, Follow, Identity, IdentityStates, InboxMessage
from users.shortcuts import by_handle_or_404 from users.shortcuts import by_handle_or_404
class HttpResponseUnauthorized(HttpResponse):
status_code = 401
class ViewIdentity(TemplateView): class ViewIdentity(TemplateView):
template_name = "identity/view.html" template_name = "identity/view.html"
@ -188,20 +192,26 @@ class Inbox(View):
if "HTTP_DIGEST" in request.META: if "HTTP_DIGEST" in request.META:
expected_digest = HttpSignature.calculate_digest(request.body) expected_digest = HttpSignature.calculate_digest(request.body)
if request.META["HTTP_DIGEST"] != expected_digest: if request.META["HTTP_DIGEST"] != expected_digest:
print("Wrong digest")
return HttpResponseBadRequest("Digest is incorrect") return HttpResponseBadRequest("Digest is incorrect")
# Verify date header # Verify date header
if "HTTP_DATE" in request.META: if "HTTP_DATE" in request.META:
header_date = parse_http_date(request.META["HTTP_DATE"]) header_date = parse_http_date(request.META["HTTP_DATE"])
if abs(timezone.now().timestamp() - header_date) > 60: if abs(timezone.now().timestamp() - header_date) > 60:
print(
f"Date mismatch - they sent {header_date}, now is {timezone.now().timestamp()}"
)
return HttpResponseBadRequest("Date is too far away") return HttpResponseBadRequest("Date is too far away")
# Get the signature details # Get the signature details
if "HTTP_SIGNATURE" not in request.META: if "HTTP_SIGNATURE" not in request.META:
print("No signature")
return HttpResponseBadRequest("No signature present") return HttpResponseBadRequest("No signature present")
signature_details = HttpSignature.parse_signature( signature_details = HttpSignature.parse_signature(
request.META["HTTP_SIGNATURE"] request.META["HTTP_SIGNATURE"]
) )
# Reject unknown algorithms # Reject unknown algorithms
if signature_details["algorithm"] != "rsa-sha256": if signature_details["algorithm"] != "rsa-sha256":
print("Unknown sig algo")
return HttpResponseBadRequest("Unknown signature algorithm") return HttpResponseBadRequest("Unknown signature algorithm")
# Create the signature payload # Create the signature payload
headers_string = HttpSignature.headers_from_request( headers_string = HttpSignature.headers_from_request(
@ -217,13 +227,14 @@ class Inbox(View):
# See if we can fetch it right now # See if we can fetch it right now
async_to_sync(identity.fetch_actor)() async_to_sync(identity.fetch_actor)()
if not identity.public_key: if not identity.public_key:
print("Cannot get actor")
return HttpResponseBadRequest("Cannot retrieve actor") return HttpResponseBadRequest("Cannot retrieve actor")
if not identity.verify_signature( if not identity.verify_signature(
signature_details["signature"], headers_string signature_details["signature"], headers_string
): ):
return HttpResponseBadRequest("Bad signature") return HttpResponseUnauthorized("Bad signature")
# Hand off the item to be processed by the queue # Hand off the item to be processed by the queue
InboxMessage.objects.create(message=document) InboxMessage.objects.create(message=document, state_ready=True)
return HttpResponse(status=202) return HttpResponse(status=202)