Check scope on API endpoints
This commit is contained in:
parent
b677bf5123
commit
1f3f28e8ff
|
@ -1,3 +1,4 @@
|
|||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
|
||||
from django.http import JsonResponse
|
||||
|
@ -13,10 +14,33 @@ def identity_required(function):
|
|||
def inner(request, *args, **kwargs):
|
||||
# They need an identity
|
||||
if not request.identity:
|
||||
return JsonResponse({"error": "identity_token_required"}, status=400)
|
||||
return JsonResponse({"error": "identity_token_required"}, status=401)
|
||||
return function(request, *args, **kwargs)
|
||||
|
||||
# This is for the API only
|
||||
inner.csrf_exempt = True
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def scope_required(scope: str, requires_identity=True):
|
||||
"""
|
||||
Asserts that the token we're using has the provided scope
|
||||
"""
|
||||
|
||||
def decorator(function: Callable):
|
||||
@wraps(function)
|
||||
def inner(request, *args, **kwargs):
|
||||
if not request.token:
|
||||
return JsonResponse({"error": "identity_token_required"}, status=401)
|
||||
# They need an identity
|
||||
if not request.identity and requires_identity:
|
||||
return JsonResponse({"error": "identity_token_required"}, status=401)
|
||||
if not request.token.has_scope(scope):
|
||||
return JsonResponse({"error": "out_of_scope_for_token"}, status=403)
|
||||
return function(request, *args, **kwargs)
|
||||
|
||||
inner.csrf_exempt = True # type:ignore
|
||||
return inner
|
||||
|
||||
return decorator
|
||||
|
|
|
@ -14,6 +14,7 @@ class ApiTokenMiddleware:
|
|||
|
||||
def __call__(self, request):
|
||||
auth_header = request.headers.get("authorization", None)
|
||||
request.token = None
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
token_value = auth_header[7:]
|
||||
try:
|
||||
|
@ -22,6 +23,7 @@ class ApiTokenMiddleware:
|
|||
return HttpResponse("Invalid Bearer token", status=400)
|
||||
request.user = token.user
|
||||
request.identity = token.identity
|
||||
request.token = token
|
||||
request.session = None
|
||||
response = self.get_response(request)
|
||||
return response
|
||||
|
|
|
@ -36,3 +36,12 @@ class Token(models.Model):
|
|||
created = models.DateTimeField(auto_now_add=True)
|
||||
updated = models.DateTimeField(auto_now=True)
|
||||
revoked = models.DateTimeField(blank=True, null=True)
|
||||
|
||||
def has_scope(self, scope: str):
|
||||
"""
|
||||
Returns if this token has the given scope.
|
||||
It's a function so we can do mapping/reduction if needed
|
||||
"""
|
||||
# TODO: Support granular scopes the other way?
|
||||
scope_prefix = scope.split(":")[0]
|
||||
return (scope in self.scopes) or (scope_prefix in self.scopes)
|
||||
|
|
|
@ -8,7 +8,7 @@ from hatchway import ApiResponse, QueryOrBody, api_view
|
|||
from activities.models import Post
|
||||
from activities.services import SearchService
|
||||
from api import schemas
|
||||
from api.decorators import identity_required
|
||||
from api.decorators import scope_required
|
||||
from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult
|
||||
from core.models import Config
|
||||
from users.models import Identity
|
||||
|
@ -16,13 +16,13 @@ from users.services import IdentityService
|
|||
from users.shortcuts import by_handle_or_404
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read")
|
||||
@api_view.get
|
||||
def verify_credentials(request) -> schemas.Account:
|
||||
return schemas.Account.from_identity(request.identity, source=True)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("write")
|
||||
@api_view.patch
|
||||
def update_credentials(
|
||||
request,
|
||||
|
@ -73,7 +73,7 @@ def update_credentials(
|
|||
return schemas.Account.from_identity(identity, source=True)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read")
|
||||
@api_view.get
|
||||
def account_relationships(request, id: list[str] | None) -> list[schemas.Relationship]:
|
||||
result = []
|
||||
|
@ -87,7 +87,7 @@ def account_relationships(request, id: list[str] | None) -> list[schemas.Relatio
|
|||
return result
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read")
|
||||
@api_view.get
|
||||
def familiar_followers(
|
||||
request, id: list[str] | None
|
||||
|
@ -114,7 +114,7 @@ def familiar_followers(
|
|||
return result
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read")
|
||||
@api_view.get
|
||||
def accounts_search(
|
||||
request,
|
||||
|
@ -146,8 +146,8 @@ def lookup(request: HttpRequest, acct: str) -> schemas.Account:
|
|||
return schemas.Account.from_identity(identity)
|
||||
|
||||
|
||||
@scope_required("read:accounts")
|
||||
@api_view.get
|
||||
@identity_required
|
||||
def account(request, id: str) -> schemas.Account:
|
||||
identity = get_object_or_404(
|
||||
Identity.objects.exclude(restriction=Identity.Restriction.blocked),
|
||||
|
@ -156,8 +156,8 @@ def account(request, id: str) -> schemas.Account:
|
|||
return schemas.Account.from_identity(identity)
|
||||
|
||||
|
||||
@scope_required("read:statuses")
|
||||
@api_view.get
|
||||
@identity_required
|
||||
def account_statuses(
|
||||
request: HttpRequest,
|
||||
id: str,
|
||||
|
@ -218,8 +218,8 @@ def account_statuses(
|
|||
)
|
||||
|
||||
|
||||
@scope_required("write:follows")
|
||||
@api_view.post
|
||||
@identity_required
|
||||
def account_follow(request, id: str, reblogs: bool = True) -> schemas.Relationship:
|
||||
identity = get_object_or_404(
|
||||
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
|
||||
|
@ -229,8 +229,8 @@ def account_follow(request, id: str, reblogs: bool = True) -> schemas.Relationsh
|
|||
return schemas.Relationship.from_identity_pair(identity, request.identity)
|
||||
|
||||
|
||||
@scope_required("write:follows")
|
||||
@api_view.post
|
||||
@identity_required
|
||||
def account_unfollow(request, id: str) -> schemas.Relationship:
|
||||
identity = get_object_or_404(
|
||||
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
|
||||
|
@ -240,8 +240,8 @@ def account_unfollow(request, id: str) -> schemas.Relationship:
|
|||
return schemas.Relationship.from_identity_pair(identity, request.identity)
|
||||
|
||||
|
||||
@scope_required("write:blocks")
|
||||
@api_view.post
|
||||
@identity_required
|
||||
def account_block(request, id: str) -> schemas.Relationship:
|
||||
identity = get_object_or_404(Identity, pk=id)
|
||||
service = IdentityService(request.identity)
|
||||
|
@ -249,8 +249,8 @@ def account_block(request, id: str) -> schemas.Relationship:
|
|||
return schemas.Relationship.from_identity_pair(identity, request.identity)
|
||||
|
||||
|
||||
@scope_required("write:blocks")
|
||||
@api_view.post
|
||||
@identity_required
|
||||
def account_unblock(request, id: str) -> schemas.Relationship:
|
||||
identity = get_object_or_404(Identity, pk=id)
|
||||
service = IdentityService(request.identity)
|
||||
|
@ -258,7 +258,7 @@ def account_unblock(request, id: str) -> schemas.Relationship:
|
|||
return schemas.Relationship.from_identity_pair(identity, request.identity)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("write:blocks")
|
||||
@api_view.post
|
||||
def account_mute(
|
||||
request,
|
||||
|
@ -276,7 +276,7 @@ def account_mute(
|
|||
return schemas.Relationship.from_identity_pair(identity, request.identity)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("write:blocks")
|
||||
@api_view.post
|
||||
def account_unmute(request, id: str) -> schemas.Relationship:
|
||||
identity = get_object_or_404(Identity, pk=id)
|
||||
|
|
|
@ -2,12 +2,12 @@ from django.shortcuts import get_object_or_404
|
|||
from hatchway import api_view
|
||||
|
||||
from api import schemas
|
||||
from api.decorators import identity_required
|
||||
from api.decorators import scope_required
|
||||
from users.models import Announcement
|
||||
from users.services import AnnouncementService
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read:notifications")
|
||||
@api_view.get
|
||||
def announcement_list(request) -> list[schemas.Announcement]:
|
||||
return [
|
||||
|
@ -16,7 +16,7 @@ def announcement_list(request) -> list[schemas.Announcement]:
|
|||
]
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("write:notifications")
|
||||
@api_view.post
|
||||
def announcement_dismiss(request, pk: str):
|
||||
announcement = get_object_or_404(Announcement, pk=pk)
|
||||
|
|
|
@ -6,10 +6,10 @@ from activities.models import PostAttachment, PostAttachmentStates
|
|||
from api import schemas
|
||||
from core.files import blurhash_image, resize_image
|
||||
|
||||
from ..decorators import identity_required
|
||||
from ..decorators import scope_required
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("write:media")
|
||||
@api_view.post
|
||||
def upload_media(
|
||||
request,
|
||||
|
@ -47,7 +47,7 @@ def upload_media(
|
|||
return schemas.MediaAttachment.from_post_attachment(attachment)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read:media")
|
||||
@api_view.get
|
||||
def get_media(
|
||||
request,
|
||||
|
@ -59,7 +59,7 @@ def get_media(
|
|||
return schemas.MediaAttachment.from_post_attachment(attachment)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("write:media")
|
||||
@api_view.put
|
||||
def update_media(
|
||||
request,
|
||||
|
|
|
@ -4,11 +4,11 @@ from hatchway import ApiResponse, api_view
|
|||
from activities.models import TimelineEvent
|
||||
from activities.services import TimelineService
|
||||
from api import schemas
|
||||
from api.decorators import identity_required
|
||||
from api.decorators import scope_required
|
||||
from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read:notifications")
|
||||
@api_view.get
|
||||
def notifications(
|
||||
request: HttpRequest,
|
||||
|
|
|
@ -3,21 +3,21 @@ from hatchway import Schema, api_view
|
|||
|
||||
from activities.models import Post, PostInteraction
|
||||
from api import schemas
|
||||
from api.decorators import identity_required
|
||||
from api.decorators import scope_required
|
||||
|
||||
|
||||
class PostVoteSchema(Schema):
|
||||
choices: list[int]
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read:statuses")
|
||||
@api_view.get
|
||||
def get_poll(request, id: str) -> schemas.Poll:
|
||||
post = get_object_or_404(Post, pk=id, type=Post.Types.question)
|
||||
return schemas.Poll.from_post(post, identity=request.identity)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("write:statuses")
|
||||
@api_view.post
|
||||
def vote_poll(request, id: str, details: PostVoteSchema) -> schemas.Poll:
|
||||
post = get_object_or_404(Post, pk=id, type=Post.Types.question)
|
||||
|
|
|
@ -5,10 +5,10 @@ from hatchway import Field, api_view
|
|||
from activities.models import PostInteraction
|
||||
from activities.services.search import SearchService
|
||||
from api import schemas
|
||||
from api.decorators import identity_required
|
||||
from api.decorators import scope_required
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read")
|
||||
@api_view.get
|
||||
def search(
|
||||
request,
|
||||
|
|
|
@ -15,7 +15,7 @@ from activities.models import (
|
|||
)
|
||||
from activities.services import PostService
|
||||
from api import schemas
|
||||
from api.decorators import identity_required
|
||||
from api.decorators import scope_required
|
||||
from api.pagination import MastodonPaginator, PaginationResult
|
||||
from core.models import Config
|
||||
|
||||
|
@ -72,7 +72,7 @@ def post_for_id(request: HttpRequest, id: str) -> Post:
|
|||
return get_object_or_404(queryset, pk=id)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("write:statuses")
|
||||
@api_view.post
|
||||
def post_status(request, details: PostStatusSchema) -> schemas.Status:
|
||||
# Check text length
|
||||
|
@ -110,7 +110,7 @@ def post_status(request, details: PostStatusSchema) -> schemas.Status:
|
|||
return schemas.Status.from_post(post, identity=request.identity)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read:statuses")
|
||||
@api_view.get
|
||||
def status(request, id: str) -> schemas.Status:
|
||||
post = post_for_id(request, id)
|
||||
|
@ -120,7 +120,7 @@ def status(request, id: str) -> schemas.Status:
|
|||
)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("write:statuses")
|
||||
@api_view.put
|
||||
def edit_status(request, id: str, details: EditStatusSchema) -> schemas.Status:
|
||||
post = post_for_id(request, id)
|
||||
|
@ -138,7 +138,7 @@ def edit_status(request, id: str, details: EditStatusSchema) -> schemas.Status:
|
|||
return schemas.Status.from_post(post)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("write:statuses")
|
||||
@api_view.delete
|
||||
def delete_status(request, id: str) -> schemas.Status:
|
||||
post = post_for_id(request, id)
|
||||
|
@ -148,14 +148,14 @@ def delete_status(request, id: str) -> schemas.Status:
|
|||
return schemas.Status.from_post(post, identity=request.identity)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read:statuses")
|
||||
@api_view.get
|
||||
def status_source(request, id: str) -> schemas.StatusSource:
|
||||
post = post_for_id(request, id)
|
||||
return schemas.StatusSource.from_post(post)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read:statuses")
|
||||
@api_view.get
|
||||
def status_context(request, id: str) -> schemas.Context:
|
||||
post = post_for_id(request, id)
|
||||
|
@ -180,7 +180,7 @@ def status_context(request, id: str) -> schemas.Context:
|
|||
)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("write:favourites")
|
||||
@api_view.post
|
||||
def favourite_status(request, id: str) -> schemas.Status:
|
||||
post = post_for_id(request, id)
|
||||
|
@ -192,7 +192,7 @@ def favourite_status(request, id: str) -> schemas.Status:
|
|||
)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("write:favourites")
|
||||
@api_view.post
|
||||
def unfavourite_status(request, id: str) -> schemas.Status:
|
||||
post = post_for_id(request, id)
|
||||
|
@ -245,7 +245,7 @@ def favourited_by(
|
|||
)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("write:favourites")
|
||||
@api_view.post
|
||||
def reblog_status(request, id: str) -> schemas.Status:
|
||||
post = post_for_id(request, id)
|
||||
|
@ -257,7 +257,7 @@ def reblog_status(request, id: str) -> schemas.Status:
|
|||
)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("write:favourites")
|
||||
@api_view.post
|
||||
def unreblog_status(request, id: str) -> schemas.Status:
|
||||
post = post_for_id(request, id)
|
||||
|
|
|
@ -4,12 +4,12 @@ from hatchway import ApiError, ApiResponse, api_view
|
|||
from activities.models import Post
|
||||
from activities.services import TimelineService
|
||||
from api import schemas
|
||||
from api.decorators import identity_required
|
||||
from api.decorators import scope_required
|
||||
from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult
|
||||
from core.models import Config
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read:statuses")
|
||||
@api_view.get
|
||||
def home(
|
||||
request: HttpRequest,
|
||||
|
@ -86,7 +86,7 @@ def public(
|
|||
)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read:statuses")
|
||||
@api_view.get
|
||||
def hashtag(
|
||||
request: HttpRequest,
|
||||
|
@ -121,7 +121,7 @@ def hashtag(
|
|||
)
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read:conversations")
|
||||
@api_view.get
|
||||
def conversations(
|
||||
request: HttpRequest,
|
||||
|
@ -134,7 +134,7 @@ def conversations(
|
|||
return []
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read:favourites")
|
||||
@api_view.get
|
||||
def favourites(
|
||||
request: HttpRequest,
|
||||
|
|
|
@ -2,10 +2,10 @@ from django.http import HttpRequest
|
|||
from hatchway import api_view
|
||||
|
||||
from api import schemas
|
||||
from api.decorators import identity_required
|
||||
from api.decorators import scope_required
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read")
|
||||
@api_view.get
|
||||
def trends_tags(
|
||||
request: HttpRequest,
|
||||
|
@ -16,7 +16,7 @@ def trends_tags(
|
|||
return []
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read")
|
||||
@api_view.get
|
||||
def trends_statuses(
|
||||
request: HttpRequest,
|
||||
|
@ -27,7 +27,7 @@ def trends_statuses(
|
|||
return []
|
||||
|
||||
|
||||
@identity_required
|
||||
@scope_required("read")
|
||||
@api_view.get
|
||||
def trends_links(
|
||||
request: HttpRequest,
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_create(api_client):
|
||||
"""
|
||||
Tests creating an app
|
||||
"""
|
||||
response = api_client.post(
|
||||
"/api/v1/apps", {"client_name": "test", "redirect_uris": ""}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "test"
|
|
@ -0,0 +1,11 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_has_scope(api_token):
|
||||
"""
|
||||
Tests has_scope on the Token model
|
||||
"""
|
||||
assert api_token.has_scope("read")
|
||||
assert api_token.has_scope("read:statuses")
|
||||
assert not api_token.has_scope("destroyearth")
|
Loading…
Reference in New Issue