Implement poll AP and API backend
This commit is contained in:
Humberto Rocha 2023-02-15 12:42:56 -05:00 committed by GitHub
parent cc0bbfb93b
commit e382994e73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 868 additions and 43 deletions

View File

@ -0,0 +1,26 @@
# Generated by Django 4.1.7 on 2023-02-14 22:30
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("activities", "0010_stator_indexes"),
]
operations = [
migrations.AddField(
model_name="postinteraction",
name="value",
field=models.CharField(blank=True, max_length=50, null=True),
),
migrations.AlterField(
model_name="postinteraction",
name="type",
field=models.CharField(
choices=[("like", "Like"), ("boost", "Boost"), ("vote", "Vote")],
max_length=100,
),
),
]

View File

@ -163,7 +163,7 @@ class FanOutStates(StateGraph):
interaction=interaction, interaction=interaction,
) )
# Handle sending remote boosts/likes # Handle sending remote boosts/likes/votes
case (FanOut.Types.interaction, False): case (FanOut.Types.interaction, False):
interaction = await fan_out.subject_post_interaction.afetch_full() interaction = await fan_out.subject_post_interaction.afetch_full()
# Send it to the remote inbox # Send it to the remote inbox
@ -174,7 +174,11 @@ class FanOutStates(StateGraph):
fan_out.identity.shared_inbox_uri fan_out.identity.shared_inbox_uri
or fan_out.identity.inbox_uri or fan_out.identity.inbox_uri
), ),
body=canonicalise(interaction.to_ap()), body=canonicalise(
interaction.to_create_ap()
if interaction.type == interaction.Types.vote
else interaction.to_ap()
),
) )
except httpx.RequestError: except httpx.RequestError:
return return

View File

@ -23,6 +23,7 @@ from activities.models.post_types import (
PostTypeData, PostTypeData,
PostTypeDataDecoder, PostTypeDataDecoder,
PostTypeDataEncoder, PostTypeDataEncoder,
QuestionData,
) )
from core.exceptions import capture_message from core.exceptions import capture_message
from core.html import ContentRenderer, FediverseHtmlParser from core.html import ContentRenderer, FediverseHtmlParser
@ -377,12 +378,16 @@ class Post(StatorModel):
def _safe_content_note(self, *, local: bool = True): def _safe_content_note(self, *, local: bool = True):
return ContentRenderer(local=local).render_post(self.content, self) return ContentRenderer(local=local).render_post(self.content, self)
# def _safe_content_question(self, *, local: bool = True): def _safe_content_question(self, *, local: bool = True):
# context = { if local:
# "post": self, context = {
# "typed_data": PostTypeData(self.type_data), "post": self,
# } "sanitized_content": self._safe_content_note(local=local),
# return loader.render_to_string("activities/_type_question.html", context) "local_display": local,
}
return loader.render_to_string("activities/_type_question.html", context)
else:
return ContentRenderer(local=local).render_post(self.content, self)
def _safe_content_typed(self, *, local: bool = True): def _safe_content_typed(self, *, local: bool = True):
context = { context = {
@ -461,6 +466,7 @@ class Post(StatorModel):
visibility: int = Visibilities.public, visibility: int = Visibilities.public,
reply_to: Optional["Post"] = None, reply_to: Optional["Post"] = None,
attachments: list | None = None, attachments: list | None = None,
question: dict | None = None,
) -> "Post": ) -> "Post":
with transaction.atomic(): with transaction.atomic():
# Find mentions in this post # Find mentions in this post
@ -493,6 +499,9 @@ class Post(StatorModel):
post.emojis.set(emojis) post.emojis.set(emojis)
if attachments: if attachments:
post.attachments.set(attachments) post.attachments.set(attachments)
if question:
post.type = question["type"]
post.type_data = PostTypeData(__root__=question).__root__
post.save() post.save()
# Recalculate parent stats for replies # Recalculate parent stats for replies
if reply_to: if reply_to:
@ -574,6 +583,30 @@ class Post(StatorModel):
if save: if save:
self.save() self.save()
def calculate_type_data(self, save=True):
"""
Recalculate type_data (used mostly for poll votes)
"""
from activities.models import PostInteraction
if self.local and isinstance(self.type_data, QuestionData):
self.type_data.voter_count = (
self.interactions.filter(
type=PostInteraction.Types.vote,
)
.values("identity")
.distinct()
.count()
)
for option in self.type_data.options:
option.votes = self.interactions.filter(
type=PostInteraction.Types.vote,
value=option.name,
).count()
if save:
self.save()
### ActivityPub (outbound) ### ### ActivityPub (outbound) ###
def to_ap(self) -> dict: def to_ap(self) -> dict:
@ -802,8 +835,7 @@ class Post(StatorModel):
if update or created: if update or created:
post.type = data["type"] post.type = data["type"]
if post.type in (cls.Types.article, cls.Types.question): if post.type in (cls.Types.article, cls.Types.question):
type_data = PostTypeData(__root__=data).__root__ post.type_data = PostTypeData(__root__=data).__root__
post.type_data = type_data.dict()
post.content = get_value_or_map(data, "content", "contentMap") post.content = get_value_or_map(data, "content", "contentMap")
post.summary = data.get("summary") post.summary = data.get("summary")
post.sensitive = data.get("sensitive", False) post.sensitive = data.get("sensitive", False)
@ -1027,7 +1059,7 @@ class Post(StatorModel):
### Mastodon API ### ### Mastodon API ###
def to_mastodon_json(self, interactions=None): def to_mastodon_json(self, interactions=None, identity=None):
reply_parent = None reply_parent = None
if self.in_reply_to: if self.in_reply_to:
# Load the PK and author.id explicitly to prevent a SELECT on the entire author Identity # Load the PK and author.id explicitly to prevent a SELECT on the entire author Identity
@ -1085,7 +1117,9 @@ class Post(StatorModel):
reply_parent.author_id if reply_parent else None reply_parent.author_id if reply_parent else None
), ),
"reblog": None, "reblog": None,
"poll": None, "poll": self.type_data.to_mastodon_json(self, identity)
if isinstance(self.type_data, QuestionData)
else None,
"card": None, "card": None,
"language": None, "language": None,
"text": self.safe_content_remote(), "text": self.safe_content_remote(),

View File

@ -3,6 +3,7 @@ from django.utils import timezone
from activities.models.fan_out import FanOut from activities.models.fan_out import FanOut
from activities.models.post import Post from activities.models.post import Post
from activities.models.post_types import QuestionData
from core.ld import format_ld_date, get_str_or_id, parse_ld_date from core.ld import format_ld_date, get_str_or_id, parse_ld_date
from core.snowflake import Snowflake from core.snowflake import Snowflake
from stator.models import State, StateField, StateGraph, StatorModel from stator.models import State, StateField, StateGraph, StatorModel
@ -60,6 +61,17 @@ class PostInteractionStates(StateGraph):
subject_post=interaction.post, subject_post=interaction.post,
subject_post_interaction=interaction, subject_post_interaction=interaction,
) )
# Vote: send a copy of the vote to the original
# post author only if it's a local interaction
# to a non local post
elif interaction.type == interaction.Types.vote:
if interaction.identity.local and not interaction.post.local:
await FanOut.objects.acreate(
type=FanOut.Types.interaction,
identity_id=interaction.post.author_id,
subject_post=interaction.post,
subject_post_interaction=interaction,
)
else: else:
raise ValueError("Cannot fan out unknown type") raise ValueError("Cannot fan out unknown type")
# And one for themselves if they're local and it's a boost # And one for themselves if they're local and it's a boost
@ -125,6 +137,7 @@ class PostInteraction(StatorModel):
class Types(models.TextChoices): class Types(models.TextChoices):
like = "like" like = "like"
boost = "boost" boost = "boost"
vote = "vote"
id = models.BigIntegerField( id = models.BigIntegerField(
primary_key=True, primary_key=True,
@ -154,6 +167,10 @@ class PostInteraction(StatorModel):
related_name="interactions", related_name="interactions",
) )
# Used to store any interaction extra text value like the vote
# in the question/poll case
value = models.CharField(max_length=50, blank=True, null=True)
# When the activity was originally created (as opposed to when we received it) # When the activity was originally created (as opposed to when we received it)
# Mastodon only seems to send this for boosts, not likes # Mastodon only seems to send this for boosts, not likes
published = models.DateTimeField(default=timezone.now) published = models.DateTimeField(default=timezone.now)
@ -203,9 +220,44 @@ class PostInteraction(StatorModel):
""" """
Returns a version of the object with all relations pre-loaded Returns a version of the object with all relations pre-loaded
""" """
return await PostInteraction.objects.select_related("identity", "post").aget( return await PostInteraction.objects.select_related(
pk=self.pk "identity", "post", "post__author"
) ).aget(pk=self.pk)
### Create helpers ###
@classmethod
def create_votes(cls, post, identity, choices) -> list["PostInteraction"]:
question = post.type_data
if question.end_time and timezone.now() > question.end_time:
raise ValueError("Validation failed: The poll has already ended")
if post.interactions.filter(identity=identity, type=cls.Types.vote).exists():
raise ValueError("Validation failed: You have already voted on this poll")
votes = []
with transaction.atomic():
for choice in set(choices):
vote = cls.objects.create(
identity=identity,
post=post,
type=PostInteraction.Types.vote,
value=question.options[choice].name,
)
vote.object_uri = f"{identity.actor_uri}#votes/{vote.id}"
vote.save()
votes.append(vote)
if not post.local:
question.options[choice].votes += 1
if not post.local:
question.voter_count += 1
post.calculate_type_data()
return votes
### ActivityPub (outbound) ### ### ActivityPub (outbound) ###
@ -233,10 +285,33 @@ class PostInteraction(StatorModel):
"actor": self.identity.actor_uri, "actor": self.identity.actor_uri,
"object": self.post.object_uri, "object": self.post.object_uri,
} }
elif self.type == self.Types.vote:
value = {
"type": "Note",
"id": self.object_uri,
"to": self.post.author.actor_uri,
"name": self.value,
"inReplyTo": self.post.object_uri,
"attributedTo": self.identity.actor_uri,
}
else: else:
raise ValueError("Cannot turn into AP") raise ValueError("Cannot turn into AP")
return value return value
def to_create_ap(self):
"""
Returns the AP JSON to create this object
"""
object = self.to_ap()
return {
"to": object.get("to", []),
"cc": object.get("cc", []),
"type": "Create",
"id": self.object_uri,
"actor": self.identity.actor_uri,
"object": object,
}
def to_undo_ap(self) -> dict: def to_undo_ap(self) -> dict:
""" """
Returns the AP JSON to undo this object Returns the AP JSON to undo this object
@ -267,12 +342,40 @@ class PostInteraction(StatorModel):
# Resolve the author # Resolve the author
identity = Identity.by_actor_uri(data["actor"], create=True) identity = Identity.by_actor_uri(data["actor"], create=True)
# Resolve the post # Resolve the post
post = Post.by_object_uri(get_str_or_id(data["object"]), fetch=True) object = data["object"]
target = get_str_or_id(object, "inReplyTo") or get_str_or_id(object)
post = Post.by_object_uri(target, fetch=True)
value = None
# Get the right type # Get the right type
if data["type"].lower() == "like": if data["type"].lower() == "like":
type = cls.Types.like type = cls.Types.like
elif data["type"].lower() == "announce": elif data["type"].lower() == "announce":
type = cls.Types.boost type = cls.Types.boost
elif (
data["type"].lower() == "create"
and object["type"].lower() == "note"
and isinstance(post.type_data, QuestionData)
):
type = cls.Types.vote
question = post.type_data
value = object["name"]
if question.end_time and timezone.now() > question.end_time:
# TODO: Maybe create an expecific expired exception?
raise cls.DoesNotExist(
f"Cannot create a vote to the expired question {post.id}"
)
already_voted = (
post.type_data.mode == "oneOf"
and post.interactions.filter(
type=cls.Types.vote, identity=identity
).exists()
)
if already_voted:
raise cls.DoesNotExist(
f"The identity {identity.handle} already voted in question {post.id}"
)
else: else:
raise ValueError(f"Cannot handle AP type {data['type']}") raise ValueError(f"Cannot handle AP type {data['type']}")
# Make the actual interaction # Make the actual interaction
@ -283,6 +386,7 @@ class PostInteraction(StatorModel):
published=parse_ld_date(data.get("published", None)) published=parse_ld_date(data.get("published", None))
or timezone.now(), or timezone.now(),
type=type, type=type,
value=value,
) )
else: else:
raise cls.DoesNotExist(f"No interaction with ID {data['id']}", data) raise cls.DoesNotExist(f"No interaction with ID {data['id']}", data)
@ -301,7 +405,9 @@ class PostInteraction(StatorModel):
# That post is gone, boss # That post is gone, boss
# TODO: Limited retry state? # TODO: Limited retry state?
return return
interaction.post.calculate_stats() interaction.post.calculate_stats()
interaction.post.calculate_type_data()
@classmethod @classmethod
def handle_undo_ap(cls, data): def handle_undo_ap(cls, data):
@ -324,10 +430,11 @@ class PostInteraction(StatorModel):
interaction.transition_perform(PostInteractionStates.undone_fanned_out) interaction.transition_perform(PostInteractionStates.undone_fanned_out)
# Recalculate post stats # Recalculate post stats
interaction.post.calculate_stats() interaction.post.calculate_stats()
interaction.post.calculate_type_data()
### Mastodon API ### ### Mastodon API ###
def to_mastodon_status_json(self, interactions=None): def to_mastodon_status_json(self, interactions=None, identity=None):
""" """
This wraps Posts in a fake Status for boost interactions. This wraps Posts in a fake Status for boost interactions.
""" """
@ -336,7 +443,9 @@ class PostInteraction(StatorModel):
f"Cannot make status JSON for interaction of type {self.type}" f"Cannot make status JSON for interaction of type {self.type}"
) )
# Make a fake post for this boost (because mastodon treats boosts as posts) # Make a fake post for this boost (because mastodon treats boosts as posts)
post_json = self.post.to_mastodon_json(interactions=interactions) post_json = self.post.to_mastodon_json(
interactions=interactions, identity=identity
)
return { return {
"id": f"{self.pk}", "id": f"{self.pk}",
"uri": post_json["uri"], "uri": post_json["uri"],
@ -356,7 +465,7 @@ class PostInteraction(StatorModel):
"url": post_json["url"], "url": post_json["url"],
"in_reply_to_id": None, "in_reply_to_id": None,
"in_reply_to_account_id": None, "in_reply_to_account_id": None,
"poll": None, "poll": post_json["poll"],
"card": None, "card": None,
"language": None, "language": None,
"text": "", "text": "",

View File

@ -2,8 +2,11 @@ import json
from datetime import datetime from datetime import datetime
from typing import Literal from typing import Literal
from django.utils import timezone
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.ld import format_ld_date
class BasePostDataType(BaseModel): class BasePostDataType(BaseModel):
pass pass
@ -45,6 +48,51 @@ class QuestionData(BasePostDataType):
data["options"] = options data["options"] = options
super().__init__(**data) super().__init__(**data)
def to_mastodon_json(self, post, identity=None):
from activities.models import PostInteraction
multiple = self.mode == "anyOf"
value = {
"id": post.id,
"expires_at": None,
"expired": False,
"multiple": multiple,
"votes_count": 0,
"voters_count": self.voter_count if multiple else None,
"voted": False,
"own_votes": [],
"options": [],
"emojis": [],
}
if self.end_time:
value["expires_at"] = format_ld_date(self.end_time)
value["expired"] = timezone.now() >= self.end_time
options = self.options or []
option_map = {}
for index, option in enumerate(options):
value["options"].append(
{
"title": option.name,
"votes_count": option.votes,
}
)
value["votes_count"] += option.votes
option_map[option.name] = index
if identity:
votes = post.interactions.filter(
identity=identity,
type=PostInteraction.Types.vote,
)
value["voted"] = post.author == identity or votes.exists()
value["own_votes"] = [
option_map[vote.value] for vote in votes if vote.value in option_map
]
return value
class ArticleData(BasePostDataType): class ArticleData(BasePostDataType):
type: Literal["Article"] type: Literal["Article"]

View File

@ -221,12 +221,14 @@ class TimelineEvent(models.Model):
raise ValueError(f"Cannot convert {self.type} to notification JSON") raise ValueError(f"Cannot convert {self.type} to notification JSON")
return result return result
def to_mastodon_status_json(self, interactions=None): def to_mastodon_status_json(self, interactions=None, identity=None):
if self.type == self.Types.post: if self.type == self.Types.post:
return self.subject_post.to_mastodon_json(interactions=interactions) return self.subject_post.to_mastodon_json(
interactions=interactions, identity=identity
)
elif self.type == self.Types.boost: elif self.type == self.Types.boost:
return self.subject_post_interaction.to_mastodon_status_json( return self.subject_post_interaction.to_mastodon_status_json(
interactions=interactions interactions=interactions, identity=identity
) )
else: else:
raise ValueError(f"Cannot make status JSON for type {self.type}") raise ValueError(f"Cannot make status JSON for type {self.type}")

View File

@ -153,7 +153,9 @@ class PaginationResult(Generic[T]):
""" """
interactions = PostInteraction.get_post_interactions(self.results, identity) interactions = PostInteraction.get_post_interactions(self.results, identity)
self.jsonify_results( self.jsonify_results(
lambda post: post.to_mastodon_json(interactions=interactions) lambda post: post.to_mastodon_json(
interactions=interactions, identity=identity
)
) )
def jsonify_status_events(self, identity): def jsonify_status_events(self, identity):
@ -162,7 +164,9 @@ class PaginationResult(Generic[T]):
""" """
interactions = PostInteraction.get_event_interactions(self.results, identity) interactions = PostInteraction.get_event_interactions(self.results, identity)
self.jsonify_results( self.jsonify_results(
lambda event: event.to_mastodon_status_json(interactions=interactions) lambda event: event.to_mastodon_status_json(
interactions=interactions, identity=identity
)
) )
def jsonify_notification_events(self, identity): def jsonify_notification_events(self, identity):

View File

@ -91,6 +91,32 @@ class MediaAttachment(Schema):
return cls(**attachment.to_mastodon_json()) return cls(**attachment.to_mastodon_json())
class PollOptions(Schema):
title: str
votes_count: int | None
class Poll(Schema):
id: str
expires_at: str | None
expired: bool
multiple: bool
votes_count: int
voters_count: int | None
voted: bool
own_votes: list[int]
options: list[PollOptions]
emojis: list[CustomEmoji]
@classmethod
def from_post(
cls,
post: activities_models.Post,
identity: users_models.Identity | None = None,
) -> "Poll":
return cls(**post.type_data.to_mastodon_json(post, identity=identity))
class StatusMention(Schema): class StatusMention(Schema):
id: str id: str
username: str username: str
@ -123,7 +149,7 @@ class Status(Schema):
in_reply_to_id: str | None = Field(...) in_reply_to_id: str | None = Field(...)
in_reply_to_account_id: str | None = Field(...) in_reply_to_account_id: str | None = Field(...)
reblog: Optional["Status"] = Field(...) reblog: Optional["Status"] = Field(...)
poll: None = Field(...) poll: Poll | None = Field(...)
card: None = Field(...) card: None = Field(...)
language: None = Field(...) language: None = Field(...)
text: str | None = Field(...) text: str | None = Field(...)
@ -139,8 +165,11 @@ class Status(Schema):
cls, cls,
post: activities_models.Post, post: activities_models.Post,
interactions: dict[str, set[str]] | None = None, interactions: dict[str, set[str]] | None = None,
identity: users_models.Identity | None = None,
) -> "Status": ) -> "Status":
return cls(**post.to_mastodon_json(interactions=interactions)) return cls(
**post.to_mastodon_json(interactions=interactions, identity=identity)
)
@classmethod @classmethod
def map_from_post( def map_from_post(
@ -151,15 +180,23 @@ class Status(Schema):
interactions = activities_models.PostInteraction.get_post_interactions( interactions = activities_models.PostInteraction.get_post_interactions(
posts, identity posts, identity
) )
return [cls.from_post(post, interactions=interactions) for post in posts] return [
cls.from_post(post, interactions=interactions, identity=identity)
for post in posts
]
@classmethod @classmethod
def from_timeline_event( def from_timeline_event(
cls, cls,
timeline_event: activities_models.TimelineEvent, timeline_event: activities_models.TimelineEvent,
interactions: dict[str, set[str]] | None = None, interactions: dict[str, set[str]] | None = None,
identity: users_models.Identity | None = None,
) -> "Status": ) -> "Status":
return cls(**timeline_event.to_mastodon_status_json(interactions=interactions)) return cls(
**timeline_event.to_mastodon_status_json(
interactions=interactions, identity=identity
)
)
@classmethod @classmethod
def map_from_timeline_event( def map_from_timeline_event(
@ -171,7 +208,7 @@ class Status(Schema):
events, identity events, identity
) )
return [ return [
cls.from_timeline_event(event, interactions=interactions) cls.from_timeline_event(event, interactions=interactions, identity=identity)
for event in events for event in events
] ]

View File

@ -10,6 +10,7 @@ from api.views import (
instance, instance,
media, media,
notifications, notifications,
polls,
search, search,
statuses, statuses,
timelines, timelines,
@ -62,6 +63,9 @@ urlpatterns = [
path("v1/statuses/<id>/source", statuses.status_source), path("v1/statuses/<id>/source", statuses.status_source),
# Notifications # Notifications
path("v1/notifications", notifications.notifications), path("v1/notifications", notifications.notifications),
# Polls
path("v1/polls/<id>", polls.get_poll),
path("v1/polls/<id>/votes", polls.vote_poll),
# Search # Search
path("v2/search", search.search), path("v2/search", search.search),
# Statuses # Statuses

26
api/views/polls.py Normal file
View File

@ -0,0 +1,26 @@
from django.shortcuts import get_object_or_404
from hatchway import Schema, api_view
from activities.models import Post, PostInteraction
from api import schemas
from api.decorators import identity_required
class PostVoteSchema(Schema):
choices: list[int]
@identity_required
@api_view.get
def get_poll(request, id: str) -> schemas.Poll:
post = get_object_or_404(Post, pk=id, type=Post.Types.question)
return schemas.Poll.from_post(post, identity=request.identity)
@identity_required
@api_view.post
def vote_poll(request, id: str, details: PostVoteSchema) -> schemas.Poll:
post = get_object_or_404(Post, pk=id, type=Post.Types.question)
PostInteraction.create_votes(post, request.identity, details.choices)
post.refresh_from_db()
return schemas.Poll.from_post(post, identity=request.identity)

View File

@ -47,7 +47,9 @@ def search(
search_result["posts"], request.identity search_result["posts"], request.identity
) )
result["statuses"] = [ result["statuses"] = [
schemas.Status.from_post(p, interactions=interactions) schemas.Status.from_post(
p, interactions=interactions, identity=request.identity
)
for p in search_result["posts"] for p in search_result["posts"]
] ]
return schemas.Search(**result) return schemas.Search(**result)

View File

@ -1,7 +1,9 @@
from datetime import timedelta
from typing import Literal from typing import Literal
from django.http import HttpRequest from django.http import HttpRequest
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.utils import timezone
from hatchway import ApiError, ApiResponse, Schema, api_view from hatchway import ApiError, ApiResponse, Schema, api_view
from activities.models import ( from activities.models import (
@ -18,6 +20,24 @@ from api.pagination import MastodonPaginator, PaginationResult
from core.models import Config from core.models import Config
class PostPollSchema(Schema):
options: list[str]
expires_in: int
multiple: bool = False
hide_totals: bool = False
def dict(self):
return {
"type": "Question",
"mode": "anyOf" if self.multiple else "oneOf",
"options": [
{"name": name, "type": "Note", "votes": 0} for name in self.options
],
"voter_count": 0,
"end_time": timezone.now() + timedelta(seconds=self.expires_in),
}
class PostStatusSchema(Schema): class PostStatusSchema(Schema):
status: str status: str
in_reply_to_id: str | None = None in_reply_to_id: str | None = None
@ -27,6 +47,7 @@ class PostStatusSchema(Schema):
language: str | None = None language: str | None = None
scheduled_at: str | None = None scheduled_at: str | None = None
media_ids: list[str] = [] media_ids: list[str] = []
poll: PostPollSchema | None = None
class EditStatusSchema(Schema): class EditStatusSchema(Schema):
@ -82,10 +103,11 @@ def post_status(request, details: PostStatusSchema) -> schemas.Status:
visibility=visibility_map[details.visibility], visibility=visibility_map[details.visibility],
reply_to=reply_post, reply_to=reply_post,
attachments=attachments, attachments=attachments,
question=details.poll.dict() if details.poll else None,
) )
# Add their own timeline event for immediate visibility # Add their own timeline event for immediate visibility
TimelineEvent.add_post(request.identity, post) TimelineEvent.add_post(request.identity, post)
return schemas.Status.from_post(post) return schemas.Status.from_post(post, identity=request.identity)
@identity_required @identity_required
@ -93,7 +115,9 @@ def post_status(request, details: PostStatusSchema) -> schemas.Status:
def status(request, id: str) -> schemas.Status: def status(request, id: str) -> schemas.Status:
post = post_for_id(request, id) post = post_for_id(request, id)
interactions = PostInteraction.get_post_interactions([post], request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity)
return schemas.Status.from_post(post, interactions=interactions) return schemas.Status.from_post(
post, interactions=interactions, identity=request.identity
)
@identity_required @identity_required
@ -121,7 +145,7 @@ def delete_status(request, id: str) -> schemas.Status:
if post.author != request.identity: if post.author != request.identity:
raise ApiError(401, "Not the author of this status") raise ApiError(401, "Not the author of this status")
PostService(post).delete() PostService(post).delete()
return schemas.Status.from_post(post) return schemas.Status.from_post(post, identity=request.identity)
@identity_required @identity_required
@ -142,11 +166,16 @@ def status_context(request, id: str) -> schemas.Context:
) )
return schemas.Context( return schemas.Context(
ancestors=[ ancestors=[
schemas.Status.from_post(p, interactions=interactions) schemas.Status.from_post(
p, interactions=interactions, identity=request.identity
)
for p in reversed(ancestors) for p in reversed(ancestors)
], ],
descendants=[ descendants=[
schemas.Status.from_post(p, interactions=interactions) for p in descendants schemas.Status.from_post(
p, interactions=interactions, identity=request.identity
)
for p in descendants
], ],
) )
@ -158,7 +187,9 @@ def favourite_status(request, id: str) -> schemas.Status:
service = PostService(post) service = PostService(post)
service.like_as(request.identity) service.like_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity)
return schemas.Status.from_post(post, interactions=interactions) return schemas.Status.from_post(
post, interactions=interactions, identity=request.identity
)
@identity_required @identity_required
@ -168,7 +199,9 @@ def unfavourite_status(request, id: str) -> schemas.Status:
service = PostService(post) service = PostService(post)
service.unlike_as(request.identity) service.unlike_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity)
return schemas.Status.from_post(post, interactions=interactions) return schemas.Status.from_post(
post, interactions=interactions, identity=request.identity
)
@api_view.get @api_view.get
@ -219,7 +252,9 @@ def reblog_status(request, id: str) -> schemas.Status:
service = PostService(post) service = PostService(post)
service.boost_as(request.identity) service.boost_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity)
return schemas.Status.from_post(post, interactions=interactions) return schemas.Status.from_post(
post, interactions=interactions, identity=request.identity
)
@identity_required @identity_required
@ -229,4 +264,6 @@ def unreblog_status(request, id: str) -> schemas.Status:
service = PostService(post) service = PostService(post)
service.unboost_as(request.identity) service.unboost_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity)
return schemas.Status.from_post(post, interactions=interactions) return schemas.Status.from_post(
post, interactions=interactions, identity=request.identity
)

View File

@ -626,14 +626,14 @@ def get_list(container, key) -> list:
return value return value
def get_str_or_id(value: str | dict | None) -> str | None: def get_str_or_id(value: str | dict | None, key: str = "id") -> str | None:
""" """
Given a value that could be a str or {"id": str}, return the str Given a value that could be a str or {"id": str}, return the str
""" """
if isinstance(value, str): if isinstance(value, str):
return value return value
elif isinstance(value, dict): elif isinstance(value, dict):
return value.get("id") return value.get(key)
return None return None

View File

@ -2,6 +2,7 @@ import pytest
from pytest_httpx import HTTPXMock from pytest_httpx import HTTPXMock
from activities.models import Post, PostStates from activities.models import Post, PostStates
from activities.models.post_types import QuestionData
from users.models import Identity, InboxMessage from users.models import Identity, InboxMessage
@ -254,6 +255,46 @@ def test_content_map(remote_identity):
assert post3.content == "Hello World" assert post3.content == "Hello World"
@pytest.mark.django_db
def test_content_map_question(remote_identity: Identity):
"""
Tests post contentmap for questions
"""
post = Post.by_ap(
data={
"id": "https://remote.test/posts/1/",
"type": "Question",
"votersCount": 10,
"closed": "2023-01-01T26:04:45Z",
"content": "Test Question",
"attributedTo": "https://remote.test/test-actor/",
"published": "2022-12-23T10:50:54Z",
"endTime": "2023-01-01T20:04:45Z",
"oneOf": [
{
"type": "Note",
"name": "Option 1",
"replies": {
"type": "Collection",
"totalItems": 6,
},
},
{
"type": "Note",
"name": "Option 2",
"replies": {
"type": "Collection",
"totalItems": 4,
},
},
],
},
create=True,
)
assert post.content == "Test Question"
assert isinstance(post.type_data, QuestionData)
@pytest.mark.django_db @pytest.mark.django_db
@pytest.mark.parametrize("delete_type", ["note", "tombstone", "ref"]) @pytest.mark.parametrize("delete_type", ["note", "tombstone", "ref"])
def test_inbound_posts( def test_inbound_posts(

View File

@ -0,0 +1,314 @@
from datetime import timedelta
import pytest
from django.utils import timezone
from activities.models import Post, PostInteraction
from activities.models.post_types import QuestionData
from core.ld import format_ld_date
from users.models import Identity
@pytest.mark.django_db
def test_vote_in_question(identity: Identity, remote_identity: Identity, config_system):
post = Post.create_local(
author=identity,
content="<p>Test Question</p>",
question={
"type": "Question",
"mode": "oneOf",
"options": [
{"name": "Option 1", "type": "Note", "votes": 0},
{"name": "Option 2", "type": "Note", "votes": 0},
],
"voter_count": 0,
"end_time": format_ld_date(timezone.now() + timedelta(1)),
},
)
PostInteraction.handle_ap(
data={
"id": "https://remote.test/test-actor#votes/11/activity",
"to": "https://example.com/@test@example.com/",
"type": "Create",
"actor": "https://remote.test/test-actor/",
"object": {
"id": "https://remote.test/users/test-actor#votes/11",
"to": "https://example.com/@test@example.com/",
"name": "Option 1",
"type": "Note",
"inReplyTo": post.object_uri,
"attributedTo": "https://remote.test/test-actor/",
},
"@context": [
"https://www.w3.org/ns/activitystreams",
{
"toot": "http://joinmastodon.org/ns#",
"Emoji": "toot:Emoji",
"Public": "as:Public",
"Hashtag": "as:Hashtag",
"votersCount": "toot:votersCount",
},
"https://w3id.org/security/v1",
],
},
)
post.refresh_from_db()
assert isinstance(post.type_data, QuestionData)
assert post.type_data.voter_count == 1
assert post.type_data.options
assert len(post.type_data.options) == 2
assert post.type_data.options[0].votes == 1
assert post.type_data.options[1].votes == 0
@pytest.mark.django_db
def test_vote_in_multiple_choice_question(
identity: Identity, remote_identity: Identity, config_system
):
post = Post.create_local(
author=identity,
content="<p>Test Question</p>",
question={
"type": "Question",
"mode": "anyOf",
"options": [
{"name": "Option 1", "type": "Note", "votes": 0},
{"name": "Option 2", "type": "Note", "votes": 0},
{"name": "Option 3", "type": "Note", "votes": 0},
],
"voter_count": 0,
"end_time": format_ld_date(timezone.now() + timedelta(1)),
},
)
PostInteraction.handle_ap(
data={
"id": "https://remote.test/test-actor#votes/12/activity",
"to": "https://example.com/@test@example.com/",
"type": "Create",
"actor": "https://remote.test/test-actor/",
"object": {
"id": "https://remote.test/users/test-actor#votes/12",
"to": "https://example.com/@test@example.com/",
"name": "Option 1",
"type": "Note",
"inReplyTo": post.object_uri,
"attributedTo": "https://remote.test/test-actor/",
},
"@context": [
"https://www.w3.org/ns/activitystreams",
{
"toot": "http://joinmastodon.org/ns#",
"Emoji": "toot:Emoji",
"Public": "as:Public",
"Hashtag": "as:Hashtag",
"votersCount": "toot:votersCount",
},
"https://w3id.org/security/v1",
],
},
)
PostInteraction.handle_ap(
data={
"id": "https://remote.test/test-actor#votes/13/activity",
"to": "https://example.com/@test@example.com/",
"type": "Create",
"actor": "https://remote.test/test-actor/",
"object": {
"id": "https://remote.test/users/test-actor#votes/13",
"to": "https://example.com/@test@example.com/",
"name": "Option 2",
"type": "Note",
"inReplyTo": post.object_uri,
"attributedTo": "https://remote.test/test-actor/",
},
"@context": [
"https://www.w3.org/ns/activitystreams",
{
"toot": "http://joinmastodon.org/ns#",
"Emoji": "toot:Emoji",
"Public": "as:Public",
"Hashtag": "as:Hashtag",
"votersCount": "toot:votersCount",
},
"https://w3id.org/security/v1",
],
},
)
post.refresh_from_db()
assert isinstance(post.type_data, QuestionData)
assert post.type_data.voter_count == 1
assert post.type_data.options
assert len(post.type_data.options) == 3
assert post.type_data.options[0].votes == 1
assert post.type_data.options[1].votes == 1
assert post.type_data.options[2].votes == 0
@pytest.mark.django_db
def test_multiple_votes_to_single_vote_question(
identity: Identity, remote_identity: Identity, config_system
):
post = Post.create_local(
author=identity,
content="<p>Test Question</p>",
question={
"type": "Question",
"mode": "oneOf",
"options": [
{"name": "Option 1", "type": "Note", "votes": 0},
{"name": "Option 2", "type": "Note", "votes": 0},
],
"voter_count": 0,
"end_time": format_ld_date(timezone.now() + timedelta(1)),
},
)
PostInteraction.by_ap(
data={
"id": "https://remote.test/test-actor#votes/14/activity",
"to": "https://example.com/@test@example.com/",
"type": "Create",
"actor": "https://remote.test/test-actor/",
"object": {
"id": "https://remote.test/users/test-actor#votes/14",
"to": "https://example.com/@test@example.com/",
"name": "Option 1",
"type": "Note",
"inReplyTo": post.object_uri,
"attributedTo": "https://remote.test/test-actor/",
},
"@context": [
"https://www.w3.org/ns/activitystreams",
{
"toot": "http://joinmastodon.org/ns#",
"Emoji": "toot:Emoji",
"Public": "as:Public",
"Hashtag": "as:Hashtag",
"votersCount": "toot:votersCount",
},
"https://w3id.org/security/v1",
],
},
create=True,
)
with pytest.raises(PostInteraction.DoesNotExist) as ex:
PostInteraction.by_ap(
data={
"id": "https://remote.test/test-actor#votes/15/activity",
"to": "https://example.com/@test@example.com/",
"type": "Create",
"actor": "https://remote.test/test-actor/",
"object": {
"id": "https://remote.test/users/test-actor#votes/15",
"to": "https://example.com/@test@example.com/",
"name": "Option 2",
"type": "Note",
"inReplyTo": post.object_uri,
"attributedTo": "https://remote.test/test-actor/",
},
"@context": [
"https://www.w3.org/ns/activitystreams",
{
"toot": "http://joinmastodon.org/ns#",
"Emoji": "toot:Emoji",
"Public": "as:Public",
"Hashtag": "as:Hashtag",
"votersCount": "toot:votersCount",
},
"https://w3id.org/security/v1",
],
},
create=True,
)
assert "already voted" in str(ex)
@pytest.mark.django_db
def test_vote_in_expired_question(
identity: Identity, remote_identity: Identity, config_system
):
post = Post.create_local(
author=identity,
content="<p>Test Question</p>",
question={
"type": "Question",
"mode": "oneOf",
"options": [
{"name": "Option 1", "type": "Note", "votes": 0},
{"name": "Option 2", "type": "Note", "votes": 0},
],
"voter_count": 0,
"end_time": "2022-01-01T23:04:45+00:00",
},
)
with pytest.raises(PostInteraction.DoesNotExist) as ex:
PostInteraction.by_ap(
data={
"id": "https://remote.test/test-actor#votes/16/activity",
"to": "https://example.com/@test@example.com/",
"type": "Create",
"actor": "https://remote.test/test-actor/",
"object": {
"id": "https://remote.test/users/test-actor#votes/16",
"to": "https://example.com/@test@example.com/",
"name": "Option 1",
"type": "Note",
"inReplyTo": post.object_uri,
"attributedTo": "https://remote.test/test-actor/",
},
"@context": [
"https://www.w3.org/ns/activitystreams",
{
"toot": "http://joinmastodon.org/ns#",
"Emoji": "toot:Emoji",
"Public": "as:Public",
"Hashtag": "as:Hashtag",
"votersCount": "toot:votersCount",
},
"https://w3id.org/security/v1",
],
},
create=True,
)
assert "Cannot create a vote to the expired question" in str(ex)
@pytest.mark.django_db
def test_vote_to_ap(identity: Identity, remote_identity: Identity, config_system):
post = Post.objects.create(
author=remote_identity,
local=False,
content="<p>Test Question</p>",
type_data={
"type": "Question",
"mode": "oneOf",
"options": [
{"name": "Option 1", "type": "Note", "votes": 6},
{"name": "Option 2", "type": "Note", "votes": 4},
],
"voter_count": 10,
"end_time": format_ld_date(timezone.now() + timedelta(1)),
},
)
post.refresh_from_db()
interaction = PostInteraction.create_votes(
post=post,
identity=identity,
choices=[0],
)[0]
data = interaction.to_create_ap()
assert data["object"]["to"] == remote_identity.actor_uri
assert data["object"]["attributedTo"] == identity.actor_uri
assert data["object"]["name"] == "Option 1"
assert data["object"]["inReplyTo"] == post.object_uri

62
tests/api/test_polls.py Normal file
View File

@ -0,0 +1,62 @@
from datetime import timedelta
import pytest
from django.utils import timezone
from activities.models import Post
from core.ld import format_ld_date
@pytest.mark.django_db
def test_get_poll(api_client):
response = api_client.post(
"/api/v1/statuses",
content_type="application/json",
data={
"status": "Hello, world!",
"poll": {
"options": ["Option 1", "Option 2"],
"expires_in": 300,
},
},
).json()
id = response["id"]
response = api_client.get(
f"/api/v1/polls/{id}",
).json()
assert response["id"] == id
assert response["voted"]
@pytest.mark.django_db
def test_vote_poll(api_client, identity2):
post = Post.create_local(
author=identity2,
content="<p>Test Question</p>",
question={
"type": "Question",
"mode": "oneOf",
"options": [
{"name": "Option 1", "type": "Note", "votes": 0},
{"name": "Option 2", "type": "Note", "votes": 0},
],
"voter_count": 0,
"end_time": format_ld_date(timezone.now() + timedelta(1)),
},
)
response = api_client.post(
f"/api/v1/polls/{post.id}/votes",
content_type="application/json",
data={
"choices": [0],
},
).json()
assert response["id"] == str(post.id)
assert response["voted"]
assert response["votes_count"] == 1
assert response["own_votes"] == [0]

View File

@ -75,3 +75,68 @@ def test_mention_format(api_client, identity, remote_identity):
assert ( assert (
response["text"] == '<p>Hey <a href="https://example.com/@test/">@test</a></p>' response["text"] == '<p>Hey <a href="https://example.com/@test/">@test</a></p>'
) )
@pytest.mark.django_db
def test_post_question_status(api_client):
response = api_client.post(
"/api/v1/statuses",
content_type="application/json",
data={
"status": "Hello, world!",
"poll": {
"options": ["Option 1", "Option 2"],
"expires_in": 300,
},
},
).json()
assert response["poll"]["id"] == response["id"]
assert response["poll"]["options"] == [
{"title": "Option 1", "votes_count": 0},
{"title": "Option 2", "votes_count": 0},
]
assert not response["poll"]["expired"]
assert not response["poll"]["multiple"]
@pytest.mark.django_db
def test_question_format(api_client, remote_identity):
"""
Ensures incoming questions are property parsed.
"""
# Make a remote question post and check it
post = Post.objects.create(
local=False,
author=remote_identity,
content="<p>Test Question</p>",
object_uri="https://remote.test/status/123456",
type=Post.Types.question,
type_data={
"type": "Question",
"mode": "oneOf",
"options": [
{"name": "Option 1", "type": "Note", "votes": 10},
{"name": "Option 2", "type": "Note", "votes": 20},
],
"voter_count": 30,
"end_time": "2022-01-01T23:04:45+00:00",
},
)
response = api_client.get(f"/api/v1/statuses/{post.id}").json()
assert response["text"] == "<p>Test Question</p>"
assert response["poll"] == {
"id": str(post.id),
"expires_at": "2022-01-01T23:04:45.000Z",
"expired": True,
"multiple": False,
"votes_count": 30,
"voters_count": None,
"voted": False,
"own_votes": [],
"options": [
{"title": "Option 1", "votes_count": 10},
{"title": "Option 2", "votes_count": 20},
],
"emojis": [],
}

View File

@ -30,9 +30,15 @@ class InboxMessageStates(StateGraph):
case "create": case "create":
match instance.message_object_type: match instance.message_object_type:
case "note": case "note":
await sync_to_async(Post.handle_create_ap)(instance.message) if instance.message_object_has_content:
await sync_to_async(Post.handle_create_ap)(instance.message)
else:
# Notes without content are Interaction candidates
await sync_to_async(PostInteraction.handle_ap)(
instance.message
)
case "question": case "question":
pass # Drop for now await sync_to_async(Post.handle_create_ap)(instance.message)
case unknown: case unknown:
if unknown in Post.Types.names: if unknown in Post.Types.names:
await sync_to_async(Post.handle_create_ap)(instance.message) await sync_to_async(Post.handle_create_ap)(instance.message)
@ -213,3 +219,7 @@ class InboxMessage(StatorModel):
@property @property
def message_actor(self): def message_actor(self):
return self.message.get("actor") return self.message.get("actor")
@property
def message_object_has_content(self):
return "content" in self.message.get("object", {})