Hatchway API Rewrite (#499)

Removes django-ninja and replaces it with a new API framework, "hatchway".

I plan to move hatchway into its own project very soon.
This commit is contained in:
Andrew Godwin 2023-02-07 12:07:15 -07:00 committed by GitHub
parent 99a0a16fb7
commit 5d2ed9edfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 1471 additions and 379 deletions

View File

@ -188,7 +188,7 @@ class PostInteraction(StatorModel):
return result
@classmethod
def get_event_interactions(cls, events, identity):
def get_event_interactions(cls, events, identity) -> dict[str, set[str]]:
"""
Returns a dict of {interaction_type: set(post_ids)} for all the posts
within the events and the given identity, for use in templates.

View File

@ -16,4 +16,7 @@ def identity_required(function):
return JsonResponse({"error": "identity_token_required"}, status=400)
return function(request, *args, **kwargs)
# This is for the API only
inner.csrf_exempt = True
return inner

View File

@ -1,22 +1,98 @@
import dataclasses
import urllib.parse
from collections.abc import Callable
from typing import Any
from typing import Any, Generic, Protocol, TypeVar
from django.db import models
from django.http import HttpRequest
from activities.models import PostInteraction
from hatchway.http import ApiResponse
T = TypeVar("T")
class SchemaWithId(Protocol):
"""
Little protocol type to represent schemas that have an ID attribute
"""
id: str
TI = TypeVar("TI", bound=SchemaWithId)
TM = TypeVar("TM", bound=models.Model)
class PaginatingApiResponse(ApiResponse[list[TI]]):
"""
An ApiResponse subclass that also handles pagination link headers
"""
def __init__(
self,
data: list[TI],
request: HttpRequest,
include_params: list[str],
**kwargs,
):
# Call superclass
super().__init__(data, **kwargs)
# Figure out if we need link headers
self._request = request
self.extra_params = self.filter_params(self._request, include_params)
link_header = self.build_link_header()
if link_header:
self.headers["link"] = link_header
@staticmethod
def filter_params(request: HttpRequest, allowed_params: list[str]):
params = {}
for key in allowed_params:
value = request.GET.get(key, None)
if value:
params[key] = value
return params
def get_part(self, data_index: int, param_name: str, rel: str) -> str | None:
"""
Used to get next/prev URLs
"""
if not self.data:
return None
# Use the ID of the last object for the next page start
params = dict(self.extra_params)
params[param_name] = self.data[data_index].id
return (
"<"
+ self._request.build_absolute_uri(self._request.path)
+ "?"
+ urllib.parse.urlencode(params)
+ f'>; rel="{rel}"'
)
def build_link_header(self):
parts = [
entry
for entry in [
self.get_part(0, "min_id", "prev"),
self.get_part(-1, "max_id", "next"),
]
if entry
]
if not parts:
return None
return ", ".join(parts)
@dataclasses.dataclass
class PaginationResult:
class PaginationResult(Generic[T]):
"""
Represents a pagination result for Mastodon (it does Link header stuff)
"""
#: A list of objects that matched the pagination query.
results: list[models.Model]
results: list[T]
#: The actual applied limit, which may be different from what was requested.
limit: int
@ -130,12 +206,12 @@ class MastodonPaginator:
def paginate(
self,
queryset,
queryset: models.QuerySet[TM],
min_id: str | None,
max_id: str | None,
since_id: str | None,
limit: int | None,
) -> PaginationResult:
) -> PaginationResult[TM]:
# These "does not start with interaction" checks can be removed after a
# couple months, when clients have flushed them out.
if max_id and not max_id.startswith("interaction"):

View File

@ -1,22 +0,0 @@
import json
from ninja.parser import Parser
class FormOrJsonParser(Parser):
"""
If there's form data in a request, makes it into a JSON dict.
This is needed as the Mastodon API allows form data OR json body as input.
"""
def parse_body(self, request):
# Did they submit JSON?
if request.content_type == "application/json" and request.body.strip():
return json.loads(request.body)
# Fall back to form data
value = {}
for key, item in request.POST.items():
value[key] = item
for key, item in request.GET.items():
value[key] = item
return value

View File

@ -1,6 +1,9 @@
from typing import Literal, Optional, Union
from ninja import Field, Schema
from activities import models as activities_models
from hatchway import Field, Schema
from users import models as users_models
from users.services import IdentityService
class Application(Schema):
@ -19,6 +22,10 @@ class CustomEmoji(Schema):
visible_in_picker: bool
category: str
@classmethod
def from_emoji(cls, emoji: activities_models.Emoji) -> "CustomEmoji":
return cls(**emoji.to_mastodon_json())
class AccountField(Schema):
name: str
@ -53,6 +60,17 @@ class Account(Schema):
following_count: int
source: dict | None
@classmethod
def from_identity(
cls,
identity: users_models.Identity,
include_counts: bool = True,
source=False,
) -> "Account":
return cls(
**identity.to_mastodon_json(include_counts=include_counts, source=source)
)
class MediaAttachment(Schema):
id: str
@ -64,6 +82,12 @@ class MediaAttachment(Schema):
description: str | None
blurhash: str | None
@classmethod
def from_post_attachment(
cls, attachment: activities_models.PostAttachment
) -> "MediaAttachment":
return cls(**attachment.to_mastodon_json())
class StatusMention(Schema):
id: str
@ -108,6 +132,47 @@ class Status(Schema):
bookmarked: bool | None
pinned: bool | None
@classmethod
def from_post(
cls,
post: activities_models.Post,
interactions: dict[str, set[str]] | None = None,
) -> "Status":
return cls(**post.to_mastodon_json(interactions=interactions))
@classmethod
def map_from_post(
cls,
posts: list[activities_models.Post],
identity: users_models.Identity,
) -> list["Status"]:
interactions = activities_models.PostInteraction.get_post_interactions(
posts, identity
)
return [cls.from_post(post, interactions=interactions) for post in posts]
@classmethod
def from_timeline_event(
cls,
timeline_event: activities_models.TimelineEvent,
interactions: dict[str, set[str]] | None = None,
) -> "Status":
return cls(**timeline_event.to_mastodon_status_json(interactions=interactions))
@classmethod
def map_from_timeline_event(
cls,
events: list[activities_models.TimelineEvent],
identity: users_models.Identity,
) -> list["Status"]:
interactions = activities_models.PostInteraction.get_event_interactions(
events, identity
)
return [
cls.from_timeline_event(event, interactions=interactions)
for event in events
]
class Conversation(Schema):
id: str
@ -134,12 +199,26 @@ class Notification(Schema):
account: Account
status: Status | None
@classmethod
def from_timeline_event(
cls,
event: activities_models.TimelineEvent,
) -> "Notification":
return cls(**event.to_mastodon_notification_json())
class Tag(Schema):
name: str
url: str
history: dict
@classmethod
def from_hashtag(
cls,
hashtag: activities_models.Hashtag,
) -> "Tag":
return cls(**hashtag.to_mastodon_json())
class Search(Schema):
accounts: list[Account]
@ -162,6 +241,16 @@ class Relationship(Schema):
endorsed: bool
note: str
@classmethod
def from_identity_pair(
cls,
identity: users_models.Identity,
from_identity: users_models.Identity,
) -> "Relationship":
return cls(
**IdentityService(identity).mastodon_json_relationship(from_identity)
)
class Context(Schema):
ancestors: list[Status]
@ -187,3 +276,11 @@ class Announcement(Schema):
tags: list[Tag]
emojis: list[CustomEmoji]
reactions: list
@classmethod
def from_announcement(
cls,
announcement: users_models.Announcement,
user: users_models.User,
) -> "Announcement":
return cls(**announcement.to_mastodon_json(user=user))

72
api/urls.py Normal file
View File

@ -0,0 +1,72 @@
from django.urls import path
from api.views import (
accounts,
announcements,
apps,
emoji,
filters,
instance,
media,
notifications,
search,
statuses,
timelines,
)
from hatchway import methods
urlpatterns = [
# Accounts
path("v1/accounts/verify_credentials", accounts.verify_credentials),
path("v1/accounts/update_credentials", accounts.update_credentials),
path("v1/accounts/relationships", accounts.account_relationships),
path("v1/accounts/familiar_followers", accounts.familiar_followers),
path("v1/accounts/search", accounts.accounts_search),
path("v1/accounts/lookup", accounts.lookup),
path("v1/accounts/<id>", accounts.account),
path("v1/accounts/<id>/statuses", accounts.account_statuses),
path("v1/accounts/<id>/follow", accounts.account_follow),
path("v1/accounts/<id>/unfollow", accounts.account_unfollow),
path("v1/accounts/<id>/block", accounts.account_block),
path("v1/accounts/<id>/unblock", accounts.account_unblock),
path("v1/accounts/<id>/mute", accounts.account_mute),
path("v1/accounts/<id>/unmute", accounts.account_unmute),
path("v1/accounts/<id>/following", accounts.account_following),
path("v1/accounts/<id>/followers", accounts.account_followers),
# Announcements
path("v1/announcements", announcements.announcement_list),
path("v1/announcements/<pk>/dismiss", announcements.announcement_dismiss),
# Apps
path("v1/apps", apps.add_app),
# Filters
path("v2/filters", filters.list_filters),
path("v1/filters", filters.list_filters),
# Instance
path("v1/instance", instance.instance_info_v1),
path("v2/instance", instance.instance_info_v2),
path("v1/timelines/home", timelines.home),
path("v1/timelines/public", timelines.public),
path("v1/timelines/tag/<hashtag>", timelines.hashtag),
path("v1/timelines/conversations", timelines.conversations),
path("v1/favourites", timelines.favourites),
path("v1/notifications", notifications.notifications),
path("v1/statuses", statuses.post_status),
path("v1/media", media.upload_media),
path("v2/media", media.upload_media),
path("v1/media/<id>", methods(get=media.get_media, put=media.update_media)),
path(
"v1/statuses/<id>",
methods(
get=statuses.status,
delete=statuses.delete_status,
),
),
path("v1/statuses/<id>/context", statuses.status_context),
path("v1/statuses/<id>/favourite", statuses.favourite_status),
path("v1/statuses/<id>/unfavourite", statuses.unfavourite_status),
path("v1/statuses/<id>/favourited_by", statuses.favourited_by),
path("v1/statuses/<id>/reblog", statuses.reblog_status),
path("v1/statuses/<id>/unreblog", statuses.unreblog_status),
path("v1/custom_emojis", emoji.emojis),
path("v2/search", search.search),
]

View File

@ -1,12 +0,0 @@
from .accounts import * # noqa
from .announcements import * # noqa
from .apps import * # noqa
from .emoji import * # noqa
from .filters import * # noqa
from .instance import * # noqa
from .media import * # noqa
from .notifications import * # noqa
from .oauth import * # noqa
from .search import * # noqa
from .statuses import * # noqa
from .timelines import * # noqa

View File

@ -1,131 +1,129 @@
from django.http import HttpRequest, HttpResponse, QueryDict
from django.http.multipartparser import MultiPartParser
from typing import Any
from django.core.files import File
from django.http import HttpRequest
from django.shortcuts import get_object_or_404
from ninja import Field, Schema
from activities.models import Post
from activities.services import SearchService
from api import schemas
from api.decorators import identity_required
from api.pagination import MastodonPaginator
from api.views.base import api_router
from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult
from core.models import Config
from hatchway import ApiResponse, QueryOrBody, api_view
from users.models import Identity
from users.services import IdentityService
from users.shortcuts import by_handle_or_404
@api_router.get("/v1/accounts/verify_credentials", response=schemas.Account)
@identity_required
def verify_credentials(request):
return request.identity.to_mastodon_json(source=True)
@api_view.get
def verify_credentials(request) -> schemas.Account:
return schemas.Account.from_identity(request.identity, source=True)
@api_router.patch("/v1/accounts/update_credentials", response=schemas.Account)
@identity_required
@api_view.patch
def update_credentials(
request,
):
# Django won't load POST and FILES for patch methods, so we do it.
if request.content_type == "multipart/form-data":
POST, FILES = MultiPartParser(
request.META, request, request.upload_handlers, request.encoding
).parse()
elif request.content_type == "application/x-www-form-urlencoded":
POST = QueryDict(request.body, encoding=request._encoding)
FILES = {}
else:
return HttpResponse(status=400)
display_name: QueryOrBody[str | None] = None,
note: QueryOrBody[str | None] = None,
discoverable: QueryOrBody[bool | None] = None,
source: QueryOrBody[dict[str, Any] | None] = None,
fields_attributes: QueryOrBody[dict[str, dict[str, str]] | None] = None,
avatar: File | None = None,
header: File | None = None,
) -> schemas.Account:
identity = request.identity
service = IdentityService(identity)
if "display_name" in POST:
identity.name = POST["display_name"]
if "note" in POST:
service.set_summary(POST["note"])
if "discoverable" in POST:
identity.discoverable = POST["discoverable"] == "checked"
if "source[privacy]" in POST:
privacy_map = {
"public": Post.Visibilities.public,
"unlisted": Post.Visibilities.unlisted,
"private": Post.Visibilities.followers,
"direct": Post.Visibilities.mentioned,
}
Config.set_identity(
identity,
"default_post_visibility",
privacy_map[POST["source[privacy]"]],
)
if "fields_attributes[0][name]" in POST:
if display_name is not None:
identity.name = display_name
if note is not None:
service.set_summary(note)
if discoverable is not None:
identity.discoverable = discoverable
if source:
if "privacy" in source:
privacy_map = {
"public": Post.Visibilities.public,
"unlisted": Post.Visibilities.unlisted,
"private": Post.Visibilities.followers,
"direct": Post.Visibilities.mentioned,
}
Config.set_identity(
identity,
"default_post_visibility",
privacy_map[source["privacy"]],
)
if fields_attributes:
identity.metadata = []
for i in range(4):
name_name = f"fields_attributes[{i}][name]"
value_name = f"fields_attributes[{i}][value]"
if name_name and value_name in POST:
for attribute in fields_attributes.values():
attr_name = attribute.get("name", None)
attr_value = attribute.get("value", None)
if attr_name:
# Empty value means delete this item
if not POST[value_name]:
if not attr_value:
break
identity.metadata.append(
{"name": POST[name_name], "value": POST[value_name]}
)
if "avatar" in FILES:
service.set_icon(FILES["avatar"])
if "header" in FILES:
service.set_image(FILES["header"])
identity.metadata.append({"name": attr_name, "value": attr_value})
if avatar:
service.set_icon(avatar)
if header:
service.set_image(header)
identity.save()
return identity.to_mastodon_json(source=True)
return schemas.Account.from_identity(identity, source=True)
@api_router.get("/v1/accounts/relationships", response=list[schemas.Relationship])
@identity_required
def account_relationships(request):
ids = request.GET.getlist("id[]")
@api_view.get
def account_relationships(request, id: list[str] | None) -> list[schemas.Relationship]:
result = []
for id in ids:
identity = get_object_or_404(Identity, pk=id)
# ID is actually a list. Thanks Mastodon!
ids = id or []
for actual_id in ids:
identity = get_object_or_404(Identity, pk=actual_id)
result.append(
IdentityService(identity).mastodon_json_relationship(request.identity)
)
return result
@api_router.get(
"/v1/accounts/familiar_followers", response=list[schemas.FamiliarFollowers]
)
@identity_required
def familiar_followers(request):
@api_view.get
def familiar_followers(
request, id: list[str] | None
) -> list[schemas.FamiliarFollowers]:
"""
Returns people you follow that also follow given account IDs
"""
ids = request.GET.getlist("id[]")
ids = id or []
result = []
for id in ids:
target_identity = get_object_or_404(Identity, pk=id)
for actual_id in ids:
target_identity = get_object_or_404(Identity, pk=actual_id)
result.append(
{
"id": id,
"accounts": [
identity.to_mastodon_json()
schemas.FamiliarFollowers(
id=actual_id,
accounts=[
schemas.Account.from_identity(identity)
for identity in Identity.objects.filter(
inbound_follows__source=request.identity,
outbound_follows__target=target_identity,
)[:20]
],
}
)
)
return result
@api_router.get("/v1/accounts/search", response=list[schemas.Account])
@identity_required
def search(
@api_view.get
def accounts_search(
request,
q: str,
fetch_identities: bool = Field(False, alias="resolve"),
resolve: bool = False,
following: bool = False,
limit: int = 20,
offset: int = 0,
):
) -> list[schemas.Account]:
"""
Handles searching for accounts by username or handle
"""
@ -135,33 +133,33 @@ def search(
return []
searcher = SearchService(q, request.identity)
search_result = searcher.search_identities_handle()
return [i.to_mastodon_json() for i in search_result]
return [schemas.Account.from_identity(i) for i in search_result]
@api_router.get("/v1/accounts/lookup", response=schemas.Account)
def lookup(request: HttpRequest, acct: str):
@api_view.get
def lookup(request: HttpRequest, acct: str) -> schemas.Account:
"""
Quickly lookup a username to see if it is available, skipping WebFinger
resolution.
"""
identity = by_handle_or_404(request, handle=acct, local=False)
return identity.to_mastodon_json()
return schemas.Account.from_identity(identity)
@api_router.get("/v1/accounts/{id}", response=schemas.Account)
@api_view.get
@identity_required
def account(request, id: str):
def account(request, id: str) -> schemas.Account:
identity = get_object_or_404(
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
Identity.objects.exclude(restriction=Identity.Restriction.blocked),
pk=id,
)
return identity.to_mastodon_json()
return schemas.Account.from_identity(identity)
@api_router.get("/v1/accounts/{id}/statuses", response=list[schemas.Status])
@api_view.get
@identity_required
def account_statuses(
request: HttpRequest,
response: HttpResponse,
id: str,
exclude_reblogs: bool = False,
exclude_replies: bool = False,
@ -172,7 +170,7 @@ def account_statuses(
since_id: str | None = None,
min_id: str | None = None,
limit: int = 20,
):
) -> ApiResponse[list[schemas.Status]]:
identity = get_object_or_404(
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
)
@ -191,177 +189,163 @@ def account_statuses(
.order_by("-created")
)
if pinned:
return []
return ApiResponse([])
if only_media:
queryset = queryset.filter(attachments__pk__isnull=False)
if tagged:
queryset = queryset.tagged_with(tagged)
# Get user posts with pagination
paginator = MastodonPaginator()
pager = paginator.paginate(
pager: PaginationResult[Post] = paginator.paginate(
queryset,
min_id=min_id,
max_id=max_id,
since_id=since_id,
limit=limit,
)
# Convert those to the JSON form
pager.jsonify_posts(identity=request.identity)
# Add a link header if we need to
if pager.results:
response.headers["Link"] = pager.link_header(
request,
[
"limit",
"id",
"exclude_reblogs",
"exclude_replies",
"only_media",
"pinned",
"tagged",
],
)
return pager.json_results
return PaginatingApiResponse(
schemas.Status.map_from_post(pager.results, request.identity),
request=request,
include_params=[
"limit",
"id",
"exclude_reblogs",
"exclude_replies",
"only_media",
"pinned",
"tagged",
],
)
@api_router.post("/v1/accounts/{id}/follow", response=schemas.Relationship)
@api_view.post
@identity_required
def account_follow(request, id: str, reblogs: bool = True):
def account_follow(request, id: str, reblogs: bool = True) -> schemas.Relationship:
identity = get_object_or_404(
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
)
service = IdentityService(identity)
service.follow_from(request.identity, boosts=reblogs)
return service.mastodon_json_relationship(request.identity)
return schemas.Relationship.from_identity_pair(identity, request.identity)
@api_router.post("/v1/accounts/{id}/unfollow", response=schemas.Relationship)
@api_view.post
@identity_required
def account_unfollow(request, id: str):
def account_unfollow(request, id: str) -> schemas.Relationship:
identity = get_object_or_404(
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
)
service = IdentityService(identity)
service.unfollow_from(request.identity)
return service.mastodon_json_relationship(request.identity)
return schemas.Relationship.from_identity_pair(identity, request.identity)
@api_router.post("/v1/accounts/{id}/block", response=schemas.Relationship)
@api_view.post
@identity_required
def account_block(request, id: str):
def account_block(request, id: str) -> schemas.Relationship:
identity = get_object_or_404(Identity, pk=id)
service = IdentityService(identity)
service.block_from(request.identity)
return service.mastodon_json_relationship(request.identity)
return schemas.Relationship.from_identity_pair(identity, request.identity)
@api_router.post("/v1/accounts/{id}/unblock", response=schemas.Relationship)
@api_view.post
@identity_required
def account_unblock(request, id: str):
def account_unblock(request, id: str) -> schemas.Relationship:
identity = get_object_or_404(Identity, pk=id)
service = IdentityService(identity)
service.unblock_from(request.identity)
return service.mastodon_json_relationship(request.identity)
return schemas.Relationship.from_identity_pair(identity, request.identity)
class MuteDetailsSchema(Schema):
notifications: bool = True
duration: int = 0
@api_router.post("/v1/accounts/{id}/mute", response=schemas.Relationship)
@identity_required
def account_mute(request, id: str, details: MuteDetailsSchema):
@api_view.post
def account_mute(
request,
id: str,
notifications: QueryOrBody[bool] = True,
duration: QueryOrBody[int] = 0,
) -> schemas.Relationship:
identity = get_object_or_404(Identity, pk=id)
service = IdentityService(identity)
service.mute_from(
request.identity,
duration=details.duration,
include_notifications=details.notifications,
duration=duration,
include_notifications=notifications,
)
return service.mastodon_json_relationship(request.identity)
return schemas.Relationship.from_identity_pair(identity, request.identity)
@api_router.post("/v1/accounts/{id}/unmute", response=schemas.Relationship)
@identity_required
def account_unmute(request, id: str):
@api_view.post
def account_unmute(request, id: str) -> schemas.Relationship:
identity = get_object_or_404(Identity, pk=id)
service = IdentityService(identity)
service.unmute_from(request.identity)
return service.mastodon_json_relationship(request.identity)
return schemas.Relationship.from_identity_pair(identity, request.identity)
@api_router.get("/v1/accounts/{id}/following", response=list[schemas.Account])
@api_view.get
def account_following(
request: HttpRequest,
response: HttpResponse,
id: str,
max_id: str | None = None,
since_id: str | None = None,
min_id: str | None = None,
limit: int = 40,
):
) -> ApiResponse[list[schemas.Account]]:
identity = get_object_or_404(
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
)
if not identity.config_identity.visible_follows and request.identity != identity:
return []
return ApiResponse([])
service = IdentityService(identity)
paginator = MastodonPaginator(max_limit=80)
pager = paginator.paginate(
pager: PaginationResult[Identity] = paginator.paginate(
service.following(),
min_id=min_id,
max_id=max_id,
since_id=since_id,
limit=limit,
)
pager.jsonify_identities()
if pager.results:
response.headers["Link"] = pager.link_header(
request,
["limit"],
)
return pager.json_results
return PaginatingApiResponse(
[schemas.Account.from_identity(i) for i in pager.results],
request=request,
include_params=["limit"],
)
@api_router.get("/v1/accounts/{id}/followers", response=list[schemas.Account])
@api_view.get
def account_followers(
request: HttpRequest,
response: HttpResponse,
id: str,
max_id: str | None = None,
since_id: str | None = None,
min_id: str | None = None,
limit: int = 40,
):
) -> ApiResponse[list[schemas.Account]]:
identity = get_object_or_404(
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
)
if not identity.config_identity.visible_follows and request.identity != identity:
return []
return ApiResponse([])
service = IdentityService(identity)
paginator = MastodonPaginator(max_limit=80)
pager = paginator.paginate(
pager: PaginationResult[Identity] = paginator.paginate(
service.followers(),
min_id=min_id,
max_id=max_id,
since_id=since_id,
limit=limit,
)
pager.jsonify_identities()
if pager.results:
response.headers["Link"] = pager.link_header(
request,
["limit"],
)
return pager.json_results
return PaginatingApiResponse(
[schemas.Account.from_identity(i) for i in pager.results],
request=request,
include_params=["limit"],
)

View File

@ -2,22 +2,22 @@ from django.shortcuts import get_object_or_404
from api import schemas
from api.decorators import identity_required
from api.views.base import api_router
from hatchway import api_view
from users.models import Announcement
from users.services import AnnouncementService
@api_router.get("/v1/announcements", response=list[schemas.Announcement])
@identity_required
def announcement_list(request):
@api_view.get
def announcement_list(request) -> list[schemas.Announcement]:
return [
a.to_mastodon_json(request.user)
schemas.Announcement.from_announcement(a, request.user)
for a in AnnouncementService(request.user).visible()
]
@api_router.post("/v1/announcements/{pk}/dismiss")
@identity_required
@api_view.post
def announcement_dismiss(request, pk: str):
announcement = get_object_or_404(Announcement, pk=pk)
AnnouncementService(request.user).mark_seen(announcement)

View File

@ -1,10 +1,9 @@
import secrets
from ninja import Schema
from hatchway import Schema, api_view
from .. import schemas
from ..models import Application
from .base import api_router
class CreateApplicationSchema(Schema):
@ -14,8 +13,8 @@ class CreateApplicationSchema(Schema):
website: None | str = None
@api_router.post("/v1/apps", response=schemas.Application)
def add_app(request, details: CreateApplicationSchema):
@api_view.post
def add_app(request, details: CreateApplicationSchema) -> schemas.Application:
client_id = "tk-" + secrets.token_urlsafe(16)
client_secret = secrets.token_urlsafe(40)
application = Application.objects.create(
@ -26,4 +25,4 @@ def add_app(request, details: CreateApplicationSchema):
redirect_uris=details.redirect_uris,
scopes=details.scopes or "read",
)
return application
return schemas.Application.from_orm(application)

View File

@ -1,5 +0,0 @@
from ninja import NinjaAPI
from api.parser import FormOrJsonParser
api_router = NinjaAPI(parser=FormOrJsonParser())

View File

@ -1,8 +1,10 @@
from activities.models import Emoji
from api.schemas import CustomEmoji
from api.views.base import api_router
from hatchway import api_view
@api_router.get("/v1/custom_emojis", response=list[CustomEmoji])
def emojis(request):
return [e.to_mastodon_json() for e in Emoji.objects.usable().filter(local=True)]
@api_view.get
def emojis(request) -> list[CustomEmoji]:
return [
CustomEmoji.from_emoji(e) for e in Emoji.objects.usable().filter(local=True)
]

View File

@ -1,8 +1,8 @@
from api.decorators import identity_required
from api.views.base import api_router
from hatchway import api_view
@api_router.get("/v1/filters")
@identity_required
def status(request):
@api_view.get
def list_filters(request):
return []

View File

@ -1,15 +1,15 @@
from django.conf import settings
from activities.models import Post
from api import schemas
from core.models import Config
from hatchway import api_view
from takahe import __version__
from users.models import Domain, Identity
from .base import api_router
@api_router.get("/v1/instance")
def instance_info(request):
@api_view.get
def instance_info_v1(request):
return {
"uri": request.headers.get("host", settings.SETUP.MAIN_DOMAIN),
"title": Config.system.site_name,
@ -53,11 +53,17 @@ def instance_info(request):
}
@api_router.get("/v2/instance")
def instance_info_v2(request):
@api_view.get
def instance_info_v2(request) -> dict:
current_domain = Domain.get_domain(
request.headers.get("host", settings.SETUP.MAIN_DOMAIN)
)
if current_domain is None or not current_domain.local:
current_domain = Domain.get_domain(
request.headers.get(settings.SETUP.MAIN_DOMAIN)
)
if current_domain is None:
raise ValueError("No domain set up for MAIN_DOMAIN")
admin_identity = (
Identity.objects.filter(users__admin=True).order_by("created").first()
)
@ -116,7 +122,7 @@ def instance_info_v2(request):
},
"contact": {
"email": "",
"account": admin_identity.to_mastodon_json(),
"account": schemas.Account.from_identity(admin_identity),
},
"rules": [],
}

View File

@ -1,28 +1,22 @@
from django.core.files import File
from django.shortcuts import get_object_or_404
from ninja import File, Schema
from ninja.files import UploadedFile
from activities.models import PostAttachment, PostAttachmentStates
from api import schemas
from api.views.base import api_router
from core.files import blurhash_image, resize_image
from hatchway import QueryOrBody, api_view
from ..decorators import identity_required
class UploadMediaSchema(Schema):
description: str = ""
focus: str = "0,0"
@api_router.post("/v1/media", response=schemas.MediaAttachment)
@api_router.post("/v2/media", response=schemas.MediaAttachment)
@identity_required
@api_view.post
def upload_media(
request,
file: UploadedFile = File(...),
details: UploadMediaSchema | None = None,
):
file: File,
description: QueryOrBody[str] = "",
focus: QueryOrBody[str] = "0,0",
) -> schemas.MediaAttachment:
main_file = resize_image(
file,
size=(2000, 2000),
@ -38,7 +32,7 @@ def upload_media(
mimetype="image/webp",
width=main_file.image.width,
height=main_file.image.height,
name=details.description if details else None,
name=description or None,
state=PostAttachmentStates.fetched,
)
attachment.file.save(
@ -50,27 +44,28 @@ def upload_media(
thumbnail_file,
)
attachment.save()
return attachment.to_mastodon_json()
return schemas.MediaAttachment.from_post_attachment(attachment)
@api_router.get("/v1/media/{id}", response=schemas.MediaAttachment)
@identity_required
@api_view.get
def get_media(
request,
id: str,
):
) -> schemas.MediaAttachment:
attachment = get_object_or_404(PostAttachment, pk=id)
return attachment.to_mastodon_json()
return schemas.MediaAttachment.from_post_attachment(attachment)
@api_router.put("/v1/media/{id}", response=schemas.MediaAttachment)
@identity_required
@api_view.put
def update_media(
request,
id: str,
details: UploadMediaSchema | None = None,
):
description: QueryOrBody[str] = "",
focus: QueryOrBody[str] = "0,0",
) -> schemas.MediaAttachment:
attachment = get_object_or_404(PostAttachment, pk=id)
attachment.name = details.description if details else None
attachment.name = description or None
attachment.save()
return attachment.to_mastodon_json()
return schemas.MediaAttachment.from_post_attachment(attachment)

View File

@ -1,24 +1,23 @@
from django.http import HttpRequest, HttpResponse
from django.http import HttpRequest
from activities.models import TimelineEvent
from activities.services import TimelineService
from api import schemas
from api.decorators import identity_required
from api.pagination import MastodonPaginator
from api.views.base import api_router
from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult
from hatchway import ApiResponse, api_view
@api_router.get("/v1/notifications", response=list[schemas.Notification])
@identity_required
@api_view.get
def notifications(
request: HttpRequest,
response: HttpResponse,
max_id: str | None = None,
since_id: str | None = None,
min_id: str | None = None,
limit: int = 20,
account_id: str | None = None,
):
) -> ApiResponse[list[schemas.Notification]]:
# Types/exclude_types use weird syntax so we have to handle them manually
base_types = {
"favourite": TimelineEvent.Types.liked,
@ -37,16 +36,15 @@ def notifications(
[base_types[r] for r in requested_types if r in base_types]
)
paginator = MastodonPaginator()
pager = paginator.paginate(
pager: PaginationResult[TimelineEvent] = paginator.paginate(
queryset,
min_id=min_id,
max_id=max_id,
since_id=since_id,
limit=limit,
)
pager.jsonify_notification_events(identity=request.identity)
if pager.results:
response.headers["Link"] = pager.link_header(request, ["limit", "account_id"])
return pager.json_results
return PaginatingApiResponse(
[schemas.Notification.from_timeline_event(event) for event in pager.results],
request=request,
include_params=["limit", "account_id"],
)

View File

@ -1,4 +1,5 @@
import base64
import json
import secrets
from urllib.parse import urlparse, urlunparse
@ -16,7 +17,6 @@ from django.views.decorators.csrf import csrf_exempt
from django.views.generic import View
from api.models import Application, Authorization, Token
from api.parser import FormOrJsonParser
class OauthRedirect(HttpResponseRedirect):
@ -40,6 +40,19 @@ class OauthRedirect(HttpResponseRedirect):
super().__init__(urlunparse(url_parts))
def get_json_and_formdata(request):
# Did they submit JSON?
if request.content_type == "application/json" and request.body.strip():
return json.loads(request.body)
# Fall back to form data
value = {}
for key, item in request.POST.items():
value[key] = item
for key, item in request.GET.items():
value[key] = item
return value
class AuthorizationView(LoginRequiredMixin, View):
"""
Asks the user to authorize access.
@ -87,7 +100,7 @@ class AuthorizationView(LoginRequiredMixin, View):
return render(request, "api/oauth_authorize.html", context)
def post(self, request):
post_data = FormOrJsonParser().parse_body(request)
post_data = get_json_and_formdata(request)
# Grab the application and other details again
redirect_uri = post_data["redirect_uri"]
scope = post_data["scope"]
@ -141,7 +154,7 @@ class TokenView(View):
)
def post(self, request):
post_data = FormOrJsonParser().parse_body(request)
post_data = get_json_and_formdata(request)
auth_client_id, auth_client_secret = extract_client_info_from_basic_auth(
request
)
@ -221,7 +234,7 @@ class TokenView(View):
@method_decorator(csrf_exempt, name="dispatch")
class RevokeTokenView(View):
def post(self, request):
post_data = FormOrJsonParser().parse_body(request)
post_data = get_json_and_formdata(request)
auth_client_id, auth_client_secret = extract_client_info_from_basic_auth(
request
)

View File

@ -1,16 +1,14 @@
from typing import Literal
from ninja import Field
from activities.models import PostInteraction
from activities.services.search import SearchService
from api import schemas
from api.decorators import identity_required
from api.views.base import api_router
from hatchway import Field, api_view
@api_router.get("/v2/search", response=schemas.Search)
@identity_required
@api_view.get
def search(
request,
q: str,
@ -24,26 +22,31 @@ def search(
min_id: str | None = None,
limit: int = 20,
offset: int = 0,
):
) -> schemas.Search:
if limit > 40:
limit = 40
result: dict[str, list] = {"accounts": [], "statuses": [], "hashtags": []}
# We don't support pagination for searches yet
if max_id or since_id or min_id or offset:
return result
return schemas.Search(**result)
# Run search
searcher = SearchService(q, request.identity)
search_result = searcher.search_all()
if type is None or type == "accounts":
result["accounts"] = [i.to_mastodon_json() for i in search_result["identities"]]
result["accounts"] = [
schemas.Account.from_identity(i, include_counts=False)
for i in search_result["identities"]
]
if type is None or type == "hashtag":
result["hashtag"] = [h.to_mastodon_json() for h in search_result["hashtags"]]
result["hashtag"] = [
schemas.Tag.from_hashtag(h) for h in search_result["hashtags"]
]
if type is None or type == "statuses":
interactions = PostInteraction.get_post_interactions(
search_result["posts"], request.identity
)
result["statuses"] = [
p.to_mastodon_json(interactions=interactions)
schemas.Status.from_post(p, interactions=interactions)
for p in search_result["posts"]
]
return result
return schemas.Search(**result)

View File

@ -1,9 +1,8 @@
from typing import Literal
from django.forms import ValidationError
from django.http import HttpRequest, HttpResponse
from django.http import HttpRequest
from django.shortcuts import get_object_or_404
from ninja import Schema
from activities.models import (
Post,
@ -14,11 +13,10 @@ from activities.models import (
)
from activities.services import PostService
from api import schemas
from api.views.base import api_router
from api.decorators import identity_required
from api.pagination import MastodonPaginator, PaginationResult
from core.models import Config
from ..decorators import identity_required
from ..pagination import MastodonPaginator
from hatchway import ApiResponse, Schema, api_view
class PostStatusSchema(Schema):
@ -32,9 +30,9 @@ class PostStatusSchema(Schema):
media_ids: list[str] = []
@api_router.post("/v1/statuses", response=schemas.Status)
@identity_required
def post_status(request, details: PostStatusSchema):
@api_view.post
def post_status(request, details: PostStatusSchema) -> schemas.Status:
# Check text length
if len(details.status) > Config.system.post_length:
raise ValidationError("Status is too long")
@ -66,74 +64,74 @@ def post_status(request, details: PostStatusSchema):
)
# Add their own timeline event for immediate visibility
TimelineEvent.add_post(request.identity, post)
return post.to_mastodon_json()
return schemas.Status.from_post(post)
@api_router.get("/v1/statuses/{id}", response=schemas.Status)
@identity_required
def status(request, id: str):
@api_view.get
def status(request, id: str) -> schemas.Status:
post = get_object_or_404(Post, pk=id)
interactions = PostInteraction.get_post_interactions([post], request.identity)
return post.to_mastodon_json(interactions=interactions)
return schemas.Status.from_post(post, interactions=interactions)
@api_router.delete("/v1/statuses/{id}", response=schemas.Status)
@identity_required
def delete_status(request, id: str):
@api_view.delete
def delete_status(request, id: str) -> schemas.Status:
post = get_object_or_404(Post, pk=id)
PostService(post).delete()
return post.to_mastodon_json()
return schemas.Status.from_post(post)
@api_router.get("/v1/statuses/{id}/context", response=schemas.Context)
@identity_required
def status_context(request, id: str):
@api_view.get
def status_context(request, id: str) -> schemas.Context:
post = get_object_or_404(Post, pk=id)
service = PostService(post)
ancestors, descendants = service.context(request.identity)
interactions = PostInteraction.get_post_interactions(
ancestors + descendants, request.identity
)
return {
"ancestors": [
p.to_mastodon_json(interactions=interactions) for p in reversed(ancestors)
return schemas.Context(
ancestors=[
schemas.Status.from_post(p, interactions=interactions)
for p in reversed(ancestors)
],
"descendants": [
p.to_mastodon_json(interactions=interactions) for p in descendants
descendants=[
schemas.Status.from_post(p, interactions=interactions) for p in descendants
],
}
)
@api_router.post("/v1/statuses/{id}/favourite", response=schemas.Status)
@identity_required
def favourite_status(request, id: str):
@api_view.post
def favourite_status(request, id: str) -> schemas.Status:
post = get_object_or_404(Post, pk=id)
service = PostService(post)
service.like_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity)
return post.to_mastodon_json(interactions=interactions)
return schemas.Status.from_post(post, interactions=interactions)
@api_router.post("/v1/statuses/{id}/unfavourite", response=schemas.Status)
@identity_required
def unfavourite_status(request, id: str):
@api_view.post
def unfavourite_status(request, id: str) -> schemas.Status:
post = get_object_or_404(Post, pk=id)
service = PostService(post)
service.unlike_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity)
return post.to_mastodon_json(interactions=interactions)
return schemas.Status.from_post(post, interactions=interactions)
@api_router.get("/v1/statuses/{id}/favourited_by", response=list[schemas.Account])
@api_view.get
def favourited_by(
request: HttpRequest,
response: HttpResponse,
id: str,
max_id: str | None = None,
since_id: str | None = None,
min_id: str | None = None,
limit: int = 20,
):
) -> ApiResponse[list[schemas.Account]]:
"""
View who favourited a given status.
"""
@ -142,7 +140,7 @@ def favourited_by(
post = get_object_or_404(Post, pk=id)
paginator = MastodonPaginator()
pager = paginator.paginate(
pager: PaginationResult[PostInteraction] = paginator.paginate(
post.interactions.filter(
type=PostInteraction.Types.like,
state__in=PostInteractionStates.group_active(),
@ -152,32 +150,37 @@ def favourited_by(
since_id=since_id,
limit=limit,
)
pager.jsonify_results(lambda r: r.identity.to_mastodon_json(include_counts=False))
headers = {}
if pager.results:
response.headers["Link"] = pager.link_header(
request,
["limit"],
)
return pager.json_results
headers = {"link": pager.link_header(request, ["limit"])}
return ApiResponse(
[
schemas.Account.from_identity(
interaction.identity,
include_counts=False,
)
for interaction in pager.results
],
headers=headers,
)
@api_router.post("/v1/statuses/{id}/reblog", response=schemas.Status)
@identity_required
def reblog_status(request, id: str):
@api_view.post
def reblog_status(request, id: str) -> schemas.Status:
post = get_object_or_404(Post, pk=id)
service = PostService(post)
service.boost_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity)
return post.to_mastodon_json(interactions=interactions)
return schemas.Status.from_post(post, interactions=interactions)
@api_router.post("/v1/statuses/{id}/unreblog", response=schemas.Status)
@identity_required
def unreblog_status(request, id: str):
@api_view.post
def unreblog_status(request, id: str) -> schemas.Status:
post = get_object_or_404(Post, pk=id)
service = PostService(post)
service.unboost_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity)
return post.to_mastodon_json(interactions=interactions)
return schemas.Status.from_post(post, interactions=interactions)

View File

@ -1,23 +1,23 @@
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.http import HttpRequest
from activities.models import Post
from activities.services import TimelineService
from api import schemas
from api.decorators import identity_required
from api.pagination import MastodonPaginator
from api.views.base import api_router
from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult
from core.models import Config
from hatchway import ApiError, ApiResponse, api_view
@api_router.get("/v1/timelines/home", response=list[schemas.Status])
@identity_required
@api_view.get
def home(
request: HttpRequest,
response: HttpResponse,
max_id: str | None = None,
since_id: str | None = None,
min_id: str | None = None,
limit: int = 20,
):
) -> ApiResponse[list[schemas.Status]]:
# Grab a paginated result set of instances
paginator = MastodonPaginator()
queryset = TimelineService(request.identity).home()
@ -41,18 +41,16 @@ def home(
since_id=since_id,
limit=limit,
)
# Convert those to the JSON form
pager.jsonify_status_events(identity=request.identity)
# Add the link header if needed
if pager.results:
response.headers["Link"] = pager.link_header(request, ["limit"])
return pager.json_results
return PaginatingApiResponse(
schemas.Status.map_from_timeline_event(pager.results, request.identity),
request=request,
include_params=["limit"],
)
@api_router.get("/v1/timelines/public", response=list[schemas.Status])
@api_view.get
def public(
request: HttpRequest,
response: HttpResponse,
local: bool = False,
remote: bool = False,
only_media: bool = False,
@ -60,9 +58,9 @@ def public(
since_id: str | None = None,
min_id: str | None = None,
limit: int = 20,
):
) -> ApiResponse[list[schemas.Status]]:
if not request.identity and not Config.system.public_timeline:
return JsonResponse({"error": "public timeline is disabled"}, status=422)
raise ApiError(error="public timeline is disabled", status=422)
if local:
queryset = TimelineService(request.identity).local()
@ -74,29 +72,24 @@ def public(
queryset = queryset.filter(attachments__id__isnull=True)
# Grab a paginated result set of instances
paginator = MastodonPaginator()
pager = paginator.paginate(
pager: PaginationResult[Post] = paginator.paginate(
queryset,
min_id=min_id,
max_id=max_id,
since_id=since_id,
limit=limit,
)
# Convert those to the JSON form
pager.jsonify_posts(identity=request.identity)
# Add the link header if needed
if pager.results:
response.headers["Link"] = pager.link_header(
request,
["limit", "local", "remote", "only_media"],
)
return pager.json_results
return PaginatingApiResponse(
schemas.Status.map_from_post(pager.results, request.identity),
request=request,
include_params=["limit", "local", "remote", "only_media"],
)
@api_router.get("/v1/timelines/tag/{hashtag}", response=list[schemas.Status])
@identity_required
@api_view.get
def hashtag(
request: HttpRequest,
response: HttpResponse,
hashtag: str,
local: bool = False,
only_media: bool = False,
@ -104,7 +97,7 @@ def hashtag(
since_id: str | None = None,
min_id: str | None = None,
limit: int = 20,
):
) -> ApiResponse[list[schemas.Status]]:
if limit > 40:
limit = 40
queryset = TimelineService(request.identity).hashtag(hashtag)
@ -114,63 +107,54 @@ def hashtag(
queryset = queryset.filter(attachments__id__isnull=True)
# Grab a paginated result set of instances
paginator = MastodonPaginator()
pager = paginator.paginate(
pager: PaginationResult[Post] = paginator.paginate(
queryset,
min_id=min_id,
max_id=max_id,
since_id=since_id,
limit=limit,
)
# Convert those to the JSON form
pager.jsonify_posts(identity=request.identity)
# Add a link header if we need to
if pager.results:
response.headers["Link"] = pager.link_header(
request,
["limit", "local", "remote", "only_media"],
)
return pager.json_results
return PaginatingApiResponse(
schemas.Status.map_from_post(pager.results, request.identity),
request=request,
include_params=["limit", "local", "remote", "only_media"],
)
@api_router.get("/v1/conversations", response=list[schemas.Status])
@identity_required
@api_view.get
def conversations(
request: HttpRequest,
response: HttpResponse,
max_id: str | None = None,
since_id: str | None = None,
min_id: str | None = None,
limit: int = 20,
):
) -> list[schemas.Status]:
# We don't implement this yet
return []
@api_router.get("/v1/favourites", response=list[schemas.Status])
@identity_required
@api_view.get
def favourites(
request: HttpRequest,
response: HttpResponse,
max_id: str | None = None,
since_id: str | None = None,
min_id: str | None = None,
limit: int = 20,
):
) -> ApiResponse[list[schemas.Status]]:
queryset = TimelineService(request.identity).likes()
paginator = MastodonPaginator()
pager = paginator.paginate(
pager: PaginationResult[Post] = paginator.paginate(
queryset,
min_id=min_id,
max_id=max_id,
since_id=since_id,
limit=limit,
)
# Convert those to the JSON form
pager.jsonify_posts(identity=request.identity)
# Add the link header if needed
if pager.results:
response.headers["Link"] = pager.link_header(request, ["limit"])
return pager.json_results
return PaginatingApiResponse(
schemas.Status.map_from_post(pager.results, request.identity),
request=request,
include_params=["limit"],
)

5
hatchway/__init__.py Normal file
View File

@ -0,0 +1,5 @@
from .http import ApiError, ApiResponse # noqa
from .schema import Field, Schema # noqa
from .types import Body, BodyDirect, Path, Query, QueryOrBody # noqa
from .urls import methods # noqa
from .view import api_view # noqa

10
hatchway/constants.py Normal file
View File

@ -0,0 +1,10 @@
import enum
class InputSource(str, enum.Enum):
path = "path"
query = "query"
body = "body"
body_direct = "body_direct"
query_and_body_direct = "query_and_body_direct"
file = "file"

47
hatchway/http.py Normal file
View File

@ -0,0 +1,47 @@
import json
from typing import Generic, TypeVar
from django.core.serializers.json import DjangoJSONEncoder
from django.http import HttpResponse
T = TypeVar("T")
class ApiResponse(Generic[T], HttpResponse):
"""
A way to return extra information with a response if you want
headers, etc.
"""
def __init__(
self,
data: T,
encoder=DjangoJSONEncoder,
json_dumps_params: dict[str, object] | None = None,
finalize: bool = False,
**kwargs
):
self.data = data
self.encoder = encoder
self.json_dumps_params = json_dumps_params or {}
kwargs.setdefault("content_type", "application/json")
super().__init__(content=b"(unfinalised)", **kwargs)
if finalize:
self.finalize()
def finalize(self):
"""
Converts whatever our current data is into HttpResponse content
"""
# TODO: Automatically call this when we're asked to write output?
self.content = json.dumps(self.data, cls=self.encoder, **self.json_dumps_params)
class ApiError(BaseException):
"""
A handy way to raise an error with JSONable contents
"""
def __init__(self, status: int, error: str):
self.status = status
self.error = error

52
hatchway/schema.py Normal file
View File

@ -0,0 +1,52 @@
from typing import Any
from django.db.models import Manager, QuerySet
from django.db.models.fields.files import FieldFile
from django.template import Variable, VariableDoesNotExist
from pydantic.fields import Field # noqa
from pydantic.main import BaseModel
from pydantic.utils import GetterDict
class DjangoGetterDict(GetterDict):
def __init__(self, obj: Any):
self._obj = obj
def __getitem__(self, key: str) -> Any:
try:
item = getattr(self._obj, key)
except AttributeError:
try:
item = Variable(key).resolve(self._obj)
except VariableDoesNotExist as e:
raise KeyError(key) from e
return self._convert_result(item)
def get(self, key: Any, default: Any = None) -> Any:
try:
return self[key]
except KeyError:
return default
def _convert_result(self, result: Any) -> Any:
if isinstance(result, Manager):
return list(result.all())
elif isinstance(result, getattr(QuerySet, "__origin__", QuerySet)):
return list(result)
if callable(result):
return result()
elif isinstance(result, FieldFile):
if not result:
return None
return result.url
return result
class Schema(BaseModel):
class Config:
orm_mode = True
getter_dict = DjangoGetterDict

View File

@ -0,0 +1,63 @@
from typing import Literal, Optional, Union
from django.core.files import File
from hatchway.http import ApiResponse
from hatchway.types import (
Query,
QueryType,
acceptable_input,
extract_output_type,
extract_signifier,
is_optional,
)
def test_is_optional():
assert is_optional(Optional[int]) == (True, int)
assert is_optional(Union[int, None]) == (True, int)
assert is_optional(Union[None, int]) == (True, int)
assert is_optional(int | None) == (True, int)
assert is_optional(None | int) == (True, int)
assert is_optional(int) == (False, int)
assert is_optional(Query[int]) == (False, Query[int])
def test_extract_signifier():
assert extract_signifier(int) == (None, int)
assert extract_signifier(Query[int]) == (QueryType, int)
assert extract_signifier(Query[Optional[int]]) == ( # type:ignore
QueryType,
Optional[int],
)
assert extract_signifier(Query[int | None]) == ( # type:ignore
QueryType,
Optional[int],
)
assert extract_signifier(Optional[Query[int]]) == (QueryType, Optional[int])
def test_extract_output_type():
assert extract_output_type(int) == int
assert extract_output_type(ApiResponse[int]) == int
assert extract_output_type(ApiResponse[int | str]) == int | str
def test_acceptable_input():
assert acceptable_input(str) is True
assert acceptable_input(int) is True
assert acceptable_input(Query[int]) is True
assert acceptable_input(Optional[int]) is True
assert acceptable_input(int | None) is True
assert acceptable_input(int | str | None) is True
assert acceptable_input(Query[int | None]) is True # type: ignore
assert acceptable_input(File) is True
assert acceptable_input(list[str]) is True
assert acceptable_input(dict[str, int]) is True
assert acceptable_input(Literal["a", "b"]) is True
assert acceptable_input(frozenset) is False
assert acceptable_input(dict[str, frozenset]) is False

244
hatchway/tests/test_view.py Normal file
View File

@ -0,0 +1,244 @@
import json
import pytest
from django.core import files
from django.core.files.uploadedfile import SimpleUploadedFile
from django.http import QueryDict
from django.test import RequestFactory
from django.test.client import MULTIPART_CONTENT
from pydantic import BaseModel
from hatchway import ApiError, Body, QueryOrBody, api_view
from hatchway.view import ApiView
def test_basic_view():
"""
Tests that a view with simple types works correctly
"""
@api_view
def test_view(
request,
a: int,
b: QueryOrBody[int | None] = None,
c: str = "x",
) -> str:
if b is None:
return c * a
else:
return c * (a - b)
# Call it with a few different patterns to verify it's type coercing right
factory = RequestFactory()
# Implicit query param
response = test_view(factory.get("/test/?a=4"))
assert json.loads(response.content) == "xxxx"
# QueryOrBody pulling from query
response = test_view(factory.get("/test/?a=4&b=2"))
assert json.loads(response.content) == "xx"
# QueryOrBody pulling from formdata body
response = test_view(factory.post("/test/?a=4", {"b": "3"}))
assert json.loads(response.content) == "x"
# QueryOrBody pulling from JSON body
response = test_view(
factory.post(
"/test/?a=4", json.dumps({"b": 3}), content_type="application/json"
)
)
assert json.loads(response.content) == "x"
# Implicit Query not pulling from body
with pytest.raises(TypeError):
test_view(factory.post("/test/", {"a": 4, "b": 3}))
def test_body_direct():
"""
Tests that a Pydantic model with BodyDirect gets its fields from the top level
"""
class TestModel(BaseModel):
number: int
name: str
@api_view
def test_view(request, data: TestModel) -> int:
return data.number
factory = RequestFactory()
# formdata version
response = test_view(factory.post("/test/", {"number": "123", "name": "Andrew"}))
assert json.loads(response.content) == 123
# JSON body version
response = test_view(
factory.post(
"/test/",
json.dumps({"number": "123", "name": "Andrew"}),
content_type="application/json",
)
)
assert json.loads(response.content) == 123
def test_list_response():
"""
Tests that a view with a list response type works correctly with both
dicts and pydantic model instances.
"""
class TestModel(BaseModel):
number: int
name: str
@api_view
def test_view_dict(request) -> list[TestModel]:
return [
{"name": "Andrew", "number": 1}, # type:ignore
{"name": "Alice", "number": 0}, # type:ignore
]
@api_view
def test_view_model(request) -> list[TestModel]:
return [TestModel(name="Andrew", number=1), TestModel(name="Alice", number=0)]
response = test_view_dict(RequestFactory().get("/test/"))
assert json.loads(response.content) == [
{"name": "Andrew", "number": 1},
{"name": "Alice", "number": 0},
]
response = test_view_model(RequestFactory().get("/test/"))
assert json.loads(response.content) == [
{"name": "Andrew", "number": 1},
{"name": "Alice", "number": 0},
]
def test_patch_body():
"""
Tests that PATCH also gets its body parsed
"""
@api_view.patch
def test_view(request, a: Body[int]):
return a
factory = RequestFactory()
response = test_view(
factory.patch(
"/test/",
content_type=MULTIPART_CONTENT,
data=factory._encode_data({"a": "42"}, MULTIPART_CONTENT),
)
)
assert json.loads(response.content) == 42
def test_file_body():
"""
Tests that file uploads work right
"""
@api_view.post
def test_view(request, a: Body[int], b: files.File) -> str:
return str(a) + b.read().decode("ascii")
factory = RequestFactory()
uploaded_file = SimpleUploadedFile(
"file.txt",
b"MY FILE IS AMAZING",
content_type="text/plain",
)
response = test_view(
factory.post(
"/test/",
data={"a": 42, "b": uploaded_file},
)
)
assert json.loads(response.content) == "42MY FILE IS AMAZING"
def test_no_response():
"""
Tests that a view with no response type returns the contents verbatim
"""
@api_view
def test_view(request):
return [1, "woooooo"]
response = test_view(RequestFactory().get("/test/"))
assert json.loads(response.content) == [1, "woooooo"]
def test_wrong_method():
"""
Tests that a view with a method limiter works
"""
@api_view.get
def test_view(request):
return "yay"
response = test_view(RequestFactory().get("/test/"))
assert json.loads(response.content) == "yay"
response = test_view(RequestFactory().post("/test/"))
assert response.status_code == 405
def test_api_error():
"""
Tests that ApiError propagates right
"""
@api_view.get
def test_view(request):
raise ApiError(401, "you did a bad thing")
response = test_view(RequestFactory().get("/test/"))
assert json.loads(response.content) == {"error": "you did a bad thing"}
assert response.status_code == 401
def test_unusable_type():
"""
Tests that you get a nice error when you use a type on an input that
Pydantic doesn't understand.
"""
with pytest.raises(ValueError):
@api_view.get
def test_view(request, a: RequestFactory):
pass
def test_get_values():
"""
Tests that ApiView.get_values correctly handles lists
"""
assert ApiView.get_values({"a": 2, "b": [3, 4]}) == {"a": 2, "b": [3, 4]}
assert ApiView.get_values({"a": 2, "b[]": [3, 4]}) == {"a": 2, "b": [3, 4]}
assert ApiView.get_values(QueryDict("a=2&b=3&b=4")) == {"a": "2", "b": ["3", "4"]}
assert ApiView.get_values(QueryDict("a=2&b[]=3&b[]=4")) == {
"a": "2",
"b": ["3", "4"],
}
assert ApiView.get_values(QueryDict("a=2&b=3")) == {"a": "2", "b": "3"}
assert ApiView.get_values(QueryDict("a=2&b[]=3")) == {"a": "2", "b": ["3"]}
assert ApiView.get_values(QueryDict("a[b]=1")) == {"a": {"b": "1"}}
assert ApiView.get_values(QueryDict("a[b]=1&a[c]=2")) == {"a": {"b": "1", "c": "2"}}
assert ApiView.get_values(QueryDict("a[b][c]=1")) == {"a": {"b": {"c": "1"}}}
assert ApiView.get_values(QueryDict("a[b][c][]=1")) == {"a": {"b": {"c": ["1"]}}}
assert ApiView.get_values(QueryDict("a[b][]=1&a[b][]=2")) == {
"a": {"b": ["1", "2"]}
}

145
hatchway/types.py Normal file
View File

@ -0,0 +1,145 @@
from types import NoneType, UnionType
from typing import ( # type: ignore[attr-defined]
Annotated,
Any,
Literal,
Optional,
TypeVar,
Union,
_AnnotatedAlias,
_GenericAlias,
get_args,
get_origin,
)
from django.core import files
from pydantic import BaseModel
from .http import ApiResponse
T = TypeVar("T")
class PathType:
"""
An input pulled from the path (url resolver kwargs)
"""
class QueryType:
"""
An input pulled from the query parameters (request.GET)
"""
class BodyType:
"""
An input pulled from the POST body (request.POST or a JSON body)
"""
class FileType:
"""
An input pulled from the POST body (request.POST or a JSON body)
"""
class BodyDirectType:
"""
A Pydantic model whose keys are all looked for in the top-level
POST data, rather than in a dict under a key named after the input.
"""
class QueryOrBodyType:
"""
An input pulled from either query parameters or post data.
"""
Path = Annotated[T, PathType]
Query = Annotated[T, QueryType]
Body = Annotated[T, BodyType]
File = Annotated[T, FileType]
BodyDirect = Annotated[T, BodyDirectType]
QueryOrBody = Annotated[T, QueryOrBodyType]
def is_optional(annotation) -> tuple[bool, Any]:
"""
If an annotation is Optional or | None, returns (True, internal type).
Returns (False, annotation) otherwise.
"""
if (isinstance(annotation, _GenericAlias) and annotation.__origin__ is Union) or (
isinstance(annotation, UnionType)
):
args = get_args(annotation)
if len(args) > 2:
return False, annotation
if args[0] is NoneType:
return True, args[1]
if args[1] is NoneType:
return True, args[0]
return False, annotation
return False, annotation
def extract_signifier(annotation) -> tuple[Any, Any]:
"""
Given a type annotation, looks to see if it can find a input source
signifier (Path, Query, etc.)
If it can, returns (signifier, annotation_without_signifier)
If not, returns (None, annotation)
"""
our_generics = {
PathType,
QueryType,
BodyType,
FileType,
BodyDirectType,
QueryOrBodyType,
}
# Remove any optional-style wrapper
optional, internal_annotation = is_optional(annotation)
# Is it an annotation?
if isinstance(internal_annotation, _AnnotatedAlias):
args = get_args(internal_annotation)
for arg in args[1:]:
if arg in our_generics:
if optional:
return (arg, Optional[args[0]])
else:
return (arg, args[0])
return None, annotation
def extract_output_type(annotation):
"""
Returns the right response type for a function
"""
# If the type is ApiResponse, we want to pull out its inside
if isinstance(annotation, _GenericAlias):
if get_origin(annotation) == ApiResponse:
return get_args(annotation)[0]
return annotation
def acceptable_input(annotation) -> bool:
"""
Returns if this annotation is something we think we can accept as input
"""
_, inner_type = extract_signifier(annotation)
try:
if issubclass(inner_type, BaseModel):
return True
except TypeError:
pass
if inner_type in [str, int, list, tuple, bool, Any, files.File, type(None)]:
return True
origin = get_origin(inner_type)
if origin == Literal:
return True
if origin in [Union, UnionType, dict, list, tuple]:
return all(acceptable_input(a) for a in get_args(inner_type))
return False

32
hatchway/urls.py Normal file
View File

@ -0,0 +1,32 @@
from collections.abc import Callable
from typing import Any
from django.http import HttpResponseNotAllowed
class Methods:
"""
Allows easy multi-method dispatch to different functions
"""
csrf_exempt = True
def __init__(self, **callables: Callable):
self.callables = {
method.lower(): callable for method, callable in callables.items()
}
unknown_methods = set(self.callables.keys()).difference(
{"get", "post", "patch", "put", "delete"}
)
if unknown_methods:
raise ValueError(f"Cannot route methods: {unknown_methods}")
def __call__(self, request, *args, **kwargs) -> Any:
method = request.method.lower()
if method in self.callables:
return self.callables[method](request, *args, **kwargs)
else:
return HttpResponseNotAllowed(self.callables.keys())
methods = Methods

297
hatchway/view.py Normal file
View File

@ -0,0 +1,297 @@
import json
from collections.abc import Callable
from typing import Any, Optional, get_type_hints
from django.core import files
from django.http import HttpRequest, HttpResponseNotAllowed, QueryDict
from django.http.multipartparser import MultiPartParser
from pydantic import BaseModel, create_model
from .constants import InputSource
from .http import ApiError, ApiResponse
from .types import (
BodyDirectType,
BodyType,
FileType,
PathType,
QueryOrBodyType,
QueryType,
acceptable_input,
extract_output_type,
extract_signifier,
is_optional,
)
class ApiView:
"""
A view 'wrapper' object that replaces the API view for anything further
up the stack.
Unlike Django's class-based views, we don't need an as_view pattern
as we are careful never to write anything per-request to self.
"""
csrf_exempt = True
def __init__(
self,
view: Callable,
input_types: dict[str, Any] | None = None,
output_type: Any = None,
implicit_lists: bool = True,
method: str | None = None,
):
self.view = view
self.implicit_lists = implicit_lists
self.view_name = getattr(view, "__name__", "unknown_view")
self.method = method
# Extract input/output types from view annotations if we need to
self.input_types = input_types
if self.input_types is None:
self.input_types = get_type_hints(view, include_extras=True)
if "return" in self.input_types:
del self.input_types["return"]
self.output_type = output_type
if self.output_type is None:
try:
self.output_type = extract_output_type(
get_type_hints(view, include_extras=True)["return"]
)
except KeyError:
self.output_type = None
self.compile()
@classmethod
def get(cls, view: Callable):
return cls(view=view, method="get")
@classmethod
def post(cls, view: Callable):
return cls(view=view, method="post")
@classmethod
def put(cls, view: Callable):
return cls(view=view, method="put")
@classmethod
def patch(cls, view: Callable):
return cls(view=view, method="patch")
@classmethod
def delete(cls, view: Callable):
return cls(view=view, method="delete")
@classmethod
def sources_for_input(cls, input_type) -> tuple[list[InputSource], Any]:
"""
Given a type that can appear as a request parameter type, returns
what sources it can come from, and what its type is as understood
by Pydantic.
"""
signifier, input_type = extract_signifier(input_type)
if signifier is QueryType:
return ([InputSource.query], input_type)
elif signifier is BodyType:
return ([InputSource.body], input_type)
elif signifier is BodyDirectType:
if not issubclass(input_type, BaseModel):
raise ValueError(
"You cannot use BodyDirect on something that is not a Pydantic model"
)
return ([InputSource.body_direct], input_type)
elif signifier is PathType:
return ([InputSource.path], input_type)
elif (
signifier is FileType
or input_type is files.File
or is_optional(input_type)[1] is files.File
):
return ([InputSource.file], input_type)
elif signifier is QueryOrBodyType:
return ([InputSource.query, InputSource.body], input_type)
# Is it a Pydantic model, which means it's implicitly body?
elif isinstance(input_type, type) and issubclass(input_type, BaseModel):
return ([InputSource.body], input_type)
# Otherwise, we look in the path first and then the query
else:
return ([InputSource.path, InputSource.query], input_type)
@classmethod
def get_values(cls, data, use_square_brackets=True) -> dict[str, Any]:
"""
Given a QueryDict or normal dict, returns data taking into account
lists made by repeated values or by suffixing names with [].
"""
result: dict[str, Any] = {}
for key, value in data.items():
# If it's a query dict with multiple values, make it a list
if isinstance(data, QueryDict):
values = data.getlist(key)
if len(values) > 1:
value = values
# If it is in dict-ish/list-ish syntax, adhere to that
# TODO: Make this better handle badly formed keys
if "[" in key and use_square_brackets:
parts = key.split("[")
target = result
last_key = parts[0]
for part in parts[1:]:
part = part.rstrip("]")
if not part:
target = target.setdefault(last_key, [])
else:
target = target.setdefault(last_key, {})
last_key = part
if isinstance(target, list):
if isinstance(value, list):
target.extend(value)
else:
target.append(value)
else:
target[last_key] = value
else:
result[key] = value
return result
def compile(self):
self.sources: dict[str, list[InputSource]] = {}
amount_from_body = 0
pydantic_model_dict = {}
self.input_files = set()
last_body_type = None
# For each input item, work out where to pull it from
for name, input_type in self.input_types.items():
# Do some basic typechecking to stop things that aren't allowed
if isinstance(input_type, type) and issubclass(input_type, HttpRequest):
continue
if not acceptable_input(input_type):
# Strip away any singifiers for the error
_, inner_type = extract_signifier(input_type)
raise ValueError(
f"Input argument {name} has an unsupported type {inner_type}"
)
sources, pydantic_type = self.sources_for_input(input_type)
self.sources[name] = sources
# Keep count of how many are pulling from the body
if InputSource.body in sources:
amount_from_body += 1
last_body_type = pydantic_type
if InputSource.file in sources:
self.input_files.add(name)
else:
pydantic_model_dict[name] = (Optional[pydantic_type], ...)
# If there is just one thing pulling from the body and it's a BaseModel,
# signify that it's actually pulling from the body keys directly and
# not a sub-dict
if amount_from_body == 1:
for name, sources in self.sources.items():
if (
InputSource.body in sources
and isinstance(last_body_type, type)
and issubclass(last_body_type, BaseModel)
):
self.sources[name] = [
x for x in sources if x != InputSource.body
] + [InputSource.body_direct]
# Turn all the main arguments into Pydantic parsing models
try:
self.input_model = create_model(
f"{self.view_name}_input", **pydantic_model_dict
)
except RuntimeError:
raise ValueError(
f"One or more inputs on view {self.view_name} have a bad configuration"
)
if self.output_type is not None:
self.output_model = create_model(
f"{self.view_name}_output", value=(self.output_type, ...)
)
def __call__(self, request: HttpRequest, *args, **kwargs):
"""
Entrypoint when this is called as a view.
"""
# Do a method check if we have one set
if self.method and self.method.upper() != request.method:
return HttpResponseNotAllowed([self.method])
# For each item we can source, go find it if we can
query_values = self.get_values(request.GET)
body_values = self.get_values(request.POST)
files_values = self.get_values(request.FILES)
# If it's a PUT or PATCH method, work around Django not handling FILES
# or POST on those requests
if request.method in ["PATCH", "PUT"]:
if request.content_type == "multipart/form-data":
POST, FILES = MultiPartParser(
request.META, request, request.upload_handlers, request.encoding
).parse()
body_values = self.get_values(POST)
files_values = self.get_values(FILES)
elif request.content_type == "application/x-www-form-urlencoded":
POST = QueryDict(request.body, encoding=request._encoding)
body_values = self.get_values(POST)
# If there was a JSON body, go load that
if request.content_type == "application/json" and request.body.strip():
body_values.update(self.get_values(json.loads(request.body)))
values = {}
for name, sources in self.sources.items():
for source in sources:
if source == InputSource.path:
if name in kwargs:
values[name] = kwargs[name]
break
elif source == InputSource.query:
if name in query_values:
values[name] = query_values[name]
break
elif source == InputSource.body:
if name in body_values:
values[name] = body_values[name]
break
elif source == InputSource.file:
if name in files_values:
values[name] = files_values[name]
break
elif source == InputSource.body_direct:
values[name] = body_values
break
elif source == InputSource.query_and_body_direct:
values[name] = dict(query_values)
values[name].update(body_values)
break
else:
raise ValueError(f"Unknown source {source}")
else:
values[name] = None
# Give that to the Pydantic model to make it handle stuff
model_instance = self.input_model(**values)
kwargs = {
name: getattr(model_instance, name)
for name in model_instance.__fields__
if values[name] is not None # Trim out missing fields
}
# Add in any files
# TODO: HTTP error if file is not optional
for name in self.input_files:
kwargs[name] = files_values.get(name, None)
# Call the view with those as kwargs
try:
response = self.view(request, **kwargs)
except ApiError as error:
return ApiResponse(
{"error": error.error}, status=error.status, finalize=True
)
# If it's not an ApiResponse, make it one
if not isinstance(response, ApiResponse):
response = ApiResponse(response)
# Get pydantic to coerce the output response
if self.output_type is not None:
response.data = self.output_model(value=response.data).dict()["value"]
elif isinstance(response.data, BaseModel):
response.data = response.data.dict()
response.finalize()
return response
api_view = ApiView

View File

@ -6,7 +6,6 @@ django-cache-url~=3.4.2
django-cors-headers~=3.13.0
django-debug-toolbar~=3.8.1
django-htmx~=1.13.0
django-ninja~=0.19.1
django-oauth-toolkit~=2.2.0
django-storages[google,boto3]~=1.13.1
django~=4.1

View File

@ -3,7 +3,7 @@ from django.contrib import admin as djadmin
from django.urls import include, path, re_path
from activities.views import compose, debug, explore, follows, posts, search, timelines
from api.views import api_router, oauth
from api.views import oauth
from core import views as core
from mediaproxy import views as mediaproxy
from stator import views as stator
@ -296,7 +296,7 @@ urlpatterns = [
path("actor/outbox/", activitypub.EmptyOutbox.as_view()),
path("inbox/", activitypub.Inbox.as_view(), name="shared_inbox"),
# API/Oauth
path("api/", api_router.urls),
path("api/", include("api.urls")),
path("oauth/authorize", oauth.AuthorizationView.as_view()),
path("oauth/token", oauth.TokenView.as_view()),
path("oauth/revoke", oauth.RevokeTokenView.as_view()),

View File

@ -281,7 +281,9 @@ def test_inbound_posts(
InboxMessage.objects.create(message=message)
# Run stator and ensure that made the post
print("prestat")
stator.run_single_cycle_sync()
print("poststat")
post = Post.objects.get(object_uri="https://remote.test/test-post")
assert post.content == "post version one"
assert post.published.day == 13