From e382994e737b8331beac75dc68672a0eddebb447 Mon Sep 17 00:00:00 2001 From: Humberto Rocha Date: Wed, 15 Feb 2023 12:42:56 -0500 Subject: [PATCH] Polls (#496) Implement poll AP and API backend --- ...action_value_alter_postinteraction_type.py | 26 ++ activities/models/fan_out.py | 8 +- activities/models/post.py | 54 ++- activities/models/post_interaction.py | 123 ++++++- activities/models/post_types.py | 48 +++ activities/models/timeline_event.py | 8 +- api/pagination.py | 8 +- api/schemas.py | 47 ++- api/urls.py | 4 + api/views/polls.py | 26 ++ api/views/search.py | 4 +- api/views/statuses.py | 55 ++- core/ld.py | 4 +- tests/activities/models/test_post.py | 41 +++ .../models/test_post_interaction.py | 314 ++++++++++++++++++ tests/api/test_polls.py | 62 ++++ tests/api/test_statuses.py | 65 ++++ users/models/inbox_message.py | 14 +- 18 files changed, 868 insertions(+), 43 deletions(-) create mode 100644 activities/migrations/0011_postinteraction_value_alter_postinteraction_type.py create mode 100644 api/views/polls.py create mode 100644 tests/activities/models/test_post_interaction.py create mode 100644 tests/api/test_polls.py diff --git a/activities/migrations/0011_postinteraction_value_alter_postinteraction_type.py b/activities/migrations/0011_postinteraction_value_alter_postinteraction_type.py new file mode 100644 index 0000000..2de059e --- /dev/null +++ b/activities/migrations/0011_postinteraction_value_alter_postinteraction_type.py @@ -0,0 +1,26 @@ +# Generated by Django 4.1.7 on 2023-02-14 22:30 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("activities", "0010_stator_indexes"), + ] + + operations = [ + migrations.AddField( + model_name="postinteraction", + name="value", + field=models.CharField(blank=True, max_length=50, null=True), + ), + migrations.AlterField( + model_name="postinteraction", + name="type", + field=models.CharField( + choices=[("like", "Like"), ("boost", "Boost"), ("vote", "Vote")], + max_length=100, + ), + ), + ] diff --git a/activities/models/fan_out.py b/activities/models/fan_out.py index e63509c..ff6fc64 100644 --- a/activities/models/fan_out.py +++ b/activities/models/fan_out.py @@ -163,7 +163,7 @@ class FanOutStates(StateGraph): interaction=interaction, ) - # Handle sending remote boosts/likes + # Handle sending remote boosts/likes/votes case (FanOut.Types.interaction, False): interaction = await fan_out.subject_post_interaction.afetch_full() # Send it to the remote inbox @@ -174,7 +174,11 @@ class FanOutStates(StateGraph): fan_out.identity.shared_inbox_uri or fan_out.identity.inbox_uri ), - body=canonicalise(interaction.to_ap()), + body=canonicalise( + interaction.to_create_ap() + if interaction.type == interaction.Types.vote + else interaction.to_ap() + ), ) except httpx.RequestError: return diff --git a/activities/models/post.py b/activities/models/post.py index 62c5983..bbcfecc 100644 --- a/activities/models/post.py +++ b/activities/models/post.py @@ -23,6 +23,7 @@ from activities.models.post_types import ( PostTypeData, PostTypeDataDecoder, PostTypeDataEncoder, + QuestionData, ) from core.exceptions import capture_message from core.html import ContentRenderer, FediverseHtmlParser @@ -377,12 +378,16 @@ class Post(StatorModel): def _safe_content_note(self, *, local: bool = True): return ContentRenderer(local=local).render_post(self.content, self) - # def _safe_content_question(self, *, local: bool = True): - # context = { - # "post": self, - # "typed_data": PostTypeData(self.type_data), - # } - # return loader.render_to_string("activities/_type_question.html", context) + def _safe_content_question(self, *, local: bool = True): + if local: + context = { + "post": self, + "sanitized_content": self._safe_content_note(local=local), + "local_display": local, + } + return loader.render_to_string("activities/_type_question.html", context) + else: + return ContentRenderer(local=local).render_post(self.content, self) def _safe_content_typed(self, *, local: bool = True): context = { @@ -461,6 +466,7 @@ class Post(StatorModel): visibility: int = Visibilities.public, reply_to: Optional["Post"] = None, attachments: list | None = None, + question: dict | None = None, ) -> "Post": with transaction.atomic(): # Find mentions in this post @@ -493,6 +499,9 @@ class Post(StatorModel): post.emojis.set(emojis) if attachments: post.attachments.set(attachments) + if question: + post.type = question["type"] + post.type_data = PostTypeData(__root__=question).__root__ post.save() # Recalculate parent stats for replies if reply_to: @@ -574,6 +583,30 @@ class Post(StatorModel): if save: self.save() + def calculate_type_data(self, save=True): + """ + Recalculate type_data (used mostly for poll votes) + """ + from activities.models import PostInteraction + + if self.local and isinstance(self.type_data, QuestionData): + self.type_data.voter_count = ( + self.interactions.filter( + type=PostInteraction.Types.vote, + ) + .values("identity") + .distinct() + .count() + ) + + for option in self.type_data.options: + option.votes = self.interactions.filter( + type=PostInteraction.Types.vote, + value=option.name, + ).count() + if save: + self.save() + ### ActivityPub (outbound) ### def to_ap(self) -> dict: @@ -802,8 +835,7 @@ class Post(StatorModel): if update or created: post.type = data["type"] if post.type in (cls.Types.article, cls.Types.question): - type_data = PostTypeData(__root__=data).__root__ - post.type_data = type_data.dict() + post.type_data = PostTypeData(__root__=data).__root__ post.content = get_value_or_map(data, "content", "contentMap") post.summary = data.get("summary") post.sensitive = data.get("sensitive", False) @@ -1027,7 +1059,7 @@ class Post(StatorModel): ### Mastodon API ### - def to_mastodon_json(self, interactions=None): + def to_mastodon_json(self, interactions=None, identity=None): reply_parent = None if self.in_reply_to: # Load the PK and author.id explicitly to prevent a SELECT on the entire author Identity @@ -1085,7 +1117,9 @@ class Post(StatorModel): reply_parent.author_id if reply_parent else None ), "reblog": None, - "poll": None, + "poll": self.type_data.to_mastodon_json(self, identity) + if isinstance(self.type_data, QuestionData) + else None, "card": None, "language": None, "text": self.safe_content_remote(), diff --git a/activities/models/post_interaction.py b/activities/models/post_interaction.py index 5bc92cf..0b66065 100644 --- a/activities/models/post_interaction.py +++ b/activities/models/post_interaction.py @@ -3,6 +3,7 @@ from django.utils import timezone from activities.models.fan_out import FanOut from activities.models.post import Post +from activities.models.post_types import QuestionData from core.ld import format_ld_date, get_str_or_id, parse_ld_date from core.snowflake import Snowflake from stator.models import State, StateField, StateGraph, StatorModel @@ -60,6 +61,17 @@ class PostInteractionStates(StateGraph): subject_post=interaction.post, subject_post_interaction=interaction, ) + # Vote: send a copy of the vote to the original + # post author only if it's a local interaction + # to a non local post + elif interaction.type == interaction.Types.vote: + if interaction.identity.local and not interaction.post.local: + await FanOut.objects.acreate( + type=FanOut.Types.interaction, + identity_id=interaction.post.author_id, + subject_post=interaction.post, + subject_post_interaction=interaction, + ) else: raise ValueError("Cannot fan out unknown type") # And one for themselves if they're local and it's a boost @@ -125,6 +137,7 @@ class PostInteraction(StatorModel): class Types(models.TextChoices): like = "like" boost = "boost" + vote = "vote" id = models.BigIntegerField( primary_key=True, @@ -154,6 +167,10 @@ class PostInteraction(StatorModel): related_name="interactions", ) + # Used to store any interaction extra text value like the vote + # in the question/poll case + value = models.CharField(max_length=50, blank=True, null=True) + # When the activity was originally created (as opposed to when we received it) # Mastodon only seems to send this for boosts, not likes published = models.DateTimeField(default=timezone.now) @@ -203,9 +220,44 @@ class PostInteraction(StatorModel): """ Returns a version of the object with all relations pre-loaded """ - return await PostInteraction.objects.select_related("identity", "post").aget( - pk=self.pk - ) + return await PostInteraction.objects.select_related( + "identity", "post", "post__author" + ).aget(pk=self.pk) + + ### Create helpers ### + + @classmethod + def create_votes(cls, post, identity, choices) -> list["PostInteraction"]: + question = post.type_data + + if question.end_time and timezone.now() > question.end_time: + raise ValueError("Validation failed: The poll has already ended") + + if post.interactions.filter(identity=identity, type=cls.Types.vote).exists(): + raise ValueError("Validation failed: You have already voted on this poll") + + votes = [] + with transaction.atomic(): + for choice in set(choices): + vote = cls.objects.create( + identity=identity, + post=post, + type=PostInteraction.Types.vote, + value=question.options[choice].name, + ) + vote.object_uri = f"{identity.actor_uri}#votes/{vote.id}" + vote.save() + votes.append(vote) + + if not post.local: + question.options[choice].votes += 1 + + if not post.local: + question.voter_count += 1 + + post.calculate_type_data() + + return votes ### ActivityPub (outbound) ### @@ -233,10 +285,33 @@ class PostInteraction(StatorModel): "actor": self.identity.actor_uri, "object": self.post.object_uri, } + elif self.type == self.Types.vote: + value = { + "type": "Note", + "id": self.object_uri, + "to": self.post.author.actor_uri, + "name": self.value, + "inReplyTo": self.post.object_uri, + "attributedTo": self.identity.actor_uri, + } else: raise ValueError("Cannot turn into AP") return value + def to_create_ap(self): + """ + Returns the AP JSON to create this object + """ + object = self.to_ap() + return { + "to": object.get("to", []), + "cc": object.get("cc", []), + "type": "Create", + "id": self.object_uri, + "actor": self.identity.actor_uri, + "object": object, + } + def to_undo_ap(self) -> dict: """ Returns the AP JSON to undo this object @@ -267,12 +342,40 @@ class PostInteraction(StatorModel): # Resolve the author identity = Identity.by_actor_uri(data["actor"], create=True) # Resolve the post - post = Post.by_object_uri(get_str_or_id(data["object"]), fetch=True) + object = data["object"] + target = get_str_or_id(object, "inReplyTo") or get_str_or_id(object) + post = Post.by_object_uri(target, fetch=True) + value = None # Get the right type if data["type"].lower() == "like": type = cls.Types.like elif data["type"].lower() == "announce": type = cls.Types.boost + elif ( + data["type"].lower() == "create" + and object["type"].lower() == "note" + and isinstance(post.type_data, QuestionData) + ): + type = cls.Types.vote + question = post.type_data + value = object["name"] + if question.end_time and timezone.now() > question.end_time: + # TODO: Maybe create an expecific expired exception? + raise cls.DoesNotExist( + f"Cannot create a vote to the expired question {post.id}" + ) + + already_voted = ( + post.type_data.mode == "oneOf" + and post.interactions.filter( + type=cls.Types.vote, identity=identity + ).exists() + ) + if already_voted: + raise cls.DoesNotExist( + f"The identity {identity.handle} already voted in question {post.id}" + ) + else: raise ValueError(f"Cannot handle AP type {data['type']}") # Make the actual interaction @@ -283,6 +386,7 @@ class PostInteraction(StatorModel): published=parse_ld_date(data.get("published", None)) or timezone.now(), type=type, + value=value, ) else: raise cls.DoesNotExist(f"No interaction with ID {data['id']}", data) @@ -301,7 +405,9 @@ class PostInteraction(StatorModel): # That post is gone, boss # TODO: Limited retry state? return + interaction.post.calculate_stats() + interaction.post.calculate_type_data() @classmethod def handle_undo_ap(cls, data): @@ -324,10 +430,11 @@ class PostInteraction(StatorModel): interaction.transition_perform(PostInteractionStates.undone_fanned_out) # Recalculate post stats interaction.post.calculate_stats() + interaction.post.calculate_type_data() ### Mastodon API ### - def to_mastodon_status_json(self, interactions=None): + def to_mastodon_status_json(self, interactions=None, identity=None): """ This wraps Posts in a fake Status for boost interactions. """ @@ -336,7 +443,9 @@ class PostInteraction(StatorModel): f"Cannot make status JSON for interaction of type {self.type}" ) # Make a fake post for this boost (because mastodon treats boosts as posts) - post_json = self.post.to_mastodon_json(interactions=interactions) + post_json = self.post.to_mastodon_json( + interactions=interactions, identity=identity + ) return { "id": f"{self.pk}", "uri": post_json["uri"], @@ -356,7 +465,7 @@ class PostInteraction(StatorModel): "url": post_json["url"], "in_reply_to_id": None, "in_reply_to_account_id": None, - "poll": None, + "poll": post_json["poll"], "card": None, "language": None, "text": "", diff --git a/activities/models/post_types.py b/activities/models/post_types.py index a1536a2..b4e58f6 100644 --- a/activities/models/post_types.py +++ b/activities/models/post_types.py @@ -2,8 +2,11 @@ import json from datetime import datetime from typing import Literal +from django.utils import timezone from pydantic import BaseModel, Field +from core.ld import format_ld_date + class BasePostDataType(BaseModel): pass @@ -45,6 +48,51 @@ class QuestionData(BasePostDataType): data["options"] = options super().__init__(**data) + def to_mastodon_json(self, post, identity=None): + from activities.models import PostInteraction + + multiple = self.mode == "anyOf" + value = { + "id": post.id, + "expires_at": None, + "expired": False, + "multiple": multiple, + "votes_count": 0, + "voters_count": self.voter_count if multiple else None, + "voted": False, + "own_votes": [], + "options": [], + "emojis": [], + } + + if self.end_time: + value["expires_at"] = format_ld_date(self.end_time) + value["expired"] = timezone.now() >= self.end_time + + options = self.options or [] + option_map = {} + for index, option in enumerate(options): + value["options"].append( + { + "title": option.name, + "votes_count": option.votes, + } + ) + value["votes_count"] += option.votes + option_map[option.name] = index + + if identity: + votes = post.interactions.filter( + identity=identity, + type=PostInteraction.Types.vote, + ) + value["voted"] = post.author == identity or votes.exists() + value["own_votes"] = [ + option_map[vote.value] for vote in votes if vote.value in option_map + ] + + return value + class ArticleData(BasePostDataType): type: Literal["Article"] diff --git a/activities/models/timeline_event.py b/activities/models/timeline_event.py index eda9472..e94dd8e 100644 --- a/activities/models/timeline_event.py +++ b/activities/models/timeline_event.py @@ -221,12 +221,14 @@ class TimelineEvent(models.Model): raise ValueError(f"Cannot convert {self.type} to notification JSON") return result - def to_mastodon_status_json(self, interactions=None): + def to_mastodon_status_json(self, interactions=None, identity=None): if self.type == self.Types.post: - return self.subject_post.to_mastodon_json(interactions=interactions) + return self.subject_post.to_mastodon_json( + interactions=interactions, identity=identity + ) elif self.type == self.Types.boost: return self.subject_post_interaction.to_mastodon_status_json( - interactions=interactions + interactions=interactions, identity=identity ) else: raise ValueError(f"Cannot make status JSON for type {self.type}") diff --git a/api/pagination.py b/api/pagination.py index 2139e09..5e8fe73 100644 --- a/api/pagination.py +++ b/api/pagination.py @@ -153,7 +153,9 @@ class PaginationResult(Generic[T]): """ interactions = PostInteraction.get_post_interactions(self.results, identity) self.jsonify_results( - lambda post: post.to_mastodon_json(interactions=interactions) + lambda post: post.to_mastodon_json( + interactions=interactions, identity=identity + ) ) def jsonify_status_events(self, identity): @@ -162,7 +164,9 @@ class PaginationResult(Generic[T]): """ interactions = PostInteraction.get_event_interactions(self.results, identity) self.jsonify_results( - lambda event: event.to_mastodon_status_json(interactions=interactions) + lambda event: event.to_mastodon_status_json( + interactions=interactions, identity=identity + ) ) def jsonify_notification_events(self, identity): diff --git a/api/schemas.py b/api/schemas.py index 5f07dd9..c40e00f 100644 --- a/api/schemas.py +++ b/api/schemas.py @@ -91,6 +91,32 @@ class MediaAttachment(Schema): return cls(**attachment.to_mastodon_json()) +class PollOptions(Schema): + title: str + votes_count: int | None + + +class Poll(Schema): + id: str + expires_at: str | None + expired: bool + multiple: bool + votes_count: int + voters_count: int | None + voted: bool + own_votes: list[int] + options: list[PollOptions] + emojis: list[CustomEmoji] + + @classmethod + def from_post( + cls, + post: activities_models.Post, + identity: users_models.Identity | None = None, + ) -> "Poll": + return cls(**post.type_data.to_mastodon_json(post, identity=identity)) + + class StatusMention(Schema): id: str username: str @@ -123,7 +149,7 @@ class Status(Schema): in_reply_to_id: str | None = Field(...) in_reply_to_account_id: str | None = Field(...) reblog: Optional["Status"] = Field(...) - poll: None = Field(...) + poll: Poll | None = Field(...) card: None = Field(...) language: None = Field(...) text: str | None = Field(...) @@ -139,8 +165,11 @@ class Status(Schema): cls, post: activities_models.Post, interactions: dict[str, set[str]] | None = None, + identity: users_models.Identity | None = None, ) -> "Status": - return cls(**post.to_mastodon_json(interactions=interactions)) + return cls( + **post.to_mastodon_json(interactions=interactions, identity=identity) + ) @classmethod def map_from_post( @@ -151,15 +180,23 @@ class Status(Schema): interactions = activities_models.PostInteraction.get_post_interactions( posts, identity ) - return [cls.from_post(post, interactions=interactions) for post in posts] + return [ + cls.from_post(post, interactions=interactions, identity=identity) + for post in posts + ] @classmethod def from_timeline_event( cls, timeline_event: activities_models.TimelineEvent, interactions: dict[str, set[str]] | None = None, + identity: users_models.Identity | None = None, ) -> "Status": - return cls(**timeline_event.to_mastodon_status_json(interactions=interactions)) + return cls( + **timeline_event.to_mastodon_status_json( + interactions=interactions, identity=identity + ) + ) @classmethod def map_from_timeline_event( @@ -171,7 +208,7 @@ class Status(Schema): events, identity ) return [ - cls.from_timeline_event(event, interactions=interactions) + cls.from_timeline_event(event, interactions=interactions, identity=identity) for event in events ] diff --git a/api/urls.py b/api/urls.py index a2ea0a6..47f2420 100644 --- a/api/urls.py +++ b/api/urls.py @@ -10,6 +10,7 @@ from api.views import ( instance, media, notifications, + polls, search, statuses, timelines, @@ -62,6 +63,9 @@ urlpatterns = [ path("v1/statuses//source", statuses.status_source), # Notifications path("v1/notifications", notifications.notifications), + # Polls + path("v1/polls/", polls.get_poll), + path("v1/polls//votes", polls.vote_poll), # Search path("v2/search", search.search), # Statuses diff --git a/api/views/polls.py b/api/views/polls.py new file mode 100644 index 0000000..2658061 --- /dev/null +++ b/api/views/polls.py @@ -0,0 +1,26 @@ +from django.shortcuts import get_object_or_404 +from hatchway import Schema, api_view + +from activities.models import Post, PostInteraction +from api import schemas +from api.decorators import identity_required + + +class PostVoteSchema(Schema): + choices: list[int] + + +@identity_required +@api_view.get +def get_poll(request, id: str) -> schemas.Poll: + post = get_object_or_404(Post, pk=id, type=Post.Types.question) + return schemas.Poll.from_post(post, identity=request.identity) + + +@identity_required +@api_view.post +def vote_poll(request, id: str, details: PostVoteSchema) -> schemas.Poll: + post = get_object_or_404(Post, pk=id, type=Post.Types.question) + PostInteraction.create_votes(post, request.identity, details.choices) + post.refresh_from_db() + return schemas.Poll.from_post(post, identity=request.identity) diff --git a/api/views/search.py b/api/views/search.py index a3a675a..853d990 100644 --- a/api/views/search.py +++ b/api/views/search.py @@ -47,7 +47,9 @@ def search( search_result["posts"], request.identity ) result["statuses"] = [ - schemas.Status.from_post(p, interactions=interactions) + schemas.Status.from_post( + p, interactions=interactions, identity=request.identity + ) for p in search_result["posts"] ] return schemas.Search(**result) diff --git a/api/views/statuses.py b/api/views/statuses.py index 0492c80..c473594 100644 --- a/api/views/statuses.py +++ b/api/views/statuses.py @@ -1,7 +1,9 @@ +from datetime import timedelta from typing import Literal from django.http import HttpRequest from django.shortcuts import get_object_or_404 +from django.utils import timezone from hatchway import ApiError, ApiResponse, Schema, api_view from activities.models import ( @@ -18,6 +20,24 @@ from api.pagination import MastodonPaginator, PaginationResult from core.models import Config +class PostPollSchema(Schema): + options: list[str] + expires_in: int + multiple: bool = False + hide_totals: bool = False + + def dict(self): + return { + "type": "Question", + "mode": "anyOf" if self.multiple else "oneOf", + "options": [ + {"name": name, "type": "Note", "votes": 0} for name in self.options + ], + "voter_count": 0, + "end_time": timezone.now() + timedelta(seconds=self.expires_in), + } + + class PostStatusSchema(Schema): status: str in_reply_to_id: str | None = None @@ -27,6 +47,7 @@ class PostStatusSchema(Schema): language: str | None = None scheduled_at: str | None = None media_ids: list[str] = [] + poll: PostPollSchema | None = None class EditStatusSchema(Schema): @@ -82,10 +103,11 @@ def post_status(request, details: PostStatusSchema) -> schemas.Status: visibility=visibility_map[details.visibility], reply_to=reply_post, attachments=attachments, + question=details.poll.dict() if details.poll else None, ) # Add their own timeline event for immediate visibility TimelineEvent.add_post(request.identity, post) - return schemas.Status.from_post(post) + return schemas.Status.from_post(post, identity=request.identity) @identity_required @@ -93,7 +115,9 @@ def post_status(request, details: PostStatusSchema) -> schemas.Status: def status(request, id: str) -> schemas.Status: post = post_for_id(request, id) interactions = PostInteraction.get_post_interactions([post], request.identity) - return schemas.Status.from_post(post, interactions=interactions) + return schemas.Status.from_post( + post, interactions=interactions, identity=request.identity + ) @identity_required @@ -121,7 +145,7 @@ def delete_status(request, id: str) -> schemas.Status: if post.author != request.identity: raise ApiError(401, "Not the author of this status") PostService(post).delete() - return schemas.Status.from_post(post) + return schemas.Status.from_post(post, identity=request.identity) @identity_required @@ -142,11 +166,16 @@ def status_context(request, id: str) -> schemas.Context: ) return schemas.Context( ancestors=[ - schemas.Status.from_post(p, interactions=interactions) + schemas.Status.from_post( + p, interactions=interactions, identity=request.identity + ) for p in reversed(ancestors) ], descendants=[ - schemas.Status.from_post(p, interactions=interactions) for p in descendants + schemas.Status.from_post( + p, interactions=interactions, identity=request.identity + ) + for p in descendants ], ) @@ -158,7 +187,9 @@ def favourite_status(request, id: str) -> schemas.Status: service = PostService(post) service.like_as(request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity) - return schemas.Status.from_post(post, interactions=interactions) + return schemas.Status.from_post( + post, interactions=interactions, identity=request.identity + ) @identity_required @@ -168,7 +199,9 @@ def unfavourite_status(request, id: str) -> schemas.Status: service = PostService(post) service.unlike_as(request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity) - return schemas.Status.from_post(post, interactions=interactions) + return schemas.Status.from_post( + post, interactions=interactions, identity=request.identity + ) @api_view.get @@ -219,7 +252,9 @@ def reblog_status(request, id: str) -> schemas.Status: service = PostService(post) service.boost_as(request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity) - return schemas.Status.from_post(post, interactions=interactions) + return schemas.Status.from_post( + post, interactions=interactions, identity=request.identity + ) @identity_required @@ -229,4 +264,6 @@ def unreblog_status(request, id: str) -> schemas.Status: service = PostService(post) service.unboost_as(request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity) - return schemas.Status.from_post(post, interactions=interactions) + return schemas.Status.from_post( + post, interactions=interactions, identity=request.identity + ) diff --git a/core/ld.py b/core/ld.py index b15320f..bdb9fa2 100644 --- a/core/ld.py +++ b/core/ld.py @@ -626,14 +626,14 @@ def get_list(container, key) -> list: return value -def get_str_or_id(value: str | dict | None) -> str | None: +def get_str_or_id(value: str | dict | None, key: str = "id") -> str | None: """ Given a value that could be a str or {"id": str}, return the str """ if isinstance(value, str): return value elif isinstance(value, dict): - return value.get("id") + return value.get(key) return None diff --git a/tests/activities/models/test_post.py b/tests/activities/models/test_post.py index 0eb1fc5..6ebc766 100644 --- a/tests/activities/models/test_post.py +++ b/tests/activities/models/test_post.py @@ -2,6 +2,7 @@ import pytest from pytest_httpx import HTTPXMock from activities.models import Post, PostStates +from activities.models.post_types import QuestionData from users.models import Identity, InboxMessage @@ -254,6 +255,46 @@ def test_content_map(remote_identity): assert post3.content == "Hello World" +@pytest.mark.django_db +def test_content_map_question(remote_identity: Identity): + """ + Tests post contentmap for questions + """ + post = Post.by_ap( + data={ + "id": "https://remote.test/posts/1/", + "type": "Question", + "votersCount": 10, + "closed": "2023-01-01T26:04:45Z", + "content": "Test Question", + "attributedTo": "https://remote.test/test-actor/", + "published": "2022-12-23T10:50:54Z", + "endTime": "2023-01-01T20:04:45Z", + "oneOf": [ + { + "type": "Note", + "name": "Option 1", + "replies": { + "type": "Collection", + "totalItems": 6, + }, + }, + { + "type": "Note", + "name": "Option 2", + "replies": { + "type": "Collection", + "totalItems": 4, + }, + }, + ], + }, + create=True, + ) + assert post.content == "Test Question" + assert isinstance(post.type_data, QuestionData) + + @pytest.mark.django_db @pytest.mark.parametrize("delete_type", ["note", "tombstone", "ref"]) def test_inbound_posts( diff --git a/tests/activities/models/test_post_interaction.py b/tests/activities/models/test_post_interaction.py new file mode 100644 index 0000000..ced7b4f --- /dev/null +++ b/tests/activities/models/test_post_interaction.py @@ -0,0 +1,314 @@ +from datetime import timedelta + +import pytest +from django.utils import timezone + +from activities.models import Post, PostInteraction +from activities.models.post_types import QuestionData +from core.ld import format_ld_date +from users.models import Identity + + +@pytest.mark.django_db +def test_vote_in_question(identity: Identity, remote_identity: Identity, config_system): + post = Post.create_local( + author=identity, + content="

Test Question

", + question={ + "type": "Question", + "mode": "oneOf", + "options": [ + {"name": "Option 1", "type": "Note", "votes": 0}, + {"name": "Option 2", "type": "Note", "votes": 0}, + ], + "voter_count": 0, + "end_time": format_ld_date(timezone.now() + timedelta(1)), + }, + ) + + PostInteraction.handle_ap( + data={ + "id": "https://remote.test/test-actor#votes/11/activity", + "to": "https://example.com/@test@example.com/", + "type": "Create", + "actor": "https://remote.test/test-actor/", + "object": { + "id": "https://remote.test/users/test-actor#votes/11", + "to": "https://example.com/@test@example.com/", + "name": "Option 1", + "type": "Note", + "inReplyTo": post.object_uri, + "attributedTo": "https://remote.test/test-actor/", + }, + "@context": [ + "https://www.w3.org/ns/activitystreams", + { + "toot": "http://joinmastodon.org/ns#", + "Emoji": "toot:Emoji", + "Public": "as:Public", + "Hashtag": "as:Hashtag", + "votersCount": "toot:votersCount", + }, + "https://w3id.org/security/v1", + ], + }, + ) + + post.refresh_from_db() + + assert isinstance(post.type_data, QuestionData) + assert post.type_data.voter_count == 1 + assert post.type_data.options + assert len(post.type_data.options) == 2 + assert post.type_data.options[0].votes == 1 + assert post.type_data.options[1].votes == 0 + + +@pytest.mark.django_db +def test_vote_in_multiple_choice_question( + identity: Identity, remote_identity: Identity, config_system +): + post = Post.create_local( + author=identity, + content="

Test Question

", + question={ + "type": "Question", + "mode": "anyOf", + "options": [ + {"name": "Option 1", "type": "Note", "votes": 0}, + {"name": "Option 2", "type": "Note", "votes": 0}, + {"name": "Option 3", "type": "Note", "votes": 0}, + ], + "voter_count": 0, + "end_time": format_ld_date(timezone.now() + timedelta(1)), + }, + ) + + PostInteraction.handle_ap( + data={ + "id": "https://remote.test/test-actor#votes/12/activity", + "to": "https://example.com/@test@example.com/", + "type": "Create", + "actor": "https://remote.test/test-actor/", + "object": { + "id": "https://remote.test/users/test-actor#votes/12", + "to": "https://example.com/@test@example.com/", + "name": "Option 1", + "type": "Note", + "inReplyTo": post.object_uri, + "attributedTo": "https://remote.test/test-actor/", + }, + "@context": [ + "https://www.w3.org/ns/activitystreams", + { + "toot": "http://joinmastodon.org/ns#", + "Emoji": "toot:Emoji", + "Public": "as:Public", + "Hashtag": "as:Hashtag", + "votersCount": "toot:votersCount", + }, + "https://w3id.org/security/v1", + ], + }, + ) + + PostInteraction.handle_ap( + data={ + "id": "https://remote.test/test-actor#votes/13/activity", + "to": "https://example.com/@test@example.com/", + "type": "Create", + "actor": "https://remote.test/test-actor/", + "object": { + "id": "https://remote.test/users/test-actor#votes/13", + "to": "https://example.com/@test@example.com/", + "name": "Option 2", + "type": "Note", + "inReplyTo": post.object_uri, + "attributedTo": "https://remote.test/test-actor/", + }, + "@context": [ + "https://www.w3.org/ns/activitystreams", + { + "toot": "http://joinmastodon.org/ns#", + "Emoji": "toot:Emoji", + "Public": "as:Public", + "Hashtag": "as:Hashtag", + "votersCount": "toot:votersCount", + }, + "https://w3id.org/security/v1", + ], + }, + ) + + post.refresh_from_db() + + assert isinstance(post.type_data, QuestionData) + assert post.type_data.voter_count == 1 + assert post.type_data.options + assert len(post.type_data.options) == 3 + assert post.type_data.options[0].votes == 1 + assert post.type_data.options[1].votes == 1 + assert post.type_data.options[2].votes == 0 + + +@pytest.mark.django_db +def test_multiple_votes_to_single_vote_question( + identity: Identity, remote_identity: Identity, config_system +): + post = Post.create_local( + author=identity, + content="

Test Question

", + question={ + "type": "Question", + "mode": "oneOf", + "options": [ + {"name": "Option 1", "type": "Note", "votes": 0}, + {"name": "Option 2", "type": "Note", "votes": 0}, + ], + "voter_count": 0, + "end_time": format_ld_date(timezone.now() + timedelta(1)), + }, + ) + + PostInteraction.by_ap( + data={ + "id": "https://remote.test/test-actor#votes/14/activity", + "to": "https://example.com/@test@example.com/", + "type": "Create", + "actor": "https://remote.test/test-actor/", + "object": { + "id": "https://remote.test/users/test-actor#votes/14", + "to": "https://example.com/@test@example.com/", + "name": "Option 1", + "type": "Note", + "inReplyTo": post.object_uri, + "attributedTo": "https://remote.test/test-actor/", + }, + "@context": [ + "https://www.w3.org/ns/activitystreams", + { + "toot": "http://joinmastodon.org/ns#", + "Emoji": "toot:Emoji", + "Public": "as:Public", + "Hashtag": "as:Hashtag", + "votersCount": "toot:votersCount", + }, + "https://w3id.org/security/v1", + ], + }, + create=True, + ) + + with pytest.raises(PostInteraction.DoesNotExist) as ex: + PostInteraction.by_ap( + data={ + "id": "https://remote.test/test-actor#votes/15/activity", + "to": "https://example.com/@test@example.com/", + "type": "Create", + "actor": "https://remote.test/test-actor/", + "object": { + "id": "https://remote.test/users/test-actor#votes/15", + "to": "https://example.com/@test@example.com/", + "name": "Option 2", + "type": "Note", + "inReplyTo": post.object_uri, + "attributedTo": "https://remote.test/test-actor/", + }, + "@context": [ + "https://www.w3.org/ns/activitystreams", + { + "toot": "http://joinmastodon.org/ns#", + "Emoji": "toot:Emoji", + "Public": "as:Public", + "Hashtag": "as:Hashtag", + "votersCount": "toot:votersCount", + }, + "https://w3id.org/security/v1", + ], + }, + create=True, + ) + assert "already voted" in str(ex) + + +@pytest.mark.django_db +def test_vote_in_expired_question( + identity: Identity, remote_identity: Identity, config_system +): + post = Post.create_local( + author=identity, + content="

Test Question

", + question={ + "type": "Question", + "mode": "oneOf", + "options": [ + {"name": "Option 1", "type": "Note", "votes": 0}, + {"name": "Option 2", "type": "Note", "votes": 0}, + ], + "voter_count": 0, + "end_time": "2022-01-01T23:04:45+00:00", + }, + ) + + with pytest.raises(PostInteraction.DoesNotExist) as ex: + PostInteraction.by_ap( + data={ + "id": "https://remote.test/test-actor#votes/16/activity", + "to": "https://example.com/@test@example.com/", + "type": "Create", + "actor": "https://remote.test/test-actor/", + "object": { + "id": "https://remote.test/users/test-actor#votes/16", + "to": "https://example.com/@test@example.com/", + "name": "Option 1", + "type": "Note", + "inReplyTo": post.object_uri, + "attributedTo": "https://remote.test/test-actor/", + }, + "@context": [ + "https://www.w3.org/ns/activitystreams", + { + "toot": "http://joinmastodon.org/ns#", + "Emoji": "toot:Emoji", + "Public": "as:Public", + "Hashtag": "as:Hashtag", + "votersCount": "toot:votersCount", + }, + "https://w3id.org/security/v1", + ], + }, + create=True, + ) + assert "Cannot create a vote to the expired question" in str(ex) + + +@pytest.mark.django_db +def test_vote_to_ap(identity: Identity, remote_identity: Identity, config_system): + post = Post.objects.create( + author=remote_identity, + local=False, + content="

Test Question

", + type_data={ + "type": "Question", + "mode": "oneOf", + "options": [ + {"name": "Option 1", "type": "Note", "votes": 6}, + {"name": "Option 2", "type": "Note", "votes": 4}, + ], + "voter_count": 10, + "end_time": format_ld_date(timezone.now() + timedelta(1)), + }, + ) + post.refresh_from_db() + + interaction = PostInteraction.create_votes( + post=post, + identity=identity, + choices=[0], + )[0] + + data = interaction.to_create_ap() + assert data["object"]["to"] == remote_identity.actor_uri + assert data["object"]["attributedTo"] == identity.actor_uri + assert data["object"]["name"] == "Option 1" + assert data["object"]["inReplyTo"] == post.object_uri diff --git a/tests/api/test_polls.py b/tests/api/test_polls.py new file mode 100644 index 0000000..4043078 --- /dev/null +++ b/tests/api/test_polls.py @@ -0,0 +1,62 @@ +from datetime import timedelta + +import pytest +from django.utils import timezone + +from activities.models import Post +from core.ld import format_ld_date + + +@pytest.mark.django_db +def test_get_poll(api_client): + response = api_client.post( + "/api/v1/statuses", + content_type="application/json", + data={ + "status": "Hello, world!", + "poll": { + "options": ["Option 1", "Option 2"], + "expires_in": 300, + }, + }, + ).json() + + id = response["id"] + + response = api_client.get( + f"/api/v1/polls/{id}", + ).json() + + assert response["id"] == id + assert response["voted"] + + +@pytest.mark.django_db +def test_vote_poll(api_client, identity2): + post = Post.create_local( + author=identity2, + content="

Test Question

", + question={ + "type": "Question", + "mode": "oneOf", + "options": [ + {"name": "Option 1", "type": "Note", "votes": 0}, + {"name": "Option 2", "type": "Note", "votes": 0}, + ], + "voter_count": 0, + "end_time": format_ld_date(timezone.now() + timedelta(1)), + }, + ) + + response = api_client.post( + f"/api/v1/polls/{post.id}/votes", + content_type="application/json", + data={ + "choices": [0], + }, + ).json() + + assert response["id"] == str(post.id) + assert response["voted"] + assert response["votes_count"] == 1 + assert response["own_votes"] == [0] diff --git a/tests/api/test_statuses.py b/tests/api/test_statuses.py index 05ede30..4add2af 100644 --- a/tests/api/test_statuses.py +++ b/tests/api/test_statuses.py @@ -75,3 +75,68 @@ def test_mention_format(api_client, identity, remote_identity): assert ( response["text"] == '

Hey @test

' ) + + +@pytest.mark.django_db +def test_post_question_status(api_client): + response = api_client.post( + "/api/v1/statuses", + content_type="application/json", + data={ + "status": "Hello, world!", + "poll": { + "options": ["Option 1", "Option 2"], + "expires_in": 300, + }, + }, + ).json() + + assert response["poll"]["id"] == response["id"] + assert response["poll"]["options"] == [ + {"title": "Option 1", "votes_count": 0}, + {"title": "Option 2", "votes_count": 0}, + ] + assert not response["poll"]["expired"] + assert not response["poll"]["multiple"] + + +@pytest.mark.django_db +def test_question_format(api_client, remote_identity): + """ + Ensures incoming questions are property parsed. + """ + # Make a remote question post and check it + post = Post.objects.create( + local=False, + author=remote_identity, + content="

Test Question

", + object_uri="https://remote.test/status/123456", + type=Post.Types.question, + type_data={ + "type": "Question", + "mode": "oneOf", + "options": [ + {"name": "Option 1", "type": "Note", "votes": 10}, + {"name": "Option 2", "type": "Note", "votes": 20}, + ], + "voter_count": 30, + "end_time": "2022-01-01T23:04:45+00:00", + }, + ) + response = api_client.get(f"/api/v1/statuses/{post.id}").json() + assert response["text"] == "

Test Question

" + assert response["poll"] == { + "id": str(post.id), + "expires_at": "2022-01-01T23:04:45.000Z", + "expired": True, + "multiple": False, + "votes_count": 30, + "voters_count": None, + "voted": False, + "own_votes": [], + "options": [ + {"title": "Option 1", "votes_count": 10}, + {"title": "Option 2", "votes_count": 20}, + ], + "emojis": [], + } diff --git a/users/models/inbox_message.py b/users/models/inbox_message.py index 82a166b..4119aab 100644 --- a/users/models/inbox_message.py +++ b/users/models/inbox_message.py @@ -30,9 +30,15 @@ class InboxMessageStates(StateGraph): case "create": match instance.message_object_type: case "note": - await sync_to_async(Post.handle_create_ap)(instance.message) + if instance.message_object_has_content: + await sync_to_async(Post.handle_create_ap)(instance.message) + else: + # Notes without content are Interaction candidates + await sync_to_async(PostInteraction.handle_ap)( + instance.message + ) case "question": - pass # Drop for now + await sync_to_async(Post.handle_create_ap)(instance.message) case unknown: if unknown in Post.Types.names: await sync_to_async(Post.handle_create_ap)(instance.message) @@ -213,3 +219,7 @@ class InboxMessage(StatorModel): @property def message_actor(self): return self.message.get("actor") + + @property + def message_object_has_content(self): + return "content" in self.message.get("object", {})