diff --git a/activities/services/post.py b/activities/services/post.py index 5745c35..59f73fe 100644 --- a/activities/services/post.py +++ b/activities/services/post.py @@ -1,6 +1,12 @@ from django.db import models -from activities.models import Post, PostInteraction, PostInteractionStates, PostStates +from activities.models import ( + Post, + PostInteraction, + PostInteractionStates, + PostStates, + TimelineEvent, +) from users.models import Identity @@ -9,6 +15,40 @@ class PostService: High-level operations on Posts """ + @classmethod + def queryset(cls): + """ + Returns the base queryset to use for fetching posts efficiently. + """ + return ( + Post.objects.not_hidden() + .prefetch_related( + "attachments", + "mentions", + "emojis", + ) + .select_related( + "author", + "author__domain", + ) + .annotate( + like_count=models.Count( + "interactions", + filter=models.Q( + interactions__type=PostInteraction.Types.like, + interactions__state__in=PostInteractionStates.group_active(), + ), + ), + boost_count=models.Count( + "interactions", + filter=models.Q( + interactions__type=PostInteraction.Types.boost, + interactions__state__in=PostInteractionStates.group_active(), + ), + ), + ) + ) + def __init__(self, post: Post): self.post = post @@ -47,40 +87,6 @@ class PostService: def unboost_as(self, identity: Identity): self.uninteract_as(identity, PostInteraction.Types.boost) - @classmethod - def queryset(cls): - """ - Returns the base queryset to use for fetching posts efficiently. - """ - return ( - Post.objects.not_hidden() - .prefetch_related( - "attachments", - "mentions", - "emojis", - ) - .select_related( - "author", - "author__domain", - ) - .annotate( - like_count=models.Count( - "interactions", - filter=models.Q( - interactions__type=PostInteraction.Types.like, - interactions__state__in=PostInteractionStates.group_active(), - ), - ), - boost_count=models.Count( - "interactions", - filter=models.Q( - interactions__type=PostInteraction.Types.boost, - interactions__state__in=PostInteractionStates.group_active(), - ), - ), - ) - ) - def context(self, identity: Identity | None) -> tuple[list[Post], list[Post]]: """ Returns ancestor/descendant information. @@ -123,3 +129,17 @@ class PostService: descendants.append(child) queue.append(child) return ancestors, descendants + + def delete(self): + """ + Marks a post as deleted and immediately cleans up its timeline events etc. + """ + self.post.transition_perform(PostStates.deleted) + TimelineEvent.objects.filter(subject_post=self.post).delete() + PostInteraction.transition_perform_queryset( + PostInteraction.objects.filter( + post=self.post, + state__in=PostInteractionStates.group_active(), + ), + PostInteractionStates.undone, + ) diff --git a/activities/views/posts.py b/activities/views/posts.py index 0925e62..49d0b5d 100644 --- a/activities/views/posts.py +++ b/activities/views/posts.py @@ -164,5 +164,5 @@ class Delete(TemplateView): return {"post": self.post_obj} def post(self, request): - self.post_obj.transition_perform(PostStates.deleted) + PostService(self.post_obj).delete() return redirect("/") diff --git a/api/views/statuses.py b/api/views/statuses.py index 88f967c..eb3b464 100644 --- a/api/views/statuses.py +++ b/api/views/statuses.py @@ -4,13 +4,7 @@ from django.forms import ValidationError from django.shortcuts import get_object_or_404 from ninja import Schema -from activities.models import ( - Post, - PostAttachment, - PostInteraction, - PostStates, - TimelineEvent, -) +from activities.models import Post, PostAttachment, PostInteraction, TimelineEvent from activities.services import PostService from api import schemas from api.views.base import api_router @@ -79,8 +73,7 @@ def status(request, id: str): @identity_required def delete_status(request, id: str): post = get_object_or_404(Post, pk=id) - post.transition_perform(PostStates.deleted) - TimelineEvent.objects.filter(subject_post=post, identity=request.identity).delete() + PostService(post).delete() return post.to_mastodon_json() diff --git a/stator/models.py b/stator/models.py index 350421d..c69ed0f 100644 --- a/stator/models.py +++ b/stator/models.py @@ -206,13 +206,29 @@ class StatorModel(models.Model): """ Transitions the instance to the given state name, forcibly. """ + self.transition_perform_queryset( + self.__class__.objects.filter(pk=self.pk), + state, + ) + + atransition_perform = sync_to_async(transition_perform) + + @classmethod + def transition_perform_queryset( + cls, + queryset: models.QuerySet, + state: State | str, + ): + """ + Transitions every instance in the queryset to the given state name, forcibly. + """ if isinstance(state, State): state = state.name - if state not in self.state_graph.states: + if state not in cls.state_graph.states: raise ValueError(f"Invalid state {state}") # See if it's ready immediately (if not, delay until first try_interval) - if self.state_graph.states[state].attempt_immediately: - self.__class__.objects.filter(pk=self.pk).update( + if cls.state_graph.states[state].attempt_immediately: + queryset.update( state=state, state_changed=timezone.now(), state_attempted=None, @@ -220,7 +236,7 @@ class StatorModel(models.Model): state_ready=True, ) else: - self.__class__.objects.filter(pk=self.pk).update( + queryset.update( state=state, state_changed=timezone.now(), state_attempted=timezone.now(), @@ -228,8 +244,6 @@ class StatorModel(models.Model): state_ready=False, ) - atransition_perform = sync_to_async(transition_perform) - class Stats(models.Model): """