Profile fetching now working on state machine

This commit is contained in:
Andrew Godwin 2022-11-09 23:48:31 -07:00
parent 7746abbbb7
commit 2c3a129970
21 changed files with 304 additions and 383 deletions

View File

@ -10,8 +10,7 @@ class DomainAdmin(admin.ModelAdmin):
"date", "date",
"model_label", "model_label",
"instance_pk", "instance_pk",
"from_state", "state",
"to_state",
"error", "error",
] ]
ordering = ["-date"] ordering = ["-date"]

View File

@ -1,16 +1,4 @@
from typing import ( from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Type
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Set,
Tuple,
Type,
Union,
cast,
)
class StateGraph: class StateGraph:
@ -44,20 +32,43 @@ class StateGraph:
terminal_states = set() terminal_states = set()
initial_state = None initial_state = None
for state in cls.states.values(): for state in cls.states.values():
# Check for multiple initial states
if state.initial: if state.initial:
if initial_state: if initial_state:
raise ValueError( raise ValueError(
f"The graph has more than one initial state: {initial_state} and {state}" f"The graph has more than one initial state: {initial_state} and {state}"
) )
initial_state = state initial_state = state
# Collect terminal states
if state.terminal: if state.terminal:
terminal_states.add(state) terminal_states.add(state)
# Ensure they do NOT have a handler
try:
state.handler
except AttributeError:
pass
else:
raise ValueError(
f"Terminal state '{state}' should not have a handler method ({state.handler_name})"
)
else:
# Ensure non-terminal states have a try interval and a handler
if not state.try_interval:
raise ValueError(
f"State '{state}' has no try_interval and is not terminal"
)
try:
state.handler
except AttributeError:
raise ValueError(
f"State '{state}' does not have a handler method ({state.handler_name})"
)
if initial_state is None: if initial_state is None:
raise ValueError("The graph has no initial state") raise ValueError("The graph has no initial state")
cls.initial_state = initial_state cls.initial_state = initial_state
cls.terminal_states = terminal_states cls.terminal_states = terminal_states
# Generate choices # Generate choices
cls.choices = [(state, name) for name, state in cls.states.items()] cls.choices = [(name, name) for name in cls.states.keys()]
class State: class State:
@ -65,49 +76,37 @@ class State:
Represents an individual state Represents an individual state
""" """
def __init__(self, try_interval: float = 300): def __init__(
self,
try_interval: Optional[float] = None,
handler_name: Optional[str] = None,
):
self.try_interval = try_interval self.try_interval = try_interval
self.handler_name = handler_name
self.parents: Set["State"] = set() self.parents: Set["State"] = set()
self.children: Dict["State", "Transition"] = {} self.children: Set["State"] = set()
def _add_to_graph(self, graph: Type[StateGraph], name: str): def _add_to_graph(self, graph: Type[StateGraph], name: str):
self.graph = graph self.graph = graph
self.name = name self.name = name
self.graph.states[name] = self self.graph.states[name] = self
if self.handler_name is None:
self.handler_name = f"handle_{self.name}"
def __repr__(self): def __repr__(self):
return f"<State {self.name}>" return f"<State {self.name}>"
def __str__(self): def __eq__(self, other):
return self.name if isinstance(other, State):
return self is other
return self.name == other
def __len__(self): def __hash__(self):
return len(self.name) return hash(id(self))
def add_transition( def transitions_to(self, other: "State"):
self, self.children.add(other)
other: "State", other.parents.add(other)
handler: Optional[Callable] = None,
priority: int = 0,
) -> Callable:
def decorator(handler: Callable[[Any], bool]):
self.children[other] = Transition(
self,
other,
handler,
priority=priority,
)
other.parents.add(self)
return handler
# If we're not being called as a decorator, invoke it immediately
if handler is not None:
decorator(handler)
return decorator
def add_manual_transition(self, other: "State"):
self.children[other] = ManualTransition(self, other)
other.parents.add(self)
@property @property
def initial(self): def initial(self):
@ -117,59 +116,8 @@ class State:
def terminal(self): def terminal(self):
return not self.children return not self.children
def transitions(self, automatic_only=False) -> List["Transition"]: @property
""" def handler(self) -> Callable[[Any], Optional[str]]:
Returns all transitions from this State in priority order if self.handler_name is None:
""" raise AttributeError("No handler defined")
if automatic_only: return getattr(self.graph, self.handler_name)
transitions = [t for t in self.children.values() if t.automatic]
else:
transitions = list(self.children.values())
return sorted(transitions, key=lambda t: t.priority, reverse=True)
class Transition:
"""
A possible transition from one state to another
"""
def __init__(
self,
from_state: State,
to_state: State,
handler: Union[str, Callable],
priority: int = 0,
):
self.from_state = from_state
self.to_state = to_state
self.handler = handler
self.priority = priority
self.automatic = True
def get_handler(self) -> Callable:
"""
Returns the handler (it might need resolving from a string)
"""
if isinstance(self.handler, str):
self.handler = getattr(self.from_state.graph, self.handler)
return cast(Callable, self.handler)
def __repr__(self):
return f"<Transition {self.from_state} -> {self.to_state}>"
class ManualTransition(Transition):
"""
A possible transition from one state to another that cannot be done by
the stator task runner, and must come from an external source.
"""
def __init__(
self,
from_state: State,
to_state: State,
):
self.from_state = from_state
self.to_state = to_state
self.priority = 0
self.automatic = False

View File

@ -1,4 +1,4 @@
# Generated by Django 4.1.3 on 2022-11-10 03:24 # Generated by Django 4.1.3 on 2022-11-10 05:56
from django.db import migrations, models from django.db import migrations, models
@ -24,8 +24,7 @@ class Migration(migrations.Migration):
), ),
("model_label", models.CharField(max_length=200)), ("model_label", models.CharField(max_length=200)),
("instance_pk", models.CharField(max_length=200)), ("instance_pk", models.CharField(max_length=200)),
("from_state", models.CharField(max_length=200)), ("state", models.CharField(max_length=200)),
("to_state", models.CharField(max_length=200)),
("date", models.DateTimeField(auto_now_add=True)), ("date", models.DateTimeField(auto_now_add=True)),
("error", models.TextField()), ("error", models.TextField()),
("error_details", models.TextField(blank=True, null=True)), ("error_details", models.TextField(blank=True, null=True)),

View File

@ -1,13 +1,13 @@
import datetime import datetime
import traceback import traceback
from typing import ClassVar, List, Optional, Type, cast from typing import ClassVar, List, Optional, Type, Union, cast
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.db import models, transaction from django.db import models, transaction
from django.utils import timezone from django.utils import timezone
from django.utils.functional import classproperty from django.utils.functional import classproperty
from stator.graph import State, StateGraph, Transition from stator.graph import State, StateGraph
class StateField(models.CharField): class StateField(models.CharField):
@ -29,16 +29,6 @@ class StateField(models.CharField):
kwargs["graph"] = self.graph kwargs["graph"] = self.graph
return name, path, args, kwargs return name, path, args, kwargs
def from_db_value(self, value, expression, connection):
if value is None:
return value
return self.graph.states[value]
def to_python(self, value):
if isinstance(value, State) or value is None:
return value
return self.graph.states[value]
def get_prep_value(self, value): def get_prep_value(self, value):
if isinstance(value, State): if isinstance(value, State):
return value.name return value.name
@ -95,7 +85,9 @@ class StatorModel(models.Model):
( (
models.Q( models.Q(
state_attempted__lte=timezone.now() state_attempted__lte=timezone.now()
- datetime.timedelta(seconds=state.try_interval) - datetime.timedelta(
seconds=cast(float, state.try_interval)
)
) )
| models.Q(state_attempted__isnull=True) | models.Q(state_attempted__isnull=True)
), ),
@ -117,7 +109,7 @@ class StatorModel(models.Model):
].select_for_update() ].select_for_update()
) )
cls.objects.filter(pk__in=[i.pk for i in selected]).update( cls.objects.filter(pk__in=[i.pk for i in selected]).update(
state_locked_until=timezone.now() state_locked_until=lock_expiry
) )
return selected return selected
@ -143,36 +135,36 @@ class StatorModel(models.Model):
self.state_ready = True self.state_ready = True
self.save() self.save()
async def atransition_attempt(self) -> bool: async def atransition_attempt(self) -> Optional[str]:
""" """
Attempts to transition the current state by running its handler(s). Attempts to transition the current state by running its handler(s).
""" """
# Try each transition in priority order try:
for transition in self.state.transitions(automatic_only=True): next_state = await self.state_graph.states[self.state].handler(self)
try: except BaseException as e:
success = await transition.get_handler()(self) await StatorError.acreate_from_instance(self, e)
except BaseException as e: traceback.print_exc()
await StatorError.acreate_from_instance(self, transition, e) else:
traceback.print_exc() if next_state:
continue await self.atransition_perform(next_state)
if success: return next_state
await self.atransition_perform(transition.to_state.name)
return True
await self.__class__.objects.filter(pk=self.pk).aupdate( await self.__class__.objects.filter(pk=self.pk).aupdate(
state_attempted=timezone.now(), state_attempted=timezone.now(),
state_locked_until=None, state_locked_until=None,
state_ready=False, state_ready=False,
) )
return False return None
def transition_perform(self, state_name): def transition_perform(self, state: Union[State, str]):
""" """
Transitions the instance to the given state name, forcibly. Transitions the instance to the given state name, forcibly.
""" """
if state_name not in self.state_graph.states: if isinstance(state, State):
raise ValueError(f"Invalid state {state_name}") state = state.name
if state not in self.state_graph.states:
raise ValueError(f"Invalid state {state}")
self.__class__.objects.filter(pk=self.pk).update( self.__class__.objects.filter(pk=self.pk).update(
state=state_name, state=state,
state_changed=timezone.now(), state_changed=timezone.now(),
state_attempted=None, state_attempted=None,
state_locked_until=None, state_locked_until=None,
@ -194,11 +186,8 @@ class StatorError(models.Model):
# The primary key of that model (probably int or str) # The primary key of that model (probably int or str)
instance_pk = models.CharField(max_length=200) instance_pk = models.CharField(max_length=200)
# The state we moved from # The state we were on
from_state = models.CharField(max_length=200) state = models.CharField(max_length=200)
# The state we moved to (or tried to)
to_state = models.CharField(max_length=200)
# When it happened # When it happened
date = models.DateTimeField(auto_now_add=True) date = models.DateTimeField(auto_now_add=True)
@ -213,14 +202,12 @@ class StatorError(models.Model):
async def acreate_from_instance( async def acreate_from_instance(
cls, cls,
instance: StatorModel, instance: StatorModel,
transition: Transition,
exception: Optional[BaseException] = None, exception: Optional[BaseException] = None,
): ):
return await cls.objects.acreate( return await cls.objects.acreate(
model_label=instance._meta.label_lower, model_label=instance._meta.label_lower,
instance_pk=str(instance.pk), instance_pk=str(instance.pk),
from_state=transition.from_state, state=instance.state,
to_state=transition.to_state,
error=str(exception), error=str(exception),
error_details=traceback.format_exc(), error_details=traceback.format_exc(),
) )

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import datetime import datetime
import time import time
import traceback
import uuid import uuid
from typing import List, Type from typing import List, Type
@ -53,7 +54,7 @@ class StatorRunner:
f"Attempting transition on {instance._meta.label_lower}#{instance.pk}" f"Attempting transition on {instance._meta.label_lower}#{instance.pk}"
) )
self.tasks.append( self.tasks.append(
asyncio.create_task(instance.atransition_attempt()) asyncio.create_task(self.run_transition(instance))
) )
self.handled += 1 self.handled += 1
space_remaining -= 1 space_remaining -= 1
@ -70,5 +71,17 @@ class StatorRunner:
print("Complete") print("Complete")
return self.handled return self.handled
async def run_transition(self, instance: StatorModel):
"""
Wrapper for atransition_attempt with fallback error handling
"""
try:
await instance.atransition_attempt()
except BaseException:
traceback.print_exc()
def remove_completed_tasks(self): def remove_completed_tasks(self):
"""
Removes all completed asyncio.Tasks from our local in-progress list
"""
self.tasks = [t for t in self.tasks if not t.done()] self.tasks = [t for t in self.tasks if not t.done()]

View File

@ -9,39 +9,29 @@ def test_declare():
lookups. lookups.
""" """
fake_handler = lambda: True
class TestGraph(StateGraph): class TestGraph(StateGraph):
initial = State() initial = State(try_interval=3600)
second = State() second = State(try_interval=1)
third = State() third = State()
fourth = State()
final = State()
initial.add_transition(second, 60, handler=fake_handler) initial.transitions_to(second)
second.add_transition(third, 60, handler="check_third") second.transitions_to(third)
def check_third(cls): @classmethod
return True def handle_initial(cls):
pass
@third.add_transition(fourth, 60) @classmethod
def check_fourth(cls): def handle_second(cls):
return True pass
fourth.add_manual_transition(final)
assert TestGraph.initial_state == TestGraph.initial assert TestGraph.initial_state == TestGraph.initial
assert TestGraph.terminal_states == {TestGraph.final} assert TestGraph.terminal_states == {TestGraph.third}
assert TestGraph.initial.children[TestGraph.second].get_handler() == fake_handler assert TestGraph.initial.handler == TestGraph.handle_initial
assert ( assert TestGraph.initial.try_interval == 3600
TestGraph.second.children[TestGraph.third].get_handler() assert TestGraph.second.handler == TestGraph.handle_second
== TestGraph.check_third assert TestGraph.second.try_interval == 1
)
assert (
TestGraph.third.children[TestGraph.fourth].get_handler().__name__
== "check_fourth"
)
def test_bad_declarations(): def test_bad_declarations():
@ -62,5 +52,18 @@ def test_bad_declarations():
loop = State() loop = State()
loop2 = State() loop2 = State()
loop.add_transition(loop2, 1, handler="fake") loop.transitions_to(loop2)
loop2.add_transition(loop, 1, handler="fake") loop2.transitions_to(loop)
def test_state():
"""
Tests basic values of the State class
"""
class TestGraph(StateGraph):
initial = State()
assert "initial" == TestGraph.initial
assert TestGraph.initial == "initial"
assert TestGraph.initial == TestGraph.initial

View File

@ -1,4 +1,4 @@
# Generated by Django 4.1.3 on 2022-11-07 04:19 # Generated by Django 4.1.3 on 2022-11-10 05:58
import django.db.models.deletion import django.db.models.deletion
from django.db import migrations, models from django.db import migrations, models

View File

@ -20,7 +20,7 @@ class UserEventAdmin(admin.ModelAdmin):
@admin.register(Identity) @admin.register(Identity)
class IdentityAdmin(admin.ModelAdmin): class IdentityAdmin(admin.ModelAdmin):
list_display = ["id", "handle", "actor_uri", "name", "local"] list_display = ["id", "handle", "actor_uri", "state", "local"]
@admin.register(Follow) @admin.register(Follow)

View File

@ -1,4 +1,4 @@
# Generated by Django 4.1.3 on 2022-11-07 04:19 # Generated by Django 4.1.3 on 2022-11-10 05:58
import functools import functools
@ -6,7 +6,10 @@ import django.db.models.deletion
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
import stator.models
import users.models.follow
import users.models.identity import users.models.identity
import users.models.inbox_message
class Migration(migrations.Migration): class Migration(migrations.Migration):
@ -77,6 +80,37 @@ class Migration(migrations.Migration):
), ),
], ],
), ),
migrations.CreateModel(
name="InboxMessage",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("state_ready", models.BooleanField(default=False)),
("state_changed", models.DateTimeField(auto_now_add=True)),
("state_attempted", models.DateTimeField(blank=True, null=True)),
("state_locked_until", models.DateTimeField(blank=True, null=True)),
("message", models.JSONField()),
(
"state",
stator.models.StateField(
choices=[("received", "received"), ("processed", "processed")],
default="received",
graph=users.models.inbox_message.InboxMessageStates,
max_length=100,
),
),
],
options={
"abstract": False,
},
),
migrations.CreateModel( migrations.CreateModel(
name="UserEvent", name="UserEvent",
fields=[ fields=[
@ -124,7 +158,20 @@ class Migration(migrations.Migration):
verbose_name="ID", verbose_name="ID",
), ),
), ),
("state_ready", models.BooleanField(default=False)),
("state_changed", models.DateTimeField(auto_now_add=True)),
("state_attempted", models.DateTimeField(blank=True, null=True)),
("state_locked_until", models.DateTimeField(blank=True, null=True)),
("actor_uri", models.CharField(max_length=500, unique=True)), ("actor_uri", models.CharField(max_length=500, unique=True)),
(
"state",
stator.models.StateField(
choices=[("outdated", "outdated"), ("updated", "updated")],
default="outdated",
graph=users.models.identity.IdentityStates,
max_length=100,
),
),
("local", models.BooleanField()), ("local", models.BooleanField()),
("username", models.CharField(blank=True, max_length=500, null=True)), ("username", models.CharField(blank=True, max_length=500, null=True)),
("name", models.CharField(blank=True, max_length=500, null=True)), ("name", models.CharField(blank=True, max_length=500, null=True)),
@ -239,10 +286,25 @@ class Migration(migrations.Migration):
verbose_name="ID", verbose_name="ID",
), ),
), ),
("state_ready", models.BooleanField(default=False)),
("state_changed", models.DateTimeField(auto_now_add=True)),
("state_attempted", models.DateTimeField(blank=True, null=True)),
("state_locked_until", models.DateTimeField(blank=True, null=True)),
("uri", models.CharField(blank=True, max_length=500, null=True)), ("uri", models.CharField(blank=True, max_length=500, null=True)),
("note", models.TextField(blank=True, null=True)), ("note", models.TextField(blank=True, null=True)),
("requested", models.BooleanField(default=False)), (
("accepted", models.BooleanField(default=False)), "state",
stator.models.StateField(
choices=[
("pending", "pending"),
("requested", "requested"),
("accepted", "accepted"),
],
default="pending",
graph=users.models.follow.FollowStates,
max_length=100,
),
),
("created", models.DateTimeField(auto_now_add=True)), ("created", models.DateTimeField(auto_now_add=True)),
("updated", models.DateTimeField(auto_now=True)), ("updated", models.DateTimeField(auto_now=True)),
( (

View File

@ -1,44 +0,0 @@
# Generated by Django 4.1.3 on 2022-11-07 19:22
import django.utils.timezone
from django.db import migrations, models
import stator.models
import users.models.follow
class Migration(migrations.Migration):
dependencies = [
("users", "0001_initial"),
]
operations = [
migrations.AddField(
model_name="follow",
name="state",
field=stator.models.StateField(
choices=[
("pending", "pending"),
("requested", "requested"),
("accepted", "accepted"),
],
default="pending",
graph=users.models.follow.FollowStates,
max_length=100,
),
),
migrations.AddField(
model_name="follow",
name="state_attempted",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="follow",
name="state_changed",
field=models.DateTimeField(
auto_now_add=True, default=django.utils.timezone.now
),
preserve_default=False,
),
]

View File

@ -1,31 +0,0 @@
# Generated by Django 4.1.3 on 2022-11-08 03:58
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("users", "0002_follow_state_follow_state_attempted_and_more"),
]
operations = [
migrations.RemoveField(
model_name="follow",
name="accepted",
),
migrations.RemoveField(
model_name="follow",
name="requested",
),
migrations.AddField(
model_name="follow",
name="state_locked",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="follow",
name="state_runner",
field=models.CharField(blank=True, max_length=100, null=True),
),
]

View File

@ -1,21 +0,0 @@
# Generated by Django 4.1.3 on 2022-11-09 05:15
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("users", "0003_remove_follow_accepted_remove_follow_requested_and_more"),
]
operations = [
migrations.RemoveField(
model_name="follow",
name="state_locked",
),
migrations.RemoveField(
model_name="follow",
name="state_runner",
),
]

View File

@ -1,23 +0,0 @@
# Generated by Django 4.1.3 on 2022-11-10 03:24
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("users", "0004_remove_follow_state_locked_and_more"),
]
operations = [
migrations.AddField(
model_name="follow",
name="state_locked_until",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="follow",
name="state_ready",
field=models.BooleanField(default=False),
),
]

View File

@ -2,5 +2,6 @@ from .block import Block # noqa
from .domain import Domain # noqa from .domain import Domain # noqa
from .follow import Follow, FollowStates # noqa from .follow import Follow, FollowStates # noqa
from .identity import Identity, IdentityStates # noqa from .identity import Identity, IdentityStates # noqa
from .inbox_message import InboxMessage, InboxMessageStates # noqa
from .user import User # noqa from .user import User # noqa
from .user_event import UserEvent # noqa from .user_event import UserEvent # noqa

View File

@ -6,16 +6,20 @@ from stator.models import State, StateField, StateGraph, StatorModel
class FollowStates(StateGraph): class FollowStates(StateGraph):
pending = State(try_interval=30) unrequested = State(try_interval=30)
requested = State() requested = State(try_interval=24 * 60 * 60)
accepted = State() accepted = State()
@pending.add_transition(requested) unrequested.transitions_to(requested)
async def try_request(instance: "Follow"): # type:ignore requested.transitions_to(accepted)
print("Would have tried to follow on", instance)
return False
requested.add_manual_transition(accepted) @classmethod
async def handle_unrequested(cls, instance: "Follow"):
print("Would have tried to follow on", instance)
@classmethod
async def handle_requested(cls, instance: "Follow"):
print("Would have tried to requested on", instance)
class Follow(StatorModel): class Follow(StatorModel):
@ -73,3 +77,17 @@ class Follow(StatorModel):
follow.state = FollowStates.accepted follow.state = FollowStates.accepted
follow.save() follow.save()
return follow return follow
@classmethod
def remote_created(cls, source, target, uri):
follow = cls.maybe_get(source=source, target=target)
if follow is None:
follow = Follow.objects.create(source=source, target=target, uri=uri)
if follow.state == FollowStates.fresh:
follow.transition_perform(FollowStates.requested)
@classmethod
def remote_accepted(cls, source, target):
follow = cls.maybe_get(source=source, target=target)
if follow and follow.state == FollowStates.requested:
follow.transition_perform(FollowStates.accepted)

View File

@ -22,11 +22,16 @@ class IdentityStates(StateGraph):
outdated = State(try_interval=3600) outdated = State(try_interval=3600)
updated = State() updated = State()
@outdated.add_transition(updated) outdated.transitions_to(updated)
async def fetch_identity(identity: "Identity"): # type:ignore
@classmethod
async def handle_outdated(cls, identity: "Identity"):
# Local identities never need fetching
if identity.local: if identity.local:
return True return "updated"
return await identity.fetch_actor() # Run the actor fetch and progress to updated if it succeeds
if await identity.fetch_actor():
return "updated"
def upload_namer(prefix, instance, filename): def upload_namer(prefix, instance, filename):

View File

@ -0,0 +1,71 @@
from asgiref.sync import sync_to_async
from django.db import models
from stator.models import State, StateField, StateGraph, StatorModel
from users.models import Follow, Identity
class InboxMessageStates(StateGraph):
received = State(try_interval=300)
processed = State()
received.transitions_to(processed)
@classmethod
async def handle_received(cls, instance: "InboxMessage"):
type = instance.message["type"].lower()
if type == "follow":
await instance.follow_request()
elif type == "accept":
inner_type = instance.message["object"]["type"].lower()
if inner_type == "follow":
await instance.follow_accepted()
else:
raise ValueError(f"Cannot handle activity of type accept.{inner_type}")
elif type == "undo":
inner_type = instance.message["object"]["type"].lower()
if inner_type == "follow":
await instance.follow_undo()
else:
raise ValueError(f"Cannot handle activity of type undo.{inner_type}")
else:
raise ValueError(f"Cannot handle activity of type {type}")
class InboxMessage(StatorModel):
"""
an incoming inbox message that needs processing.
Yes, this is kind of its own message queue built on the state graph system.
It's fine. It'll scale up to a decent point.
"""
message = models.JSONField()
state = StateField(InboxMessageStates)
@sync_to_async
def follow_request(self):
"""
Handles an incoming follow request
"""
Follow.remote_created(
source=Identity.by_actor_uri_with_create(self.message["actor"]),
target=Identity.by_actor_uri(self.message["object"]),
uri=self.message["id"],
)
@sync_to_async
def follow_accepted(self):
"""
Handles an incoming acceptance of one of our follow requests
"""
Follow.remote_accepted(
source=Identity.by_actor_uri_with_create(self.message["actor"]),
target=Identity.by_actor_uri(self.message["object"]),
)
async def follow_undo(self):
"""
Handles an incoming follow undo
"""

View File

@ -19,10 +19,7 @@ def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity:
else: else:
username, domain = handle.split("@", 1) username, domain = handle.split("@", 1)
# Resolve the domain to the display domain # Resolve the domain to the display domain
domain_instance = Domain.get_domain(domain) domain = Domain.get_remote_domain(domain).domain
if domain_instance is None:
raise Http404("No matching domains found")
domain = domain_instance.domain
identity = Identity.by_username_and_domain( identity = Identity.by_username_and_domain(
username, username,
domain, domain,

View File

@ -1,11 +0,0 @@
from asgiref.sync import sync_to_async
from users.models import Identity
async def handle_identity_fetch(task_handler):
# Get the actor URI via webfinger
actor_uri, handle = await Identity.fetch_webfinger(task_handler.subject)
# Get or create the identity, then fetch
identity = await sync_to_async(Identity.by_actor_uri_with_create)(actor_uri)
await identity.fetch_actor()

View File

@ -1,56 +0,0 @@
from asgiref.sync import sync_to_async
from users.models import Follow, Identity
async def handle_inbox_item(task_handler):
type = task_handler.payload["type"].lower()
if type == "follow":
await inbox_follow(task_handler.payload)
elif type == "accept":
inner_type = task_handler.payload["object"]["type"].lower()
if inner_type == "follow":
await sync_to_async(accept_follow)(task_handler.payload["object"])
else:
raise ValueError(f"Cannot handle activity of type accept.{inner_type}")
elif type == "undo":
inner_type = task_handler.payload["object"]["type"].lower()
if inner_type == "follow":
await inbox_unfollow(task_handler.payload["object"])
else:
raise ValueError(f"Cannot handle activity of type undo.{inner_type}")
else:
raise ValueError(f"Cannot handle activity of type {inner_type}")
async def inbox_follow(payload):
"""
Handles an incoming follow request
"""
# TODO: Manually approved follows
source = Identity.by_actor_uri_with_create(payload["actor"])
target = Identity.by_actor_uri(payload["object"])
# See if this follow already exists
try:
follow = Follow.objects.get(source=source, target=target)
except Follow.DoesNotExist:
follow = Follow.objects.create(source=source, target=target, uri=payload["id"])
# See if we need to acknowledge it
if not follow.acknowledged:
pass
async def inbox_unfollow(payload):
pass
def accept_follow(payload):
"""
Another server has acknowledged our follow request
"""
source = Identity.by_actor_uri_with_create(payload["actor"])
target = Identity.by_actor_uri(payload["object"])
follow = Follow.maybe_get(source, target)
if follow:
follow.accepted = True
follow.save()

View File

@ -17,7 +17,7 @@ from core.forms import FormHelper
from core.ld import canonicalise from core.ld import canonicalise
from core.signatures import HttpSignature from core.signatures import HttpSignature
from users.decorators import identity_required from users.decorators import identity_required
from users.models import Domain, Follow, Identity, IdentityStates from users.models import Domain, Follow, Identity, IdentityStates, InboxMessage
from users.shortcuts import by_handle_or_404 from users.shortcuts import by_handle_or_404
@ -117,9 +117,13 @@ class CreateIdentity(FormView):
def clean(self): def clean(self):
# Check for existing users # Check for existing users
username = self.cleaned_data["username"] username = self.cleaned_data.get("username")
domain = self.cleaned_data["domain"] domain = self.cleaned_data.get("domain")
if Identity.objects.filter(username=username, domain=domain).exists(): if (
username
and domain
and Identity.objects.filter(username=username, domain=domain).exists()
):
raise forms.ValidationError(f"{username}@{domain} is already taken") raise forms.ValidationError(f"{username}@{domain} is already taken")
def get_form(self): def get_form(self):
@ -219,7 +223,7 @@ class Inbox(View):
): ):
return HttpResponseBadRequest("Bad signature") return HttpResponseBadRequest("Bad signature")
# Hand off the item to be processed by the queue # Hand off the item to be processed by the queue
Task.submit("inbox_item", subject=identity.actor_uri, payload=document) InboxMessage.objects.create(message=document)
return HttpResponse(status=202) return HttpResponse(status=202)