Implement post editing

This commit is contained in:
Andrew Godwin 2023-02-13 19:40:10 -07:00
parent 04ad97c69b
commit 1c5ef675f0
10 changed files with 136 additions and 63 deletions

View File

@ -503,6 +503,7 @@ class Post(StatorModel):
self, self,
content: str, content: str,
summary: str | None = None, summary: str | None = None,
sensitive: bool | None = None,
visibility: int = Visibilities.public, visibility: int = Visibilities.public,
attachments: list | None = None, attachments: list | None = None,
): ):
@ -512,7 +513,7 @@ class Post(StatorModel):
self.content = parser.html self.content = parser.html
self.hashtags = sorted(parser.hashtags) or None self.hashtags = sorted(parser.hashtags) or None
self.summary = summary or None self.summary = summary or None
self.sensitive = bool(summary) self.sensitive = bool(summary) if sensitive is None else sensitive
self.visibility = visibility self.visibility = visibility
self.edited = timezone.now() self.edited = timezone.now()
self.mentions.set(self.mentions_from_content(content, self.author)) self.mentions.set(self.mentions_from_content(content, self.author))

View File

@ -1,7 +1,9 @@
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
from activities import models as activities_models
from hatchway import Field, Schema from hatchway import Field, Schema
from activities import models as activities_models
from core.html import FediverseHtmlParser
from users import models as users_models from users import models as users_models
from users.services import IdentityService from users.services import IdentityService
@ -174,6 +176,20 @@ class Status(Schema):
] ]
class StatusSource(Schema):
id: str
text: str
spoiler_text: str
@classmethod
def from_post(cls, post: activities_models.Post):
return cls(
id=post.id,
text=FediverseHtmlParser(post.content).plain_text,
spoiler_text=post.summary or "",
)
class Conversation(Schema): class Conversation(Schema):
id: str id: str
unread: bool unread: bool

View File

@ -1,4 +1,5 @@
from django.urls import path from django.urls import path
from hatchway import methods
from api.views import ( from api.views import (
accounts, accounts,
@ -14,7 +15,6 @@ from api.views import (
timelines, timelines,
trends, trends,
) )
from hatchway import methods
urlpatterns = [ urlpatterns = [
# Accounts # Accounts
@ -55,9 +55,11 @@ urlpatterns = [
"v1/statuses/<id>", "v1/statuses/<id>",
methods( methods(
get=statuses.status, get=statuses.status,
put=statuses.edit_status,
delete=statuses.delete_status, delete=statuses.delete_status,
), ),
), ),
path("v1/statuses/<id>/source", statuses.status_source),
# Notifications # Notifications
path("v1/notifications", notifications.notifications), path("v1/notifications", notifications.notifications),
# Search # Search

View File

@ -1,10 +1,10 @@
from django.core.files import File from django.core.files import File
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from hatchway import ApiError, QueryOrBody, api_view
from activities.models import PostAttachment, PostAttachmentStates from activities.models import PostAttachment, PostAttachmentStates
from api import schemas from api import schemas
from core.files import blurhash_image, resize_image from core.files import blurhash_image, resize_image
from hatchway import QueryOrBody, api_view
from ..decorators import identity_required from ..decorators import identity_required
@ -54,6 +54,8 @@ def get_media(
id: str, id: str,
) -> schemas.MediaAttachment: ) -> schemas.MediaAttachment:
attachment = get_object_or_404(PostAttachment, pk=id) attachment = get_object_or_404(PostAttachment, pk=id)
if attachment.post.author != request.identity:
raise ApiError(401, "Not the author of this attachment")
return schemas.MediaAttachment.from_post_attachment(attachment) return schemas.MediaAttachment.from_post_attachment(attachment)
@ -66,6 +68,8 @@ def update_media(
focus: QueryOrBody[str] = "0,0", focus: QueryOrBody[str] = "0,0",
) -> schemas.MediaAttachment: ) -> schemas.MediaAttachment:
attachment = get_object_or_404(PostAttachment, pk=id) attachment = get_object_or_404(PostAttachment, pk=id)
if attachment.post.author != request.identity:
raise ApiError(401, "Not the author of this attachment")
attachment.name = description or None attachment.name = description or None
attachment.save() attachment.save()
return schemas.MediaAttachment.from_post_attachment(attachment) return schemas.MediaAttachment.from_post_attachment(attachment)

View File

@ -1,8 +1,8 @@
from typing import Literal from typing import Literal
from django.forms import ValidationError
from django.http import HttpRequest from django.http import HttpRequest
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from hatchway import ApiError, ApiResponse, Schema, api_view
from activities.models import ( from activities.models import (
Post, Post,
@ -16,7 +16,6 @@ from api import schemas
from api.decorators import identity_required from api.decorators import identity_required
from api.pagination import MastodonPaginator, PaginationResult from api.pagination import MastodonPaginator, PaginationResult
from core.models import Config from core.models import Config
from hatchway import ApiResponse, Schema, api_view
class PostStatusSchema(Schema): class PostStatusSchema(Schema):
@ -30,14 +29,36 @@ class PostStatusSchema(Schema):
media_ids: list[str] = [] media_ids: list[str] = []
class EditStatusSchema(Schema):
status: str
sensitive: bool = False
spoiler_text: str | None = None
language: str | None = None
media_ids: list[str] = []
def post_for_id(request: HttpRequest, id: str) -> Post:
"""
Common logic to get a Post object for an ID, taking visibility into
account.
"""
if request.identity:
queryset = Post.objects.not_hidden().visible_to(
request.identity, include_replies=True
)
else:
queryset = Post.objects.not_hidden().unlisted()
return get_object_or_404(queryset, pk=id)
@identity_required @identity_required
@api_view.post @api_view.post
def post_status(request, details: PostStatusSchema) -> schemas.Status: def post_status(request, details: PostStatusSchema) -> schemas.Status:
# Check text length # Check text length
if len(details.status) > Config.system.post_length: if len(details.status) > Config.system.post_length:
raise ValidationError("Status is too long") raise ApiError(400, "Status is too long")
if len(details.status) == 0 and not details.media_ids: if len(details.status) == 0 and not details.media_ids:
raise ValidationError("Status is empty") raise ApiError(400, "Status is empty")
# Grab attachments # Grab attachments
attachments = [get_object_or_404(PostAttachment, pk=id) for id in details.media_ids] attachments = [get_object_or_404(PostAttachment, pk=id) for id in details.media_ids]
# Create the Post # Create the Post
@ -70,23 +91,50 @@ def post_status(request, details: PostStatusSchema) -> schemas.Status:
@identity_required @identity_required
@api_view.get @api_view.get
def status(request, id: str) -> schemas.Status: def status(request, id: str) -> schemas.Status:
post = get_object_or_404(Post, pk=id) post = post_for_id(request, id)
interactions = PostInteraction.get_post_interactions([post], request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity)
return schemas.Status.from_post(post, interactions=interactions) return schemas.Status.from_post(post, interactions=interactions)
@identity_required
@api_view.put
def edit_status(request, id: str, details: EditStatusSchema) -> schemas.Status:
post = post_for_id(request, id)
if post.author != request.identity:
raise ApiError(401, "Not the author of this status")
# Grab attachments
attachments = [get_object_or_404(PostAttachment, pk=id) for id in details.media_ids]
# Update all details, as the client must provide them all
post.edit_local(
content=details.status,
summary=details.spoiler_text,
sensitive=details.sensitive,
attachments=attachments,
)
return schemas.Status.from_post(post)
@identity_required @identity_required
@api_view.delete @api_view.delete
def delete_status(request, id: str) -> schemas.Status: def delete_status(request, id: str) -> schemas.Status:
post = get_object_or_404(Post, pk=id) post = post_for_id(request, id)
if post.author != request.identity:
raise ApiError(401, "Not the author of this status")
PostService(post).delete() PostService(post).delete()
return schemas.Status.from_post(post) return schemas.Status.from_post(post)
@identity_required
@api_view.get
def status_source(request, id: str) -> schemas.StatusSource:
post = post_for_id(request, id)
return schemas.StatusSource.from_post(post)
@identity_required @identity_required
@api_view.get @api_view.get
def status_context(request, id: str) -> schemas.Context: def status_context(request, id: str) -> schemas.Context:
post = get_object_or_404(Post, pk=id) post = post_for_id(request, id)
service = PostService(post) service = PostService(post)
ancestors, descendants = service.context(request.identity) ancestors, descendants = service.context(request.identity)
interactions = PostInteraction.get_post_interactions( interactions = PostInteraction.get_post_interactions(
@ -106,7 +154,7 @@ def status_context(request, id: str) -> schemas.Context:
@identity_required @identity_required
@api_view.post @api_view.post
def favourite_status(request, id: str) -> schemas.Status: def favourite_status(request, id: str) -> schemas.Status:
post = get_object_or_404(Post, pk=id) post = post_for_id(request, id)
service = PostService(post) service = PostService(post)
service.like_as(request.identity) service.like_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity)
@ -116,7 +164,7 @@ def favourite_status(request, id: str) -> schemas.Status:
@identity_required @identity_required
@api_view.post @api_view.post
def unfavourite_status(request, id: str) -> schemas.Status: def unfavourite_status(request, id: str) -> schemas.Status:
post = get_object_or_404(Post, pk=id) post = post_for_id(request, id)
service = PostService(post) service = PostService(post)
service.unlike_as(request.identity) service.unlike_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity)
@ -135,9 +183,7 @@ def favourited_by(
""" """
View who favourited a given status. View who favourited a given status.
""" """
# This method should filter out private statuses, but we don't really have post = post_for_id(request, id)
# a concept of "private status" yet.
post = get_object_or_404(Post, pk=id)
paginator = MastodonPaginator() paginator = MastodonPaginator()
pager: PaginationResult[PostInteraction] = paginator.paginate( pager: PaginationResult[PostInteraction] = paginator.paginate(
@ -169,7 +215,7 @@ def favourited_by(
@identity_required @identity_required
@api_view.post @api_view.post
def reblog_status(request, id: str) -> schemas.Status: def reblog_status(request, id: str) -> schemas.Status:
post = get_object_or_404(Post, pk=id) post = post_for_id(request, id)
service = PostService(post) service = PostService(post)
service.boost_as(request.identity) service.boost_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity)
@ -179,7 +225,7 @@ def reblog_status(request, id: str) -> schemas.Status:
@identity_required @identity_required
@api_view.post @api_view.post
def unreblog_status(request, id: str) -> schemas.Status: def unreblog_status(request, id: str) -> schemas.Status:
post = get_object_or_404(Post, pk=id) post = post_for_id(request, id)
service = PostService(post) service = PostService(post)
service.unboost_as(request.identity) service.unboost_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity)

View File

@ -2,22 +2,14 @@ import pytest
@pytest.mark.django_db @pytest.mark.django_db
def test_verify_credentials(api_token, identity, client): def test_verify_credentials(api_client, identity):
response = client.get( response = api_client.get("/api/v1/accounts/verify_credentials").json()
"/api/v1/accounts/verify_credentials",
HTTP_AUTHORIZATION=f"Bearer {api_token.token}",
HTTP_ACCEPT="application/json",
).json()
assert response["id"] == str(identity.pk) assert response["id"] == str(identity.pk)
assert response["username"] == identity.username assert response["username"] == identity.username
@pytest.mark.django_db @pytest.mark.django_db
def test_account_search(api_token, identity, client): def test_account_search(api_client, identity):
response = client.get( response = api_client.get("/api/v1/accounts/search?q=test").json()
"/api/v1/accounts/search?q=test",
HTTP_AUTHORIZATION=f"Bearer {api_token.token}",
HTTP_ACCEPT="application/json",
).json()
assert response[0]["id"] == str(identity.pk) assert response[0]["id"] == str(identity.pk)
assert response[0]["username"] == identity.username assert response[0]["username"] == identity.username

View File

@ -2,10 +2,6 @@ import pytest
@pytest.mark.django_db @pytest.mark.django_db
def test_instance(api_token, client): def test_instance(api_client):
response = client.get( response = api_client.get("/api/v1/instance").json()
"/api/v1/instance",
HTTP_AUTHORIZATION=f"Bearer {api_token.token}",
HTTP_ACCEPT="application/json",
).json()
assert response["uri"] == "example.com" assert response["uri"] == "example.com"

View File

@ -2,12 +2,10 @@ import pytest
@pytest.mark.django_db @pytest.mark.django_db
def test_likes_flow(api_token, client): def test_likes_flow(api_client):
# Add a post # Add a post
response = client.post( response = api_client.post(
"/api/v1/statuses", "/api/v1/statuses",
HTTP_AUTHORIZATION=f"Bearer {api_token.token}",
HTTP_ACCEPT="application/json",
content_type="application/json", content_type="application/json",
data={ data={
"status": "Like test.", "status": "Like test.",
@ -19,17 +17,9 @@ def test_likes_flow(api_token, client):
status_id = response["id"] status_id = response["id"]
# Like it # Like it
response = client.post( response = api_client.post(f"/api/v1/statuses/{status_id}/favourite").json()
f"/api/v1/statuses/{status_id}/favourite",
HTTP_AUTHORIZATION=f"Bearer {api_token.token}",
HTTP_ACCEPT="application/json",
).json()
assert response["favourited"] is True assert response["favourited"] is True
# Check if it's displaying at likes endpoint # Check if it's displaying at likes endpoint
response = client.get( response = api_client.get("/api/v1/favourites").json()
"/api/v1/favourites",
HTTP_AUTHORIZATION=f"Bearer {api_token.token}",
HTTP_ACCEPT="application/json",
).json()
assert response[0]["id"] == status_id assert response[0]["id"] == status_id

View File

@ -4,11 +4,13 @@ from activities.models import Post
@pytest.mark.django_db @pytest.mark.django_db
def test_post_status(api_token, identity, client): def test_post_status(api_client):
response = client.post( """
Tests posting, editing and deleting a status
"""
# Post new one
response = api_client.post(
"/api/v1/statuses", "/api/v1/statuses",
HTTP_AUTHORIZATION=f"Bearer {api_token.token}",
HTTP_ACCEPT="application/json",
content_type="application/json", content_type="application/json",
data={ data={
"status": "Hello, world!", "status": "Hello, world!",
@ -17,19 +19,37 @@ def test_post_status(api_token, identity, client):
).json() ).json()
assert response["content"] == "<p>Hello, world!</p>" assert response["content"] == "<p>Hello, world!</p>"
assert response["visibility"] == "unlisted" assert response["visibility"] == "unlisted"
status_id = response["id"]
# Retrieve "source" version an edit would use
response = api_client.get(f"/api/v1/statuses/{status_id}/source").json()
assert response["text"] == "Hello, world!"
# Post an edit
response = api_client.put(
f"/api/v1/statuses/{status_id}",
content_type="application/json",
data={
"status": "Hello, world! Again!",
},
).json()
# Check it stuck
response = api_client.get(f"/api/v1/statuses/{status_id}").json()
assert response["content"] == "<p>Hello, world! Again!</p>"
# Delete it
response = api_client.delete(f"/api/v1/statuses/{status_id}")
assert response.status_code == 200
# Check it's gone
response = api_client.get(f"/api/v1/statuses/{status_id}")
assert response.status_code == 404
@pytest.mark.django_db @pytest.mark.django_db
def test_mention_format(api_token, identity, remote_identity, client): def test_mention_format(api_client, identity, remote_identity):
""" """
Ensures mentions work, and only have one link around them. Ensures mentions work, and only have one link around them.
""" """
# Make a local post and check it # Make a local post and check it
response = client.post( response = api_client.post(
"/api/v1/statuses", "/api/v1/statuses",
HTTP_AUTHORIZATION=f"Bearer {api_token.token}",
HTTP_ACCEPT="application/json",
content_type="application/json",
data={ data={
"status": "Hello, @test!", "status": "Hello, @test!",
"visibility": "unlisted", "visibility": "unlisted",
@ -49,11 +69,8 @@ def test_mention_format(api_token, identity, remote_identity, client):
object_uri="https://remote.test/status/12345", object_uri="https://remote.test/status/12345",
) )
post.mentions.add(identity) post.mentions.add(identity)
response = client.get( response = api_client.get(
f"/api/v1/statuses/{post.id}", f"/api/v1/statuses/{post.id}",
HTTP_AUTHORIZATION=f"Bearer {api_token.token}",
HTTP_ACCEPT="application/json",
content_type="application/json",
).json() ).json()
assert ( assert (
response["text"] == '<p>Hey <a href="https://example.com/@test/">@test</a></p>' response["text"] == '<p>Hey <a href="https://example.com/@test/">@test</a></p>'

View File

@ -2,6 +2,7 @@ import time
import pytest import pytest
from django.conf import settings from django.conf import settings
from django.test import Client
from api.models import Application, Token from api.models import Application, Token
from core.models import Config from core.models import Config
@ -218,6 +219,14 @@ def api_token(identity) -> Token:
) )
@pytest.fixture
def api_client(api_token):
return Client(
HTTP_AUTHORIZATION=f"Bearer {api_token.token}",
HTTP_ACCEPT="application/json",
)
@pytest.fixture @pytest.fixture
def stator(config_system) -> StatorRunner: def stator(config_system) -> StatorRunner:
""" """