Remove hatchway's internal copy
This commit is contained in:
parent
43ecf19cd1
commit
04ad97c69b
|
@ -1,5 +0,0 @@
|
|||
from .http import ApiError, ApiResponse # noqa
|
||||
from .schema import Field, Schema # noqa
|
||||
from .types import Body, BodyDirect, Path, Query, QueryOrBody # noqa
|
||||
from .urls import methods # noqa
|
||||
from .view import api_view # noqa
|
|
@ -1,10 +0,0 @@
|
|||
import enum
|
||||
|
||||
|
||||
class InputSource(str, enum.Enum):
|
||||
path = "path"
|
||||
query = "query"
|
||||
body = "body"
|
||||
body_direct = "body_direct"
|
||||
query_and_body_direct = "query_and_body_direct"
|
||||
file = "file"
|
|
@ -1,47 +0,0 @@
|
|||
import json
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
from django.http import HttpResponse
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ApiResponse(Generic[T], HttpResponse):
|
||||
"""
|
||||
A way to return extra information with a response if you want
|
||||
headers, etc.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: T,
|
||||
encoder=DjangoJSONEncoder,
|
||||
json_dumps_params: dict[str, object] | None = None,
|
||||
finalize: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
self.data = data
|
||||
self.encoder = encoder
|
||||
self.json_dumps_params = json_dumps_params or {}
|
||||
kwargs.setdefault("content_type", "application/json")
|
||||
super().__init__(content=b"(unfinalised)", **kwargs)
|
||||
if finalize:
|
||||
self.finalize()
|
||||
|
||||
def finalize(self):
|
||||
"""
|
||||
Converts whatever our current data is into HttpResponse content
|
||||
"""
|
||||
# TODO: Automatically call this when we're asked to write output?
|
||||
self.content = json.dumps(self.data, cls=self.encoder, **self.json_dumps_params)
|
||||
|
||||
|
||||
class ApiError(BaseException):
|
||||
"""
|
||||
A handy way to raise an error with JSONable contents
|
||||
"""
|
||||
|
||||
def __init__(self, status: int, error: str):
|
||||
self.status = status
|
||||
self.error = error
|
|
@ -1,52 +0,0 @@
|
|||
from typing import Any
|
||||
|
||||
from django.db.models import Manager, QuerySet
|
||||
from django.db.models.fields.files import FieldFile
|
||||
from django.template import Variable, VariableDoesNotExist
|
||||
from pydantic.fields import Field # noqa
|
||||
from pydantic.main import BaseModel
|
||||
from pydantic.utils import GetterDict
|
||||
|
||||
|
||||
class DjangoGetterDict(GetterDict):
|
||||
def __init__(self, obj: Any):
|
||||
self._obj = obj
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
try:
|
||||
item = getattr(self._obj, key)
|
||||
except AttributeError:
|
||||
try:
|
||||
item = Variable(key).resolve(self._obj)
|
||||
except VariableDoesNotExist as e:
|
||||
raise KeyError(key) from e
|
||||
return self._convert_result(item)
|
||||
|
||||
def get(self, key: Any, default: Any = None) -> Any:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def _convert_result(self, result: Any) -> Any:
|
||||
if isinstance(result, Manager):
|
||||
return list(result.all())
|
||||
|
||||
elif isinstance(result, getattr(QuerySet, "__origin__", QuerySet)):
|
||||
return list(result)
|
||||
|
||||
if callable(result):
|
||||
return result()
|
||||
|
||||
elif isinstance(result, FieldFile):
|
||||
if not result:
|
||||
return None
|
||||
return result.url
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Schema(BaseModel):
|
||||
class Config:
|
||||
orm_mode = True
|
||||
getter_dict = DjangoGetterDict
|
|
@ -1,63 +0,0 @@
|
|||
from typing import Literal, Optional, Union
|
||||
|
||||
from django.core.files import File
|
||||
|
||||
from hatchway.http import ApiResponse
|
||||
from hatchway.types import (
|
||||
Query,
|
||||
QueryType,
|
||||
acceptable_input,
|
||||
extract_output_type,
|
||||
extract_signifier,
|
||||
is_optional,
|
||||
)
|
||||
|
||||
|
||||
def test_is_optional():
|
||||
|
||||
assert is_optional(Optional[int]) == (True, int)
|
||||
assert is_optional(Union[int, None]) == (True, int)
|
||||
assert is_optional(Union[None, int]) == (True, int)
|
||||
assert is_optional(int | None) == (True, int)
|
||||
assert is_optional(None | int) == (True, int)
|
||||
assert is_optional(int) == (False, int)
|
||||
assert is_optional(Query[int]) == (False, Query[int])
|
||||
|
||||
|
||||
def test_extract_signifier():
|
||||
|
||||
assert extract_signifier(int) == (None, int)
|
||||
assert extract_signifier(Query[int]) == (QueryType, int)
|
||||
assert extract_signifier(Query[Optional[int]]) == ( # type:ignore
|
||||
QueryType,
|
||||
Optional[int],
|
||||
)
|
||||
assert extract_signifier(Query[int | None]) == ( # type:ignore
|
||||
QueryType,
|
||||
Optional[int],
|
||||
)
|
||||
assert extract_signifier(Optional[Query[int]]) == (QueryType, Optional[int])
|
||||
|
||||
|
||||
def test_extract_output_type():
|
||||
|
||||
assert extract_output_type(int) == int
|
||||
assert extract_output_type(ApiResponse[int]) == int
|
||||
assert extract_output_type(ApiResponse[int | str]) == int | str
|
||||
|
||||
|
||||
def test_acceptable_input():
|
||||
|
||||
assert acceptable_input(str) is True
|
||||
assert acceptable_input(int) is True
|
||||
assert acceptable_input(Query[int]) is True
|
||||
assert acceptable_input(Optional[int]) is True
|
||||
assert acceptable_input(int | None) is True
|
||||
assert acceptable_input(int | str | None) is True
|
||||
assert acceptable_input(Query[int | None]) is True # type: ignore
|
||||
assert acceptable_input(File) is True
|
||||
assert acceptable_input(list[str]) is True
|
||||
assert acceptable_input(dict[str, int]) is True
|
||||
assert acceptable_input(Literal["a", "b"]) is True
|
||||
assert acceptable_input(frozenset) is False
|
||||
assert acceptable_input(dict[str, frozenset]) is False
|
|
@ -1,244 +0,0 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from django.core import files
|
||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||
from django.http import QueryDict
|
||||
from django.test import RequestFactory
|
||||
from django.test.client import MULTIPART_CONTENT
|
||||
from pydantic import BaseModel
|
||||
|
||||
from hatchway import ApiError, Body, QueryOrBody, api_view
|
||||
from hatchway.view import ApiView
|
||||
|
||||
|
||||
def test_basic_view():
|
||||
"""
|
||||
Tests that a view with simple types works correctly
|
||||
"""
|
||||
|
||||
@api_view
|
||||
def test_view(
|
||||
request,
|
||||
a: int,
|
||||
b: QueryOrBody[int | None] = None,
|
||||
c: str = "x",
|
||||
) -> str:
|
||||
if b is None:
|
||||
return c * a
|
||||
else:
|
||||
return c * (a - b)
|
||||
|
||||
# Call it with a few different patterns to verify it's type coercing right
|
||||
factory = RequestFactory()
|
||||
|
||||
# Implicit query param
|
||||
response = test_view(factory.get("/test/?a=4"))
|
||||
assert json.loads(response.content) == "xxxx"
|
||||
|
||||
# QueryOrBody pulling from query
|
||||
response = test_view(factory.get("/test/?a=4&b=2"))
|
||||
assert json.loads(response.content) == "xx"
|
||||
|
||||
# QueryOrBody pulling from formdata body
|
||||
response = test_view(factory.post("/test/?a=4", {"b": "3"}))
|
||||
assert json.loads(response.content) == "x"
|
||||
|
||||
# QueryOrBody pulling from JSON body
|
||||
response = test_view(
|
||||
factory.post(
|
||||
"/test/?a=4", json.dumps({"b": 3}), content_type="application/json"
|
||||
)
|
||||
)
|
||||
assert json.loads(response.content) == "x"
|
||||
|
||||
# Implicit Query not pulling from body
|
||||
with pytest.raises(TypeError):
|
||||
test_view(factory.post("/test/", {"a": 4, "b": 3}))
|
||||
|
||||
|
||||
def test_body_direct():
|
||||
"""
|
||||
Tests that a Pydantic model with BodyDirect gets its fields from the top level
|
||||
"""
|
||||
|
||||
class TestModel(BaseModel):
|
||||
number: int
|
||||
name: str
|
||||
|
||||
@api_view
|
||||
def test_view(request, data: TestModel) -> int:
|
||||
return data.number
|
||||
|
||||
factory = RequestFactory()
|
||||
|
||||
# formdata version
|
||||
response = test_view(factory.post("/test/", {"number": "123", "name": "Andrew"}))
|
||||
assert json.loads(response.content) == 123
|
||||
|
||||
# JSON body version
|
||||
response = test_view(
|
||||
factory.post(
|
||||
"/test/",
|
||||
json.dumps({"number": "123", "name": "Andrew"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
)
|
||||
assert json.loads(response.content) == 123
|
||||
|
||||
|
||||
def test_list_response():
|
||||
"""
|
||||
Tests that a view with a list response type works correctly with both
|
||||
dicts and pydantic model instances.
|
||||
"""
|
||||
|
||||
class TestModel(BaseModel):
|
||||
number: int
|
||||
name: str
|
||||
|
||||
@api_view
|
||||
def test_view_dict(request) -> list[TestModel]:
|
||||
return [
|
||||
{"name": "Andrew", "number": 1}, # type:ignore
|
||||
{"name": "Alice", "number": 0}, # type:ignore
|
||||
]
|
||||
|
||||
@api_view
|
||||
def test_view_model(request) -> list[TestModel]:
|
||||
return [TestModel(name="Andrew", number=1), TestModel(name="Alice", number=0)]
|
||||
|
||||
response = test_view_dict(RequestFactory().get("/test/"))
|
||||
assert json.loads(response.content) == [
|
||||
{"name": "Andrew", "number": 1},
|
||||
{"name": "Alice", "number": 0},
|
||||
]
|
||||
|
||||
response = test_view_model(RequestFactory().get("/test/"))
|
||||
assert json.loads(response.content) == [
|
||||
{"name": "Andrew", "number": 1},
|
||||
{"name": "Alice", "number": 0},
|
||||
]
|
||||
|
||||
|
||||
def test_patch_body():
|
||||
"""
|
||||
Tests that PATCH also gets its body parsed
|
||||
"""
|
||||
|
||||
@api_view.patch
|
||||
def test_view(request, a: Body[int]):
|
||||
return a
|
||||
|
||||
factory = RequestFactory()
|
||||
response = test_view(
|
||||
factory.patch(
|
||||
"/test/",
|
||||
content_type=MULTIPART_CONTENT,
|
||||
data=factory._encode_data({"a": "42"}, MULTIPART_CONTENT),
|
||||
)
|
||||
)
|
||||
assert json.loads(response.content) == 42
|
||||
|
||||
|
||||
def test_file_body():
|
||||
"""
|
||||
Tests that file uploads work right
|
||||
"""
|
||||
|
||||
@api_view.post
|
||||
def test_view(request, a: Body[int], b: files.File) -> str:
|
||||
return str(a) + b.read().decode("ascii")
|
||||
|
||||
factory = RequestFactory()
|
||||
uploaded_file = SimpleUploadedFile(
|
||||
"file.txt",
|
||||
b"MY FILE IS AMAZING",
|
||||
content_type="text/plain",
|
||||
)
|
||||
response = test_view(
|
||||
factory.post(
|
||||
"/test/",
|
||||
data={"a": 42, "b": uploaded_file},
|
||||
)
|
||||
)
|
||||
assert json.loads(response.content) == "42MY FILE IS AMAZING"
|
||||
|
||||
|
||||
def test_no_response():
|
||||
"""
|
||||
Tests that a view with no response type returns the contents verbatim
|
||||
"""
|
||||
|
||||
@api_view
|
||||
def test_view(request):
|
||||
return [1, "woooooo"]
|
||||
|
||||
response = test_view(RequestFactory().get("/test/"))
|
||||
assert json.loads(response.content) == [1, "woooooo"]
|
||||
|
||||
|
||||
def test_wrong_method():
|
||||
"""
|
||||
Tests that a view with a method limiter works
|
||||
"""
|
||||
|
||||
@api_view.get
|
||||
def test_view(request):
|
||||
return "yay"
|
||||
|
||||
response = test_view(RequestFactory().get("/test/"))
|
||||
assert json.loads(response.content) == "yay"
|
||||
|
||||
response = test_view(RequestFactory().post("/test/"))
|
||||
assert response.status_code == 405
|
||||
|
||||
|
||||
def test_api_error():
|
||||
"""
|
||||
Tests that ApiError propagates right
|
||||
"""
|
||||
|
||||
@api_view.get
|
||||
def test_view(request):
|
||||
raise ApiError(401, "you did a bad thing")
|
||||
|
||||
response = test_view(RequestFactory().get("/test/"))
|
||||
assert json.loads(response.content) == {"error": "you did a bad thing"}
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_unusable_type():
|
||||
"""
|
||||
Tests that you get a nice error when you use a type on an input that
|
||||
Pydantic doesn't understand.
|
||||
"""
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
@api_view.get
|
||||
def test_view(request, a: RequestFactory):
|
||||
pass
|
||||
|
||||
|
||||
def test_get_values():
|
||||
"""
|
||||
Tests that ApiView.get_values correctly handles lists
|
||||
"""
|
||||
|
||||
assert ApiView.get_values({"a": 2, "b": [3, 4]}) == {"a": 2, "b": [3, 4]}
|
||||
assert ApiView.get_values({"a": 2, "b[]": [3, 4]}) == {"a": 2, "b": [3, 4]}
|
||||
assert ApiView.get_values(QueryDict("a=2&b=3&b=4")) == {"a": "2", "b": ["3", "4"]}
|
||||
assert ApiView.get_values(QueryDict("a=2&b[]=3&b[]=4")) == {
|
||||
"a": "2",
|
||||
"b": ["3", "4"],
|
||||
}
|
||||
assert ApiView.get_values(QueryDict("a=2&b=3")) == {"a": "2", "b": "3"}
|
||||
assert ApiView.get_values(QueryDict("a=2&b[]=3")) == {"a": "2", "b": ["3"]}
|
||||
assert ApiView.get_values(QueryDict("a[b]=1")) == {"a": {"b": "1"}}
|
||||
assert ApiView.get_values(QueryDict("a[b]=1&a[c]=2")) == {"a": {"b": "1", "c": "2"}}
|
||||
assert ApiView.get_values(QueryDict("a[b][c]=1")) == {"a": {"b": {"c": "1"}}}
|
||||
assert ApiView.get_values(QueryDict("a[b][c][]=1")) == {"a": {"b": {"c": ["1"]}}}
|
||||
assert ApiView.get_values(QueryDict("a[b][]=1&a[b][]=2")) == {
|
||||
"a": {"b": ["1", "2"]}
|
||||
}
|
|
@ -1,145 +0,0 @@
|
|||
from types import NoneType, UnionType
|
||||
from typing import ( # type: ignore[attr-defined]
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
_AnnotatedAlias,
|
||||
_GenericAlias,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from django.core import files
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .http import ApiResponse
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PathType:
|
||||
"""
|
||||
An input pulled from the path (url resolver kwargs)
|
||||
"""
|
||||
|
||||
|
||||
class QueryType:
|
||||
"""
|
||||
An input pulled from the query parameters (request.GET)
|
||||
"""
|
||||
|
||||
|
||||
class BodyType:
|
||||
"""
|
||||
An input pulled from the POST body (request.POST or a JSON body)
|
||||
"""
|
||||
|
||||
|
||||
class FileType:
|
||||
"""
|
||||
An input pulled from the POST body (request.POST or a JSON body)
|
||||
"""
|
||||
|
||||
|
||||
class BodyDirectType:
|
||||
"""
|
||||
A Pydantic model whose keys are all looked for in the top-level
|
||||
POST data, rather than in a dict under a key named after the input.
|
||||
"""
|
||||
|
||||
|
||||
class QueryOrBodyType:
|
||||
"""
|
||||
An input pulled from either query parameters or post data.
|
||||
"""
|
||||
|
||||
|
||||
Path = Annotated[T, PathType]
|
||||
Query = Annotated[T, QueryType]
|
||||
Body = Annotated[T, BodyType]
|
||||
File = Annotated[T, FileType]
|
||||
BodyDirect = Annotated[T, BodyDirectType]
|
||||
QueryOrBody = Annotated[T, QueryOrBodyType]
|
||||
|
||||
|
||||
def is_optional(annotation) -> tuple[bool, Any]:
|
||||
"""
|
||||
If an annotation is Optional or | None, returns (True, internal type).
|
||||
Returns (False, annotation) otherwise.
|
||||
"""
|
||||
if (isinstance(annotation, _GenericAlias) and annotation.__origin__ is Union) or (
|
||||
isinstance(annotation, UnionType)
|
||||
):
|
||||
args = get_args(annotation)
|
||||
if len(args) > 2:
|
||||
return False, annotation
|
||||
if args[0] is NoneType:
|
||||
return True, args[1]
|
||||
if args[1] is NoneType:
|
||||
return True, args[0]
|
||||
return False, annotation
|
||||
return False, annotation
|
||||
|
||||
|
||||
def extract_signifier(annotation) -> tuple[Any, Any]:
|
||||
"""
|
||||
Given a type annotation, looks to see if it can find a input source
|
||||
signifier (Path, Query, etc.)
|
||||
|
||||
If it can, returns (signifier, annotation_without_signifier)
|
||||
If not, returns (None, annotation)
|
||||
"""
|
||||
our_generics = {
|
||||
PathType,
|
||||
QueryType,
|
||||
BodyType,
|
||||
FileType,
|
||||
BodyDirectType,
|
||||
QueryOrBodyType,
|
||||
}
|
||||
# Remove any optional-style wrapper
|
||||
optional, internal_annotation = is_optional(annotation)
|
||||
# Is it an annotation?
|
||||
if isinstance(internal_annotation, _AnnotatedAlias):
|
||||
args = get_args(internal_annotation)
|
||||
for arg in args[1:]:
|
||||
if arg in our_generics:
|
||||
if optional:
|
||||
return (arg, Optional[args[0]])
|
||||
else:
|
||||
return (arg, args[0])
|
||||
return None, annotation
|
||||
|
||||
|
||||
def extract_output_type(annotation):
|
||||
"""
|
||||
Returns the right response type for a function
|
||||
"""
|
||||
# If the type is ApiResponse, we want to pull out its inside
|
||||
if isinstance(annotation, _GenericAlias):
|
||||
if get_origin(annotation) == ApiResponse:
|
||||
return get_args(annotation)[0]
|
||||
return annotation
|
||||
|
||||
|
||||
def acceptable_input(annotation) -> bool:
|
||||
"""
|
||||
Returns if this annotation is something we think we can accept as input
|
||||
"""
|
||||
_, inner_type = extract_signifier(annotation)
|
||||
try:
|
||||
if issubclass(inner_type, BaseModel):
|
||||
return True
|
||||
except TypeError:
|
||||
pass
|
||||
if inner_type in [str, int, list, tuple, bool, Any, files.File, type(None)]:
|
||||
return True
|
||||
origin = get_origin(inner_type)
|
||||
if origin == Literal:
|
||||
return True
|
||||
if origin in [Union, UnionType, dict, list, tuple]:
|
||||
return all(acceptable_input(a) for a in get_args(inner_type))
|
||||
return False
|
|
@ -1,32 +0,0 @@
|
|||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from django.http import HttpResponseNotAllowed
|
||||
|
||||
|
||||
class Methods:
|
||||
"""
|
||||
Allows easy multi-method dispatch to different functions
|
||||
"""
|
||||
|
||||
csrf_exempt = True
|
||||
|
||||
def __init__(self, **callables: Callable):
|
||||
self.callables = {
|
||||
method.lower(): callable for method, callable in callables.items()
|
||||
}
|
||||
unknown_methods = set(self.callables.keys()).difference(
|
||||
{"get", "post", "patch", "put", "delete"}
|
||||
)
|
||||
if unknown_methods:
|
||||
raise ValueError(f"Cannot route methods: {unknown_methods}")
|
||||
|
||||
def __call__(self, request, *args, **kwargs) -> Any:
|
||||
method = request.method.lower()
|
||||
if method in self.callables:
|
||||
return self.callables[method](request, *args, **kwargs)
|
||||
else:
|
||||
return HttpResponseNotAllowed(self.callables.keys())
|
||||
|
||||
|
||||
methods = Methods
|
297
hatchway/view.py
297
hatchway/view.py
|
@ -1,297 +0,0 @@
|
|||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional, get_type_hints
|
||||
|
||||
from django.core import files
|
||||
from django.http import HttpRequest, HttpResponseNotAllowed, QueryDict
|
||||
from django.http.multipartparser import MultiPartParser
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
from .constants import InputSource
|
||||
from .http import ApiError, ApiResponse
|
||||
from .types import (
|
||||
BodyDirectType,
|
||||
BodyType,
|
||||
FileType,
|
||||
PathType,
|
||||
QueryOrBodyType,
|
||||
QueryType,
|
||||
acceptable_input,
|
||||
extract_output_type,
|
||||
extract_signifier,
|
||||
is_optional,
|
||||
)
|
||||
|
||||
|
||||
class ApiView:
|
||||
"""
|
||||
A view 'wrapper' object that replaces the API view for anything further
|
||||
up the stack.
|
||||
|
||||
Unlike Django's class-based views, we don't need an as_view pattern
|
||||
as we are careful never to write anything per-request to self.
|
||||
"""
|
||||
|
||||
csrf_exempt = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
view: Callable,
|
||||
input_types: dict[str, Any] | None = None,
|
||||
output_type: Any = None,
|
||||
implicit_lists: bool = True,
|
||||
method: str | None = None,
|
||||
):
|
||||
self.view = view
|
||||
self.implicit_lists = implicit_lists
|
||||
self.view_name = getattr(view, "__name__", "unknown_view")
|
||||
self.method = method
|
||||
# Extract input/output types from view annotations if we need to
|
||||
self.input_types = input_types
|
||||
if self.input_types is None:
|
||||
self.input_types = get_type_hints(view, include_extras=True)
|
||||
if "return" in self.input_types:
|
||||
del self.input_types["return"]
|
||||
self.output_type = output_type
|
||||
if self.output_type is None:
|
||||
try:
|
||||
self.output_type = extract_output_type(
|
||||
get_type_hints(view, include_extras=True)["return"]
|
||||
)
|
||||
except KeyError:
|
||||
self.output_type = None
|
||||
self.compile()
|
||||
|
||||
@classmethod
|
||||
def get(cls, view: Callable):
|
||||
return cls(view=view, method="get")
|
||||
|
||||
@classmethod
|
||||
def post(cls, view: Callable):
|
||||
return cls(view=view, method="post")
|
||||
|
||||
@classmethod
|
||||
def put(cls, view: Callable):
|
||||
return cls(view=view, method="put")
|
||||
|
||||
@classmethod
|
||||
def patch(cls, view: Callable):
|
||||
return cls(view=view, method="patch")
|
||||
|
||||
@classmethod
|
||||
def delete(cls, view: Callable):
|
||||
return cls(view=view, method="delete")
|
||||
|
||||
@classmethod
|
||||
def sources_for_input(cls, input_type) -> tuple[list[InputSource], Any]:
|
||||
"""
|
||||
Given a type that can appear as a request parameter type, returns
|
||||
what sources it can come from, and what its type is as understood
|
||||
by Pydantic.
|
||||
"""
|
||||
signifier, input_type = extract_signifier(input_type)
|
||||
if signifier is QueryType:
|
||||
return ([InputSource.query], input_type)
|
||||
elif signifier is BodyType:
|
||||
return ([InputSource.body], input_type)
|
||||
elif signifier is BodyDirectType:
|
||||
if not issubclass(input_type, BaseModel):
|
||||
raise ValueError(
|
||||
"You cannot use BodyDirect on something that is not a Pydantic model"
|
||||
)
|
||||
return ([InputSource.body_direct], input_type)
|
||||
elif signifier is PathType:
|
||||
return ([InputSource.path], input_type)
|
||||
elif (
|
||||
signifier is FileType
|
||||
or input_type is files.File
|
||||
or is_optional(input_type)[1] is files.File
|
||||
):
|
||||
return ([InputSource.file], input_type)
|
||||
elif signifier is QueryOrBodyType:
|
||||
return ([InputSource.query, InputSource.body], input_type)
|
||||
# Is it a Pydantic model, which means it's implicitly body?
|
||||
elif isinstance(input_type, type) and issubclass(input_type, BaseModel):
|
||||
return ([InputSource.body], input_type)
|
||||
# Otherwise, we look in the path first and then the query
|
||||
else:
|
||||
return ([InputSource.path, InputSource.query], input_type)
|
||||
|
||||
@classmethod
|
||||
def get_values(cls, data, use_square_brackets=True) -> dict[str, Any]:
|
||||
"""
|
||||
Given a QueryDict or normal dict, returns data taking into account
|
||||
lists made by repeated values or by suffixing names with [].
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
for key, value in data.items():
|
||||
# If it's a query dict with multiple values, make it a list
|
||||
if isinstance(data, QueryDict):
|
||||
values = data.getlist(key)
|
||||
if len(values) > 1:
|
||||
value = values
|
||||
# If it is in dict-ish/list-ish syntax, adhere to that
|
||||
# TODO: Make this better handle badly formed keys
|
||||
if "[" in key and use_square_brackets:
|
||||
parts = key.split("[")
|
||||
target = result
|
||||
last_key = parts[0]
|
||||
for part in parts[1:]:
|
||||
part = part.rstrip("]")
|
||||
if not part:
|
||||
target = target.setdefault(last_key, [])
|
||||
else:
|
||||
target = target.setdefault(last_key, {})
|
||||
last_key = part
|
||||
if isinstance(target, list):
|
||||
if isinstance(value, list):
|
||||
target.extend(value)
|
||||
else:
|
||||
target.append(value)
|
||||
else:
|
||||
target[last_key] = value
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
def compile(self):
|
||||
self.sources: dict[str, list[InputSource]] = {}
|
||||
amount_from_body = 0
|
||||
pydantic_model_dict = {}
|
||||
self.input_files = set()
|
||||
last_body_type = None
|
||||
# For each input item, work out where to pull it from
|
||||
for name, input_type in self.input_types.items():
|
||||
# Do some basic typechecking to stop things that aren't allowed
|
||||
if isinstance(input_type, type) and issubclass(input_type, HttpRequest):
|
||||
continue
|
||||
if not acceptable_input(input_type):
|
||||
# Strip away any singifiers for the error
|
||||
_, inner_type = extract_signifier(input_type)
|
||||
raise ValueError(
|
||||
f"Input argument {name} has an unsupported type {inner_type}"
|
||||
)
|
||||
sources, pydantic_type = self.sources_for_input(input_type)
|
||||
self.sources[name] = sources
|
||||
# Keep count of how many are pulling from the body
|
||||
if InputSource.body in sources:
|
||||
amount_from_body += 1
|
||||
last_body_type = pydantic_type
|
||||
if InputSource.file in sources:
|
||||
self.input_files.add(name)
|
||||
else:
|
||||
pydantic_model_dict[name] = (Optional[pydantic_type], ...)
|
||||
# If there is just one thing pulling from the body and it's a BaseModel,
|
||||
# signify that it's actually pulling from the body keys directly and
|
||||
# not a sub-dict
|
||||
if amount_from_body == 1:
|
||||
for name, sources in self.sources.items():
|
||||
if (
|
||||
InputSource.body in sources
|
||||
and isinstance(last_body_type, type)
|
||||
and issubclass(last_body_type, BaseModel)
|
||||
):
|
||||
self.sources[name] = [
|
||||
x for x in sources if x != InputSource.body
|
||||
] + [InputSource.body_direct]
|
||||
# Turn all the main arguments into Pydantic parsing models
|
||||
try:
|
||||
self.input_model = create_model(
|
||||
f"{self.view_name}_input", **pydantic_model_dict
|
||||
)
|
||||
except RuntimeError:
|
||||
raise ValueError(
|
||||
f"One or more inputs on view {self.view_name} have a bad configuration"
|
||||
)
|
||||
if self.output_type is not None:
|
||||
self.output_model = create_model(
|
||||
f"{self.view_name}_output", value=(self.output_type, ...)
|
||||
)
|
||||
|
||||
def __call__(self, request: HttpRequest, *args, **kwargs):
|
||||
"""
|
||||
Entrypoint when this is called as a view.
|
||||
"""
|
||||
# Do a method check if we have one set
|
||||
if self.method and self.method.upper() != request.method:
|
||||
return HttpResponseNotAllowed([self.method])
|
||||
# For each item we can source, go find it if we can
|
||||
query_values = self.get_values(request.GET)
|
||||
body_values = self.get_values(request.POST)
|
||||
files_values = self.get_values(request.FILES)
|
||||
# If it's a PUT or PATCH method, work around Django not handling FILES
|
||||
# or POST on those requests
|
||||
if request.method in ["PATCH", "PUT"]:
|
||||
if request.content_type == "multipart/form-data":
|
||||
POST, FILES = MultiPartParser(
|
||||
request.META, request, request.upload_handlers, request.encoding
|
||||
).parse()
|
||||
body_values = self.get_values(POST)
|
||||
files_values = self.get_values(FILES)
|
||||
elif request.content_type == "application/x-www-form-urlencoded":
|
||||
POST = QueryDict(request.body, encoding=request._encoding)
|
||||
body_values = self.get_values(POST)
|
||||
# If there was a JSON body, go load that
|
||||
if request.content_type == "application/json" and request.body.strip():
|
||||
body_values.update(self.get_values(json.loads(request.body)))
|
||||
values = {}
|
||||
for name, sources in self.sources.items():
|
||||
for source in sources:
|
||||
if source == InputSource.path:
|
||||
if name in kwargs:
|
||||
values[name] = kwargs[name]
|
||||
break
|
||||
elif source == InputSource.query:
|
||||
if name in query_values:
|
||||
values[name] = query_values[name]
|
||||
break
|
||||
elif source == InputSource.body:
|
||||
if name in body_values:
|
||||
values[name] = body_values[name]
|
||||
break
|
||||
elif source == InputSource.file:
|
||||
if name in files_values:
|
||||
values[name] = files_values[name]
|
||||
break
|
||||
elif source == InputSource.body_direct:
|
||||
values[name] = body_values
|
||||
break
|
||||
elif source == InputSource.query_and_body_direct:
|
||||
values[name] = dict(query_values)
|
||||
values[name].update(body_values)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Unknown source {source}")
|
||||
else:
|
||||
values[name] = None
|
||||
# Give that to the Pydantic model to make it handle stuff
|
||||
model_instance = self.input_model(**values)
|
||||
kwargs = {
|
||||
name: getattr(model_instance, name)
|
||||
for name in model_instance.__fields__
|
||||
if values[name] is not None # Trim out missing fields
|
||||
}
|
||||
# Add in any files
|
||||
# TODO: HTTP error if file is not optional
|
||||
for name in self.input_files:
|
||||
kwargs[name] = files_values.get(name, None)
|
||||
# Call the view with those as kwargs
|
||||
try:
|
||||
response = self.view(request, **kwargs)
|
||||
except ApiError as error:
|
||||
return ApiResponse(
|
||||
{"error": error.error}, status=error.status, finalize=True
|
||||
)
|
||||
# If it's not an ApiResponse, make it one
|
||||
if not isinstance(response, ApiResponse):
|
||||
response = ApiResponse(response)
|
||||
# Get pydantic to coerce the output response
|
||||
if self.output_type is not None:
|
||||
response.data = self.output_model(value=response.data).dict()["value"]
|
||||
elif isinstance(response.data, BaseModel):
|
||||
response.data = response.data.dict()
|
||||
response.finalize()
|
||||
return response
|
||||
|
||||
|
||||
api_view = ApiView
|
|
@ -5,6 +5,7 @@ dj_database_url~=1.0.0
|
|||
django-cache-url~=3.4.2
|
||||
django-cors-headers~=3.13.0
|
||||
django-debug-toolbar~=3.8.1
|
||||
django-hatchway~=0.5.0
|
||||
django-htmx~=1.13.0
|
||||
django-oauth-toolkit~=2.2.0
|
||||
django-storages[google,boto3]~=1.13.1
|
||||
|
|
|
@ -196,6 +196,7 @@ INSTALLED_APPS = [
|
|||
"django.contrib.staticfiles",
|
||||
"corsheaders",
|
||||
"django_htmx",
|
||||
"hatchway",
|
||||
"core",
|
||||
"activities",
|
||||
"api",
|
||||
|
|
Loading…
Reference in New Issue