Most of the way through the stator refactor
This commit is contained in:
parent
61c324508e
commit
7746abbbb7
|
@ -1,8 +1,17 @@
|
||||||
from django.contrib import admin
|
from django.contrib import admin
|
||||||
|
|
||||||
from stator.models import StatorTask
|
from stator.models import StatorError
|
||||||
|
|
||||||
|
|
||||||
@admin.register(StatorTask)
|
@admin.register(StatorError)
|
||||||
class DomainAdmin(admin.ModelAdmin):
|
class DomainAdmin(admin.ModelAdmin):
|
||||||
list_display = ["id", "model_label", "instance_pk", "locked_until"]
|
list_display = [
|
||||||
|
"id",
|
||||||
|
"date",
|
||||||
|
"model_label",
|
||||||
|
"instance_pk",
|
||||||
|
"from_state",
|
||||||
|
"to_state",
|
||||||
|
"error",
|
||||||
|
]
|
||||||
|
ordering = ["-date"]
|
||||||
|
|
|
@ -1,9 +1,16 @@
|
||||||
import datetime
|
from typing import (
|
||||||
from functools import wraps
|
Any,
|
||||||
from typing import Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
|
Callable,
|
||||||
|
ClassVar,
|
||||||
from django.db import models
|
Dict,
|
||||||
from django.utils import timezone
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StateGraph:
|
class StateGraph:
|
||||||
|
@ -13,7 +20,7 @@ class StateGraph:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
states: ClassVar[Dict[str, "State"]]
|
states: ClassVar[Dict[str, "State"]]
|
||||||
choices: ClassVar[List[Tuple[str, str]]]
|
choices: ClassVar[List[Tuple[object, str]]]
|
||||||
initial_state: ClassVar["State"]
|
initial_state: ClassVar["State"]
|
||||||
terminal_states: ClassVar[Set["State"]]
|
terminal_states: ClassVar[Set["State"]]
|
||||||
|
|
||||||
|
@ -50,7 +57,7 @@ class StateGraph:
|
||||||
cls.initial_state = initial_state
|
cls.initial_state = initial_state
|
||||||
cls.terminal_states = terminal_states
|
cls.terminal_states = terminal_states
|
||||||
# Generate choices
|
# Generate choices
|
||||||
cls.choices = [(name, name) for name in cls.states.keys()]
|
cls.choices = [(state, name) for name, state in cls.states.items()]
|
||||||
|
|
||||||
|
|
||||||
class State:
|
class State:
|
||||||
|
@ -63,7 +70,7 @@ class State:
|
||||||
self.parents: Set["State"] = set()
|
self.parents: Set["State"] = set()
|
||||||
self.children: Dict["State", "Transition"] = {}
|
self.children: Dict["State", "Transition"] = {}
|
||||||
|
|
||||||
def _add_to_graph(self, graph: StateGraph, name: str):
|
def _add_to_graph(self, graph: Type[StateGraph], name: str):
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
self.name = name
|
self.name = name
|
||||||
self.graph.states[name] = self
|
self.graph.states[name] = self
|
||||||
|
@ -71,13 +78,19 @@ class State:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<State {self.name}>"
|
return f"<State {self.name}>"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.name)
|
||||||
|
|
||||||
def add_transition(
|
def add_transition(
|
||||||
self,
|
self,
|
||||||
other: "State",
|
other: "State",
|
||||||
handler: Optional[Union[str, Callable]] = None,
|
handler: Optional[Callable] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
def decorator(handler: Union[str, Callable]):
|
def decorator(handler: Callable[[Any], bool]):
|
||||||
self.children[other] = Transition(
|
self.children[other] = Transition(
|
||||||
self,
|
self,
|
||||||
other,
|
other,
|
||||||
|
@ -85,9 +98,7 @@ class State:
|
||||||
priority=priority,
|
priority=priority,
|
||||||
)
|
)
|
||||||
other.parents.add(self)
|
other.parents.add(self)
|
||||||
# All handlers should be class methods, so do that automatically.
|
return handler
|
||||||
if callable(handler):
|
|
||||||
return classmethod(handler)
|
|
||||||
|
|
||||||
# If we're not being called as a decorator, invoke it immediately
|
# If we're not being called as a decorator, invoke it immediately
|
||||||
if handler is not None:
|
if handler is not None:
|
||||||
|
@ -113,7 +124,7 @@ class State:
|
||||||
if automatic_only:
|
if automatic_only:
|
||||||
transitions = [t for t in self.children.values() if t.automatic]
|
transitions = [t for t in self.children.values() if t.automatic]
|
||||||
else:
|
else:
|
||||||
transitions = self.children.values()
|
transitions = list(self.children.values())
|
||||||
return sorted(transitions, key=lambda t: t.priority, reverse=True)
|
return sorted(transitions, key=lambda t: t.priority, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
|
@ -141,7 +152,10 @@ class Transition:
|
||||||
"""
|
"""
|
||||||
if isinstance(self.handler, str):
|
if isinstance(self.handler, str):
|
||||||
self.handler = getattr(self.from_state.graph, self.handler)
|
self.handler = getattr(self.from_state.graph, self.handler)
|
||||||
return self.handler
|
return cast(Callable, self.handler)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<Transition {self.from_state} -> {self.to_state}>"
|
||||||
|
|
||||||
|
|
||||||
class ManualTransition(Transition):
|
class ManualTransition(Transition):
|
||||||
|
@ -157,6 +171,5 @@ class ManualTransition(Transition):
|
||||||
):
|
):
|
||||||
self.from_state = from_state
|
self.from_state = from_state
|
||||||
self.to_state = to_state
|
self.to_state = to_state
|
||||||
self.handler = None
|
|
||||||
self.priority = 0
|
self.priority = 0
|
||||||
self.automatic = False
|
self.automatic = False
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
from typing import List, Type, cast
|
||||||
|
|
||||||
|
from asgiref.sync import async_to_sync
|
||||||
|
from django.apps import apps
|
||||||
|
from django.core.management.base import BaseCommand
|
||||||
|
|
||||||
|
from stator.models import StatorModel
|
||||||
|
from stator.runner import StatorRunner
|
||||||
|
|
||||||
|
|
||||||
|
class Command(BaseCommand):
|
||||||
|
help = "Runs a Stator runner for a short period"
|
||||||
|
|
||||||
|
def add_arguments(self, parser):
|
||||||
|
parser.add_argument("model_labels", nargs="*", type=str)
|
||||||
|
|
||||||
|
def handle(self, model_labels: List[str], *args, **options):
|
||||||
|
# Resolve the models list into names
|
||||||
|
models = cast(
|
||||||
|
List[Type[StatorModel]],
|
||||||
|
[apps.get_model(label) for label in model_labels],
|
||||||
|
)
|
||||||
|
if not models:
|
||||||
|
models = StatorModel.subclasses
|
||||||
|
print("Running for models: " + " ".join(m._meta.label_lower for m in models))
|
||||||
|
# Run a runner
|
||||||
|
runner = StatorRunner(models)
|
||||||
|
async_to_sync(runner.run)()
|
|
@ -1,4 +1,4 @@
|
||||||
# Generated by Django 4.1.3 on 2022-11-09 05:46
|
# Generated by Django 4.1.3 on 2022-11-10 03:24
|
||||||
|
|
||||||
from django.db import migrations, models
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ class Migration(migrations.Migration):
|
||||||
|
|
||||||
operations = [
|
operations = [
|
||||||
migrations.CreateModel(
|
migrations.CreateModel(
|
||||||
name="StatorTask",
|
name="StatorError",
|
||||||
fields=[
|
fields=[
|
||||||
(
|
(
|
||||||
"id",
|
"id",
|
||||||
|
@ -24,8 +24,11 @@ class Migration(migrations.Migration):
|
||||||
),
|
),
|
||||||
("model_label", models.CharField(max_length=200)),
|
("model_label", models.CharField(max_length=200)),
|
||||||
("instance_pk", models.CharField(max_length=200)),
|
("instance_pk", models.CharField(max_length=200)),
|
||||||
("locked_until", models.DateTimeField(blank=True, null=True)),
|
("from_state", models.CharField(max_length=200)),
|
||||||
("priority", models.IntegerField(default=0)),
|
("to_state", models.CharField(max_length=200)),
|
||||||
|
("date", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("error", models.TextField()),
|
||||||
|
("error_details", models.TextField(blank=True, null=True)),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
201
stator/models.py
201
stator/models.py
|
@ -1,14 +1,13 @@
|
||||||
import datetime
|
import datetime
|
||||||
from functools import reduce
|
import traceback
|
||||||
from typing import Type, cast
|
from typing import ClassVar, List, Optional, Type, cast
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from django.apps import apps
|
|
||||||
from django.db import models, transaction
|
from django.db import models, transaction
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.functional import classproperty
|
from django.utils.functional import classproperty
|
||||||
|
|
||||||
from stator.graph import State, StateGraph
|
from stator.graph import State, StateGraph, Transition
|
||||||
|
|
||||||
|
|
||||||
class StateField(models.CharField):
|
class StateField(models.CharField):
|
||||||
|
@ -55,6 +54,9 @@ class StatorModel(models.Model):
|
||||||
concrete model yourself.
|
concrete model yourself.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# If this row is up for transition attempts
|
||||||
|
state_ready = models.BooleanField(default=False)
|
||||||
|
|
||||||
# When the state last actually changed, or the date of instance creation
|
# When the state last actually changed, or the date of instance creation
|
||||||
state_changed = models.DateTimeField(auto_now_add=True)
|
state_changed = models.DateTimeField(auto_now_add=True)
|
||||||
|
|
||||||
|
@ -62,68 +64,128 @@ class StatorModel(models.Model):
|
||||||
# (and not successful, as this is cleared on transition)
|
# (and not successful, as this is cleared on transition)
|
||||||
state_attempted = models.DateTimeField(blank=True, null=True)
|
state_attempted = models.DateTimeField(blank=True, null=True)
|
||||||
|
|
||||||
|
# If a lock is out on this row, when it is locked until
|
||||||
|
# (we don't identify the lock owner, as there's no heartbeats)
|
||||||
|
state_locked_until = models.DateTimeField(null=True, blank=True)
|
||||||
|
|
||||||
|
# Collection of subclasses of us
|
||||||
|
subclasses: ClassVar[List[Type["StatorModel"]]] = []
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
abstract = True
|
abstract = True
|
||||||
|
|
||||||
@classmethod
|
def __init_subclass__(cls) -> None:
|
||||||
def schedule_overdue(cls, now=None) -> models.QuerySet:
|
if cls is not StatorModel:
|
||||||
"""
|
cls.subclasses.append(cls)
|
||||||
Finds instances of this model that need to run and schedule them.
|
|
||||||
"""
|
|
||||||
q = models.Q()
|
|
||||||
for transition in cls.state_graph.transitions(automatic_only=True):
|
|
||||||
q = q | transition.get_query(now=now)
|
|
||||||
return cls.objects.filter(q)
|
|
||||||
|
|
||||||
@classproperty
|
@classproperty
|
||||||
def state_graph(cls) -> Type[StateGraph]:
|
def state_graph(cls) -> Type[StateGraph]:
|
||||||
return cls._meta.get_field("state").graph
|
return cls._meta.get_field("state").graph
|
||||||
|
|
||||||
def schedule_transition(self, priority: int = 0):
|
@classmethod
|
||||||
|
async def atransition_schedule_due(cls, now=None) -> models.QuerySet:
|
||||||
|
"""
|
||||||
|
Finds instances of this model that need to run and schedule them.
|
||||||
|
"""
|
||||||
|
q = models.Q()
|
||||||
|
for state in cls.state_graph.states.values():
|
||||||
|
state = cast(State, state)
|
||||||
|
if not state.terminal:
|
||||||
|
q = q | models.Q(
|
||||||
|
(
|
||||||
|
models.Q(
|
||||||
|
state_attempted__lte=timezone.now()
|
||||||
|
- datetime.timedelta(seconds=state.try_interval)
|
||||||
|
)
|
||||||
|
| models.Q(state_attempted__isnull=True)
|
||||||
|
),
|
||||||
|
state=state.name,
|
||||||
|
)
|
||||||
|
await cls.objects.filter(q).aupdate(state_ready=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def transition_get_with_lock(
|
||||||
|
cls, number: int, lock_expiry: datetime.datetime
|
||||||
|
) -> List["StatorModel"]:
|
||||||
|
"""
|
||||||
|
Returns up to `number` tasks for execution, having locked them.
|
||||||
|
"""
|
||||||
|
with transaction.atomic():
|
||||||
|
selected = list(
|
||||||
|
cls.objects.filter(state_locked_until__isnull=True, state_ready=True)[
|
||||||
|
:number
|
||||||
|
].select_for_update()
|
||||||
|
)
|
||||||
|
cls.objects.filter(pk__in=[i.pk for i in selected]).update(
|
||||||
|
state_locked_until=timezone.now()
|
||||||
|
)
|
||||||
|
return selected
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def atransition_get_with_lock(
|
||||||
|
cls, number: int, lock_expiry: datetime.datetime
|
||||||
|
) -> List["StatorModel"]:
|
||||||
|
return await sync_to_async(cls.transition_get_with_lock)(number, lock_expiry)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def atransition_clean_locks(cls):
|
||||||
|
await cls.objects.filter(state_locked_until__lte=timezone.now()).aupdate(
|
||||||
|
state_locked_until=None
|
||||||
|
)
|
||||||
|
|
||||||
|
def transition_schedule(self):
|
||||||
"""
|
"""
|
||||||
Adds this instance to the queue to get its state transition attempted.
|
Adds this instance to the queue to get its state transition attempted.
|
||||||
|
|
||||||
The scheduler will call this, but you can also call it directly if you
|
The scheduler will call this, but you can also call it directly if you
|
||||||
know it'll be ready and want to lower latency.
|
know it'll be ready and want to lower latency.
|
||||||
"""
|
"""
|
||||||
StatorTask.schedule_for_execution(self, priority=priority)
|
self.state_ready = True
|
||||||
|
self.save()
|
||||||
|
|
||||||
async def attempt_transition(self):
|
async def atransition_attempt(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Attempts to transition the current state by running its handler(s).
|
Attempts to transition the current state by running its handler(s).
|
||||||
"""
|
"""
|
||||||
# Try each transition in priority order
|
# Try each transition in priority order
|
||||||
for transition in self.state_graph.states[self.state].transitions(
|
for transition in self.state.transitions(automatic_only=True):
|
||||||
automatic_only=True
|
try:
|
||||||
):
|
|
||||||
success = await transition.get_handler()(self)
|
success = await transition.get_handler()(self)
|
||||||
|
except BaseException as e:
|
||||||
|
await StatorError.acreate_from_instance(self, transition, e)
|
||||||
|
traceback.print_exc()
|
||||||
|
continue
|
||||||
if success:
|
if success:
|
||||||
await self.perform_transition(transition.to_state.name)
|
await self.atransition_perform(transition.to_state.name)
|
||||||
return
|
return True
|
||||||
await self.__class__.objects.filter(pk=self.pk).aupdate(
|
await self.__class__.objects.filter(pk=self.pk).aupdate(
|
||||||
state_attempted=timezone.now()
|
state_attempted=timezone.now(),
|
||||||
|
state_locked_until=None,
|
||||||
|
state_ready=False,
|
||||||
)
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
async def perform_transition(self, state_name):
|
def transition_perform(self, state_name):
|
||||||
"""
|
"""
|
||||||
Transitions the instance to the given state name
|
Transitions the instance to the given state name, forcibly.
|
||||||
"""
|
"""
|
||||||
if state_name not in self.state_graph.states:
|
if state_name not in self.state_graph.states:
|
||||||
raise ValueError(f"Invalid state {state_name}")
|
raise ValueError(f"Invalid state {state_name}")
|
||||||
await self.__class__.objects.filter(pk=self.pk).aupdate(
|
self.__class__.objects.filter(pk=self.pk).update(
|
||||||
state=state_name,
|
state=state_name,
|
||||||
state_changed=timezone.now(),
|
state_changed=timezone.now(),
|
||||||
state_attempted=None,
|
state_attempted=None,
|
||||||
|
state_locked_until=None,
|
||||||
|
state_ready=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
atransition_perform = sync_to_async(transition_perform)
|
||||||
|
|
||||||
class StatorTask(models.Model):
|
|
||||||
|
class StatorError(models.Model):
|
||||||
"""
|
"""
|
||||||
The model that we use for an internal scheduling queue.
|
Tracks any errors running the transitions.
|
||||||
|
Meant to be cleaned out regularly. Should probably be a log.
|
||||||
Entries in this queue are up for checking and execution - it also performs
|
|
||||||
locking to ensure we get closer to exactly-once execution (but we err on
|
|
||||||
the side of at-least-once)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# appname.modelname (lowercased) label for the model this represents
|
# appname.modelname (lowercased) label for the model this represents
|
||||||
|
@ -132,60 +194,33 @@ class StatorTask(models.Model):
|
||||||
# The primary key of that model (probably int or str)
|
# The primary key of that model (probably int or str)
|
||||||
instance_pk = models.CharField(max_length=200)
|
instance_pk = models.CharField(max_length=200)
|
||||||
|
|
||||||
# Locking columns (no runner ID, as we have no heartbeats - all runners
|
# The state we moved from
|
||||||
# only live for a short amount of time anyway)
|
from_state = models.CharField(max_length=200)
|
||||||
locked_until = models.DateTimeField(null=True, blank=True)
|
|
||||||
|
|
||||||
# Basic total ordering priority - higher is more important
|
# The state we moved to (or tried to)
|
||||||
priority = models.IntegerField(default=0)
|
to_state = models.CharField(max_length=200)
|
||||||
|
|
||||||
def __str__(self):
|
# When it happened
|
||||||
return f"#{self.pk}: {self.model_label}.{self.instance_pk}"
|
date = models.DateTimeField(auto_now_add=True)
|
||||||
|
|
||||||
|
# Error name
|
||||||
|
error = models.TextField()
|
||||||
|
|
||||||
|
# Error details
|
||||||
|
error_details = models.TextField(blank=True, null=True)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def schedule_for_execution(cls, model_instance: StatorModel, priority: int = 0):
|
async def acreate_from_instance(
|
||||||
# We don't do a transaction here as it's fine to occasionally double up
|
cls,
|
||||||
model_label = model_instance._meta.label_lower
|
instance: StatorModel,
|
||||||
pk = model_instance.pk
|
transition: Transition,
|
||||||
# TODO: Increase priority of existing if present
|
exception: Optional[BaseException] = None,
|
||||||
if not cls.objects.filter(
|
):
|
||||||
model_label=model_label, instance_pk=pk, locked__isnull=True
|
return await cls.objects.acreate(
|
||||||
).exists():
|
model_label=instance._meta.label_lower,
|
||||||
StatorTask.objects.create(
|
instance_pk=str(instance.pk),
|
||||||
model_label=model_label,
|
from_state=transition.from_state,
|
||||||
instance_pk=pk,
|
to_state=transition.to_state,
|
||||||
priority=priority,
|
error=str(exception),
|
||||||
|
error_details=traceback.format_exc(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_for_execution(cls, number: int, lock_expiry: datetime.datetime):
|
|
||||||
"""
|
|
||||||
Returns up to `number` tasks for execution, having locked them.
|
|
||||||
"""
|
|
||||||
with transaction.atomic():
|
|
||||||
selected = list(
|
|
||||||
cls.objects.filter(locked_until__isnull=True)[
|
|
||||||
:number
|
|
||||||
].select_for_update()
|
|
||||||
)
|
|
||||||
cls.objects.filter(pk__in=[i.pk for i in selected]).update(
|
|
||||||
locked_until=timezone.now()
|
|
||||||
)
|
|
||||||
return selected
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def aget_for_execution(cls, number: int, lock_expiry: datetime.datetime):
|
|
||||||
return await sync_to_async(cls.get_for_execution)(number, lock_expiry)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def aclean_old_locks(cls):
|
|
||||||
await cls.objects.filter(locked_until__lte=timezone.now()).aupdate(
|
|
||||||
locked_until=None
|
|
||||||
)
|
|
||||||
|
|
||||||
async def aget_model_instance(self) -> StatorModel:
|
|
||||||
model = apps.get_model(self.model_label)
|
|
||||||
return cast(StatorModel, await model.objects.aget(pk=self.pk))
|
|
||||||
|
|
||||||
async def adelete(self):
|
|
||||||
self.__class__.objects.adelete(pk=self.pk)
|
|
||||||
|
|
|
@ -4,11 +4,9 @@ import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Type
|
from typing import List, Type
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
|
||||||
from django.db import transaction
|
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
|
||||||
from stator.models import StatorModel, StatorTask
|
from stator.models import StatorModel
|
||||||
|
|
||||||
|
|
||||||
class StatorRunner:
|
class StatorRunner:
|
||||||
|
@ -22,6 +20,7 @@ class StatorRunner:
|
||||||
LOCK_TIMEOUT = 120
|
LOCK_TIMEOUT = 120
|
||||||
|
|
||||||
MAX_TASKS = 30
|
MAX_TASKS = 30
|
||||||
|
MAX_TASKS_PER_MODEL = 5
|
||||||
|
|
||||||
def __init__(self, models: List[Type[StatorModel]]):
|
def __init__(self, models: List[Type[StatorModel]]):
|
||||||
self.models = models
|
self.models = models
|
||||||
|
@ -32,38 +31,44 @@ class StatorRunner:
|
||||||
self.handled = 0
|
self.handled = 0
|
||||||
self.tasks = []
|
self.tasks = []
|
||||||
# Clean up old locks
|
# Clean up old locks
|
||||||
await StatorTask.aclean_old_locks()
|
print("Running initial cleaning and scheduling")
|
||||||
# Examine what needs scheduling
|
initial_tasks = []
|
||||||
|
for model in self.models:
|
||||||
|
initial_tasks.append(model.atransition_clean_locks())
|
||||||
|
initial_tasks.append(model.atransition_schedule_due())
|
||||||
|
await asyncio.gather(*initial_tasks)
|
||||||
# For the first time period, launch tasks
|
# For the first time period, launch tasks
|
||||||
|
print("Running main task loop")
|
||||||
while (time.monotonic() - start_time) < self.START_TIMEOUT:
|
while (time.monotonic() - start_time) < self.START_TIMEOUT:
|
||||||
self.remove_completed_tasks()
|
self.remove_completed_tasks()
|
||||||
space_remaining = self.MAX_TASKS - len(self.tasks)
|
space_remaining = self.MAX_TASKS - len(self.tasks)
|
||||||
# Fetch new tasks
|
# Fetch new tasks
|
||||||
|
for model in self.models:
|
||||||
if space_remaining > 0:
|
if space_remaining > 0:
|
||||||
for new_task in await StatorTask.aget_for_execution(
|
for instance in await model.atransition_get_with_lock(
|
||||||
space_remaining,
|
min(space_remaining, self.MAX_TASKS_PER_MODEL),
|
||||||
timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),
|
timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),
|
||||||
):
|
):
|
||||||
self.tasks.append(asyncio.create_task(self.run_task(new_task)))
|
print(
|
||||||
|
f"Attempting transition on {instance._meta.label_lower}#{instance.pk}"
|
||||||
|
)
|
||||||
|
self.tasks.append(
|
||||||
|
asyncio.create_task(instance.atransition_attempt())
|
||||||
|
)
|
||||||
self.handled += 1
|
self.handled += 1
|
||||||
|
space_remaining -= 1
|
||||||
# Prevent busylooping
|
# Prevent busylooping
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.1)
|
||||||
# Then wait for tasks to finish
|
# Then wait for tasks to finish
|
||||||
|
print("Waiting for tasks to complete")
|
||||||
while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT:
|
while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT:
|
||||||
self.remove_completed_tasks()
|
self.remove_completed_tasks()
|
||||||
if not self.tasks:
|
if not self.tasks:
|
||||||
break
|
break
|
||||||
# Prevent busylooping
|
# Prevent busylooping
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
print("Complete")
|
||||||
return self.handled
|
return self.handled
|
||||||
|
|
||||||
async def run_task(self, task: StatorTask):
|
|
||||||
# Resolve the model instance
|
|
||||||
model_instance = await task.aget_model_instance()
|
|
||||||
await model_instance.attempt_transition()
|
|
||||||
# Remove ourselves from the database as complete
|
|
||||||
await task.adelete()
|
|
||||||
|
|
||||||
def remove_completed_tasks(self):
|
def remove_completed_tasks(self):
|
||||||
self.tasks = [t for t in self.tasks if not t.done()]
|
self.tasks = [t for t in self.tasks if not t.done()]
|
||||||
|
|
|
@ -51,14 +51,14 @@ def test_bad_declarations():
|
||||||
# More than one initial state
|
# More than one initial state
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
|
||||||
class TestGraph(StateGraph):
|
class TestGraph2(StateGraph):
|
||||||
initial = State()
|
initial = State()
|
||||||
initial2 = State()
|
initial2 = State()
|
||||||
|
|
||||||
# No initial states
|
# No initial states
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
|
||||||
class TestGraph(StateGraph):
|
class TestGraph3(StateGraph):
|
||||||
loop = State()
|
loop = State()
|
||||||
loop2 = State()
|
loop2 = State()
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
# Generated by Django 4.1.3 on 2022-11-10 03:24
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("users", "0004_remove_follow_state_locked_and_more"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="follow",
|
||||||
|
name="state_locked_until",
|
||||||
|
field=models.DateTimeField(blank=True, null=True),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="follow",
|
||||||
|
name="state_ready",
|
||||||
|
field=models.BooleanField(default=False),
|
||||||
|
),
|
||||||
|
]
|
|
@ -1,6 +1,6 @@
|
||||||
from .block import Block # noqa
|
from .block import Block # noqa
|
||||||
from .domain import Domain # noqa
|
from .domain import Domain # noqa
|
||||||
from .follow import Follow # noqa
|
from .follow import Follow, FollowStates # noqa
|
||||||
from .identity import Identity # noqa
|
from .identity import Identity, IdentityStates # noqa
|
||||||
from .user import User # noqa
|
from .user import User # noqa
|
||||||
from .user_event import UserEvent # noqa
|
from .user_event import UserEvent # noqa
|
||||||
|
|
|
@ -55,7 +55,7 @@ class Domain(models.Model):
|
||||||
return cls.objects.create(domain=domain, local=False)
|
return cls.objects.create(domain=domain, local=False)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_local_domain(cls, domain: str) -> Optional["Domain"]:
|
def get_domain(cls, domain: str) -> Optional["Domain"]:
|
||||||
try:
|
try:
|
||||||
return cls.objects.get(
|
return cls.objects.get(
|
||||||
models.Q(domain=domain) | models.Q(service_domain=domain)
|
models.Q(domain=domain) | models.Q(service_domain=domain)
|
||||||
|
|
|
@ -6,13 +6,13 @@ from stator.models import State, StateField, StateGraph, StatorModel
|
||||||
|
|
||||||
|
|
||||||
class FollowStates(StateGraph):
|
class FollowStates(StateGraph):
|
||||||
pending = State(try_interval=3600)
|
pending = State(try_interval=30)
|
||||||
requested = State()
|
requested = State()
|
||||||
accepted = State()
|
accepted = State()
|
||||||
|
|
||||||
@pending.add_transition(requested)
|
@pending.add_transition(requested)
|
||||||
async def try_request(cls, instance):
|
async def try_request(instance: "Follow"): # type:ignore
|
||||||
print("Would have tried to follow")
|
print("Would have tried to follow on", instance)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
requested.add_manual_transition(accepted)
|
requested.add_manual_transition(accepted)
|
||||||
|
@ -73,11 +73,3 @@ class Follow(StatorModel):
|
||||||
follow.state = FollowStates.accepted
|
follow.state = FollowStates.accepted
|
||||||
follow.save()
|
follow.save()
|
||||||
return follow
|
return follow
|
||||||
|
|
||||||
def undo(self):
|
|
||||||
"""
|
|
||||||
Undoes this follow
|
|
||||||
"""
|
|
||||||
if not self.target.local:
|
|
||||||
Task.submit("follow_undo", str(self.pk))
|
|
||||||
self.delete()
|
|
||||||
|
|
|
@ -14,9 +14,21 @@ from django.utils import timezone
|
||||||
from OpenSSL import crypto
|
from OpenSSL import crypto
|
||||||
|
|
||||||
from core.ld import canonicalise
|
from core.ld import canonicalise
|
||||||
|
from stator.models import State, StateField, StateGraph, StatorModel
|
||||||
from users.models.domain import Domain
|
from users.models.domain import Domain
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityStates(StateGraph):
|
||||||
|
outdated = State(try_interval=3600)
|
||||||
|
updated = State()
|
||||||
|
|
||||||
|
@outdated.add_transition(updated)
|
||||||
|
async def fetch_identity(identity: "Identity"): # type:ignore
|
||||||
|
if identity.local:
|
||||||
|
return True
|
||||||
|
return await identity.fetch_actor()
|
||||||
|
|
||||||
|
|
||||||
def upload_namer(prefix, instance, filename):
|
def upload_namer(prefix, instance, filename):
|
||||||
"""
|
"""
|
||||||
Names uploaded images etc.
|
Names uploaded images etc.
|
||||||
|
@ -26,7 +38,7 @@ def upload_namer(prefix, instance, filename):
|
||||||
return f"{prefix}/{now.year}/{now.month}/{now.day}/{filename}"
|
return f"{prefix}/{now.year}/{now.month}/{now.day}/{filename}"
|
||||||
|
|
||||||
|
|
||||||
class Identity(models.Model):
|
class Identity(StatorModel):
|
||||||
"""
|
"""
|
||||||
Represents both local and remote Fediverse identities (actors)
|
Represents both local and remote Fediverse identities (actors)
|
||||||
"""
|
"""
|
||||||
|
@ -35,6 +47,8 @@ class Identity(models.Model):
|
||||||
# one around as well for making nice URLs etc.
|
# one around as well for making nice URLs etc.
|
||||||
actor_uri = models.CharField(max_length=500, unique=True)
|
actor_uri = models.CharField(max_length=500, unique=True)
|
||||||
|
|
||||||
|
state = StateField(IdentityStates)
|
||||||
|
|
||||||
local = models.BooleanField()
|
local = models.BooleanField()
|
||||||
users = models.ManyToManyField("users.User", related_name="identities")
|
users = models.ManyToManyField("users.User", related_name="identities")
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from django.http import Http404
|
||||||
from users.models import Domain, Identity
|
from users.models import Domain, Identity
|
||||||
|
|
||||||
|
|
||||||
def by_handle_or_404(request, handle, local=True, fetch=False):
|
def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity:
|
||||||
"""
|
"""
|
||||||
Retrieves an Identity by its long or short handle.
|
Retrieves an Identity by its long or short handle.
|
||||||
Domain-sensitive, so it will understand short handles on alternate domains.
|
Domain-sensitive, so it will understand short handles on alternate domains.
|
||||||
|
@ -12,14 +12,17 @@ def by_handle_or_404(request, handle, local=True, fetch=False):
|
||||||
if "HTTP_HOST" not in request.META:
|
if "HTTP_HOST" not in request.META:
|
||||||
raise Http404("No hostname available")
|
raise Http404("No hostname available")
|
||||||
username = handle
|
username = handle
|
||||||
domain_instance = Domain.get_local_domain(request.META["HTTP_HOST"])
|
domain_instance = Domain.get_domain(request.META["HTTP_HOST"])
|
||||||
if domain_instance is None:
|
if domain_instance is None:
|
||||||
raise Http404("No matching domains found")
|
raise Http404("No matching domains found")
|
||||||
domain = domain_instance.domain
|
domain = domain_instance.domain
|
||||||
else:
|
else:
|
||||||
username, domain = handle.split("@", 1)
|
username, domain = handle.split("@", 1)
|
||||||
# Resolve the domain to the display domain
|
# Resolve the domain to the display domain
|
||||||
domain = Domain.get_local_domain(request.META["HTTP_HOST"]).domain
|
domain_instance = Domain.get_domain(domain)
|
||||||
|
if domain_instance is None:
|
||||||
|
raise Http404("No matching domains found")
|
||||||
|
domain = domain_instance.domain
|
||||||
identity = Identity.by_username_and_domain(
|
identity = Identity.by_username_and_domain(
|
||||||
username,
|
username,
|
||||||
domain,
|
domain,
|
||||||
|
|
|
@ -17,7 +17,7 @@ from core.forms import FormHelper
|
||||||
from core.ld import canonicalise
|
from core.ld import canonicalise
|
||||||
from core.signatures import HttpSignature
|
from core.signatures import HttpSignature
|
||||||
from users.decorators import identity_required
|
from users.decorators import identity_required
|
||||||
from users.models import Domain, Follow, Identity
|
from users.models import Domain, Follow, Identity, IdentityStates
|
||||||
from users.shortcuts import by_handle_or_404
|
from users.shortcuts import by_handle_or_404
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ class ViewIdentity(TemplateView):
|
||||||
)
|
)
|
||||||
statuses = identity.statuses.all()[:100]
|
statuses = identity.statuses.all()[:100]
|
||||||
if identity.data_age > settings.IDENTITY_MAX_AGE:
|
if identity.data_age > settings.IDENTITY_MAX_AGE:
|
||||||
Task.submit("identity_fetch", identity.handle)
|
identity.transition_perform(IdentityStates.outdated)
|
||||||
return {
|
return {
|
||||||
"identity": identity,
|
"identity": identity,
|
||||||
"statuses": statuses,
|
"statuses": statuses,
|
||||||
|
@ -129,7 +129,7 @@ class CreateIdentity(FormView):
|
||||||
def form_valid(self, form):
|
def form_valid(self, form):
|
||||||
username = form.cleaned_data["username"]
|
username = form.cleaned_data["username"]
|
||||||
domain = form.cleaned_data["domain"]
|
domain = form.cleaned_data["domain"]
|
||||||
domain_instance = Domain.get_local_domain(domain)
|
domain_instance = Domain.get_domain(domain)
|
||||||
new_identity = Identity.objects.create(
|
new_identity = Identity.objects.create(
|
||||||
actor_uri=f"https://{domain_instance.uri_domain}/@{username}@{domain}/actor/",
|
actor_uri=f"https://{domain_instance.uri_domain}/@{username}@{domain}/actor/",
|
||||||
username=username,
|
username=username,
|
||||||
|
|
Loading…
Reference in New Issue