Hatchway API Rewrite (#499)
Removes django-ninja and replaces it with a new API framework, "hatchway". I plan to move hatchway into its own project very soon.
This commit is contained in:
parent
99a0a16fb7
commit
5d2ed9edfe
|
@ -188,7 +188,7 @@ class PostInteraction(StatorModel):
|
|||
return result
|
||||
|
||||
@classmethod
|
||||
def get_event_interactions(cls, events, identity):
|
||||
def get_event_interactions(cls, events, identity) -> dict[str, set[str]]:
|
||||
"""
|
||||
Returns a dict of {interaction_type: set(post_ids)} for all the posts
|
||||
within the events and the given identity, for use in templates.
|
||||
|
|
|
@ -16,4 +16,7 @@ def identity_required(function):
|
|||
return JsonResponse({"error": "identity_token_required"}, status=400)
|
||||
return function(request, *args, **kwargs)
|
||||
|
||||
# This is for the API only
|
||||
inner.csrf_exempt = True
|
||||
|
||||
return inner
|
||||
|
|
|
@ -1,22 +1,98 @@
|
|||
import dataclasses
|
||||
import urllib.parse
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Any, Generic, Protocol, TypeVar
|
||||
|
||||
from django.db import models
|
||||
from django.http import HttpRequest
|
||||
|
||||
from activities.models import PostInteraction
|
||||
from hatchway.http import ApiResponse
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SchemaWithId(Protocol):
|
||||
"""
|
||||
Little protocol type to represent schemas that have an ID attribute
|
||||
"""
|
||||
|
||||
id: str
|
||||
|
||||
|
||||
TI = TypeVar("TI", bound=SchemaWithId)
|
||||
TM = TypeVar("TM", bound=models.Model)
|
||||
|
||||
|
||||
class PaginatingApiResponse(ApiResponse[list[TI]]):
|
||||
"""
|
||||
An ApiResponse subclass that also handles pagination link headers
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: list[TI],
|
||||
request: HttpRequest,
|
||||
include_params: list[str],
|
||||
**kwargs,
|
||||
):
|
||||
# Call superclass
|
||||
super().__init__(data, **kwargs)
|
||||
# Figure out if we need link headers
|
||||
self._request = request
|
||||
self.extra_params = self.filter_params(self._request, include_params)
|
||||
link_header = self.build_link_header()
|
||||
if link_header:
|
||||
self.headers["link"] = link_header
|
||||
|
||||
@staticmethod
|
||||
def filter_params(request: HttpRequest, allowed_params: list[str]):
|
||||
params = {}
|
||||
for key in allowed_params:
|
||||
value = request.GET.get(key, None)
|
||||
if value:
|
||||
params[key] = value
|
||||
return params
|
||||
|
||||
def get_part(self, data_index: int, param_name: str, rel: str) -> str | None:
|
||||
"""
|
||||
Used to get next/prev URLs
|
||||
"""
|
||||
if not self.data:
|
||||
return None
|
||||
# Use the ID of the last object for the next page start
|
||||
params = dict(self.extra_params)
|
||||
params[param_name] = self.data[data_index].id
|
||||
return (
|
||||
"<"
|
||||
+ self._request.build_absolute_uri(self._request.path)
|
||||
+ "?"
|
||||
+ urllib.parse.urlencode(params)
|
||||
+ f'>; rel="{rel}"'
|
||||
)
|
||||
|
||||
def build_link_header(self):
|
||||
parts = [
|
||||
entry
|
||||
for entry in [
|
||||
self.get_part(0, "min_id", "prev"),
|
||||
self.get_part(-1, "max_id", "next"),
|
||||
]
|
||||
if entry
|
||||
]
|
||||
if not parts:
|
||||
return None
|
||||
return ", ".join(parts)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PaginationResult:
|
||||
class PaginationResult(Generic[T]):
|
||||
"""
|
||||
Represents a pagination result for Mastodon (it does Link header stuff)
|
||||
"""
|
||||
|
||||
#: A list of objects that matched the pagination query.
|
||||
results: list[models.Model]
|
||||
results: list[T]
|
||||
|
||||
#: The actual applied limit, which may be different from what was requested.
|
||||
limit: int
|
||||
|
@ -130,12 +206,12 @@ class MastodonPaginator:
|
|||
|
||||
def paginate(
|
||||
self,
|
||||
queryset,
|
||||
queryset: models.QuerySet[TM],
|
||||
min_id: str | None,
|
||||
max_id: str | None,
|
||||
since_id: str | None,
|
||||
limit: int | None,
|
||||
) -> PaginationResult:
|
||||
) -> PaginationResult[TM]:
|
||||
# These "does not start with interaction" checks can be removed after a
|
||||
# couple months, when clients have flushed them out.
|
||||
if max_id and not max_id.startswith("interaction"):
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
import json
|
||||
|
||||
from ninja.parser import Parser
|
||||
|
||||
|
||||
class FormOrJsonParser(Parser):
|
||||
"""
|
||||
If there's form data in a request, makes it into a JSON dict.
|
||||
This is needed as the Mastodon API allows form data OR json body as input.
|
||||
"""
|
||||
|
||||
def parse_body(self, request):
|
||||
# Did they submit JSON?
|
||||
if request.content_type == "application/json" and request.body.strip():
|
||||
return json.loads(request.body)
|
||||
# Fall back to form data
|
||||
value = {}
|
||||
for key, item in request.POST.items():
|
||||
value[key] = item
|
||||
for key, item in request.GET.items():
|
||||
value[key] = item
|
||||
return value
|
|
@ -1,6 +1,9 @@
|
|||
from typing import Literal, Optional, Union
|
||||
|
||||
from ninja import Field, Schema
|
||||
from activities import models as activities_models
|
||||
from hatchway import Field, Schema
|
||||
from users import models as users_models
|
||||
from users.services import IdentityService
|
||||
|
||||
|
||||
class Application(Schema):
|
||||
|
@ -19,6 +22,10 @@ class CustomEmoji(Schema):
|
|||
visible_in_picker: bool
|
||||
category: str
|
||||
|
||||
@classmethod
|
||||
def from_emoji(cls, emoji: activities_models.Emoji) -> "CustomEmoji":
|
||||
return cls(**emoji.to_mastodon_json())
|
||||
|
||||
|
||||
class AccountField(Schema):
|
||||
name: str
|
||||
|
@ -53,6 +60,17 @@ class Account(Schema):
|
|||
following_count: int
|
||||
source: dict | None
|
||||
|
||||
@classmethod
|
||||
def from_identity(
|
||||
cls,
|
||||
identity: users_models.Identity,
|
||||
include_counts: bool = True,
|
||||
source=False,
|
||||
) -> "Account":
|
||||
return cls(
|
||||
**identity.to_mastodon_json(include_counts=include_counts, source=source)
|
||||
)
|
||||
|
||||
|
||||
class MediaAttachment(Schema):
|
||||
id: str
|
||||
|
@ -64,6 +82,12 @@ class MediaAttachment(Schema):
|
|||
description: str | None
|
||||
blurhash: str | None
|
||||
|
||||
@classmethod
|
||||
def from_post_attachment(
|
||||
cls, attachment: activities_models.PostAttachment
|
||||
) -> "MediaAttachment":
|
||||
return cls(**attachment.to_mastodon_json())
|
||||
|
||||
|
||||
class StatusMention(Schema):
|
||||
id: str
|
||||
|
@ -108,6 +132,47 @@ class Status(Schema):
|
|||
bookmarked: bool | None
|
||||
pinned: bool | None
|
||||
|
||||
@classmethod
|
||||
def from_post(
|
||||
cls,
|
||||
post: activities_models.Post,
|
||||
interactions: dict[str, set[str]] | None = None,
|
||||
) -> "Status":
|
||||
return cls(**post.to_mastodon_json(interactions=interactions))
|
||||
|
||||
@classmethod
|
||||
def map_from_post(
|
||||
cls,
|
||||
posts: list[activities_models.Post],
|
||||
identity: users_models.Identity,
|
||||
) -> list["Status"]:
|
||||
interactions = activities_models.PostInteraction.get_post_interactions(
|
||||
posts, identity
|
||||
)
|
||||
return [cls.from_post(post, interactions=interactions) for post in posts]
|
||||
|
||||
@classmethod
|
||||
def from_timeline_event(
|
||||
cls,
|
||||
timeline_event: activities_models.TimelineEvent,
|
||||
interactions: dict[str, set[str]] | None = None,
|
||||
) -> "Status":
|
||||
return cls(**timeline_event.to_mastodon_status_json(interactions=interactions))
|
||||
|
||||
@classmethod
|
||||
def map_from_timeline_event(
|
||||
cls,
|
||||
events: list[activities_models.TimelineEvent],
|
||||
identity: users_models.Identity,
|
||||
) -> list["Status"]:
|
||||
interactions = activities_models.PostInteraction.get_event_interactions(
|
||||
events, identity
|
||||
)
|
||||
return [
|
||||
cls.from_timeline_event(event, interactions=interactions)
|
||||
for event in events
|
||||
]
|
||||
|
||||
|
||||
class Conversation(Schema):
|
||||
id: str
|
||||
|
@ -134,12 +199,26 @@ class Notification(Schema):
|
|||
account: Account
|
||||
status: Status | None
|
||||
|
||||
@classmethod
|
||||
def from_timeline_event(
|
||||
cls,
|
||||
event: activities_models.TimelineEvent,
|
||||
) -> "Notification":
|
||||
return cls(**event.to_mastodon_notification_json())
|
||||
|
||||
|
||||
class Tag(Schema):
|
||||
name: str
|
||||
url: str
|
||||
history: dict
|
||||
|
||||
@classmethod
|
||||
def from_hashtag(
|
||||
cls,
|
||||
hashtag: activities_models.Hashtag,
|
||||
) -> "Tag":
|
||||
return cls(**hashtag.to_mastodon_json())
|
||||
|
||||
|
||||
class Search(Schema):
|
||||
accounts: list[Account]
|
||||
|
@ -162,6 +241,16 @@ class Relationship(Schema):
|
|||
endorsed: bool
|
||||
note: str
|
||||
|
||||
@classmethod
|
||||
def from_identity_pair(
|
||||
cls,
|
||||
identity: users_models.Identity,
|
||||
from_identity: users_models.Identity,
|
||||
) -> "Relationship":
|
||||
return cls(
|
||||
**IdentityService(identity).mastodon_json_relationship(from_identity)
|
||||
)
|
||||
|
||||
|
||||
class Context(Schema):
|
||||
ancestors: list[Status]
|
||||
|
@ -187,3 +276,11 @@ class Announcement(Schema):
|
|||
tags: list[Tag]
|
||||
emojis: list[CustomEmoji]
|
||||
reactions: list
|
||||
|
||||
@classmethod
|
||||
def from_announcement(
|
||||
cls,
|
||||
announcement: users_models.Announcement,
|
||||
user: users_models.User,
|
||||
) -> "Announcement":
|
||||
return cls(**announcement.to_mastodon_json(user=user))
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
from django.urls import path
|
||||
|
||||
from api.views import (
|
||||
accounts,
|
||||
announcements,
|
||||
apps,
|
||||
emoji,
|
||||
filters,
|
||||
instance,
|
||||
media,
|
||||
notifications,
|
||||
search,
|
||||
statuses,
|
||||
timelines,
|
||||
)
|
||||
from hatchway import methods
|
||||
|
||||
urlpatterns = [
|
||||
# Accounts
|
||||
path("v1/accounts/verify_credentials", accounts.verify_credentials),
|
||||
path("v1/accounts/update_credentials", accounts.update_credentials),
|
||||
path("v1/accounts/relationships", accounts.account_relationships),
|
||||
path("v1/accounts/familiar_followers", accounts.familiar_followers),
|
||||
path("v1/accounts/search", accounts.accounts_search),
|
||||
path("v1/accounts/lookup", accounts.lookup),
|
||||
path("v1/accounts/<id>", accounts.account),
|
||||
path("v1/accounts/<id>/statuses", accounts.account_statuses),
|
||||
path("v1/accounts/<id>/follow", accounts.account_follow),
|
||||
path("v1/accounts/<id>/unfollow", accounts.account_unfollow),
|
||||
path("v1/accounts/<id>/block", accounts.account_block),
|
||||
path("v1/accounts/<id>/unblock", accounts.account_unblock),
|
||||
path("v1/accounts/<id>/mute", accounts.account_mute),
|
||||
path("v1/accounts/<id>/unmute", accounts.account_unmute),
|
||||
path("v1/accounts/<id>/following", accounts.account_following),
|
||||
path("v1/accounts/<id>/followers", accounts.account_followers),
|
||||
# Announcements
|
||||
path("v1/announcements", announcements.announcement_list),
|
||||
path("v1/announcements/<pk>/dismiss", announcements.announcement_dismiss),
|
||||
# Apps
|
||||
path("v1/apps", apps.add_app),
|
||||
# Filters
|
||||
path("v2/filters", filters.list_filters),
|
||||
path("v1/filters", filters.list_filters),
|
||||
# Instance
|
||||
path("v1/instance", instance.instance_info_v1),
|
||||
path("v2/instance", instance.instance_info_v2),
|
||||
path("v1/timelines/home", timelines.home),
|
||||
path("v1/timelines/public", timelines.public),
|
||||
path("v1/timelines/tag/<hashtag>", timelines.hashtag),
|
||||
path("v1/timelines/conversations", timelines.conversations),
|
||||
path("v1/favourites", timelines.favourites),
|
||||
path("v1/notifications", notifications.notifications),
|
||||
path("v1/statuses", statuses.post_status),
|
||||
path("v1/media", media.upload_media),
|
||||
path("v2/media", media.upload_media),
|
||||
path("v1/media/<id>", methods(get=media.get_media, put=media.update_media)),
|
||||
path(
|
||||
"v1/statuses/<id>",
|
||||
methods(
|
||||
get=statuses.status,
|
||||
delete=statuses.delete_status,
|
||||
),
|
||||
),
|
||||
path("v1/statuses/<id>/context", statuses.status_context),
|
||||
path("v1/statuses/<id>/favourite", statuses.favourite_status),
|
||||
path("v1/statuses/<id>/unfavourite", statuses.unfavourite_status),
|
||||
path("v1/statuses/<id>/favourited_by", statuses.favourited_by),
|
||||
path("v1/statuses/<id>/reblog", statuses.reblog_status),
|
||||
path("v1/statuses/<id>/unreblog", statuses.unreblog_status),
|
||||
path("v1/custom_emojis", emoji.emojis),
|
||||
path("v2/search", search.search),
|
||||
]
|
|
@ -1,12 +0,0 @@
|
|||
from .accounts import * # noqa
|
||||
from .announcements import * # noqa
|
||||
from .apps import * # noqa
|
||||
from .emoji import * # noqa
|
||||
from .filters import * # noqa
|
||||
from .instance import * # noqa
|
||||
from .media import * # noqa
|
||||
from .notifications import * # noqa
|
||||
from .oauth import * # noqa
|
||||
from .search import * # noqa
|
||||
from .statuses import * # noqa
|
||||
from .timelines import * # noqa
|
|
@ -1,131 +1,129 @@
|
|||
from django.http import HttpRequest, HttpResponse, QueryDict
|
||||
from django.http.multipartparser import MultiPartParser
|
||||
from typing import Any
|
||||
|
||||
from django.core.files import File
|
||||
from django.http import HttpRequest
|
||||
from django.shortcuts import get_object_or_404
|
||||
from ninja import Field, Schema
|
||||
|
||||
from activities.models import Post
|
||||
from activities.services import SearchService
|
||||
from api import schemas
|
||||
from api.decorators import identity_required
|
||||
from api.pagination import MastodonPaginator
|
||||
from api.views.base import api_router
|
||||
from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult
|
||||
from core.models import Config
|
||||
from hatchway import ApiResponse, QueryOrBody, api_view
|
||||
from users.models import Identity
|
||||
from users.services import IdentityService
|
||||
from users.shortcuts import by_handle_or_404
|
||||
|
||||
|
||||
@api_router.get("/v1/accounts/verify_credentials", response=schemas.Account)
|
||||
@identity_required
|
||||
def verify_credentials(request):
|
||||
return request.identity.to_mastodon_json(source=True)
|
||||
@api_view.get
|
||||
def verify_credentials(request) -> schemas.Account:
|
||||
return schemas.Account.from_identity(request.identity, source=True)
|
||||
|
||||
|
||||
@api_router.patch("/v1/accounts/update_credentials", response=schemas.Account)
|
||||
@identity_required
|
||||
@api_view.patch
|
||||
def update_credentials(
|
||||
request,
|
||||
):
|
||||
# Django won't load POST and FILES for patch methods, so we do it.
|
||||
if request.content_type == "multipart/form-data":
|
||||
POST, FILES = MultiPartParser(
|
||||
request.META, request, request.upload_handlers, request.encoding
|
||||
).parse()
|
||||
elif request.content_type == "application/x-www-form-urlencoded":
|
||||
POST = QueryDict(request.body, encoding=request._encoding)
|
||||
FILES = {}
|
||||
else:
|
||||
return HttpResponse(status=400)
|
||||
display_name: QueryOrBody[str | None] = None,
|
||||
note: QueryOrBody[str | None] = None,
|
||||
discoverable: QueryOrBody[bool | None] = None,
|
||||
source: QueryOrBody[dict[str, Any] | None] = None,
|
||||
fields_attributes: QueryOrBody[dict[str, dict[str, str]] | None] = None,
|
||||
avatar: File | None = None,
|
||||
header: File | None = None,
|
||||
) -> schemas.Account:
|
||||
identity = request.identity
|
||||
service = IdentityService(identity)
|
||||
if "display_name" in POST:
|
||||
identity.name = POST["display_name"]
|
||||
if "note" in POST:
|
||||
service.set_summary(POST["note"])
|
||||
if "discoverable" in POST:
|
||||
identity.discoverable = POST["discoverable"] == "checked"
|
||||
if "source[privacy]" in POST:
|
||||
privacy_map = {
|
||||
"public": Post.Visibilities.public,
|
||||
"unlisted": Post.Visibilities.unlisted,
|
||||
"private": Post.Visibilities.followers,
|
||||
"direct": Post.Visibilities.mentioned,
|
||||
}
|
||||
Config.set_identity(
|
||||
identity,
|
||||
"default_post_visibility",
|
||||
privacy_map[POST["source[privacy]"]],
|
||||
)
|
||||
if "fields_attributes[0][name]" in POST:
|
||||
if display_name is not None:
|
||||
identity.name = display_name
|
||||
if note is not None:
|
||||
service.set_summary(note)
|
||||
if discoverable is not None:
|
||||
identity.discoverable = discoverable
|
||||
if source:
|
||||
if "privacy" in source:
|
||||
privacy_map = {
|
||||
"public": Post.Visibilities.public,
|
||||
"unlisted": Post.Visibilities.unlisted,
|
||||
"private": Post.Visibilities.followers,
|
||||
"direct": Post.Visibilities.mentioned,
|
||||
}
|
||||
Config.set_identity(
|
||||
identity,
|
||||
"default_post_visibility",
|
||||
privacy_map[source["privacy"]],
|
||||
)
|
||||
if fields_attributes:
|
||||
identity.metadata = []
|
||||
for i in range(4):
|
||||
name_name = f"fields_attributes[{i}][name]"
|
||||
value_name = f"fields_attributes[{i}][value]"
|
||||
if name_name and value_name in POST:
|
||||
for attribute in fields_attributes.values():
|
||||
attr_name = attribute.get("name", None)
|
||||
attr_value = attribute.get("value", None)
|
||||
if attr_name:
|
||||
# Empty value means delete this item
|
||||
if not POST[value_name]:
|
||||
if not attr_value:
|
||||
break
|
||||
identity.metadata.append(
|
||||
{"name": POST[name_name], "value": POST[value_name]}
|
||||
)
|
||||
if "avatar" in FILES:
|
||||
service.set_icon(FILES["avatar"])
|
||||
if "header" in FILES:
|
||||
service.set_image(FILES["header"])
|
||||
identity.metadata.append({"name": attr_name, "value": attr_value})
|
||||
if avatar:
|
||||
service.set_icon(avatar)
|
||||
if header:
|
||||
service.set_image(header)
|
||||
identity.save()
|
||||
return identity.to_mastodon_json(source=True)
|
||||
return schemas.Account.from_identity(identity, source=True)
|
||||
|
||||
|
||||
@api_router.get("/v1/accounts/relationships", response=list[schemas.Relationship])
|
||||
@identity_required
|
||||
def account_relationships(request):
|
||||
ids = request.GET.getlist("id[]")
|
||||
@api_view.get
|
||||
def account_relationships(request, id: list[str] | None) -> list[schemas.Relationship]:
|
||||
result = []
|
||||
for id in ids:
|
||||
identity = get_object_or_404(Identity, pk=id)
|
||||
# ID is actually a list. Thanks Mastodon!
|
||||
ids = id or []
|
||||
for actual_id in ids:
|
||||
identity = get_object_or_404(Identity, pk=actual_id)
|
||||
result.append(
|
||||
IdentityService(identity).mastodon_json_relationship(request.identity)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@api_router.get(
|
||||
"/v1/accounts/familiar_followers", response=list[schemas.FamiliarFollowers]
|
||||
)
|
||||
@identity_required
|
||||
def familiar_followers(request):
|
||||
@api_view.get
|
||||
def familiar_followers(
|
||||
request, id: list[str] | None
|
||||
) -> list[schemas.FamiliarFollowers]:
|
||||
"""
|
||||
Returns people you follow that also follow given account IDs
|
||||
"""
|
||||
ids = request.GET.getlist("id[]")
|
||||
ids = id or []
|
||||
result = []
|
||||
for id in ids:
|
||||
target_identity = get_object_or_404(Identity, pk=id)
|
||||
for actual_id in ids:
|
||||
target_identity = get_object_or_404(Identity, pk=actual_id)
|
||||
result.append(
|
||||
{
|
||||
"id": id,
|
||||
"accounts": [
|
||||
identity.to_mastodon_json()
|
||||
schemas.FamiliarFollowers(
|
||||
id=actual_id,
|
||||
accounts=[
|
||||
schemas.Account.from_identity(identity)
|
||||
for identity in Identity.objects.filter(
|
||||
inbound_follows__source=request.identity,
|
||||
outbound_follows__target=target_identity,
|
||||
)[:20]
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@api_router.get("/v1/accounts/search", response=list[schemas.Account])
|
||||
@identity_required
|
||||
def search(
|
||||
@api_view.get
|
||||
def accounts_search(
|
||||
request,
|
||||
q: str,
|
||||
fetch_identities: bool = Field(False, alias="resolve"),
|
||||
resolve: bool = False,
|
||||
following: bool = False,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
):
|
||||
) -> list[schemas.Account]:
|
||||
"""
|
||||
Handles searching for accounts by username or handle
|
||||
"""
|
||||
|
@ -135,33 +133,33 @@ def search(
|
|||
return []
|
||||
searcher = SearchService(q, request.identity)
|
||||
search_result = searcher.search_identities_handle()
|
||||
return [i.to_mastodon_json() for i in search_result]
|
||||
return [schemas.Account.from_identity(i) for i in search_result]
|
||||
|
||||
|
||||
@api_router.get("/v1/accounts/lookup", response=schemas.Account)
|
||||
def lookup(request: HttpRequest, acct: str):
|
||||
@api_view.get
|
||||
def lookup(request: HttpRequest, acct: str) -> schemas.Account:
|
||||
"""
|
||||
Quickly lookup a username to see if it is available, skipping WebFinger
|
||||
resolution.
|
||||
"""
|
||||
identity = by_handle_or_404(request, handle=acct, local=False)
|
||||
return identity.to_mastodon_json()
|
||||
return schemas.Account.from_identity(identity)
|
||||
|
||||
|
||||
@api_router.get("/v1/accounts/{id}", response=schemas.Account)
|
||||
@api_view.get
|
||||
@identity_required
|
||||
def account(request, id: str):
|
||||
def account(request, id: str) -> schemas.Account:
|
||||
identity = get_object_or_404(
|
||||
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
|
||||
Identity.objects.exclude(restriction=Identity.Restriction.blocked),
|
||||
pk=id,
|
||||
)
|
||||
return identity.to_mastodon_json()
|
||||
return schemas.Account.from_identity(identity)
|
||||
|
||||
|
||||
@api_router.get("/v1/accounts/{id}/statuses", response=list[schemas.Status])
|
||||
@api_view.get
|
||||
@identity_required
|
||||
def account_statuses(
|
||||
request: HttpRequest,
|
||||
response: HttpResponse,
|
||||
id: str,
|
||||
exclude_reblogs: bool = False,
|
||||
exclude_replies: bool = False,
|
||||
|
@ -172,7 +170,7 @@ def account_statuses(
|
|||
since_id: str | None = None,
|
||||
min_id: str | None = None,
|
||||
limit: int = 20,
|
||||
):
|
||||
) -> ApiResponse[list[schemas.Status]]:
|
||||
identity = get_object_or_404(
|
||||
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
|
||||
)
|
||||
|
@ -191,177 +189,163 @@ def account_statuses(
|
|||
.order_by("-created")
|
||||
)
|
||||
if pinned:
|
||||
return []
|
||||
return ApiResponse([])
|
||||
if only_media:
|
||||
queryset = queryset.filter(attachments__pk__isnull=False)
|
||||
if tagged:
|
||||
queryset = queryset.tagged_with(tagged)
|
||||
# Get user posts with pagination
|
||||
paginator = MastodonPaginator()
|
||||
pager = paginator.paginate(
|
||||
pager: PaginationResult[Post] = paginator.paginate(
|
||||
queryset,
|
||||
min_id=min_id,
|
||||
max_id=max_id,
|
||||
since_id=since_id,
|
||||
limit=limit,
|
||||
)
|
||||
# Convert those to the JSON form
|
||||
pager.jsonify_posts(identity=request.identity)
|
||||
# Add a link header if we need to
|
||||
if pager.results:
|
||||
response.headers["Link"] = pager.link_header(
|
||||
request,
|
||||
[
|
||||
"limit",
|
||||
"id",
|
||||
"exclude_reblogs",
|
||||
"exclude_replies",
|
||||
"only_media",
|
||||
"pinned",
|
||||
"tagged",
|
||||
],
|
||||
)
|
||||
return pager.json_results
|
||||
return PaginatingApiResponse(
|
||||
schemas.Status.map_from_post(pager.results, request.identity),
|
||||
request=request,
|
||||
include_params=[
|
||||
"limit",
|
||||
"id",
|
||||
"exclude_reblogs",
|
||||
"exclude_replies",
|
||||
"only_media",
|
||||
"pinned",
|
||||
"tagged",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@api_router.post("/v1/accounts/{id}/follow", response=schemas.Relationship)
|
||||
@api_view.post
|
||||
@identity_required
|
||||
def account_follow(request, id: str, reblogs: bool = True):
|
||||
def account_follow(request, id: str, reblogs: bool = True) -> schemas.Relationship:
|
||||
identity = get_object_or_404(
|
||||
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
|
||||
)
|
||||
service = IdentityService(identity)
|
||||
service.follow_from(request.identity, boosts=reblogs)
|
||||
return service.mastodon_json_relationship(request.identity)
|
||||
return schemas.Relationship.from_identity_pair(identity, request.identity)
|
||||
|
||||
|
||||
@api_router.post("/v1/accounts/{id}/unfollow", response=schemas.Relationship)
|
||||
@api_view.post
|
||||
@identity_required
|
||||
def account_unfollow(request, id: str):
|
||||
def account_unfollow(request, id: str) -> schemas.Relationship:
|
||||
identity = get_object_or_404(
|
||||
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
|
||||
)
|
||||
service = IdentityService(identity)
|
||||
service.unfollow_from(request.identity)
|
||||
return service.mastodon_json_relationship(request.identity)
|
||||
return schemas.Relationship.from_identity_pair(identity, request.identity)
|
||||
|
||||
|
||||
@api_router.post("/v1/accounts/{id}/block", response=schemas.Relationship)
|
||||
@api_view.post
|
||||
@identity_required
|
||||
def account_block(request, id: str):
|
||||
def account_block(request, id: str) -> schemas.Relationship:
|
||||
identity = get_object_or_404(Identity, pk=id)
|
||||
service = IdentityService(identity)
|
||||
service.block_from(request.identity)
|
||||
return service.mastodon_json_relationship(request.identity)
|
||||
return schemas.Relationship.from_identity_pair(identity, request.identity)
|
||||
|
||||
|
||||
@api_router.post("/v1/accounts/{id}/unblock", response=schemas.Relationship)
|
||||
@api_view.post
|
||||
@identity_required
|
||||
def account_unblock(request, id: str):
|
||||
def account_unblock(request, id: str) -> schemas.Relationship:
|
||||
identity = get_object_or_404(Identity, pk=id)
|
||||
service = IdentityService(identity)
|
||||
service.unblock_from(request.identity)
|
||||
return service.mastodon_json_relationship(request.identity)
|
||||
return schemas.Relationship.from_identity_pair(identity, request.identity)
|
||||
|
||||
|
||||
class MuteDetailsSchema(Schema):
|
||||
notifications: bool = True
|
||||
duration: int = 0
|
||||
|
||||
|
||||
@api_router.post("/v1/accounts/{id}/mute", response=schemas.Relationship)
|
||||
@identity_required
|
||||
def account_mute(request, id: str, details: MuteDetailsSchema):
|
||||
@api_view.post
|
||||
def account_mute(
|
||||
request,
|
||||
id: str,
|
||||
notifications: QueryOrBody[bool] = True,
|
||||
duration: QueryOrBody[int] = 0,
|
||||
) -> schemas.Relationship:
|
||||
identity = get_object_or_404(Identity, pk=id)
|
||||
service = IdentityService(identity)
|
||||
service.mute_from(
|
||||
request.identity,
|
||||
duration=details.duration,
|
||||
include_notifications=details.notifications,
|
||||
duration=duration,
|
||||
include_notifications=notifications,
|
||||
)
|
||||
return service.mastodon_json_relationship(request.identity)
|
||||
return schemas.Relationship.from_identity_pair(identity, request.identity)
|
||||
|
||||
|
||||
@api_router.post("/v1/accounts/{id}/unmute", response=schemas.Relationship)
|
||||
@identity_required
|
||||
def account_unmute(request, id: str):
|
||||
@api_view.post
|
||||
def account_unmute(request, id: str) -> schemas.Relationship:
|
||||
identity = get_object_or_404(Identity, pk=id)
|
||||
service = IdentityService(identity)
|
||||
service.unmute_from(request.identity)
|
||||
return service.mastodon_json_relationship(request.identity)
|
||||
return schemas.Relationship.from_identity_pair(identity, request.identity)
|
||||
|
||||
|
||||
@api_router.get("/v1/accounts/{id}/following", response=list[schemas.Account])
|
||||
@api_view.get
|
||||
def account_following(
|
||||
request: HttpRequest,
|
||||
response: HttpResponse,
|
||||
id: str,
|
||||
max_id: str | None = None,
|
||||
since_id: str | None = None,
|
||||
min_id: str | None = None,
|
||||
limit: int = 40,
|
||||
):
|
||||
) -> ApiResponse[list[schemas.Account]]:
|
||||
identity = get_object_or_404(
|
||||
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
|
||||
)
|
||||
|
||||
if not identity.config_identity.visible_follows and request.identity != identity:
|
||||
return []
|
||||
return ApiResponse([])
|
||||
|
||||
service = IdentityService(identity)
|
||||
|
||||
paginator = MastodonPaginator(max_limit=80)
|
||||
pager = paginator.paginate(
|
||||
pager: PaginationResult[Identity] = paginator.paginate(
|
||||
service.following(),
|
||||
min_id=min_id,
|
||||
max_id=max_id,
|
||||
since_id=since_id,
|
||||
limit=limit,
|
||||
)
|
||||
pager.jsonify_identities()
|
||||
|
||||
if pager.results:
|
||||
response.headers["Link"] = pager.link_header(
|
||||
request,
|
||||
["limit"],
|
||||
)
|
||||
|
||||
return pager.json_results
|
||||
return PaginatingApiResponse(
|
||||
[schemas.Account.from_identity(i) for i in pager.results],
|
||||
request=request,
|
||||
include_params=["limit"],
|
||||
)
|
||||
|
||||
|
||||
@api_router.get("/v1/accounts/{id}/followers", response=list[schemas.Account])
|
||||
@api_view.get
|
||||
def account_followers(
|
||||
request: HttpRequest,
|
||||
response: HttpResponse,
|
||||
id: str,
|
||||
max_id: str | None = None,
|
||||
since_id: str | None = None,
|
||||
min_id: str | None = None,
|
||||
limit: int = 40,
|
||||
):
|
||||
) -> ApiResponse[list[schemas.Account]]:
|
||||
identity = get_object_or_404(
|
||||
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
|
||||
)
|
||||
|
||||
if not identity.config_identity.visible_follows and request.identity != identity:
|
||||
return []
|
||||
return ApiResponse([])
|
||||
|
||||
service = IdentityService(identity)
|
||||
|
||||
paginator = MastodonPaginator(max_limit=80)
|
||||
pager = paginator.paginate(
|
||||
pager: PaginationResult[Identity] = paginator.paginate(
|
||||
service.followers(),
|
||||
min_id=min_id,
|
||||
max_id=max_id,
|
||||
since_id=since_id,
|
||||
limit=limit,
|
||||
)
|
||||
pager.jsonify_identities()
|
||||
|
||||
if pager.results:
|
||||
response.headers["Link"] = pager.link_header(
|
||||
request,
|
||||
["limit"],
|
||||
)
|
||||
|
||||
return pager.json_results
|
||||
return PaginatingApiResponse(
|
||||
[schemas.Account.from_identity(i) for i in pager.results],
|
||||
request=request,
|
||||
include_params=["limit"],
|
||||
)
|
||||
|
|
|
@ -2,22 +2,22 @@ from django.shortcuts import get_object_or_404
|
|||
|
||||
from api import schemas
|
||||
from api.decorators import identity_required
|
||||
from api.views.base import api_router
|
||||
from hatchway import api_view
|
||||
from users.models import Announcement
|
||||
from users.services import AnnouncementService
|
||||
|
||||
|
||||
@api_router.get("/v1/announcements", response=list[schemas.Announcement])
|
||||
@identity_required
|
||||
def announcement_list(request):
|
||||
@api_view.get
|
||||
def announcement_list(request) -> list[schemas.Announcement]:
|
||||
return [
|
||||
a.to_mastodon_json(request.user)
|
||||
schemas.Announcement.from_announcement(a, request.user)
|
||||
for a in AnnouncementService(request.user).visible()
|
||||
]
|
||||
|
||||
|
||||
@api_router.post("/v1/announcements/{pk}/dismiss")
|
||||
@identity_required
|
||||
@api_view.post
|
||||
def announcement_dismiss(request, pk: str):
|
||||
announcement = get_object_or_404(Announcement, pk=pk)
|
||||
AnnouncementService(request.user).mark_seen(announcement)
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
import secrets
|
||||
|
||||
from ninja import Schema
|
||||
from hatchway import Schema, api_view
|
||||
|
||||
from .. import schemas
|
||||
from ..models import Application
|
||||
from .base import api_router
|
||||
|
||||
|
||||
class CreateApplicationSchema(Schema):
|
||||
|
@ -14,8 +13,8 @@ class CreateApplicationSchema(Schema):
|
|||
website: None | str = None
|
||||
|
||||
|
||||
@api_router.post("/v1/apps", response=schemas.Application)
|
||||
def add_app(request, details: CreateApplicationSchema):
|
||||
@api_view.post
|
||||
def add_app(request, details: CreateApplicationSchema) -> schemas.Application:
|
||||
client_id = "tk-" + secrets.token_urlsafe(16)
|
||||
client_secret = secrets.token_urlsafe(40)
|
||||
application = Application.objects.create(
|
||||
|
@ -26,4 +25,4 @@ def add_app(request, details: CreateApplicationSchema):
|
|||
redirect_uris=details.redirect_uris,
|
||||
scopes=details.scopes or "read",
|
||||
)
|
||||
return application
|
||||
return schemas.Application.from_orm(application)
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
from ninja import NinjaAPI
|
||||
|
||||
from api.parser import FormOrJsonParser
|
||||
|
||||
api_router = NinjaAPI(parser=FormOrJsonParser())
|
|
@ -1,8 +1,10 @@
|
|||
from activities.models import Emoji
|
||||
from api.schemas import CustomEmoji
|
||||
from api.views.base import api_router
|
||||
from hatchway import api_view
|
||||
|
||||
|
||||
@api_router.get("/v1/custom_emojis", response=list[CustomEmoji])
|
||||
def emojis(request):
|
||||
return [e.to_mastodon_json() for e in Emoji.objects.usable().filter(local=True)]
|
||||
@api_view.get
|
||||
def emojis(request) -> list[CustomEmoji]:
|
||||
return [
|
||||
CustomEmoji.from_emoji(e) for e in Emoji.objects.usable().filter(local=True)
|
||||
]
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from api.decorators import identity_required
|
||||
from api.views.base import api_router
|
||||
from hatchway import api_view
|
||||
|
||||
|
||||
@api_router.get("/v1/filters")
|
||||
@identity_required
|
||||
def status(request):
|
||||
@api_view.get
|
||||
def list_filters(request):
|
||||
return []
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
from django.conf import settings
|
||||
|
||||
from activities.models import Post
|
||||
from api import schemas
|
||||
from core.models import Config
|
||||
from hatchway import api_view
|
||||
from takahe import __version__
|
||||
from users.models import Domain, Identity
|
||||
|
||||
from .base import api_router
|
||||
|
||||
|
||||
@api_router.get("/v1/instance")
|
||||
def instance_info(request):
|
||||
@api_view.get
|
||||
def instance_info_v1(request):
|
||||
return {
|
||||
"uri": request.headers.get("host", settings.SETUP.MAIN_DOMAIN),
|
||||
"title": Config.system.site_name,
|
||||
|
@ -53,11 +53,17 @@ def instance_info(request):
|
|||
}
|
||||
|
||||
|
||||
@api_router.get("/v2/instance")
|
||||
def instance_info_v2(request):
|
||||
@api_view.get
|
||||
def instance_info_v2(request) -> dict:
|
||||
current_domain = Domain.get_domain(
|
||||
request.headers.get("host", settings.SETUP.MAIN_DOMAIN)
|
||||
)
|
||||
if current_domain is None or not current_domain.local:
|
||||
current_domain = Domain.get_domain(
|
||||
request.headers.get(settings.SETUP.MAIN_DOMAIN)
|
||||
)
|
||||
if current_domain is None:
|
||||
raise ValueError("No domain set up for MAIN_DOMAIN")
|
||||
admin_identity = (
|
||||
Identity.objects.filter(users__admin=True).order_by("created").first()
|
||||
)
|
||||
|
@ -116,7 +122,7 @@ def instance_info_v2(request):
|
|||
},
|
||||
"contact": {
|
||||
"email": "",
|
||||
"account": admin_identity.to_mastodon_json(),
|
||||
"account": schemas.Account.from_identity(admin_identity),
|
||||
},
|
||||
"rules": [],
|
||||
}
|
||||
|
|
|
@ -1,28 +1,22 @@
|
|||
from django.core.files import File
|
||||
from django.shortcuts import get_object_or_404
|
||||
from ninja import File, Schema
|
||||
from ninja.files import UploadedFile
|
||||
|
||||
from activities.models import PostAttachment, PostAttachmentStates
|
||||
from api import schemas
|
||||
from api.views.base import api_router
|
||||
from core.files import blurhash_image, resize_image
|
||||
from hatchway import QueryOrBody, api_view
|
||||
|
||||
from ..decorators import identity_required
|
||||
|
||||
|
||||
class UploadMediaSchema(Schema):
|
||||
description: str = ""
|
||||
focus: str = "0,0"
|
||||
|
||||
|
||||
@api_router.post("/v1/media", response=schemas.MediaAttachment)
|
||||
@api_router.post("/v2/media", response=schemas.MediaAttachment)
|
||||
@identity_required
|
||||
@api_view.post
|
||||
def upload_media(
|
||||
request,
|
||||
file: UploadedFile = File(...),
|
||||
details: UploadMediaSchema | None = None,
|
||||
):
|
||||
file: File,
|
||||
description: QueryOrBody[str] = "",
|
||||
focus: QueryOrBody[str] = "0,0",
|
||||
) -> schemas.MediaAttachment:
|
||||
main_file = resize_image(
|
||||
file,
|
||||
size=(2000, 2000),
|
||||
|
@ -38,7 +32,7 @@ def upload_media(
|
|||
mimetype="image/webp",
|
||||
width=main_file.image.width,
|
||||
height=main_file.image.height,
|
||||
name=details.description if details else None,
|
||||
name=description or None,
|
||||
state=PostAttachmentStates.fetched,
|
||||
)
|
||||
attachment.file.save(
|
||||
|
@ -50,27 +44,28 @@ def upload_media(
|
|||
thumbnail_file,
|
||||
)
|
||||
attachment.save()
|
||||
return attachment.to_mastodon_json()
|
||||
return schemas.MediaAttachment.from_post_attachment(attachment)
|
||||
|
||||
|
||||
@api_router.get("/v1/media/{id}", response=schemas.MediaAttachment)
|
||||
@identity_required
|
||||
@api_view.get
|
||||
def get_media(
|
||||
request,
|
||||
id: str,
|
||||
):
|
||||
) -> schemas.MediaAttachment:
|
||||
attachment = get_object_or_404(PostAttachment, pk=id)
|
||||
return attachment.to_mastodon_json()
|
||||
return schemas.MediaAttachment.from_post_attachment(attachment)
|
||||
|
||||
|
||||
@api_router.put("/v1/media/{id}", response=schemas.MediaAttachment)
|
||||
@identity_required
|
||||
@api_view.put
|
||||
def update_media(
|
||||
request,
|
||||
id: str,
|
||||
details: UploadMediaSchema | None = None,
|
||||
):
|
||||
description: QueryOrBody[str] = "",
|
||||
focus: QueryOrBody[str] = "0,0",
|
||||
) -> schemas.MediaAttachment:
|
||||
attachment = get_object_or_404(PostAttachment, pk=id)
|
||||
attachment.name = details.description if details else None
|
||||
attachment.name = description or None
|
||||
attachment.save()
|
||||
return attachment.to_mastodon_json()
|
||||
return schemas.MediaAttachment.from_post_attachment(attachment)
|
||||
|
|
|
@ -1,24 +1,23 @@
|
|||
from django.http import HttpRequest, HttpResponse
|
||||
from django.http import HttpRequest
|
||||
|
||||
from activities.models import TimelineEvent
|
||||
from activities.services import TimelineService
|
||||
from api import schemas
|
||||
from api.decorators import identity_required
|
||||
from api.pagination import MastodonPaginator
|
||||
from api.views.base import api_router
|
||||
from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult
|
||||
from hatchway import ApiResponse, api_view
|
||||
|
||||
|
||||
@api_router.get("/v1/notifications", response=list[schemas.Notification])
|
||||
@identity_required
|
||||
@api_view.get
|
||||
def notifications(
|
||||
request: HttpRequest,
|
||||
response: HttpResponse,
|
||||
max_id: str | None = None,
|
||||
since_id: str | None = None,
|
||||
min_id: str | None = None,
|
||||
limit: int = 20,
|
||||
account_id: str | None = None,
|
||||
):
|
||||
) -> ApiResponse[list[schemas.Notification]]:
|
||||
# Types/exclude_types use weird syntax so we have to handle them manually
|
||||
base_types = {
|
||||
"favourite": TimelineEvent.Types.liked,
|
||||
|
@ -37,16 +36,15 @@ def notifications(
|
|||
[base_types[r] for r in requested_types if r in base_types]
|
||||
)
|
||||
paginator = MastodonPaginator()
|
||||
pager = paginator.paginate(
|
||||
pager: PaginationResult[TimelineEvent] = paginator.paginate(
|
||||
queryset,
|
||||
min_id=min_id,
|
||||
max_id=max_id,
|
||||
since_id=since_id,
|
||||
limit=limit,
|
||||
)
|
||||
pager.jsonify_notification_events(identity=request.identity)
|
||||
|
||||
if pager.results:
|
||||
response.headers["Link"] = pager.link_header(request, ["limit", "account_id"])
|
||||
|
||||
return pager.json_results
|
||||
return PaginatingApiResponse(
|
||||
[schemas.Notification.from_timeline_event(event) for event in pager.results],
|
||||
request=request,
|
||||
include_params=["limit", "account_id"],
|
||||
)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import base64
|
||||
import json
|
||||
import secrets
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
|
@ -16,7 +17,6 @@ from django.views.decorators.csrf import csrf_exempt
|
|||
from django.views.generic import View
|
||||
|
||||
from api.models import Application, Authorization, Token
|
||||
from api.parser import FormOrJsonParser
|
||||
|
||||
|
||||
class OauthRedirect(HttpResponseRedirect):
|
||||
|
@ -40,6 +40,19 @@ class OauthRedirect(HttpResponseRedirect):
|
|||
super().__init__(urlunparse(url_parts))
|
||||
|
||||
|
||||
def get_json_and_formdata(request):
|
||||
# Did they submit JSON?
|
||||
if request.content_type == "application/json" and request.body.strip():
|
||||
return json.loads(request.body)
|
||||
# Fall back to form data
|
||||
value = {}
|
||||
for key, item in request.POST.items():
|
||||
value[key] = item
|
||||
for key, item in request.GET.items():
|
||||
value[key] = item
|
||||
return value
|
||||
|
||||
|
||||
class AuthorizationView(LoginRequiredMixin, View):
|
||||
"""
|
||||
Asks the user to authorize access.
|
||||
|
@ -87,7 +100,7 @@ class AuthorizationView(LoginRequiredMixin, View):
|
|||
return render(request, "api/oauth_authorize.html", context)
|
||||
|
||||
def post(self, request):
|
||||
post_data = FormOrJsonParser().parse_body(request)
|
||||
post_data = get_json_and_formdata(request)
|
||||
# Grab the application and other details again
|
||||
redirect_uri = post_data["redirect_uri"]
|
||||
scope = post_data["scope"]
|
||||
|
@ -141,7 +154,7 @@ class TokenView(View):
|
|||
)
|
||||
|
||||
def post(self, request):
|
||||
post_data = FormOrJsonParser().parse_body(request)
|
||||
post_data = get_json_and_formdata(request)
|
||||
auth_client_id, auth_client_secret = extract_client_info_from_basic_auth(
|
||||
request
|
||||
)
|
||||
|
@ -221,7 +234,7 @@ class TokenView(View):
|
|||
@method_decorator(csrf_exempt, name="dispatch")
|
||||
class RevokeTokenView(View):
|
||||
def post(self, request):
|
||||
post_data = FormOrJsonParser().parse_body(request)
|
||||
post_data = get_json_and_formdata(request)
|
||||
auth_client_id, auth_client_secret = extract_client_info_from_basic_auth(
|
||||
request
|
||||
)
|
||||
|
|
|
@ -1,16 +1,14 @@
|
|||
from typing import Literal
|
||||
|
||||
from ninja import Field
|
||||
|
||||
from activities.models import PostInteraction
|
||||
from activities.services.search import SearchService
|
||||
from api import schemas
|
||||
from api.decorators import identity_required
|
||||
from api.views.base import api_router
|
||||
from hatchway import Field, api_view
|
||||
|
||||
|
||||
@api_router.get("/v2/search", response=schemas.Search)
|
||||
@identity_required
|
||||
@api_view.get
|
||||
def search(
|
||||
request,
|
||||
q: str,
|
||||
|
@ -24,26 +22,31 @@ def search(
|
|||
min_id: str | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
):
|
||||
) -> schemas.Search:
|
||||
if limit > 40:
|
||||
limit = 40
|
||||
result: dict[str, list] = {"accounts": [], "statuses": [], "hashtags": []}
|
||||
# We don't support pagination for searches yet
|
||||
if max_id or since_id or min_id or offset:
|
||||
return result
|
||||
return schemas.Search(**result)
|
||||
# Run search
|
||||
searcher = SearchService(q, request.identity)
|
||||
search_result = searcher.search_all()
|
||||
if type is None or type == "accounts":
|
||||
result["accounts"] = [i.to_mastodon_json() for i in search_result["identities"]]
|
||||
result["accounts"] = [
|
||||
schemas.Account.from_identity(i, include_counts=False)
|
||||
for i in search_result["identities"]
|
||||
]
|
||||
if type is None or type == "hashtag":
|
||||
result["hashtag"] = [h.to_mastodon_json() for h in search_result["hashtags"]]
|
||||
result["hashtag"] = [
|
||||
schemas.Tag.from_hashtag(h) for h in search_result["hashtags"]
|
||||
]
|
||||
if type is None or type == "statuses":
|
||||
interactions = PostInteraction.get_post_interactions(
|
||||
search_result["posts"], request.identity
|
||||
)
|
||||
result["statuses"] = [
|
||||
p.to_mastodon_json(interactions=interactions)
|
||||
schemas.Status.from_post(p, interactions=interactions)
|
||||
for p in search_result["posts"]
|
||||
]
|
||||
return result
|
||||
return schemas.Search(**result)
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
from typing import Literal
|
||||
|
||||
from django.forms import ValidationError
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.http import HttpRequest
|
||||
from django.shortcuts import get_object_or_404
|
||||
from ninja import Schema
|
||||
|
||||
from activities.models import (
|
||||
Post,
|
||||
|
@ -14,11 +13,10 @@ from activities.models import (
|
|||
)
|
||||
from activities.services import PostService
|
||||
from api import schemas
|
||||
from api.views.base import api_router
|
||||
from api.decorators import identity_required
|
||||
from api.pagination import MastodonPaginator, PaginationResult
|
||||
from core.models import Config
|
||||
|
||||
from ..decorators import identity_required
|
||||
from ..pagination import MastodonPaginator
|
||||
from hatchway import ApiResponse, Schema, api_view
|
||||
|
||||
|
||||
class PostStatusSchema(Schema):
|
||||
|
@ -32,9 +30,9 @@ class PostStatusSchema(Schema):
|
|||
media_ids: list[str] = []
|
||||
|
||||
|
||||
@api_router.post("/v1/statuses", response=schemas.Status)
|
||||
@identity_required
|
||||
def post_status(request, details: PostStatusSchema):
|
||||
@api_view.post
|
||||
def post_status(request, details: PostStatusSchema) -> schemas.Status:
|
||||
# Check text length
|
||||
if len(details.status) > Config.system.post_length:
|
||||
raise ValidationError("Status is too long")
|
||||
|
@ -66,74 +64,74 @@ def post_status(request, details: PostStatusSchema):
|
|||
)
|
||||
# Add their own timeline event for immediate visibility
|
||||
TimelineEvent.add_post(request.identity, post)
|
||||
return post.to_mastodon_json()
|
||||
return schemas.Status.from_post(post)
|
||||
|
||||
|
||||
@api_router.get("/v1/statuses/{id}", response=schemas.Status)
|
||||
@identity_required
|
||||
def status(request, id: str):
|
||||
@api_view.get
|
||||
def status(request, id: str) -> schemas.Status:
|
||||
post = get_object_or_404(Post, pk=id)
|
||||
interactions = PostInteraction.get_post_interactions([post], request.identity)
|
||||
return post.to_mastodon_json(interactions=interactions)
|
||||
return schemas.Status.from_post(post, interactions=interactions)
|
||||
|
||||
|
||||
@api_router.delete("/v1/statuses/{id}", response=schemas.Status)
|
||||
@identity_required
|
||||
def delete_status(request, id: str):
|
||||
@api_view.delete
|
||||
def delete_status(request, id: str) -> schemas.Status:
|
||||
post = get_object_or_404(Post, pk=id)
|
||||
PostService(post).delete()
|
||||
return post.to_mastodon_json()
|
||||
return schemas.Status.from_post(post)
|
||||
|
||||
|
||||
@api_router.get("/v1/statuses/{id}/context", response=schemas.Context)
|
||||
@identity_required
|
||||
def status_context(request, id: str):
|
||||
@api_view.get
|
||||
def status_context(request, id: str) -> schemas.Context:
|
||||
post = get_object_or_404(Post, pk=id)
|
||||
service = PostService(post)
|
||||
ancestors, descendants = service.context(request.identity)
|
||||
interactions = PostInteraction.get_post_interactions(
|
||||
ancestors + descendants, request.identity
|
||||
)
|
||||
return {
|
||||
"ancestors": [
|
||||
p.to_mastodon_json(interactions=interactions) for p in reversed(ancestors)
|
||||
return schemas.Context(
|
||||
ancestors=[
|
||||
schemas.Status.from_post(p, interactions=interactions)
|
||||
for p in reversed(ancestors)
|
||||
],
|
||||
"descendants": [
|
||||
p.to_mastodon_json(interactions=interactions) for p in descendants
|
||||
descendants=[
|
||||
schemas.Status.from_post(p, interactions=interactions) for p in descendants
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@api_router.post("/v1/statuses/{id}/favourite", response=schemas.Status)
|
||||
@identity_required
|
||||
def favourite_status(request, id: str):
|
||||
@api_view.post
|
||||
def favourite_status(request, id: str) -> schemas.Status:
|
||||
post = get_object_or_404(Post, pk=id)
|
||||
service = PostService(post)
|
||||
service.like_as(request.identity)
|
||||
interactions = PostInteraction.get_post_interactions([post], request.identity)
|
||||
return post.to_mastodon_json(interactions=interactions)
|
||||
return schemas.Status.from_post(post, interactions=interactions)
|
||||
|
||||
|
||||
@api_router.post("/v1/statuses/{id}/unfavourite", response=schemas.Status)
|
||||
@identity_required
|
||||
def unfavourite_status(request, id: str):
|
||||
@api_view.post
|
||||
def unfavourite_status(request, id: str) -> schemas.Status:
|
||||
post = get_object_or_404(Post, pk=id)
|
||||
service = PostService(post)
|
||||
service.unlike_as(request.identity)
|
||||
interactions = PostInteraction.get_post_interactions([post], request.identity)
|
||||
return post.to_mastodon_json(interactions=interactions)
|
||||
return schemas.Status.from_post(post, interactions=interactions)
|
||||
|
||||
|
||||
@api_router.get("/v1/statuses/{id}/favourited_by", response=list[schemas.Account])
|
||||
@api_view.get
|
||||
def favourited_by(
|
||||
request: HttpRequest,
|
||||
response: HttpResponse,
|
||||
id: str,
|
||||
max_id: str | None = None,
|
||||
since_id: str | None = None,
|
||||
min_id: str | None = None,
|
||||
limit: int = 20,
|
||||
):
|
||||
) -> ApiResponse[list[schemas.Account]]:
|
||||
"""
|
||||
View who favourited a given status.
|
||||
"""
|
||||
|
@ -142,7 +140,7 @@ def favourited_by(
|
|||
post = get_object_or_404(Post, pk=id)
|
||||
|
||||
paginator = MastodonPaginator()
|
||||
pager = paginator.paginate(
|
||||
pager: PaginationResult[PostInteraction] = paginator.paginate(
|
||||
post.interactions.filter(
|
||||
type=PostInteraction.Types.like,
|
||||
state__in=PostInteractionStates.group_active(),
|
||||
|
@ -152,32 +150,37 @@ def favourited_by(
|
|||
since_id=since_id,
|
||||
limit=limit,
|
||||
)
|
||||
pager.jsonify_results(lambda r: r.identity.to_mastodon_json(include_counts=False))
|
||||
|
||||
headers = {}
|
||||
if pager.results:
|
||||
response.headers["Link"] = pager.link_header(
|
||||
request,
|
||||
["limit"],
|
||||
)
|
||||
|
||||
return pager.json_results
|
||||
headers = {"link": pager.link_header(request, ["limit"])}
|
||||
return ApiResponse(
|
||||
[
|
||||
schemas.Account.from_identity(
|
||||
interaction.identity,
|
||||
include_counts=False,
|
||||
)
|
||||
for interaction in pager.results
|
||||
],
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
@api_router.post("/v1/statuses/{id}/reblog", response=schemas.Status)
|
||||
@identity_required
|
||||
def reblog_status(request, id: str):
|
||||
@api_view.post
|
||||
def reblog_status(request, id: str) -> schemas.Status:
|
||||
post = get_object_or_404(Post, pk=id)
|
||||
service = PostService(post)
|
||||
service.boost_as(request.identity)
|
||||
interactions = PostInteraction.get_post_interactions([post], request.identity)
|
||||
return post.to_mastodon_json(interactions=interactions)
|
||||
return schemas.Status.from_post(post, interactions=interactions)
|
||||
|
||||
|
||||
@api_router.post("/v1/statuses/{id}/unreblog", response=schemas.Status)
|
||||
@identity_required
|
||||
def unreblog_status(request, id: str):
|
||||
@api_view.post
|
||||
def unreblog_status(request, id: str) -> schemas.Status:
|
||||
post = get_object_or_404(Post, pk=id)
|
||||
service = PostService(post)
|
||||
service.unboost_as(request.identity)
|
||||
interactions = PostInteraction.get_post_interactions([post], request.identity)
|
||||
return post.to_mastodon_json(interactions=interactions)
|
||||
return schemas.Status.from_post(post, interactions=interactions)
|
||||
|
|
|
@ -1,23 +1,23 @@
|
|||
from django.http import HttpRequest, HttpResponse, JsonResponse
|
||||
from django.http import HttpRequest
|
||||
|
||||
from activities.models import Post
|
||||
from activities.services import TimelineService
|
||||
from api import schemas
|
||||
from api.decorators import identity_required
|
||||
from api.pagination import MastodonPaginator
|
||||
from api.views.base import api_router
|
||||
from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult
|
||||
from core.models import Config
|
||||
from hatchway import ApiError, ApiResponse, api_view
|
||||
|
||||
|
||||
@api_router.get("/v1/timelines/home", response=list[schemas.Status])
|
||||
@identity_required
|
||||
@api_view.get
|
||||
def home(
|
||||
request: HttpRequest,
|
||||
response: HttpResponse,
|
||||
max_id: str | None = None,
|
||||
since_id: str | None = None,
|
||||
min_id: str | None = None,
|
||||
limit: int = 20,
|
||||
):
|
||||
) -> ApiResponse[list[schemas.Status]]:
|
||||
# Grab a paginated result set of instances
|
||||
paginator = MastodonPaginator()
|
||||
queryset = TimelineService(request.identity).home()
|
||||
|
@ -41,18 +41,16 @@ def home(
|
|||
since_id=since_id,
|
||||
limit=limit,
|
||||
)
|
||||
# Convert those to the JSON form
|
||||
pager.jsonify_status_events(identity=request.identity)
|
||||
# Add the link header if needed
|
||||
if pager.results:
|
||||
response.headers["Link"] = pager.link_header(request, ["limit"])
|
||||
return pager.json_results
|
||||
return PaginatingApiResponse(
|
||||
schemas.Status.map_from_timeline_event(pager.results, request.identity),
|
||||
request=request,
|
||||
include_params=["limit"],
|
||||
)
|
||||
|
||||
|
||||
@api_router.get("/v1/timelines/public", response=list[schemas.Status])
|
||||
@api_view.get
|
||||
def public(
|
||||
request: HttpRequest,
|
||||
response: HttpResponse,
|
||||
local: bool = False,
|
||||
remote: bool = False,
|
||||
only_media: bool = False,
|
||||
|
@ -60,9 +58,9 @@ def public(
|
|||
since_id: str | None = None,
|
||||
min_id: str | None = None,
|
||||
limit: int = 20,
|
||||
):
|
||||
) -> ApiResponse[list[schemas.Status]]:
|
||||
if not request.identity and not Config.system.public_timeline:
|
||||
return JsonResponse({"error": "public timeline is disabled"}, status=422)
|
||||
raise ApiError(error="public timeline is disabled", status=422)
|
||||
|
||||
if local:
|
||||
queryset = TimelineService(request.identity).local()
|
||||
|
@ -74,29 +72,24 @@ def public(
|
|||
queryset = queryset.filter(attachments__id__isnull=True)
|
||||
# Grab a paginated result set of instances
|
||||
paginator = MastodonPaginator()
|
||||
pager = paginator.paginate(
|
||||
pager: PaginationResult[Post] = paginator.paginate(
|
||||
queryset,
|
||||
min_id=min_id,
|
||||
max_id=max_id,
|
||||
since_id=since_id,
|
||||
limit=limit,
|
||||
)
|
||||
# Convert those to the JSON form
|
||||
pager.jsonify_posts(identity=request.identity)
|
||||
# Add the link header if needed
|
||||
if pager.results:
|
||||
response.headers["Link"] = pager.link_header(
|
||||
request,
|
||||
["limit", "local", "remote", "only_media"],
|
||||
)
|
||||
return pager.json_results
|
||||
return PaginatingApiResponse(
|
||||
schemas.Status.map_from_post(pager.results, request.identity),
|
||||
request=request,
|
||||
include_params=["limit", "local", "remote", "only_media"],
|
||||
)
|
||||
|
||||
|
||||
@api_router.get("/v1/timelines/tag/{hashtag}", response=list[schemas.Status])
|
||||
@identity_required
|
||||
@api_view.get
|
||||
def hashtag(
|
||||
request: HttpRequest,
|
||||
response: HttpResponse,
|
||||
hashtag: str,
|
||||
local: bool = False,
|
||||
only_media: bool = False,
|
||||
|
@ -104,7 +97,7 @@ def hashtag(
|
|||
since_id: str | None = None,
|
||||
min_id: str | None = None,
|
||||
limit: int = 20,
|
||||
):
|
||||
) -> ApiResponse[list[schemas.Status]]:
|
||||
if limit > 40:
|
||||
limit = 40
|
||||
queryset = TimelineService(request.identity).hashtag(hashtag)
|
||||
|
@ -114,63 +107,54 @@ def hashtag(
|
|||
queryset = queryset.filter(attachments__id__isnull=True)
|
||||
# Grab a paginated result set of instances
|
||||
paginator = MastodonPaginator()
|
||||
pager = paginator.paginate(
|
||||
pager: PaginationResult[Post] = paginator.paginate(
|
||||
queryset,
|
||||
min_id=min_id,
|
||||
max_id=max_id,
|
||||
since_id=since_id,
|
||||
limit=limit,
|
||||
)
|
||||
# Convert those to the JSON form
|
||||
pager.jsonify_posts(identity=request.identity)
|
||||
# Add a link header if we need to
|
||||
if pager.results:
|
||||
response.headers["Link"] = pager.link_header(
|
||||
request,
|
||||
["limit", "local", "remote", "only_media"],
|
||||
)
|
||||
return pager.json_results
|
||||
return PaginatingApiResponse(
|
||||
schemas.Status.map_from_post(pager.results, request.identity),
|
||||
request=request,
|
||||
include_params=["limit", "local", "remote", "only_media"],
|
||||
)
|
||||
|
||||
|
||||
@api_router.get("/v1/conversations", response=list[schemas.Status])
|
||||
@identity_required
|
||||
@api_view.get
|
||||
def conversations(
|
||||
request: HttpRequest,
|
||||
response: HttpResponse,
|
||||
max_id: str | None = None,
|
||||
since_id: str | None = None,
|
||||
min_id: str | None = None,
|
||||
limit: int = 20,
|
||||
):
|
||||
) -> list[schemas.Status]:
|
||||
# We don't implement this yet
|
||||
return []
|
||||
|
||||
|
||||
@api_router.get("/v1/favourites", response=list[schemas.Status])
|
||||
@identity_required
|
||||
@api_view.get
|
||||
def favourites(
|
||||
request: HttpRequest,
|
||||
response: HttpResponse,
|
||||
max_id: str | None = None,
|
||||
since_id: str | None = None,
|
||||
min_id: str | None = None,
|
||||
limit: int = 20,
|
||||
):
|
||||
) -> ApiResponse[list[schemas.Status]]:
|
||||
queryset = TimelineService(request.identity).likes()
|
||||
|
||||
paginator = MastodonPaginator()
|
||||
pager = paginator.paginate(
|
||||
pager: PaginationResult[Post] = paginator.paginate(
|
||||
queryset,
|
||||
min_id=min_id,
|
||||
max_id=max_id,
|
||||
since_id=since_id,
|
||||
limit=limit,
|
||||
)
|
||||
# Convert those to the JSON form
|
||||
pager.jsonify_posts(identity=request.identity)
|
||||
|
||||
# Add the link header if needed
|
||||
if pager.results:
|
||||
response.headers["Link"] = pager.link_header(request, ["limit"])
|
||||
|
||||
return pager.json_results
|
||||
return PaginatingApiResponse(
|
||||
schemas.Status.map_from_post(pager.results, request.identity),
|
||||
request=request,
|
||||
include_params=["limit"],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
from .http import ApiError, ApiResponse # noqa
|
||||
from .schema import Field, Schema # noqa
|
||||
from .types import Body, BodyDirect, Path, Query, QueryOrBody # noqa
|
||||
from .urls import methods # noqa
|
||||
from .view import api_view # noqa
|
|
@ -0,0 +1,10 @@
|
|||
import enum
|
||||
|
||||
|
||||
class InputSource(str, enum.Enum):
|
||||
path = "path"
|
||||
query = "query"
|
||||
body = "body"
|
||||
body_direct = "body_direct"
|
||||
query_and_body_direct = "query_and_body_direct"
|
||||
file = "file"
|
|
@ -0,0 +1,47 @@
|
|||
import json
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
from django.http import HttpResponse
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ApiResponse(Generic[T], HttpResponse):
|
||||
"""
|
||||
A way to return extra information with a response if you want
|
||||
headers, etc.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: T,
|
||||
encoder=DjangoJSONEncoder,
|
||||
json_dumps_params: dict[str, object] | None = None,
|
||||
finalize: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
self.data = data
|
||||
self.encoder = encoder
|
||||
self.json_dumps_params = json_dumps_params or {}
|
||||
kwargs.setdefault("content_type", "application/json")
|
||||
super().__init__(content=b"(unfinalised)", **kwargs)
|
||||
if finalize:
|
||||
self.finalize()
|
||||
|
||||
def finalize(self):
|
||||
"""
|
||||
Converts whatever our current data is into HttpResponse content
|
||||
"""
|
||||
# TODO: Automatically call this when we're asked to write output?
|
||||
self.content = json.dumps(self.data, cls=self.encoder, **self.json_dumps_params)
|
||||
|
||||
|
||||
class ApiError(BaseException):
|
||||
"""
|
||||
A handy way to raise an error with JSONable contents
|
||||
"""
|
||||
|
||||
def __init__(self, status: int, error: str):
|
||||
self.status = status
|
||||
self.error = error
|
|
@ -0,0 +1,52 @@
|
|||
from typing import Any
|
||||
|
||||
from django.db.models import Manager, QuerySet
|
||||
from django.db.models.fields.files import FieldFile
|
||||
from django.template import Variable, VariableDoesNotExist
|
||||
from pydantic.fields import Field # noqa
|
||||
from pydantic.main import BaseModel
|
||||
from pydantic.utils import GetterDict
|
||||
|
||||
|
||||
class DjangoGetterDict(GetterDict):
|
||||
def __init__(self, obj: Any):
|
||||
self._obj = obj
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
try:
|
||||
item = getattr(self._obj, key)
|
||||
except AttributeError:
|
||||
try:
|
||||
item = Variable(key).resolve(self._obj)
|
||||
except VariableDoesNotExist as e:
|
||||
raise KeyError(key) from e
|
||||
return self._convert_result(item)
|
||||
|
||||
def get(self, key: Any, default: Any = None) -> Any:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def _convert_result(self, result: Any) -> Any:
|
||||
if isinstance(result, Manager):
|
||||
return list(result.all())
|
||||
|
||||
elif isinstance(result, getattr(QuerySet, "__origin__", QuerySet)):
|
||||
return list(result)
|
||||
|
||||
if callable(result):
|
||||
return result()
|
||||
|
||||
elif isinstance(result, FieldFile):
|
||||
if not result:
|
||||
return None
|
||||
return result.url
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Schema(BaseModel):
|
||||
class Config:
|
||||
orm_mode = True
|
||||
getter_dict = DjangoGetterDict
|
|
@ -0,0 +1,63 @@
|
|||
from typing import Literal, Optional, Union
|
||||
|
||||
from django.core.files import File
|
||||
|
||||
from hatchway.http import ApiResponse
|
||||
from hatchway.types import (
|
||||
Query,
|
||||
QueryType,
|
||||
acceptable_input,
|
||||
extract_output_type,
|
||||
extract_signifier,
|
||||
is_optional,
|
||||
)
|
||||
|
||||
|
||||
def test_is_optional():
|
||||
|
||||
assert is_optional(Optional[int]) == (True, int)
|
||||
assert is_optional(Union[int, None]) == (True, int)
|
||||
assert is_optional(Union[None, int]) == (True, int)
|
||||
assert is_optional(int | None) == (True, int)
|
||||
assert is_optional(None | int) == (True, int)
|
||||
assert is_optional(int) == (False, int)
|
||||
assert is_optional(Query[int]) == (False, Query[int])
|
||||
|
||||
|
||||
def test_extract_signifier():
|
||||
|
||||
assert extract_signifier(int) == (None, int)
|
||||
assert extract_signifier(Query[int]) == (QueryType, int)
|
||||
assert extract_signifier(Query[Optional[int]]) == ( # type:ignore
|
||||
QueryType,
|
||||
Optional[int],
|
||||
)
|
||||
assert extract_signifier(Query[int | None]) == ( # type:ignore
|
||||
QueryType,
|
||||
Optional[int],
|
||||
)
|
||||
assert extract_signifier(Optional[Query[int]]) == (QueryType, Optional[int])
|
||||
|
||||
|
||||
def test_extract_output_type():
|
||||
|
||||
assert extract_output_type(int) == int
|
||||
assert extract_output_type(ApiResponse[int]) == int
|
||||
assert extract_output_type(ApiResponse[int | str]) == int | str
|
||||
|
||||
|
||||
def test_acceptable_input():
|
||||
|
||||
assert acceptable_input(str) is True
|
||||
assert acceptable_input(int) is True
|
||||
assert acceptable_input(Query[int]) is True
|
||||
assert acceptable_input(Optional[int]) is True
|
||||
assert acceptable_input(int | None) is True
|
||||
assert acceptable_input(int | str | None) is True
|
||||
assert acceptable_input(Query[int | None]) is True # type: ignore
|
||||
assert acceptable_input(File) is True
|
||||
assert acceptable_input(list[str]) is True
|
||||
assert acceptable_input(dict[str, int]) is True
|
||||
assert acceptable_input(Literal["a", "b"]) is True
|
||||
assert acceptable_input(frozenset) is False
|
||||
assert acceptable_input(dict[str, frozenset]) is False
|
|
@ -0,0 +1,244 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from django.core import files
|
||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||
from django.http import QueryDict
|
||||
from django.test import RequestFactory
|
||||
from django.test.client import MULTIPART_CONTENT
|
||||
from pydantic import BaseModel
|
||||
|
||||
from hatchway import ApiError, Body, QueryOrBody, api_view
|
||||
from hatchway.view import ApiView
|
||||
|
||||
|
||||
def test_basic_view():
|
||||
"""
|
||||
Tests that a view with simple types works correctly
|
||||
"""
|
||||
|
||||
@api_view
|
||||
def test_view(
|
||||
request,
|
||||
a: int,
|
||||
b: QueryOrBody[int | None] = None,
|
||||
c: str = "x",
|
||||
) -> str:
|
||||
if b is None:
|
||||
return c * a
|
||||
else:
|
||||
return c * (a - b)
|
||||
|
||||
# Call it with a few different patterns to verify it's type coercing right
|
||||
factory = RequestFactory()
|
||||
|
||||
# Implicit query param
|
||||
response = test_view(factory.get("/test/?a=4"))
|
||||
assert json.loads(response.content) == "xxxx"
|
||||
|
||||
# QueryOrBody pulling from query
|
||||
response = test_view(factory.get("/test/?a=4&b=2"))
|
||||
assert json.loads(response.content) == "xx"
|
||||
|
||||
# QueryOrBody pulling from formdata body
|
||||
response = test_view(factory.post("/test/?a=4", {"b": "3"}))
|
||||
assert json.loads(response.content) == "x"
|
||||
|
||||
# QueryOrBody pulling from JSON body
|
||||
response = test_view(
|
||||
factory.post(
|
||||
"/test/?a=4", json.dumps({"b": 3}), content_type="application/json"
|
||||
)
|
||||
)
|
||||
assert json.loads(response.content) == "x"
|
||||
|
||||
# Implicit Query not pulling from body
|
||||
with pytest.raises(TypeError):
|
||||
test_view(factory.post("/test/", {"a": 4, "b": 3}))
|
||||
|
||||
|
||||
def test_body_direct():
|
||||
"""
|
||||
Tests that a Pydantic model with BodyDirect gets its fields from the top level
|
||||
"""
|
||||
|
||||
class TestModel(BaseModel):
|
||||
number: int
|
||||
name: str
|
||||
|
||||
@api_view
|
||||
def test_view(request, data: TestModel) -> int:
|
||||
return data.number
|
||||
|
||||
factory = RequestFactory()
|
||||
|
||||
# formdata version
|
||||
response = test_view(factory.post("/test/", {"number": "123", "name": "Andrew"}))
|
||||
assert json.loads(response.content) == 123
|
||||
|
||||
# JSON body version
|
||||
response = test_view(
|
||||
factory.post(
|
||||
"/test/",
|
||||
json.dumps({"number": "123", "name": "Andrew"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
)
|
||||
assert json.loads(response.content) == 123
|
||||
|
||||
|
||||
def test_list_response():
|
||||
"""
|
||||
Tests that a view with a list response type works correctly with both
|
||||
dicts and pydantic model instances.
|
||||
"""
|
||||
|
||||
class TestModel(BaseModel):
|
||||
number: int
|
||||
name: str
|
||||
|
||||
@api_view
|
||||
def test_view_dict(request) -> list[TestModel]:
|
||||
return [
|
||||
{"name": "Andrew", "number": 1}, # type:ignore
|
||||
{"name": "Alice", "number": 0}, # type:ignore
|
||||
]
|
||||
|
||||
@api_view
|
||||
def test_view_model(request) -> list[TestModel]:
|
||||
return [TestModel(name="Andrew", number=1), TestModel(name="Alice", number=0)]
|
||||
|
||||
response = test_view_dict(RequestFactory().get("/test/"))
|
||||
assert json.loads(response.content) == [
|
||||
{"name": "Andrew", "number": 1},
|
||||
{"name": "Alice", "number": 0},
|
||||
]
|
||||
|
||||
response = test_view_model(RequestFactory().get("/test/"))
|
||||
assert json.loads(response.content) == [
|
||||
{"name": "Andrew", "number": 1},
|
||||
{"name": "Alice", "number": 0},
|
||||
]
|
||||
|
||||
|
||||
def test_patch_body():
|
||||
"""
|
||||
Tests that PATCH also gets its body parsed
|
||||
"""
|
||||
|
||||
@api_view.patch
|
||||
def test_view(request, a: Body[int]):
|
||||
return a
|
||||
|
||||
factory = RequestFactory()
|
||||
response = test_view(
|
||||
factory.patch(
|
||||
"/test/",
|
||||
content_type=MULTIPART_CONTENT,
|
||||
data=factory._encode_data({"a": "42"}, MULTIPART_CONTENT),
|
||||
)
|
||||
)
|
||||
assert json.loads(response.content) == 42
|
||||
|
||||
|
||||
def test_file_body():
|
||||
"""
|
||||
Tests that file uploads work right
|
||||
"""
|
||||
|
||||
@api_view.post
|
||||
def test_view(request, a: Body[int], b: files.File) -> str:
|
||||
return str(a) + b.read().decode("ascii")
|
||||
|
||||
factory = RequestFactory()
|
||||
uploaded_file = SimpleUploadedFile(
|
||||
"file.txt",
|
||||
b"MY FILE IS AMAZING",
|
||||
content_type="text/plain",
|
||||
)
|
||||
response = test_view(
|
||||
factory.post(
|
||||
"/test/",
|
||||
data={"a": 42, "b": uploaded_file},
|
||||
)
|
||||
)
|
||||
assert json.loads(response.content) == "42MY FILE IS AMAZING"
|
||||
|
||||
|
||||
def test_no_response():
|
||||
"""
|
||||
Tests that a view with no response type returns the contents verbatim
|
||||
"""
|
||||
|
||||
@api_view
|
||||
def test_view(request):
|
||||
return [1, "woooooo"]
|
||||
|
||||
response = test_view(RequestFactory().get("/test/"))
|
||||
assert json.loads(response.content) == [1, "woooooo"]
|
||||
|
||||
|
||||
def test_wrong_method():
|
||||
"""
|
||||
Tests that a view with a method limiter works
|
||||
"""
|
||||
|
||||
@api_view.get
|
||||
def test_view(request):
|
||||
return "yay"
|
||||
|
||||
response = test_view(RequestFactory().get("/test/"))
|
||||
assert json.loads(response.content) == "yay"
|
||||
|
||||
response = test_view(RequestFactory().post("/test/"))
|
||||
assert response.status_code == 405
|
||||
|
||||
|
||||
def test_api_error():
|
||||
"""
|
||||
Tests that ApiError propagates right
|
||||
"""
|
||||
|
||||
@api_view.get
|
||||
def test_view(request):
|
||||
raise ApiError(401, "you did a bad thing")
|
||||
|
||||
response = test_view(RequestFactory().get("/test/"))
|
||||
assert json.loads(response.content) == {"error": "you did a bad thing"}
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_unusable_type():
|
||||
"""
|
||||
Tests that you get a nice error when you use a type on an input that
|
||||
Pydantic doesn't understand.
|
||||
"""
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
@api_view.get
|
||||
def test_view(request, a: RequestFactory):
|
||||
pass
|
||||
|
||||
|
||||
def test_get_values():
|
||||
"""
|
||||
Tests that ApiView.get_values correctly handles lists
|
||||
"""
|
||||
|
||||
assert ApiView.get_values({"a": 2, "b": [3, 4]}) == {"a": 2, "b": [3, 4]}
|
||||
assert ApiView.get_values({"a": 2, "b[]": [3, 4]}) == {"a": 2, "b": [3, 4]}
|
||||
assert ApiView.get_values(QueryDict("a=2&b=3&b=4")) == {"a": "2", "b": ["3", "4"]}
|
||||
assert ApiView.get_values(QueryDict("a=2&b[]=3&b[]=4")) == {
|
||||
"a": "2",
|
||||
"b": ["3", "4"],
|
||||
}
|
||||
assert ApiView.get_values(QueryDict("a=2&b=3")) == {"a": "2", "b": "3"}
|
||||
assert ApiView.get_values(QueryDict("a=2&b[]=3")) == {"a": "2", "b": ["3"]}
|
||||
assert ApiView.get_values(QueryDict("a[b]=1")) == {"a": {"b": "1"}}
|
||||
assert ApiView.get_values(QueryDict("a[b]=1&a[c]=2")) == {"a": {"b": "1", "c": "2"}}
|
||||
assert ApiView.get_values(QueryDict("a[b][c]=1")) == {"a": {"b": {"c": "1"}}}
|
||||
assert ApiView.get_values(QueryDict("a[b][c][]=1")) == {"a": {"b": {"c": ["1"]}}}
|
||||
assert ApiView.get_values(QueryDict("a[b][]=1&a[b][]=2")) == {
|
||||
"a": {"b": ["1", "2"]}
|
||||
}
|
|
@ -0,0 +1,145 @@
|
|||
from types import NoneType, UnionType
|
||||
from typing import ( # type: ignore[attr-defined]
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
_AnnotatedAlias,
|
||||
_GenericAlias,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from django.core import files
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .http import ApiResponse
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PathType:
|
||||
"""
|
||||
An input pulled from the path (url resolver kwargs)
|
||||
"""
|
||||
|
||||
|
||||
class QueryType:
|
||||
"""
|
||||
An input pulled from the query parameters (request.GET)
|
||||
"""
|
||||
|
||||
|
||||
class BodyType:
|
||||
"""
|
||||
An input pulled from the POST body (request.POST or a JSON body)
|
||||
"""
|
||||
|
||||
|
||||
class FileType:
|
||||
"""
|
||||
An input pulled from the POST body (request.POST or a JSON body)
|
||||
"""
|
||||
|
||||
|
||||
class BodyDirectType:
|
||||
"""
|
||||
A Pydantic model whose keys are all looked for in the top-level
|
||||
POST data, rather than in a dict under a key named after the input.
|
||||
"""
|
||||
|
||||
|
||||
class QueryOrBodyType:
|
||||
"""
|
||||
An input pulled from either query parameters or post data.
|
||||
"""
|
||||
|
||||
|
||||
Path = Annotated[T, PathType]
|
||||
Query = Annotated[T, QueryType]
|
||||
Body = Annotated[T, BodyType]
|
||||
File = Annotated[T, FileType]
|
||||
BodyDirect = Annotated[T, BodyDirectType]
|
||||
QueryOrBody = Annotated[T, QueryOrBodyType]
|
||||
|
||||
|
||||
def is_optional(annotation) -> tuple[bool, Any]:
|
||||
"""
|
||||
If an annotation is Optional or | None, returns (True, internal type).
|
||||
Returns (False, annotation) otherwise.
|
||||
"""
|
||||
if (isinstance(annotation, _GenericAlias) and annotation.__origin__ is Union) or (
|
||||
isinstance(annotation, UnionType)
|
||||
):
|
||||
args = get_args(annotation)
|
||||
if len(args) > 2:
|
||||
return False, annotation
|
||||
if args[0] is NoneType:
|
||||
return True, args[1]
|
||||
if args[1] is NoneType:
|
||||
return True, args[0]
|
||||
return False, annotation
|
||||
return False, annotation
|
||||
|
||||
|
||||
def extract_signifier(annotation) -> tuple[Any, Any]:
|
||||
"""
|
||||
Given a type annotation, looks to see if it can find a input source
|
||||
signifier (Path, Query, etc.)
|
||||
|
||||
If it can, returns (signifier, annotation_without_signifier)
|
||||
If not, returns (None, annotation)
|
||||
"""
|
||||
our_generics = {
|
||||
PathType,
|
||||
QueryType,
|
||||
BodyType,
|
||||
FileType,
|
||||
BodyDirectType,
|
||||
QueryOrBodyType,
|
||||
}
|
||||
# Remove any optional-style wrapper
|
||||
optional, internal_annotation = is_optional(annotation)
|
||||
# Is it an annotation?
|
||||
if isinstance(internal_annotation, _AnnotatedAlias):
|
||||
args = get_args(internal_annotation)
|
||||
for arg in args[1:]:
|
||||
if arg in our_generics:
|
||||
if optional:
|
||||
return (arg, Optional[args[0]])
|
||||
else:
|
||||
return (arg, args[0])
|
||||
return None, annotation
|
||||
|
||||
|
||||
def extract_output_type(annotation):
|
||||
"""
|
||||
Returns the right response type for a function
|
||||
"""
|
||||
# If the type is ApiResponse, we want to pull out its inside
|
||||
if isinstance(annotation, _GenericAlias):
|
||||
if get_origin(annotation) == ApiResponse:
|
||||
return get_args(annotation)[0]
|
||||
return annotation
|
||||
|
||||
|
||||
def acceptable_input(annotation) -> bool:
|
||||
"""
|
||||
Returns if this annotation is something we think we can accept as input
|
||||
"""
|
||||
_, inner_type = extract_signifier(annotation)
|
||||
try:
|
||||
if issubclass(inner_type, BaseModel):
|
||||
return True
|
||||
except TypeError:
|
||||
pass
|
||||
if inner_type in [str, int, list, tuple, bool, Any, files.File, type(None)]:
|
||||
return True
|
||||
origin = get_origin(inner_type)
|
||||
if origin == Literal:
|
||||
return True
|
||||
if origin in [Union, UnionType, dict, list, tuple]:
|
||||
return all(acceptable_input(a) for a in get_args(inner_type))
|
||||
return False
|
|
@ -0,0 +1,32 @@
|
|||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from django.http import HttpResponseNotAllowed
|
||||
|
||||
|
||||
class Methods:
|
||||
"""
|
||||
Allows easy multi-method dispatch to different functions
|
||||
"""
|
||||
|
||||
csrf_exempt = True
|
||||
|
||||
def __init__(self, **callables: Callable):
|
||||
self.callables = {
|
||||
method.lower(): callable for method, callable in callables.items()
|
||||
}
|
||||
unknown_methods = set(self.callables.keys()).difference(
|
||||
{"get", "post", "patch", "put", "delete"}
|
||||
)
|
||||
if unknown_methods:
|
||||
raise ValueError(f"Cannot route methods: {unknown_methods}")
|
||||
|
||||
def __call__(self, request, *args, **kwargs) -> Any:
|
||||
method = request.method.lower()
|
||||
if method in self.callables:
|
||||
return self.callables[method](request, *args, **kwargs)
|
||||
else:
|
||||
return HttpResponseNotAllowed(self.callables.keys())
|
||||
|
||||
|
||||
methods = Methods
|
|
@ -0,0 +1,297 @@
|
|||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional, get_type_hints
|
||||
|
||||
from django.core import files
|
||||
from django.http import HttpRequest, HttpResponseNotAllowed, QueryDict
|
||||
from django.http.multipartparser import MultiPartParser
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
from .constants import InputSource
|
||||
from .http import ApiError, ApiResponse
|
||||
from .types import (
|
||||
BodyDirectType,
|
||||
BodyType,
|
||||
FileType,
|
||||
PathType,
|
||||
QueryOrBodyType,
|
||||
QueryType,
|
||||
acceptable_input,
|
||||
extract_output_type,
|
||||
extract_signifier,
|
||||
is_optional,
|
||||
)
|
||||
|
||||
|
||||
class ApiView:
|
||||
"""
|
||||
A view 'wrapper' object that replaces the API view for anything further
|
||||
up the stack.
|
||||
|
||||
Unlike Django's class-based views, we don't need an as_view pattern
|
||||
as we are careful never to write anything per-request to self.
|
||||
"""
|
||||
|
||||
csrf_exempt = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
view: Callable,
|
||||
input_types: dict[str, Any] | None = None,
|
||||
output_type: Any = None,
|
||||
implicit_lists: bool = True,
|
||||
method: str | None = None,
|
||||
):
|
||||
self.view = view
|
||||
self.implicit_lists = implicit_lists
|
||||
self.view_name = getattr(view, "__name__", "unknown_view")
|
||||
self.method = method
|
||||
# Extract input/output types from view annotations if we need to
|
||||
self.input_types = input_types
|
||||
if self.input_types is None:
|
||||
self.input_types = get_type_hints(view, include_extras=True)
|
||||
if "return" in self.input_types:
|
||||
del self.input_types["return"]
|
||||
self.output_type = output_type
|
||||
if self.output_type is None:
|
||||
try:
|
||||
self.output_type = extract_output_type(
|
||||
get_type_hints(view, include_extras=True)["return"]
|
||||
)
|
||||
except KeyError:
|
||||
self.output_type = None
|
||||
self.compile()
|
||||
|
||||
@classmethod
|
||||
def get(cls, view: Callable):
|
||||
return cls(view=view, method="get")
|
||||
|
||||
@classmethod
|
||||
def post(cls, view: Callable):
|
||||
return cls(view=view, method="post")
|
||||
|
||||
@classmethod
|
||||
def put(cls, view: Callable):
|
||||
return cls(view=view, method="put")
|
||||
|
||||
@classmethod
|
||||
def patch(cls, view: Callable):
|
||||
return cls(view=view, method="patch")
|
||||
|
||||
@classmethod
|
||||
def delete(cls, view: Callable):
|
||||
return cls(view=view, method="delete")
|
||||
|
||||
@classmethod
|
||||
def sources_for_input(cls, input_type) -> tuple[list[InputSource], Any]:
|
||||
"""
|
||||
Given a type that can appear as a request parameter type, returns
|
||||
what sources it can come from, and what its type is as understood
|
||||
by Pydantic.
|
||||
"""
|
||||
signifier, input_type = extract_signifier(input_type)
|
||||
if signifier is QueryType:
|
||||
return ([InputSource.query], input_type)
|
||||
elif signifier is BodyType:
|
||||
return ([InputSource.body], input_type)
|
||||
elif signifier is BodyDirectType:
|
||||
if not issubclass(input_type, BaseModel):
|
||||
raise ValueError(
|
||||
"You cannot use BodyDirect on something that is not a Pydantic model"
|
||||
)
|
||||
return ([InputSource.body_direct], input_type)
|
||||
elif signifier is PathType:
|
||||
return ([InputSource.path], input_type)
|
||||
elif (
|
||||
signifier is FileType
|
||||
or input_type is files.File
|
||||
or is_optional(input_type)[1] is files.File
|
||||
):
|
||||
return ([InputSource.file], input_type)
|
||||
elif signifier is QueryOrBodyType:
|
||||
return ([InputSource.query, InputSource.body], input_type)
|
||||
# Is it a Pydantic model, which means it's implicitly body?
|
||||
elif isinstance(input_type, type) and issubclass(input_type, BaseModel):
|
||||
return ([InputSource.body], input_type)
|
||||
# Otherwise, we look in the path first and then the query
|
||||
else:
|
||||
return ([InputSource.path, InputSource.query], input_type)
|
||||
|
||||
@classmethod
|
||||
def get_values(cls, data, use_square_brackets=True) -> dict[str, Any]:
|
||||
"""
|
||||
Given a QueryDict or normal dict, returns data taking into account
|
||||
lists made by repeated values or by suffixing names with [].
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
for key, value in data.items():
|
||||
# If it's a query dict with multiple values, make it a list
|
||||
if isinstance(data, QueryDict):
|
||||
values = data.getlist(key)
|
||||
if len(values) > 1:
|
||||
value = values
|
||||
# If it is in dict-ish/list-ish syntax, adhere to that
|
||||
# TODO: Make this better handle badly formed keys
|
||||
if "[" in key and use_square_brackets:
|
||||
parts = key.split("[")
|
||||
target = result
|
||||
last_key = parts[0]
|
||||
for part in parts[1:]:
|
||||
part = part.rstrip("]")
|
||||
if not part:
|
||||
target = target.setdefault(last_key, [])
|
||||
else:
|
||||
target = target.setdefault(last_key, {})
|
||||
last_key = part
|
||||
if isinstance(target, list):
|
||||
if isinstance(value, list):
|
||||
target.extend(value)
|
||||
else:
|
||||
target.append(value)
|
||||
else:
|
||||
target[last_key] = value
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
def compile(self):
|
||||
self.sources: dict[str, list[InputSource]] = {}
|
||||
amount_from_body = 0
|
||||
pydantic_model_dict = {}
|
||||
self.input_files = set()
|
||||
last_body_type = None
|
||||
# For each input item, work out where to pull it from
|
||||
for name, input_type in self.input_types.items():
|
||||
# Do some basic typechecking to stop things that aren't allowed
|
||||
if isinstance(input_type, type) and issubclass(input_type, HttpRequest):
|
||||
continue
|
||||
if not acceptable_input(input_type):
|
||||
# Strip away any singifiers for the error
|
||||
_, inner_type = extract_signifier(input_type)
|
||||
raise ValueError(
|
||||
f"Input argument {name} has an unsupported type {inner_type}"
|
||||
)
|
||||
sources, pydantic_type = self.sources_for_input(input_type)
|
||||
self.sources[name] = sources
|
||||
# Keep count of how many are pulling from the body
|
||||
if InputSource.body in sources:
|
||||
amount_from_body += 1
|
||||
last_body_type = pydantic_type
|
||||
if InputSource.file in sources:
|
||||
self.input_files.add(name)
|
||||
else:
|
||||
pydantic_model_dict[name] = (Optional[pydantic_type], ...)
|
||||
# If there is just one thing pulling from the body and it's a BaseModel,
|
||||
# signify that it's actually pulling from the body keys directly and
|
||||
# not a sub-dict
|
||||
if amount_from_body == 1:
|
||||
for name, sources in self.sources.items():
|
||||
if (
|
||||
InputSource.body in sources
|
||||
and isinstance(last_body_type, type)
|
||||
and issubclass(last_body_type, BaseModel)
|
||||
):
|
||||
self.sources[name] = [
|
||||
x for x in sources if x != InputSource.body
|
||||
] + [InputSource.body_direct]
|
||||
# Turn all the main arguments into Pydantic parsing models
|
||||
try:
|
||||
self.input_model = create_model(
|
||||
f"{self.view_name}_input", **pydantic_model_dict
|
||||
)
|
||||
except RuntimeError:
|
||||
raise ValueError(
|
||||
f"One or more inputs on view {self.view_name} have a bad configuration"
|
||||
)
|
||||
if self.output_type is not None:
|
||||
self.output_model = create_model(
|
||||
f"{self.view_name}_output", value=(self.output_type, ...)
|
||||
)
|
||||
|
||||
def __call__(self, request: HttpRequest, *args, **kwargs):
|
||||
"""
|
||||
Entrypoint when this is called as a view.
|
||||
"""
|
||||
# Do a method check if we have one set
|
||||
if self.method and self.method.upper() != request.method:
|
||||
return HttpResponseNotAllowed([self.method])
|
||||
# For each item we can source, go find it if we can
|
||||
query_values = self.get_values(request.GET)
|
||||
body_values = self.get_values(request.POST)
|
||||
files_values = self.get_values(request.FILES)
|
||||
# If it's a PUT or PATCH method, work around Django not handling FILES
|
||||
# or POST on those requests
|
||||
if request.method in ["PATCH", "PUT"]:
|
||||
if request.content_type == "multipart/form-data":
|
||||
POST, FILES = MultiPartParser(
|
||||
request.META, request, request.upload_handlers, request.encoding
|
||||
).parse()
|
||||
body_values = self.get_values(POST)
|
||||
files_values = self.get_values(FILES)
|
||||
elif request.content_type == "application/x-www-form-urlencoded":
|
||||
POST = QueryDict(request.body, encoding=request._encoding)
|
||||
body_values = self.get_values(POST)
|
||||
# If there was a JSON body, go load that
|
||||
if request.content_type == "application/json" and request.body.strip():
|
||||
body_values.update(self.get_values(json.loads(request.body)))
|
||||
values = {}
|
||||
for name, sources in self.sources.items():
|
||||
for source in sources:
|
||||
if source == InputSource.path:
|
||||
if name in kwargs:
|
||||
values[name] = kwargs[name]
|
||||
break
|
||||
elif source == InputSource.query:
|
||||
if name in query_values:
|
||||
values[name] = query_values[name]
|
||||
break
|
||||
elif source == InputSource.body:
|
||||
if name in body_values:
|
||||
values[name] = body_values[name]
|
||||
break
|
||||
elif source == InputSource.file:
|
||||
if name in files_values:
|
||||
values[name] = files_values[name]
|
||||
break
|
||||
elif source == InputSource.body_direct:
|
||||
values[name] = body_values
|
||||
break
|
||||
elif source == InputSource.query_and_body_direct:
|
||||
values[name] = dict(query_values)
|
||||
values[name].update(body_values)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Unknown source {source}")
|
||||
else:
|
||||
values[name] = None
|
||||
# Give that to the Pydantic model to make it handle stuff
|
||||
model_instance = self.input_model(**values)
|
||||
kwargs = {
|
||||
name: getattr(model_instance, name)
|
||||
for name in model_instance.__fields__
|
||||
if values[name] is not None # Trim out missing fields
|
||||
}
|
||||
# Add in any files
|
||||
# TODO: HTTP error if file is not optional
|
||||
for name in self.input_files:
|
||||
kwargs[name] = files_values.get(name, None)
|
||||
# Call the view with those as kwargs
|
||||
try:
|
||||
response = self.view(request, **kwargs)
|
||||
except ApiError as error:
|
||||
return ApiResponse(
|
||||
{"error": error.error}, status=error.status, finalize=True
|
||||
)
|
||||
# If it's not an ApiResponse, make it one
|
||||
if not isinstance(response, ApiResponse):
|
||||
response = ApiResponse(response)
|
||||
# Get pydantic to coerce the output response
|
||||
if self.output_type is not None:
|
||||
response.data = self.output_model(value=response.data).dict()["value"]
|
||||
elif isinstance(response.data, BaseModel):
|
||||
response.data = response.data.dict()
|
||||
response.finalize()
|
||||
return response
|
||||
|
||||
|
||||
api_view = ApiView
|
|
@ -6,7 +6,6 @@ django-cache-url~=3.4.2
|
|||
django-cors-headers~=3.13.0
|
||||
django-debug-toolbar~=3.8.1
|
||||
django-htmx~=1.13.0
|
||||
django-ninja~=0.19.1
|
||||
django-oauth-toolkit~=2.2.0
|
||||
django-storages[google,boto3]~=1.13.1
|
||||
django~=4.1
|
||||
|
|
|
@ -3,7 +3,7 @@ from django.contrib import admin as djadmin
|
|||
from django.urls import include, path, re_path
|
||||
|
||||
from activities.views import compose, debug, explore, follows, posts, search, timelines
|
||||
from api.views import api_router, oauth
|
||||
from api.views import oauth
|
||||
from core import views as core
|
||||
from mediaproxy import views as mediaproxy
|
||||
from stator import views as stator
|
||||
|
@ -296,7 +296,7 @@ urlpatterns = [
|
|||
path("actor/outbox/", activitypub.EmptyOutbox.as_view()),
|
||||
path("inbox/", activitypub.Inbox.as_view(), name="shared_inbox"),
|
||||
# API/Oauth
|
||||
path("api/", api_router.urls),
|
||||
path("api/", include("api.urls")),
|
||||
path("oauth/authorize", oauth.AuthorizationView.as_view()),
|
||||
path("oauth/token", oauth.TokenView.as_view()),
|
||||
path("oauth/revoke", oauth.RevokeTokenView.as_view()),
|
||||
|
|
|
@ -281,7 +281,9 @@ def test_inbound_posts(
|
|||
InboxMessage.objects.create(message=message)
|
||||
|
||||
# Run stator and ensure that made the post
|
||||
print("prestat")
|
||||
stator.run_single_cycle_sync()
|
||||
print("poststat")
|
||||
post = Post.objects.get(object_uri="https://remote.test/test-post")
|
||||
assert post.content == "post version one"
|
||||
assert post.published.day == 13
|
||||
|
|
Loading…
Reference in New Issue