Check scope on API endpoints

This commit is contained in:
Andrew Godwin 2023-02-19 11:37:02 -07:00
parent b677bf5123
commit 1f3f28e8ff
14 changed files with 108 additions and 49 deletions

View File

@ -1,3 +1,4 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from django.http import JsonResponse from django.http import JsonResponse
@ -13,10 +14,33 @@ def identity_required(function):
def inner(request, *args, **kwargs): def inner(request, *args, **kwargs):
# They need an identity # They need an identity
if not request.identity: if not request.identity:
return JsonResponse({"error": "identity_token_required"}, status=400) return JsonResponse({"error": "identity_token_required"}, status=401)
return function(request, *args, **kwargs) return function(request, *args, **kwargs)
# This is for the API only # This is for the API only
inner.csrf_exempt = True inner.csrf_exempt = True
return inner return inner
def scope_required(scope: str, requires_identity=True):
"""
Asserts that the token we're using has the provided scope
"""
def decorator(function: Callable):
@wraps(function)
def inner(request, *args, **kwargs):
if not request.token:
return JsonResponse({"error": "identity_token_required"}, status=401)
# They need an identity
if not request.identity and requires_identity:
return JsonResponse({"error": "identity_token_required"}, status=401)
if not request.token.has_scope(scope):
return JsonResponse({"error": "out_of_scope_for_token"}, status=403)
return function(request, *args, **kwargs)
inner.csrf_exempt = True # type:ignore
return inner
return decorator

View File

@ -14,6 +14,7 @@ class ApiTokenMiddleware:
def __call__(self, request): def __call__(self, request):
auth_header = request.headers.get("authorization", None) auth_header = request.headers.get("authorization", None)
request.token = None
if auth_header and auth_header.startswith("Bearer "): if auth_header and auth_header.startswith("Bearer "):
token_value = auth_header[7:] token_value = auth_header[7:]
try: try:
@ -22,6 +23,7 @@ class ApiTokenMiddleware:
return HttpResponse("Invalid Bearer token", status=400) return HttpResponse("Invalid Bearer token", status=400)
request.user = token.user request.user = token.user
request.identity = token.identity request.identity = token.identity
request.token = token
request.session = None request.session = None
response = self.get_response(request) response = self.get_response(request)
return response return response

View File

@ -36,3 +36,12 @@ class Token(models.Model):
created = models.DateTimeField(auto_now_add=True) created = models.DateTimeField(auto_now_add=True)
updated = models.DateTimeField(auto_now=True) updated = models.DateTimeField(auto_now=True)
revoked = models.DateTimeField(blank=True, null=True) revoked = models.DateTimeField(blank=True, null=True)
def has_scope(self, scope: str):
"""
Returns if this token has the given scope.
It's a function so we can do mapping/reduction if needed
"""
# TODO: Support granular scopes the other way?
scope_prefix = scope.split(":")[0]
return (scope in self.scopes) or (scope_prefix in self.scopes)

View File

@ -8,7 +8,7 @@ from hatchway import ApiResponse, QueryOrBody, api_view
from activities.models import Post from activities.models import Post
from activities.services import SearchService from activities.services import SearchService
from api import schemas from api import schemas
from api.decorators import identity_required from api.decorators import scope_required
from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult
from core.models import Config from core.models import Config
from users.models import Identity from users.models import Identity
@ -16,13 +16,13 @@ from users.services import IdentityService
from users.shortcuts import by_handle_or_404 from users.shortcuts import by_handle_or_404
@identity_required @scope_required("read")
@api_view.get @api_view.get
def verify_credentials(request) -> schemas.Account: def verify_credentials(request) -> schemas.Account:
return schemas.Account.from_identity(request.identity, source=True) return schemas.Account.from_identity(request.identity, source=True)
@identity_required @scope_required("write")
@api_view.patch @api_view.patch
def update_credentials( def update_credentials(
request, request,
@ -73,7 +73,7 @@ def update_credentials(
return schemas.Account.from_identity(identity, source=True) return schemas.Account.from_identity(identity, source=True)
@identity_required @scope_required("read")
@api_view.get @api_view.get
def account_relationships(request, id: list[str] | None) -> list[schemas.Relationship]: def account_relationships(request, id: list[str] | None) -> list[schemas.Relationship]:
result = [] result = []
@ -87,7 +87,7 @@ def account_relationships(request, id: list[str] | None) -> list[schemas.Relatio
return result return result
@identity_required @scope_required("read")
@api_view.get @api_view.get
def familiar_followers( def familiar_followers(
request, id: list[str] | None request, id: list[str] | None
@ -114,7 +114,7 @@ def familiar_followers(
return result return result
@identity_required @scope_required("read")
@api_view.get @api_view.get
def accounts_search( def accounts_search(
request, request,
@ -146,8 +146,8 @@ def lookup(request: HttpRequest, acct: str) -> schemas.Account:
return schemas.Account.from_identity(identity) return schemas.Account.from_identity(identity)
@scope_required("read:accounts")
@api_view.get @api_view.get
@identity_required
def account(request, id: str) -> schemas.Account: def account(request, id: str) -> schemas.Account:
identity = get_object_or_404( identity = get_object_or_404(
Identity.objects.exclude(restriction=Identity.Restriction.blocked), Identity.objects.exclude(restriction=Identity.Restriction.blocked),
@ -156,8 +156,8 @@ def account(request, id: str) -> schemas.Account:
return schemas.Account.from_identity(identity) return schemas.Account.from_identity(identity)
@scope_required("read:statuses")
@api_view.get @api_view.get
@identity_required
def account_statuses( def account_statuses(
request: HttpRequest, request: HttpRequest,
id: str, id: str,
@ -218,8 +218,8 @@ def account_statuses(
) )
@scope_required("write:follows")
@api_view.post @api_view.post
@identity_required
def account_follow(request, id: str, reblogs: bool = True) -> schemas.Relationship: def account_follow(request, id: str, reblogs: bool = True) -> schemas.Relationship:
identity = get_object_or_404( identity = get_object_or_404(
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
@ -229,8 +229,8 @@ def account_follow(request, id: str, reblogs: bool = True) -> schemas.Relationsh
return schemas.Relationship.from_identity_pair(identity, request.identity) return schemas.Relationship.from_identity_pair(identity, request.identity)
@scope_required("write:follows")
@api_view.post @api_view.post
@identity_required
def account_unfollow(request, id: str) -> schemas.Relationship: def account_unfollow(request, id: str) -> schemas.Relationship:
identity = get_object_or_404( identity = get_object_or_404(
Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id
@ -240,8 +240,8 @@ def account_unfollow(request, id: str) -> schemas.Relationship:
return schemas.Relationship.from_identity_pair(identity, request.identity) return schemas.Relationship.from_identity_pair(identity, request.identity)
@scope_required("write:blocks")
@api_view.post @api_view.post
@identity_required
def account_block(request, id: str) -> schemas.Relationship: def account_block(request, id: str) -> schemas.Relationship:
identity = get_object_or_404(Identity, pk=id) identity = get_object_or_404(Identity, pk=id)
service = IdentityService(request.identity) service = IdentityService(request.identity)
@ -249,8 +249,8 @@ def account_block(request, id: str) -> schemas.Relationship:
return schemas.Relationship.from_identity_pair(identity, request.identity) return schemas.Relationship.from_identity_pair(identity, request.identity)
@scope_required("write:blocks")
@api_view.post @api_view.post
@identity_required
def account_unblock(request, id: str) -> schemas.Relationship: def account_unblock(request, id: str) -> schemas.Relationship:
identity = get_object_or_404(Identity, pk=id) identity = get_object_or_404(Identity, pk=id)
service = IdentityService(request.identity) service = IdentityService(request.identity)
@ -258,7 +258,7 @@ def account_unblock(request, id: str) -> schemas.Relationship:
return schemas.Relationship.from_identity_pair(identity, request.identity) return schemas.Relationship.from_identity_pair(identity, request.identity)
@identity_required @scope_required("write:blocks")
@api_view.post @api_view.post
def account_mute( def account_mute(
request, request,
@ -276,7 +276,7 @@ def account_mute(
return schemas.Relationship.from_identity_pair(identity, request.identity) return schemas.Relationship.from_identity_pair(identity, request.identity)
@identity_required @scope_required("write:blocks")
@api_view.post @api_view.post
def account_unmute(request, id: str) -> schemas.Relationship: def account_unmute(request, id: str) -> schemas.Relationship:
identity = get_object_or_404(Identity, pk=id) identity = get_object_or_404(Identity, pk=id)

View File

@ -2,12 +2,12 @@ from django.shortcuts import get_object_or_404
from hatchway import api_view from hatchway import api_view
from api import schemas from api import schemas
from api.decorators import identity_required from api.decorators import scope_required
from users.models import Announcement from users.models import Announcement
from users.services import AnnouncementService from users.services import AnnouncementService
@identity_required @scope_required("read:notifications")
@api_view.get @api_view.get
def announcement_list(request) -> list[schemas.Announcement]: def announcement_list(request) -> list[schemas.Announcement]:
return [ return [
@ -16,7 +16,7 @@ def announcement_list(request) -> list[schemas.Announcement]:
] ]
@identity_required @scope_required("write:notifications")
@api_view.post @api_view.post
def announcement_dismiss(request, pk: str): def announcement_dismiss(request, pk: str):
announcement = get_object_or_404(Announcement, pk=pk) announcement = get_object_or_404(Announcement, pk=pk)

View File

@ -6,10 +6,10 @@ from activities.models import PostAttachment, PostAttachmentStates
from api import schemas from api import schemas
from core.files import blurhash_image, resize_image from core.files import blurhash_image, resize_image
from ..decorators import identity_required from ..decorators import scope_required
@identity_required @scope_required("write:media")
@api_view.post @api_view.post
def upload_media( def upload_media(
request, request,
@ -47,7 +47,7 @@ def upload_media(
return schemas.MediaAttachment.from_post_attachment(attachment) return schemas.MediaAttachment.from_post_attachment(attachment)
@identity_required @scope_required("read:media")
@api_view.get @api_view.get
def get_media( def get_media(
request, request,
@ -59,7 +59,7 @@ def get_media(
return schemas.MediaAttachment.from_post_attachment(attachment) return schemas.MediaAttachment.from_post_attachment(attachment)
@identity_required @scope_required("write:media")
@api_view.put @api_view.put
def update_media( def update_media(
request, request,

View File

@ -4,11 +4,11 @@ from hatchway import ApiResponse, api_view
from activities.models import TimelineEvent from activities.models import TimelineEvent
from activities.services import TimelineService from activities.services import TimelineService
from api import schemas from api import schemas
from api.decorators import identity_required from api.decorators import scope_required
from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult
@identity_required @scope_required("read:notifications")
@api_view.get @api_view.get
def notifications( def notifications(
request: HttpRequest, request: HttpRequest,

View File

@ -3,21 +3,21 @@ from hatchway import Schema, api_view
from activities.models import Post, PostInteraction from activities.models import Post, PostInteraction
from api import schemas from api import schemas
from api.decorators import identity_required from api.decorators import scope_required
class PostVoteSchema(Schema): class PostVoteSchema(Schema):
choices: list[int] choices: list[int]
@identity_required @scope_required("read:statuses")
@api_view.get @api_view.get
def get_poll(request, id: str) -> schemas.Poll: def get_poll(request, id: str) -> schemas.Poll:
post = get_object_or_404(Post, pk=id, type=Post.Types.question) post = get_object_or_404(Post, pk=id, type=Post.Types.question)
return schemas.Poll.from_post(post, identity=request.identity) return schemas.Poll.from_post(post, identity=request.identity)
@identity_required @scope_required("write:statuses")
@api_view.post @api_view.post
def vote_poll(request, id: str, details: PostVoteSchema) -> schemas.Poll: def vote_poll(request, id: str, details: PostVoteSchema) -> schemas.Poll:
post = get_object_or_404(Post, pk=id, type=Post.Types.question) post = get_object_or_404(Post, pk=id, type=Post.Types.question)

View File

@ -5,10 +5,10 @@ from hatchway import Field, api_view
from activities.models import PostInteraction from activities.models import PostInteraction
from activities.services.search import SearchService from activities.services.search import SearchService
from api import schemas from api import schemas
from api.decorators import identity_required from api.decorators import scope_required
@identity_required @scope_required("read")
@api_view.get @api_view.get
def search( def search(
request, request,

View File

@ -15,7 +15,7 @@ from activities.models import (
) )
from activities.services import PostService from activities.services import PostService
from api import schemas from api import schemas
from api.decorators import identity_required from api.decorators import scope_required
from api.pagination import MastodonPaginator, PaginationResult from api.pagination import MastodonPaginator, PaginationResult
from core.models import Config from core.models import Config
@ -72,7 +72,7 @@ def post_for_id(request: HttpRequest, id: str) -> Post:
return get_object_or_404(queryset, pk=id) return get_object_or_404(queryset, pk=id)
@identity_required @scope_required("write:statuses")
@api_view.post @api_view.post
def post_status(request, details: PostStatusSchema) -> schemas.Status: def post_status(request, details: PostStatusSchema) -> schemas.Status:
# Check text length # Check text length
@ -110,7 +110,7 @@ def post_status(request, details: PostStatusSchema) -> schemas.Status:
return schemas.Status.from_post(post, identity=request.identity) return schemas.Status.from_post(post, identity=request.identity)
@identity_required @scope_required("read:statuses")
@api_view.get @api_view.get
def status(request, id: str) -> schemas.Status: def status(request, id: str) -> schemas.Status:
post = post_for_id(request, id) post = post_for_id(request, id)
@ -120,7 +120,7 @@ def status(request, id: str) -> schemas.Status:
) )
@identity_required @scope_required("write:statuses")
@api_view.put @api_view.put
def edit_status(request, id: str, details: EditStatusSchema) -> schemas.Status: def edit_status(request, id: str, details: EditStatusSchema) -> schemas.Status:
post = post_for_id(request, id) post = post_for_id(request, id)
@ -138,7 +138,7 @@ def edit_status(request, id: str, details: EditStatusSchema) -> schemas.Status:
return schemas.Status.from_post(post) return schemas.Status.from_post(post)
@identity_required @scope_required("write:statuses")
@api_view.delete @api_view.delete
def delete_status(request, id: str) -> schemas.Status: def delete_status(request, id: str) -> schemas.Status:
post = post_for_id(request, id) post = post_for_id(request, id)
@ -148,14 +148,14 @@ def delete_status(request, id: str) -> schemas.Status:
return schemas.Status.from_post(post, identity=request.identity) return schemas.Status.from_post(post, identity=request.identity)
@identity_required @scope_required("read:statuses")
@api_view.get @api_view.get
def status_source(request, id: str) -> schemas.StatusSource: def status_source(request, id: str) -> schemas.StatusSource:
post = post_for_id(request, id) post = post_for_id(request, id)
return schemas.StatusSource.from_post(post) return schemas.StatusSource.from_post(post)
@identity_required @scope_required("read:statuses")
@api_view.get @api_view.get
def status_context(request, id: str) -> schemas.Context: def status_context(request, id: str) -> schemas.Context:
post = post_for_id(request, id) post = post_for_id(request, id)
@ -180,7 +180,7 @@ def status_context(request, id: str) -> schemas.Context:
) )
@identity_required @scope_required("write:favourites")
@api_view.post @api_view.post
def favourite_status(request, id: str) -> schemas.Status: def favourite_status(request, id: str) -> schemas.Status:
post = post_for_id(request, id) post = post_for_id(request, id)
@ -192,7 +192,7 @@ def favourite_status(request, id: str) -> schemas.Status:
) )
@identity_required @scope_required("write:favourites")
@api_view.post @api_view.post
def unfavourite_status(request, id: str) -> schemas.Status: def unfavourite_status(request, id: str) -> schemas.Status:
post = post_for_id(request, id) post = post_for_id(request, id)
@ -245,7 +245,7 @@ def favourited_by(
) )
@identity_required @scope_required("write:favourites")
@api_view.post @api_view.post
def reblog_status(request, id: str) -> schemas.Status: def reblog_status(request, id: str) -> schemas.Status:
post = post_for_id(request, id) post = post_for_id(request, id)
@ -257,7 +257,7 @@ def reblog_status(request, id: str) -> schemas.Status:
) )
@identity_required @scope_required("write:favourites")
@api_view.post @api_view.post
def unreblog_status(request, id: str) -> schemas.Status: def unreblog_status(request, id: str) -> schemas.Status:
post = post_for_id(request, id) post = post_for_id(request, id)

View File

@ -4,12 +4,12 @@ from hatchway import ApiError, ApiResponse, api_view
from activities.models import Post from activities.models import Post
from activities.services import TimelineService from activities.services import TimelineService
from api import schemas from api import schemas
from api.decorators import identity_required from api.decorators import scope_required
from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult
from core.models import Config from core.models import Config
@identity_required @scope_required("read:statuses")
@api_view.get @api_view.get
def home( def home(
request: HttpRequest, request: HttpRequest,
@ -86,7 +86,7 @@ def public(
) )
@identity_required @scope_required("read:statuses")
@api_view.get @api_view.get
def hashtag( def hashtag(
request: HttpRequest, request: HttpRequest,
@ -121,7 +121,7 @@ def hashtag(
) )
@identity_required @scope_required("read:conversations")
@api_view.get @api_view.get
def conversations( def conversations(
request: HttpRequest, request: HttpRequest,
@ -134,7 +134,7 @@ def conversations(
return [] return []
@identity_required @scope_required("read:favourites")
@api_view.get @api_view.get
def favourites( def favourites(
request: HttpRequest, request: HttpRequest,

View File

@ -2,10 +2,10 @@ from django.http import HttpRequest
from hatchway import api_view from hatchway import api_view
from api import schemas from api import schemas
from api.decorators import identity_required from api.decorators import scope_required
@identity_required @scope_required("read")
@api_view.get @api_view.get
def trends_tags( def trends_tags(
request: HttpRequest, request: HttpRequest,
@ -16,7 +16,7 @@ def trends_tags(
return [] return []
@identity_required @scope_required("read")
@api_view.get @api_view.get
def trends_statuses( def trends_statuses(
request: HttpRequest, request: HttpRequest,
@ -27,7 +27,7 @@ def trends_statuses(
return [] return []
@identity_required @scope_required("read")
@api_view.get @api_view.get
def trends_links( def trends_links(
request: HttpRequest, request: HttpRequest,

13
tests/api/test_apps.py Normal file
View File

@ -0,0 +1,13 @@
import pytest
@pytest.mark.django_db
def test_create(api_client):
"""
Tests creating an app
"""
response = api_client.post(
"/api/v1/apps", {"client_name": "test", "redirect_uris": ""}
)
assert response.status_code == 200
assert response.json()["name"] == "test"

11
tests/api/test_tokens.py Normal file
View File

@ -0,0 +1,11 @@
import pytest
@pytest.mark.django_db
def test_has_scope(api_token):
"""
Tests has_scope on the Token model
"""
assert api_token.has_scope("read")
assert api_token.has_scope("read:statuses")
assert not api_token.has_scope("destroyearth")