muffin-rest 11.0.1__py3-none-any.whl → 12.0.1__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
muffin_rest/__init__.py CHANGED
@@ -17,24 +17,24 @@ Api = API
17
17
 
18
18
  __all__ = (
19
19
  "API",
20
- "Api",
21
- "RESTHandler",
22
20
  "APIError",
23
- "PWRESTHandler",
21
+ "Api",
22
+ "MongoFilter",
23
+ "MongoFilters",
24
+ "MongoRESTHandler",
25
+ "MongoSort",
26
+ "MongoSorting",
24
27
  "PWFilter",
25
28
  "PWFilters",
29
+ "PWRESTHandler",
26
30
  "PWSort",
27
31
  "PWSorting",
28
- "SARESTHandler",
32
+ "RESTHandler",
29
33
  "SAFilter",
30
34
  "SAFilters",
35
+ "SARESTHandler",
31
36
  "SASort",
32
37
  "SASorting",
33
- "MongoRESTHandler",
34
- "MongoFilter",
35
- "MongoFilters",
36
- "MongoSort",
37
- "MongoSorting",
38
38
  )
39
39
 
40
40
  # Support Peewee ORM
muffin_rest/api.py CHANGED
@@ -4,7 +4,7 @@ from __future__ import annotations
4
4
 
5
5
  import dataclasses as dc
6
6
  from pathlib import Path
7
- from typing import TYPE_CHECKING, Any, Callable, Optional, Union, overload
7
+ from typing import TYPE_CHECKING, Any, Callable, overload
8
8
 
9
9
  from http_router import Router
10
10
  from muffin.utils import TV, to_awaitable
@@ -27,11 +27,11 @@ class API:
27
27
 
28
28
  def __init__(
29
29
  self,
30
- app: Optional[muffin.Application] = None,
30
+ app: muffin.Application | None = None,
31
31
  prefix: str = "",
32
32
  *,
33
33
  openapi: bool = True,
34
- servers: Optional[list] = None,
34
+ servers: list | None = None,
35
35
  **openapi_info,
36
36
  ):
37
37
  """Post initialize the API if we have an application already."""
@@ -66,8 +66,8 @@ class API:
66
66
  app: muffin.Application,
67
67
  *,
68
68
  prefix: str = "",
69
- openapi: Optional[bool] = None,
70
- servers: Optional[list] = None,
69
+ openapi: bool | None = None,
70
+ servers: list | None = None,
71
71
  **openapi_info,
72
72
  ):
73
73
  """Initialize the API."""
@@ -100,16 +100,12 @@ class API:
100
100
  self.router.route("/openapi.json")(openapi_json)
101
101
 
102
102
  @overload
103
- def route(self, obj: str, *paths: str, **params) -> Callable[[TV], TV]:
104
- ...
103
+ def route(self, obj: str, *paths: str, **params) -> Callable[[TV], TV]: ...
105
104
 
106
105
  @overload
107
- def route(self, obj: TVHandler, *paths: str, **params) -> TVHandler:
108
- ...
106
+ def route(self, obj: TVHandler, *paths: str, **params) -> TVHandler: ...
109
107
 
110
- def route(
111
- self, obj: Union[str, TVHandler], *paths: str, **params
112
- ) -> Union[Callable[[TV], TV], TVHandler]:
108
+ def route(self, obj: str | TVHandler, *paths: str, **params) -> Callable[[TV], TV] | TVHandler:
113
109
  """Route an endpoint by the API."""
114
110
  from .handler import RESTBase
115
111
 
muffin_rest/errors.py CHANGED
@@ -1,9 +1,10 @@
1
1
  """Helpers to raise API errors as JSON responses."""
2
+
2
3
  from __future__ import annotations
3
4
 
4
5
  import json
5
6
  from http import HTTPStatus
6
- from typing import TYPE_CHECKING, Optional
7
+ from typing import TYPE_CHECKING
7
8
 
8
9
  from muffin import ResponseError
9
10
 
@@ -16,7 +17,7 @@ class APIError(ResponseError):
16
17
 
17
18
  def __init__(
18
19
  self,
19
- content: Optional[TJSON] = None,
20
+ content: TJSON | None = None,
20
21
  *,
21
22
  status_code: int = HTTPStatus.BAD_REQUEST.value,
22
23
  **json_data,
muffin_rest/filters.py CHANGED
@@ -3,7 +3,7 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import operator
6
- from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Mapping, Optional # py39
6
+ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Mapping
7
7
 
8
8
  import marshmallow as ma
9
9
  from asgi_tools._compat import json_loads
@@ -59,8 +59,8 @@ class Filter(Mutate):
59
59
  name: str,
60
60
  *,
61
61
  field: Any = None,
62
- schema_field: Optional[ma.fields.Field] = None,
63
- operator: Optional[str] = None,
62
+ schema_field: ma.fields.Field | None = None,
63
+ operator: str | None = None,
64
64
  **meta,
65
65
  ):
66
66
  """Initialize filter.
@@ -75,7 +75,7 @@ class Filter(Mutate):
75
75
  self.schema_field = schema_field or self.schema_field
76
76
  self.default_operator = operator or self.default_operator
77
77
 
78
- async def apply(self, collection: Any, data: Optional[Mapping] = None):
78
+ async def apply(self, collection: Any, data: Mapping | None = None):
79
79
  """Filter given collection."""
80
80
  if not data:
81
81
  return None, collection
muffin_rest/handler.py CHANGED
@@ -9,9 +9,7 @@ from typing import (
9
9
  Iterable,
10
10
  Literal,
11
11
  Mapping,
12
- Optional,
13
12
  Sequence,
14
- Union,
15
13
  cast,
16
14
  overload,
17
15
  )
@@ -31,7 +29,7 @@ from muffin_rest.types import TSchemaRes
31
29
 
32
30
  from .errors import HandlerNotBindedError
33
31
  from .options import RESTOptions
34
- from .types import TVCollection, TVData, TVResource
32
+ from .types import TVCollection, TVResource
35
33
 
36
34
 
37
35
  class RESTHandlerMeta(HandlerMeta):
@@ -39,7 +37,7 @@ class RESTHandlerMeta(HandlerMeta):
39
37
 
40
38
  def __new__(mcs, name, bases, params):
41
39
  """Prepare options for the handler."""
42
- kls = cast(type["RESTBase"], super().__new__(mcs, name, bases, params))
40
+ kls = cast("type[RESTBase]", super().__new__(mcs, name, bases, params))
43
41
  kls.meta = kls.meta_class(kls)
44
42
 
45
43
  if getattr(kls.meta, kls.meta_class.base_property, None) is not None:
@@ -58,22 +56,22 @@ class RESTBase(Generic[TVResource], Handler, metaclass=RESTHandlerMeta):
58
56
 
59
57
  meta: RESTOptions
60
58
  meta_class: type[RESTOptions] = RESTOptions
61
- _api: Optional[API] = None
59
+ _api: API | None = None
62
60
 
63
- filters: Optional[dict[str, Any]] = None
64
- sorting: Optional[dict[str, Any]] = None
61
+ filters: dict[str, Any] | None = None
62
+ sorting: dict[str, Any] | None = None
65
63
 
66
64
  class Meta:
67
65
  """Tune the handler."""
68
66
 
69
67
  # Resource filters
70
- filters: Sequence[Union[str, tuple[str, str], Filter]] = ()
68
+ filters: Sequence[str | tuple[str, str] | Filter] = ()
71
69
 
72
70
  # Define allowed resource sorting params
73
- sorting: Sequence[Union[str, tuple[str, dict], Sort]] = ()
71
+ sorting: Sequence[str | tuple[str, dict] | Sort] = ()
74
72
 
75
73
  # Serialize/Deserialize Schema class
76
- Schema: Optional[type[ma.Schema]] = None
74
+ Schema: type[ma.Schema] | None = None
77
75
 
78
76
  @classmethod
79
77
  def __route__(cls, router, *paths, **params):
@@ -84,9 +82,7 @@ class RESTBase(Generic[TVResource], Handler, metaclass=RESTHandlerMeta):
84
82
 
85
83
  else:
86
84
  router.bind(cls, f"/{ cls.meta.name }", methods=methods, **params)
87
- router.bind(
88
- cls, f"/{ cls.meta.name }/{{{ cls.meta.name_id }}}", methods=methods, **params
89
- )
85
+ router.bind(cls, f"/{ cls.meta.name }/{{pk}}", methods=methods, **params)
90
86
 
91
87
  for _, method in inspect.getmembers(cls, lambda m: hasattr(m, "__route__")):
92
88
  paths, methods = method.__route__
@@ -94,7 +90,7 @@ class RESTBase(Generic[TVResource], Handler, metaclass=RESTHandlerMeta):
94
90
 
95
91
  return cls
96
92
 
97
- async def __call__(self, request: Request, *, method_name: Optional[str] = None, **_) -> Any:
93
+ async def __call__(self, request: Request, *, method_name: str | None = None, **_) -> Any:
98
94
  """Dispatch the given request by HTTP method."""
99
95
  self.auth = await self.authorize(request)
100
96
 
@@ -151,7 +147,7 @@ class RESTBase(Generic[TVResource], Handler, metaclass=RESTHandlerMeta):
151
147
 
152
148
  async def prepare_resource(self, request: Request) -> Any:
153
149
  """Load a resource."""
154
- return request["path_params"].get(self.meta.name_id)
150
+ return request["path_params"].get("pk")
155
151
 
156
152
  async def filter(self, request: Request, collection: TVCollection) -> tuple[TVCollection, Any]:
157
153
  """Filter the collection."""
@@ -190,7 +186,7 @@ class RESTBase(Generic[TVResource], Handler, metaclass=RESTHandlerMeta):
190
186
  @abc.abstractmethod
191
187
  async def paginate(
192
188
  self, request: Request, *, limit: int = 0, offset: int = 0
193
- ) -> tuple[Any, Optional[int]]:
189
+ ) -> tuple[Any, int | None]:
194
190
  """Paginate the results."""
195
191
  raise NotImplementedError
196
192
 
@@ -215,7 +211,7 @@ class RESTBase(Generic[TVResource], Handler, metaclass=RESTHandlerMeta):
215
211
  # Parse data
216
212
  # -----------
217
213
  def get_schema(
218
- self, request: Request, *, resource: Optional[TVResource] = None, **schema_options
214
+ self, request: Request, *, resource: TVResource | None = None, **schema_options
219
215
  ) -> ma.Schema:
220
216
  """Initialize marshmallow schema for serialization/deserialization."""
221
217
  query = request.url.query
@@ -232,34 +228,37 @@ class RESTBase(Generic[TVResource], Handler, metaclass=RESTHandlerMeta):
232
228
 
233
229
  return data
234
230
 
235
- async def load(
236
- self, request: Request, resource: Optional[TVResource] = None, **schema_options
237
- ) -> TVData[TVResource]:
231
+ async def load(self, request: Request, resource: TVResource | None = None, **schema_options):
238
232
  """Load data from request and create/update a resource."""
239
233
  schema = self.get_schema(request, resource=resource, **schema_options)
240
- data = cast(Union[Mapping, list], await self.load_data(request))
241
- return cast(TVData[TVResource], await load_data(data, schema, partial=resource is not None))
234
+ data = cast("Mapping | list", await self.load_data(request))
235
+ return cast(
236
+ "TVResource | list[TVResource]",
237
+ await load_data(data, schema, partial=resource is not None),
238
+ )
242
239
 
243
240
  @overload
244
241
  async def dump( # type: ignore[misc]
245
- self, request, data: TVData, *, many: Literal[True]
242
+ self, request, data: TVResource | Iterable[TVResource], *, many: Literal[True]
246
243
  ) -> list[TSchemaRes]: ...
247
244
 
248
245
  @overload
249
- async def dump(self, request, data: TVData, *, many: bool = False) -> TSchemaRes: ...
246
+ async def dump(
247
+ self, request, data: TVResource | Iterable[TVResource], *, many: bool = False
248
+ ) -> TSchemaRes: ...
250
249
 
251
250
  async def dump(
252
251
  self,
253
252
  request: Request,
254
- data: Union[TVResource, Iterable[TVResource]],
253
+ data: TVResource | Iterable[TVResource],
255
254
  *,
256
255
  many: bool = False,
257
- ) -> Union[TSchemaRes, list[TSchemaRes]]:
256
+ ) -> TSchemaRes | list[TSchemaRes]:
258
257
  """Serialize the given response."""
259
258
  schema = self.get_schema(request)
260
259
  return schema.dump(data, many=many)
261
260
 
262
- async def get(self, request: Request, *, resource: Optional[TVResource] = None) -> ResponseJSON:
261
+ async def get(self, request: Request, *, resource: TVResource | None = None) -> ResponseJSON:
263
262
  """Get a resource or a collection of resources.
264
263
 
265
264
  Specify a path param to load a resource.
@@ -269,11 +268,9 @@ class RESTBase(Generic[TVResource], Handler, metaclass=RESTHandlerMeta):
269
268
  if resource
270
269
  else self.dump(request, data=self.collection, many=True)
271
270
  )
272
- return ResponseJSON(res)
271
+ return ResponseJSON(res) # type: ignore[type-var]
273
272
 
274
- async def post(
275
- self, request: Request, *, resource: Optional[TVResource] = None
276
- ) -> ResponseJSON:
273
+ async def post(self, request: Request, *, resource: TVResource | None = None) -> ResponseJSON:
277
274
  """Create a resource.
278
275
 
279
276
  The method accepts a single resource's data or a list of resources to create.
@@ -283,19 +280,19 @@ class RESTBase(Generic[TVResource], Handler, metaclass=RESTHandlerMeta):
283
280
  if many:
284
281
  data = await self.save_many(request, data, update=resource is not None)
285
282
  else:
286
- data = await self.save(request, cast(TVResource, data), update=resource is not None)
283
+ data = await self.save(request, cast("TVResource", data), update=resource is not None)
287
284
 
288
285
  res = await self.dump(request, data, many=many)
289
286
  return ResponseJSON(res)
290
287
 
291
- async def put(self, request: Request, *, resource: Optional[TVResource] = None) -> ResponseJSON:
288
+ async def put(self, request: Request, *, resource: TVResource | None = None) -> ResponseJSON:
292
289
  """Update a resource."""
293
290
  if resource is None:
294
291
  raise APIError.NOT_FOUND()
295
292
 
296
293
  return await self.post(request, resource=resource)
297
294
 
298
- async def delete(self, request: Request, resource: Optional[TVResource] = None):
295
+ async def delete(self, request: Request, resource: TVResource | None = None):
299
296
  """Delete a resource."""
300
297
  if resource is None:
301
298
  raise APIError.NOT_FOUND()
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import TYPE_CHECKING, Optional, Union
3
+ from typing import TYPE_CHECKING
4
4
 
5
5
  from marshmallow import Schema, ValidationError
6
6
 
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
10
10
  from collections.abc import Mapping
11
11
 
12
12
 
13
- async def load_data(data: Union[Mapping, list], schema: Optional[Schema] = None, **params):
13
+ async def load_data(data: Mapping | list, schema: Schema | None = None, **params):
14
14
  if schema is None:
15
15
  return data
16
16
 
@@ -1,7 +1,8 @@
1
1
  """Mongo DB support."""
2
+
2
3
  from __future__ import annotations
3
4
 
4
- from typing import TYPE_CHECKING, Optional, cast
5
+ from typing import TYPE_CHECKING, cast
5
6
 
6
7
  import bson
7
8
  from bson.errors import InvalidId
@@ -28,7 +29,7 @@ class MongoRESTOptions(RESTOptions):
28
29
  sorting_cls: type[MongoSorting] = MongoSorting
29
30
  schema_base: type[MongoSchema] = MongoSchema
30
31
 
31
- aggregate: Optional[list] = None # Support aggregation. Set to pipeline.
32
+ aggregate: list | None = None # Support aggregation. Set to pipeline.
32
33
  collection_id: str = "_id"
33
34
  collection: motor.AsyncIOMotorCollection
34
35
 
@@ -57,7 +58,7 @@ class MongoRESTHandler(RESTHandler[TVResource]):
57
58
 
58
59
  async def paginate(
59
60
  self, _: Request, *, limit: int = 0, offset: int = 0
60
- ) -> tuple[motor.AsyncIOMotorCursor, Optional[int]]:
61
+ ) -> tuple[motor.AsyncIOMotorCursor, int | None]:
61
62
  """Paginate collection."""
62
63
  if self.meta.aggregate:
63
64
  pipeline_all = [*self.meta.aggregate, {"$skip": offset}, {"$limit": limit}]
@@ -68,14 +69,14 @@ class MongoRESTHandler(RESTHandler[TVResource]):
68
69
  counts = list(self.collection.aggregate(pipeline_num))
69
70
  return (
70
71
  self.collection.aggregate(pipeline_all),
71
- counts and counts[0]["total"] or 0, # type: ignore[]
72
+ (counts and counts[0]["total"]) or 0, # type: ignore[]
72
73
  )
73
74
  total = None
74
75
  if self.meta.limit_total:
75
76
  total = await self.collection.count()
76
77
  return self.collection.skip(offset).limit(limit), total
77
78
 
78
- async def get(self, request, *, resource: Optional[TVResource] = None):
79
+ async def get(self, request, *, resource: TVResource | None = None):
79
80
  """Get resource or collection of resources."""
80
81
  if resource:
81
82
  return await self.dump(request, resource)
@@ -83,9 +84,9 @@ class MongoRESTHandler(RESTHandler[TVResource]):
83
84
  docs = await self.collection.to_list(None)
84
85
  return await self.dump(request, docs, many=True)
85
86
 
86
- async def prepare_resource(self, request: Request) -> Optional[TVResource]:
87
+ async def prepare_resource(self, request: Request) -> TVResource | None:
87
88
  """Load a resource."""
88
- pk = request["path_params"].get(self.meta.name_id)
89
+ pk = request["path_params"].get("pk")
89
90
  if not pk:
90
91
  return None
91
92
 
@@ -97,7 +98,7 @@ class MongoRESTHandler(RESTHandler[TVResource]):
97
98
  raise APIError.NOT_FOUND() from exc
98
99
 
99
100
  def get_schema(
100
- self, request: Request, resource: Optional[TVResource] = None, **schema_options
101
+ self, request: Request, resource: TVResource | None = None, **schema_options
101
102
  ) -> ma.Schema:
102
103
  """Initialize marshmallow schema for serialization/deserialization."""
103
104
  return super().get_schema(request, instance=resource, **schema_options)
@@ -115,10 +116,12 @@ class MongoRESTHandler(RESTHandler[TVResource]):
115
116
 
116
117
  return resource
117
118
 
118
- async def delete(self, request: Request, resource: Optional[TVResource] = None):
119
+ async def delete(self, request: Request, resource: TVResource | None = None):
119
120
  """Remove the given resource(s)."""
120
121
  meta = self.meta
121
- oids = [resource[meta.collection_id]] if resource else cast(list[str], await request.data())
122
+ oids = (
123
+ [resource[meta.collection_id]] if resource else cast("list[str]", await request.data())
124
+ )
122
125
  if not oids:
123
126
  raise APIError.NOT_FOUND()
124
127
 
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING, Awaitable, Union
5
+ from typing import TYPE_CHECKING, Awaitable
6
6
 
7
7
  if TYPE_CHECKING:
8
8
  from motor import motor_asyncio as motor
@@ -58,7 +58,7 @@ class MongoChain:
58
58
 
59
59
  def find(
60
60
  self,
61
- query: Union[list, dict, None] = None,
61
+ query: list | dict | None = None,
62
62
  projection=None,
63
63
  ) -> MongoChain:
64
64
  """Store filters in self."""
@@ -68,17 +68,17 @@ class MongoChain:
68
68
 
69
69
  def find_one(
70
70
  self,
71
- query: Union[list, dict, None] = None,
71
+ query: list | dict | None = None,
72
72
  projection=None,
73
73
  ) -> Awaitable:
74
74
  """Apply filters and return cursor."""
75
75
  query = self.__update__(query)
76
- query = query and {"$and": query} or {}
76
+ query = (query and {"$and": query}) or {}
77
77
  return self.collection.find_one(query, projection=projection)
78
78
 
79
79
  def count(self) -> Awaitable[int]:
80
80
  """Count documents."""
81
- query = self.query and {"$and": self.query} or {}
81
+ query = (self.query and {"$and": self.query}) or {}
82
82
  return self.collection.count_documents(query)
83
83
 
84
84
  def aggregate(self, pipeline, **kwargs):
@@ -121,7 +121,7 @@ class MongoChain:
121
121
 
122
122
  def __iter__(self):
123
123
  """Iterate by self collection."""
124
- query = self.query and {"$and": self.query} or {}
124
+ query = (self.query and {"$and": self.query}) or {}
125
125
  if self.sorting:
126
126
  return self.collection.find(query, self.projection).sort(self.sorting)
127
127
 
@@ -130,7 +130,7 @@ class MongoChain:
130
130
  def __getattr__(self, name):
131
131
  """Proxy any attributes except find to self.collection."""
132
132
  if name in self.CURSOR_METHODS:
133
- query = self.query and {"$and": self.query} or {}
133
+ query = (self.query and {"$and": self.query}) or {}
134
134
  cursor = self.collection.find(query, self.projection)
135
135
  if self.sorting:
136
136
  cursor = cursor.sort(self.sorting)
muffin_rest/openapi.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Create openapi schema from the given API."""
2
+
2
3
  from __future__ import annotations
3
4
 
4
5
  import inspect
@@ -88,7 +89,7 @@ def route_to_spec(route: Route, spec: APISpec, tags: dict) -> dict:
88
89
  for param in route.params:
89
90
  results["parameters"].append({"in": "path", "name": param})
90
91
 
91
- target = cast(Callable, route.target)
92
+ target = cast("Callable", route.target)
92
93
  if isinstance(target, partial):
93
94
  target = target.func
94
95
 
@@ -198,9 +199,7 @@ class OpenAPIMixin:
198
199
  schema_ref = {"$ref": f"#/components/schemas/{ meta.Schema.__name__ }"}
199
200
  for method in route_to_methods(route):
200
201
  operations[method] = {"tags": [tags[cls]]}
201
- is_resource_route = isinstance(route, DynamicRoute) and route.params.get(
202
- meta.name_id,
203
- )
202
+ is_resource_route = isinstance(route, DynamicRoute) and route.params.get("pk")
204
203
 
205
204
  if method == "get" and not is_resource_route:
206
205
  operations[method]["parameters"] = []
muffin_rest/options.py CHANGED
@@ -12,7 +12,7 @@ class RESTOptions:
12
12
  """Handler Options."""
13
13
 
14
14
  name: str = ""
15
- name_id: str = "id"
15
+ pk: str = "id"
16
16
  base_property: str = "name"
17
17
 
18
18
  # Pagination
@@ -4,7 +4,7 @@ from __future__ import annotations
4
4
 
5
5
  import operator
6
6
  from functools import reduce
7
- from typing import TYPE_CHECKING, ClassVar, Union, cast
7
+ from typing import TYPE_CHECKING, ClassVar, cast
8
8
 
9
9
  from peewee import ColumnBase, Field, ModelSelect
10
10
 
@@ -15,6 +15,8 @@ from .utils import get_model_field_by_name
15
15
  if TYPE_CHECKING:
16
16
  from muffin_rest.types import TFilterValue
17
17
 
18
+ from . import PWRESTHandler
19
+
18
20
 
19
21
  class PWFilter(Filter):
20
22
  """Support Peewee."""
@@ -41,7 +43,9 @@ class PWFilter(Filter):
41
43
  """Apply the filters to Peewee QuerySet.."""
42
44
  column = self.field
43
45
  if isinstance(column, ColumnBase):
44
- collection = cast(ModelSelect, collection.where(*[op(column, val) for op, val in ops]))
46
+ collection = cast(
47
+ "ModelSelect", collection.where(*[op(column, val) for op, val in ops])
48
+ )
45
49
  return collection
46
50
 
47
51
 
@@ -50,11 +54,10 @@ class PWFilters(Filters):
50
54
 
51
55
  MUTATE_CLASS: type[PWFilter] = PWFilter
52
56
 
53
- def convert(self, obj: Union[str, Field, PWFilter], **meta):
57
+ def convert(self, obj: str | Field | PWFilter, **meta):
54
58
  """Convert params to filters."""
55
- from . import PWRESTHandler
56
59
 
57
- handler = cast(PWRESTHandler, self.handler)
60
+ handler = cast("PWRESTHandler", self.handler)
58
61
  if isinstance(obj, PWFilter):
59
62
  return obj
60
63
 
@@ -2,10 +2,9 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload
5
+ from typing import TYPE_CHECKING, Any, cast, overload
6
6
 
7
7
  import marshmallow as ma
8
- import peewee as pw
9
8
  from apispec.ext.marshmallow import MarshmallowPlugin
10
9
  from marshmallow_peewee import ForeignKey
11
10
  from peewee_aio.model import AIOModel, AIOModelSelect
@@ -19,6 +18,7 @@ from .schemas import EnumField
19
18
  from .types import TVModel
20
19
 
21
20
  if TYPE_CHECKING:
21
+ import peewee as pw
22
22
  from muffin import Request
23
23
  from peewee_aio.types import TVAIOModel
24
24
 
@@ -34,7 +34,7 @@ class PWRESTBase(RESTBase[TVModel], PeeweeOpenAPIMixin):
34
34
 
35
35
  if TYPE_CHECKING:
36
36
  resource: TVModel
37
- collection: Union[AIOModelSelect, pw.ModelSelect]
37
+ collection: AIOModelSelect | pw.ModelSelect
38
38
 
39
39
  meta: PWRESTOptions
40
40
  meta_class: type[PWRESTOptions] = PWRESTOptions
@@ -56,9 +56,9 @@ class PWRESTBase(RESTBase[TVModel], PeeweeOpenAPIMixin):
56
56
  """Initialize Peeewee QuerySet for a binded to the resource model."""
57
57
  return self.meta.model.select()
58
58
 
59
- async def prepare_resource(self, request: Request) -> Optional[TVModel]:
59
+ async def prepare_resource(self, request: Request) -> TVModel | None:
60
60
  """Load a resource."""
61
- pk = request["path_params"].get(self.meta.name_id)
61
+ pk = request["path_params"].get("pk")
62
62
  if not pk:
63
63
  return None
64
64
 
@@ -89,7 +89,7 @@ class PWRESTBase(RESTBase[TVModel], PeeweeOpenAPIMixin):
89
89
  async def paginate(self, _: Request, *, limit: int = 0, offset: int = 0): # type: ignore[override]
90
90
  """Paginate the collection."""
91
91
  if self.meta.limit_total:
92
- cqs = cast(pw.ModelSelect, self.collection.order_by())
92
+ cqs = cast("pw.ModelSelect", self.collection.order_by())
93
93
  if cqs._group_by: # type: ignore[misc]
94
94
  cqs._returning = cqs._group_by # type: ignore[misc]
95
95
  cqs._having = None # type: ignore[misc]
@@ -101,7 +101,7 @@ class PWRESTBase(RESTBase[TVModel], PeeweeOpenAPIMixin):
101
101
 
102
102
  return self.collection.offset(offset).limit(limit), count
103
103
 
104
- async def get(self, request, *, resource: Optional[TVModel] = None) -> Any:
104
+ async def get(self, request, *, resource: TVModel | None = None) -> Any:
105
105
  """Get resource or collection of resources."""
106
106
  if resource:
107
107
  return await self.dump(request, resource)
@@ -120,7 +120,7 @@ class PWRESTBase(RESTBase[TVModel], PeeweeOpenAPIMixin):
120
120
 
121
121
  return resource
122
122
 
123
- async def remove(self, request: Request, resource: Optional[TVModel] = None):
123
+ async def remove(self, request: Request, resource: TVModel | None = None):
124
124
  """Remove the given resource."""
125
125
  meta = self.meta
126
126
  if resource:
@@ -131,7 +131,7 @@ class PWRESTBase(RESTBase[TVModel], PeeweeOpenAPIMixin):
131
131
  if not data:
132
132
  return
133
133
 
134
- model_pk = cast(pw.Field, meta.model_pk)
134
+ model_pk = cast("pw.Field", meta.model_pk)
135
135
  resources = await meta.manager.fetchall(self.collection.where(model_pk << data)) # type: ignore[]
136
136
 
137
137
  if not resources:
@@ -147,11 +147,11 @@ class PWRESTBase(RESTBase[TVModel], PeeweeOpenAPIMixin):
147
147
 
148
148
  return resource.get_id() if resource else [r.get_id() for r in resources]
149
149
 
150
- async def delete(self, request: Request, resource: Optional[TVModel] = None): # type: ignore[override]
150
+ async def delete(self, request: Request, resource: TVModel | None = None): # type: ignore[override]
151
151
  return await self.remove(request, resource)
152
152
 
153
153
  def get_schema(
154
- self, request: Request, *, resource: Optional[TVModel] = None, **schema_options
154
+ self, request: Request, *, resource: TVModel | None = None, **schema_options
155
155
  ) -> ma.Schema:
156
156
  """Initialize marshmallow schema for serialization/deserialization."""
157
157
  return super().get_schema(request, instance=resource, **schema_options)
@@ -22,7 +22,7 @@ class PeeweeOpenAPIMixin(OpenAPIMixin):
22
22
  def openapi(cls, route: Route, spec: APISpec, tags: dict) -> dict:
23
23
  """Get openapi specs for the endpoint."""
24
24
  operations = super(PeeweeOpenAPIMixin, cls).openapi(route, spec, tags)
25
- is_resource_route = getattr(route, "params", {}).get(cls.meta.name_id)
25
+ is_resource_route = getattr(route, "params", {}).get("pk")
26
26
  if not is_resource_route and "delete" in operations:
27
27
  operations["delete"].setdefault("parameters", [])
28
28
  operations["delete"]["requestBody"] = {
@@ -13,5 +13,3 @@ def build_field(field, opts, **params):
13
13
 
14
14
 
15
15
  DefaultConverter.register(URLField, ma.fields.Url)
16
-
17
- # ruff: noqa: ARG001, ARG002