starmallow 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- starmallow/__init__.py +1 -1
- starmallow/applications.py +196 -232
- starmallow/background.py +1 -1
- starmallow/concurrency.py +8 -7
- starmallow/dataclasses.py +11 -10
- starmallow/datastructures.py +1 -1
- starmallow/decorators.py +31 -30
- starmallow/delimited_field.py +37 -15
- starmallow/docs.py +5 -5
- starmallow/endpoint.py +105 -79
- starmallow/endpoints.py +3 -2
- starmallow/exceptions.py +3 -3
- starmallow/ext/marshmallow/openapi.py +13 -16
- starmallow/fields.py +3 -3
- starmallow/generics.py +34 -0
- starmallow/middleware/asyncexitstack.py +1 -2
- starmallow/params.py +20 -21
- starmallow/py.typed +0 -0
- starmallow/request_resolver.py +62 -58
- starmallow/responses.py +5 -4
- starmallow/routing.py +231 -239
- starmallow/schema_generator.py +98 -52
- starmallow/security/api_key.py +11 -11
- starmallow/security/base.py +12 -4
- starmallow/security/http.py +31 -26
- starmallow/security/oauth2.py +48 -48
- starmallow/security/open_id_connect_url.py +7 -7
- starmallow/security/utils.py +2 -5
- starmallow/serializers.py +59 -63
- starmallow/types.py +12 -8
- starmallow/utils.py +114 -70
- starmallow/websockets.py +3 -6
- {starmallow-0.7.0.dist-info → starmallow-0.9.0.dist-info}/METADATA +17 -16
- starmallow-0.9.0.dist-info/RECORD +43 -0
- {starmallow-0.7.0.dist-info → starmallow-0.9.0.dist-info}/WHEEL +1 -1
- starmallow/union_field.py +0 -86
- starmallow-0.7.0.dist-info/RECORD +0 -42
- {starmallow-0.7.0.dist-info → starmallow-0.9.0.dist-info}/licenses/LICENSE.md +0 -0
starmallow/endpoint.py
CHANGED
@@ -1,30 +1,27 @@
|
|
1
1
|
import inspect
|
2
2
|
import logging
|
3
|
+
from collections.abc import Callable, Iterable, Mapping, Sequence
|
3
4
|
from dataclasses import dataclass, field
|
4
5
|
from typing import (
|
5
6
|
TYPE_CHECKING,
|
7
|
+
Annotated,
|
6
8
|
Any,
|
7
|
-
Callable,
|
8
|
-
Dict,
|
9
|
-
Iterable,
|
10
|
-
List,
|
11
|
-
Mapping,
|
12
9
|
NewType,
|
13
|
-
|
14
|
-
Type,
|
15
|
-
Union,
|
10
|
+
cast,
|
16
11
|
get_args,
|
17
12
|
get_origin,
|
18
13
|
)
|
19
14
|
|
20
15
|
import marshmallow as ma
|
21
16
|
import marshmallow.fields as mf
|
17
|
+
import typing_inspect
|
18
|
+
from marshmallow.types import StrSequenceOrSet
|
22
19
|
from marshmallow.utils import missing as missing_
|
20
|
+
from marshmallow_dataclass2 import class_schema, is_generic_alias_of_dataclass
|
23
21
|
from starlette.background import BackgroundTasks
|
24
22
|
from starlette.requests import HTTPConnection, Request
|
25
23
|
from starlette.responses import Response
|
26
24
|
from starlette.websockets import WebSocket
|
27
|
-
from typing_extensions import Annotated
|
28
25
|
|
29
26
|
from starmallow.params import (
|
30
27
|
Body,
|
@@ -42,20 +39,21 @@ from starmallow.params import (
|
|
42
39
|
from starmallow.responses import JSONResponse
|
43
40
|
from starmallow.security.base import SecurityBaseResolver
|
44
41
|
from starmallow.utils import (
|
42
|
+
MaDataclassProtocol,
|
45
43
|
create_response_model,
|
46
44
|
get_model_field,
|
47
45
|
get_path_param_names,
|
48
46
|
get_typed_return_annotation,
|
49
47
|
get_typed_signature,
|
50
48
|
is_marshmallow_dataclass,
|
51
|
-
|
49
|
+
is_marshmallow_field_or_generic,
|
52
50
|
is_marshmallow_schema,
|
53
51
|
is_optional,
|
54
52
|
lenient_issubclass,
|
55
53
|
)
|
56
54
|
|
57
55
|
if TYPE_CHECKING:
|
58
|
-
from starmallow.routing import APIRoute
|
56
|
+
from starmallow.routing import APIRoute, APIWebSocketRoute
|
59
57
|
|
60
58
|
logger = logging.getLogger(__name__)
|
61
59
|
|
@@ -69,52 +67,52 @@ STARMALLOW_PARAM_TYPES = (
|
|
69
67
|
|
70
68
|
@dataclass
|
71
69
|
class EndpointModel:
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
70
|
+
path: str
|
71
|
+
call: Callable[..., Any]
|
72
|
+
route: 'APIRoute | APIWebSocketRoute'
|
73
|
+
params: dict[ParamType, dict[str, Param]] = field(default_factory=dict)
|
74
|
+
flat_params: dict[ParamType, dict[str, Param]] = field(default_factory=dict)
|
75
|
+
name: str | None = None
|
76
|
+
methods: Sequence[str] | None = None
|
77
|
+
response_model: ma.Schema | type[ma.Schema | MaDataclassProtocol] | mf.Field | None = None
|
78
|
+
response_class: type[Response] = JSONResponse
|
79
|
+
status_code: int | None = None
|
82
80
|
|
83
81
|
@property
|
84
|
-
def path_params(self) ->
|
85
|
-
return self.flat_params.get(ParamType.path)
|
82
|
+
def path_params(self) -> dict[str, Path] | None:
|
83
|
+
return cast(dict[str, Path], self.flat_params.get(ParamType.path))
|
86
84
|
|
87
85
|
@property
|
88
|
-
def query_params(self) ->
|
89
|
-
return self.flat_params.get(ParamType.query)
|
86
|
+
def query_params(self) -> dict[str, Query] | None:
|
87
|
+
return cast(dict[str, Query], self.flat_params.get(ParamType.query))
|
90
88
|
|
91
89
|
@property
|
92
|
-
def header_params(self) ->
|
93
|
-
return self.flat_params.get(ParamType.header)
|
90
|
+
def header_params(self) -> dict[str, Header] | None:
|
91
|
+
return cast(dict[str, Header], self.flat_params.get(ParamType.header))
|
94
92
|
|
95
93
|
@property
|
96
|
-
def cookie_params(self) ->
|
97
|
-
return self.flat_params.get(ParamType.cookie)
|
94
|
+
def cookie_params(self) -> dict[str, Cookie] | None:
|
95
|
+
return cast(dict[str, Cookie], self.flat_params.get(ParamType.cookie))
|
98
96
|
|
99
97
|
@property
|
100
|
-
def body_params(self) ->
|
101
|
-
return self.flat_params.get(ParamType.body)
|
98
|
+
def body_params(self) -> dict[str, Body] | None:
|
99
|
+
return cast(dict[str, Body], self.flat_params.get(ParamType.body))
|
102
100
|
|
103
101
|
@property
|
104
|
-
def form_params(self) ->
|
105
|
-
return self.flat_params.get(ParamType.form)
|
102
|
+
def form_params(self) -> dict[str, Form] | None:
|
103
|
+
return cast(dict[str, Form], self.flat_params.get(ParamType.form))
|
106
104
|
|
107
105
|
@property
|
108
|
-
def non_field_params(self) ->
|
109
|
-
return self.flat_params.get(ParamType.noparam)
|
106
|
+
def non_field_params(self) -> dict[str, NoParam] | None:
|
107
|
+
return cast(dict[str, NoParam], self.flat_params.get(ParamType.noparam))
|
110
108
|
|
111
109
|
@property
|
112
|
-
def resolved_params(self) ->
|
113
|
-
return self.flat_params.get(ParamType.resolved)
|
110
|
+
def resolved_params(self) -> dict[str, ResolvedParam] | None:
|
111
|
+
return cast(dict[str, ResolvedParam], self.flat_params.get(ParamType.resolved))
|
114
112
|
|
115
113
|
@property
|
116
|
-
def security_params(self) ->
|
117
|
-
return self.flat_params.get(ParamType.security)
|
114
|
+
def security_params(self) -> dict[str, Security] | None:
|
115
|
+
return cast(dict[str, Security], self.flat_params.get(ParamType.security))
|
118
116
|
|
119
117
|
|
120
118
|
class SchemaMeta:
|
@@ -128,20 +126,20 @@ class SchemaModel(ma.Schema):
|
|
128
126
|
schema: ma.Schema,
|
129
127
|
load_default: Any = missing_,
|
130
128
|
required: bool = True,
|
131
|
-
metadata:
|
129
|
+
metadata: dict[str, Any] | None = None,
|
132
130
|
**kwargs,
|
133
131
|
) -> None:
|
134
132
|
self.schema = schema
|
135
133
|
self.load_default = load_default
|
136
134
|
self.required = required
|
137
|
-
self.title = metadata.get('title')
|
135
|
+
self.title = metadata.get('title') if metadata else None
|
138
136
|
self.metadata = metadata
|
139
137
|
self.kwargs = kwargs
|
140
138
|
|
141
139
|
if not getattr(schema.Meta, "title", None):
|
142
140
|
if schema.Meta is ma.Schema.Meta:
|
143
141
|
# Don't override global Meta object's title
|
144
|
-
schema.Meta = SchemaMeta(self.title)
|
142
|
+
schema.Meta = SchemaMeta(self.title) # type: ignore
|
145
143
|
else:
|
146
144
|
schema.Meta.title = self.title
|
147
145
|
|
@@ -161,14 +159,11 @@ class SchemaModel(ma.Schema):
|
|
161
159
|
|
162
160
|
def load(
|
163
161
|
self,
|
164
|
-
data:
|
165
|
-
Mapping[str, Any],
|
166
|
-
Iterable[Mapping[str, Any]],
|
167
|
-
],
|
162
|
+
data: Mapping[str, Any] | Iterable[Mapping[str, Any]],
|
168
163
|
*,
|
169
|
-
many:
|
170
|
-
partial:
|
171
|
-
unknown:
|
164
|
+
many: bool | None = None,
|
165
|
+
partial: bool | StrSequenceOrSet | None = None,
|
166
|
+
unknown: str | None = None,
|
172
167
|
) -> Any:
|
173
168
|
if not data and self.load_default:
|
174
169
|
return self.load_default
|
@@ -184,7 +179,7 @@ class EndpointMixin:
|
|
184
179
|
parameter: Param,
|
185
180
|
parameter_name: str,
|
186
181
|
default_value: Any,
|
187
|
-
) ->
|
182
|
+
) -> ma.Schema | mf.Field | None:
|
188
183
|
model = getattr(parameter, 'model', None) or type_annotation
|
189
184
|
if isinstance(model, SchemaModel):
|
190
185
|
return model
|
@@ -243,18 +238,22 @@ class EndpointMixin:
|
|
243
238
|
if is_marshmallow_dataclass(model):
|
244
239
|
model = model.Schema
|
245
240
|
|
246
|
-
if
|
247
|
-
|
241
|
+
if is_generic_alias_of_dataclass(model): # type: ignore
|
242
|
+
model = class_schema(model) # type: ignore
|
243
|
+
|
244
|
+
mmf = getattr(model, '_marshmallow_field', None)
|
245
|
+
if isinstance(model, NewType) and mmf and issubclass(mmf, mf.Field):
|
246
|
+
return mmf(**kwargs)
|
248
247
|
elif is_marshmallow_schema(model):
|
249
248
|
return SchemaModel(model() if inspect.isclass(model) else model, **kwargs)
|
250
|
-
elif
|
251
|
-
if inspect.isclass(model):
|
249
|
+
elif is_marshmallow_field_or_generic(model):
|
250
|
+
if not isinstance(model, mf.Field): # inspect.isclass(model):
|
252
251
|
model = model()
|
253
252
|
|
254
253
|
if model.load_default is not None and model.load_default != kwargs.get('load_default', ma.missing):
|
255
254
|
logger.warning(
|
256
255
|
f"'{parameter_name}' model and annotation have different 'load_default' values."
|
257
|
-
|
256
|
+
f" {model.load_default} <> {kwargs.get('load_default', ma.missing)}",
|
258
257
|
)
|
259
258
|
|
260
259
|
model.required = kwargs['required']
|
@@ -262,6 +261,11 @@ class EndpointMixin:
|
|
262
261
|
model.metadata.update(kwargs['metadata'])
|
263
262
|
|
264
263
|
return model
|
264
|
+
|
265
|
+
# elif is_marshmallow_field_or_generic(model):
|
266
|
+
# if isinstance(model, mf.Field):
|
267
|
+
# model if isinstance(model, mf.Field) else model()
|
268
|
+
|
265
269
|
else:
|
266
270
|
try:
|
267
271
|
return get_model_field(model, **kwargs)
|
@@ -278,11 +282,19 @@ class EndpointMixin:
|
|
278
282
|
|
279
283
|
return resolved_param
|
280
284
|
|
285
|
+
def get_security_param(self, resolved_param: Security, annotation: Any, path: str) -> Security:
|
286
|
+
if resolved_param.resolver is None:
|
287
|
+
resolved_param.resolver = annotation
|
288
|
+
|
289
|
+
resolved_param.resolver_params = self._get_params(resolved_param.resolver, path=path)
|
290
|
+
|
291
|
+
return resolved_param
|
292
|
+
|
281
293
|
def _get_params(
|
282
294
|
self,
|
283
295
|
func: Callable[..., Any],
|
284
296
|
path: str,
|
285
|
-
) ->
|
297
|
+
) -> dict[ParamType, dict[str, Param]]:
|
286
298
|
path_param_names = get_path_param_names(path)
|
287
299
|
params = {param_type: {} for param_type in ParamType}
|
288
300
|
for name, parameter in get_typed_signature(func).parameters.items():
|
@@ -309,8 +321,7 @@ class EndpointMixin:
|
|
309
321
|
]
|
310
322
|
if starmallow_annotations:
|
311
323
|
assert starmallow_param is inspect._empty, (
|
312
|
-
"Cannot specify `Param` in `Annotated` and default value"
|
313
|
-
f" together for {name!r}"
|
324
|
+
f"Cannot specify `Param` in `Annotated` and default value together for {name!r}"
|
314
325
|
)
|
315
326
|
|
316
327
|
starmallow_param = starmallow_annotations[-1]
|
@@ -321,6 +332,20 @@ class EndpointMixin:
|
|
321
332
|
):
|
322
333
|
default_value = starmallow_param.default
|
323
334
|
|
335
|
+
field_annotations = [
|
336
|
+
arg
|
337
|
+
for arg in annotated_args
|
338
|
+
if (
|
339
|
+
isinstance(arg, mf.Field)
|
340
|
+
or lenient_issubclass(arg, mf.Field)
|
341
|
+
or (
|
342
|
+
isinstance(arg, typing_inspect.typingGenericAlias)
|
343
|
+
and lenient_issubclass(get_origin(arg), mf.Field)
|
344
|
+
)
|
345
|
+
)
|
346
|
+
]
|
347
|
+
if field_annotations:
|
348
|
+
type_annotation = field_annotations[-1]
|
324
349
|
if (
|
325
350
|
# Skip 'self' in APIHTTPEndpoint functions
|
326
351
|
(name == 'self' and '.' in func.__qualname__)
|
@@ -328,11 +353,11 @@ class EndpointMixin:
|
|
328
353
|
):
|
329
354
|
continue
|
330
355
|
elif isinstance(starmallow_param, Security):
|
331
|
-
security_param
|
356
|
+
security_param = self.get_security_param(starmallow_param, type_annotation, path=path)
|
332
357
|
params[ParamType.security][name] = security_param
|
333
358
|
continue
|
334
359
|
elif isinstance(starmallow_param, ResolvedParam):
|
335
|
-
resolved_param
|
360
|
+
resolved_param = self.get_resolved_param(starmallow_param, type_annotation, path=path)
|
336
361
|
|
337
362
|
# Allow `ResolvedParam(HTTPBearer())` - treat as securty param
|
338
363
|
if isinstance(resolved_param.resolver, SecurityBaseResolver):
|
@@ -355,7 +380,7 @@ class EndpointMixin:
|
|
355
380
|
continue
|
356
381
|
|
357
382
|
model = self._get_param_model(type_annotation, starmallow_param, name, default_value)
|
358
|
-
model.name = name
|
383
|
+
model.name = name # type: ignore
|
359
384
|
|
360
385
|
if isinstance(starmallow_param, Param):
|
361
386
|
# Create new field_info with processed model
|
@@ -393,18 +418,18 @@ class EndpointMixin:
|
|
393
418
|
def get_endpoint_model(
|
394
419
|
self,
|
395
420
|
path: str,
|
396
|
-
endpoint: Callable[...,
|
397
|
-
route: 'APIRoute',
|
398
|
-
name:
|
399
|
-
methods:
|
400
|
-
|
401
|
-
status_code:
|
402
|
-
response_model:
|
403
|
-
response_class:
|
421
|
+
endpoint: Callable[..., ma.Schema | mf.Field | Response | None],
|
422
|
+
route: 'APIRoute | APIWebSocketRoute',
|
423
|
+
name: str | None = None,
|
424
|
+
methods: Sequence[str] | None = None,
|
425
|
+
|
426
|
+
status_code: int | None = None,
|
427
|
+
response_model: ma.Schema | type[ma.Schema | MaDataclassProtocol] | None = None,
|
428
|
+
response_class: type[Response] = JSONResponse,
|
404
429
|
) -> EndpointModel:
|
405
430
|
params = self._get_params(endpoint, path)
|
406
431
|
|
407
|
-
response_model = create_response_model(response_model or get_typed_return_annotation(endpoint))
|
432
|
+
response_model = create_response_model(response_model or get_typed_return_annotation(endpoint)) # type: ignore
|
408
433
|
|
409
434
|
return EndpointModel(
|
410
435
|
path=path,
|
@@ -421,9 +446,9 @@ class EndpointMixin:
|
|
421
446
|
|
422
447
|
|
423
448
|
def safe_merge_params(
|
424
|
-
left:
|
425
|
-
right:
|
426
|
-
) ->
|
449
|
+
left: dict[str, Param],
|
450
|
+
right: dict[str, Param],
|
451
|
+
) -> dict[str, Param]:
|
427
452
|
res = left.copy()
|
428
453
|
for name, param in right.items():
|
429
454
|
if name not in left:
|
@@ -435,9 +460,9 @@ def safe_merge_params(
|
|
435
460
|
|
436
461
|
|
437
462
|
def safe_merge_all_params(
|
438
|
-
left:
|
439
|
-
right:
|
440
|
-
) ->
|
463
|
+
left: dict[ParamType, dict[str, Param]],
|
464
|
+
right: dict[ParamType, dict[str, Param]],
|
465
|
+
) -> dict[ParamType, dict[str, Param]]:
|
441
466
|
res = {
|
442
467
|
param_type: safe_merge_params(left[param_type], right[param_type])
|
443
468
|
for param_type in ParamType
|
@@ -447,12 +472,13 @@ def safe_merge_all_params(
|
|
447
472
|
|
448
473
|
|
449
474
|
def flatten_parameters(
|
450
|
-
params:
|
451
|
-
) ->
|
475
|
+
params: dict[ParamType, dict[str, Param]],
|
476
|
+
) -> dict[ParamType, dict[str, Param]]:
|
452
477
|
# flat_params = {param_type: {} for param_type in ParamType}
|
453
478
|
flat_params = params.copy()
|
479
|
+
|
454
480
|
for param in params[ParamType.resolved].values():
|
455
|
-
if not param.resolver_params:
|
481
|
+
if not (isinstance(param, ResolvedParam) and param.resolver_params):
|
456
482
|
continue
|
457
483
|
|
458
484
|
flat_params = safe_merge_all_params(flat_params, param.resolver_params)
|
starmallow/endpoints.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
|
-
from
|
1
|
+
from collections.abc import Collection
|
2
|
+
from typing import Any, ClassVar
|
2
3
|
|
3
4
|
HTTP_METHOD_FUNCS = ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
|
4
5
|
|
@@ -8,7 +9,7 @@ class APIHTTPEndpoint:
|
|
8
9
|
#: The methods this view is registered for. Uses the same default
|
9
10
|
#: (``["GET", "HEAD", "OPTIONS"]``) as ``route`` and
|
10
11
|
#: ``add_url_rule`` by default.
|
11
|
-
methods: ClassVar[
|
12
|
+
methods: ClassVar[Collection[str] | None] = None
|
12
13
|
|
13
14
|
def __init_subclass__(cls, **kwargs: Any) -> None:
|
14
15
|
|
starmallow/exceptions.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any
|
1
|
+
from typing import Any
|
2
2
|
|
3
3
|
from starlette.exceptions import HTTPException
|
4
4
|
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_500_INTERNAL_SERVER_ERROR
|
@@ -7,7 +7,7 @@ from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_500_INTERNAL_SE
|
|
7
7
|
class RequestValidationError(HTTPException):
|
8
8
|
def __init__(
|
9
9
|
self,
|
10
|
-
errors:
|
10
|
+
errors: dict[str, Any | list | dict],
|
11
11
|
) -> None:
|
12
12
|
super().__init__(status_code=HTTP_422_UNPROCESSABLE_ENTITY)
|
13
13
|
self.errors = errors
|
@@ -16,7 +16,7 @@ class RequestValidationError(HTTPException):
|
|
16
16
|
class WebSocketRequestValidationError(HTTPException):
|
17
17
|
def __init__(
|
18
18
|
self,
|
19
|
-
errors:
|
19
|
+
errors: dict[str, Any | list | dict],
|
20
20
|
) -> None:
|
21
21
|
super().__init__(status_code=HTTP_422_UNPROCESSABLE_ENTITY)
|
22
22
|
self.errors = errors
|
@@ -2,7 +2,7 @@ from typing import Any
|
|
2
2
|
|
3
3
|
import marshmallow as ma
|
4
4
|
import marshmallow.fields as mf
|
5
|
-
import
|
5
|
+
import marshmallow_dataclass2.collection_field as collection_field
|
6
6
|
from apispec import APISpec
|
7
7
|
from apispec.ext.marshmallow.common import get_fields
|
8
8
|
from apispec.ext.marshmallow.field_converter import (
|
@@ -12,9 +12,9 @@ from apispec.ext.marshmallow.field_converter import (
|
|
12
12
|
)
|
13
13
|
from apispec.ext.marshmallow.openapi import OpenAPIConverter as ApiSpecOpenAPIConverter
|
14
14
|
from marshmallow.utils import is_collection
|
15
|
+
from marshmallow_dataclass2.union_field import Union as UnionField
|
15
16
|
from packaging.version import Version
|
16
17
|
|
17
|
-
from starmallow.union_field import Union as UnionField
|
18
18
|
from starmallow.utils import MARSHMALLOW_ITERABLES
|
19
19
|
|
20
20
|
# marshmallow field => (JSON Schema type, format)
|
@@ -97,7 +97,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
|
|
97
97
|
spec=spec,
|
98
98
|
)
|
99
99
|
self.add_attribute_function(self.field2title)
|
100
|
-
self.add_attribute_function(self.
|
100
|
+
self.add_attribute_function(self.field2unique_items)
|
101
101
|
self.add_attribute_function(self.field2enum)
|
102
102
|
self.add_attribute_function(self.field2union)
|
103
103
|
|
@@ -187,7 +187,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
|
|
187
187
|
|
188
188
|
return ret
|
189
189
|
|
190
|
-
def
|
190
|
+
def field2unique_items(self: FieldConverterMixin, field: mf.Field, **kwargs: Any) -> dict:
|
191
191
|
ret = {}
|
192
192
|
|
193
193
|
# If this type isn't directly in the field mapping then check the
|
@@ -203,10 +203,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
|
|
203
203
|
ret = {}
|
204
204
|
|
205
205
|
if isinstance(field, mf.Enum):
|
206
|
-
if field.by_value
|
207
|
-
choices = [x.value for x in field.enum]
|
208
|
-
else:
|
209
|
-
choices = list(field.enum.__members__)
|
206
|
+
choices = [x.value for x in field.enum] if field.by_value else list(field.enum.__members__)
|
210
207
|
|
211
208
|
if choices:
|
212
209
|
ret['enum'] = choices
|
@@ -242,7 +239,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
|
|
242
239
|
|
243
240
|
return ret
|
244
241
|
|
245
|
-
#
|
242
|
+
# Override to add 'deprecated' support
|
246
243
|
def _field2parameter(
|
247
244
|
self, field: mf.Field, *, name: str, location: str,
|
248
245
|
):
|
@@ -289,19 +286,19 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
|
|
289
286
|
:rtype: dict, a JSON Schema Object
|
290
287
|
"""
|
291
288
|
fields = get_fields(schema)
|
292
|
-
Meta = getattr(schema, "Meta", None)
|
289
|
+
meta: ma.Schema.Meta = getattr(schema, "Meta", None)
|
293
290
|
partial = getattr(schema, "partial", None)
|
294
291
|
|
295
292
|
jsonschema = self.fields2jsonschema(fields, partial=partial)
|
296
293
|
|
297
|
-
if hasattr(
|
298
|
-
jsonschema["title"] =
|
294
|
+
if hasattr(meta, "title"):
|
295
|
+
jsonschema["title"] = meta.title
|
299
296
|
else:
|
300
297
|
jsonschema['title'] = schema.__class__.__name__
|
301
298
|
|
302
|
-
if hasattr(
|
303
|
-
jsonschema["description"] =
|
304
|
-
if hasattr(
|
305
|
-
jsonschema["additionalProperties"] =
|
299
|
+
if hasattr(meta, "description"):
|
300
|
+
jsonschema["description"] = meta.description
|
301
|
+
if hasattr(meta, "unknown") and meta.unknown != ma.EXCLUDE:
|
302
|
+
jsonschema["additionalProperties"] = meta.unknown == ma.INCLUDE
|
306
303
|
|
307
304
|
return jsonschema
|
starmallow/fields.py
CHANGED
@@ -7,19 +7,19 @@ import marshmallow.fields as mf
|
|
7
7
|
from .delimited_field import DelimitedFieldMixin
|
8
8
|
|
9
9
|
|
10
|
-
class DelimitedListUUID(DelimitedFieldMixin, mf.List):
|
10
|
+
class DelimitedListUUID(DelimitedFieldMixin, mf.List): # type: ignore
|
11
11
|
def __init__(self, *, delimiter: str | None = None, **kwargs):
|
12
12
|
self.delimiter = delimiter or self.delimiter
|
13
13
|
super().__init__(mf.UUID(), **kwargs)
|
14
14
|
|
15
15
|
|
16
|
-
class DelimitedListStr(DelimitedFieldMixin, mf.List):
|
16
|
+
class DelimitedListStr(DelimitedFieldMixin, mf.List): # type: ignore
|
17
17
|
def __init__(self, *, delimiter: str | None = None, **kwargs):
|
18
18
|
self.delimiter = delimiter or self.delimiter
|
19
19
|
super().__init__(mf.String(), **kwargs)
|
20
20
|
|
21
21
|
|
22
|
-
class DelimitedListInt(DelimitedFieldMixin, mf.List):
|
22
|
+
class DelimitedListInt(DelimitedFieldMixin, mf.List): # type: ignore
|
23
23
|
def __init__(self, *, delimiter: str | None = None, **kwargs):
|
24
24
|
self.delimiter = delimiter or self.delimiter
|
25
25
|
super().__init__(mf.Integer(), **kwargs)
|
starmallow/generics.py
ADDED
@@ -0,0 +1,34 @@
|
|
1
|
+
import inspect
|
2
|
+
|
3
|
+
from typing_inspect import is_generic_type
|
4
|
+
|
5
|
+
|
6
|
+
def get_orig_class(obj):
|
7
|
+
"""
|
8
|
+
Allows you got get the runtime origin class inside __init__
|
9
|
+
|
10
|
+
Near duplicate of https://github.com/Stewori/pytypes/blob/master/pytypes/type_util.py#L182
|
11
|
+
"""
|
12
|
+
try:
|
13
|
+
return object.__getattribute__(obj, "__orig_class__")
|
14
|
+
except AttributeError:
|
15
|
+
cls = object.__getattribute__(obj, "__class__")
|
16
|
+
if is_generic_type(cls):
|
17
|
+
# Searching from index 1 is sufficient: At 0 is get_orig_class, at 1 is the caller.
|
18
|
+
frame = inspect.currentframe()
|
19
|
+
if frame is None:
|
20
|
+
raise ValueError('Frame does not have a caller') from None
|
21
|
+
|
22
|
+
frame = frame.f_back
|
23
|
+
try:
|
24
|
+
while frame:
|
25
|
+
try:
|
26
|
+
res = frame.f_locals["self"]
|
27
|
+
if res.__origin__ is cls:
|
28
|
+
return res
|
29
|
+
except (KeyError, AttributeError):
|
30
|
+
frame = frame.f_back
|
31
|
+
finally:
|
32
|
+
del frame
|
33
|
+
|
34
|
+
raise
|
@@ -1,5 +1,4 @@
|
|
1
1
|
from contextlib import AsyncExitStack
|
2
|
-
from typing import Optional
|
3
2
|
|
4
3
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
5
4
|
|
@@ -11,7 +10,7 @@ class AsyncExitStackMiddleware:
|
|
11
10
|
|
12
11
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
13
12
|
if AsyncExitStack:
|
14
|
-
dependency_exception:
|
13
|
+
dependency_exception: Exception | None = None
|
15
14
|
async with AsyncExitStack() as stack:
|
16
15
|
scope[self.context_name] = stack
|
17
16
|
try:
|