starmallow 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- starmallow/__init__.py +1 -1
- starmallow/applications.py +196 -232
- starmallow/background.py +1 -1
- starmallow/concurrency.py +8 -7
- starmallow/dataclasses.py +11 -10
- starmallow/datastructures.py +1 -1
- starmallow/decorators.py +31 -30
- starmallow/delimited_field.py +37 -15
- starmallow/docs.py +5 -5
- starmallow/endpoint.py +103 -81
- starmallow/endpoints.py +3 -2
- starmallow/exceptions.py +3 -3
- starmallow/ext/marshmallow/openapi.py +11 -14
- starmallow/fields.py +3 -3
- starmallow/generics.py +34 -0
- starmallow/middleware/asyncexitstack.py +1 -2
- starmallow/params.py +20 -21
- starmallow/py.typed +0 -0
- starmallow/request_resolver.py +62 -58
- starmallow/responses.py +5 -4
- starmallow/routing.py +231 -239
- starmallow/schema_generator.py +98 -52
- starmallow/security/api_key.py +10 -10
- starmallow/security/base.py +11 -3
- starmallow/security/http.py +30 -25
- starmallow/security/oauth2.py +47 -47
- starmallow/security/open_id_connect_url.py +6 -6
- starmallow/security/utils.py +2 -5
- starmallow/serializers.py +59 -63
- starmallow/types.py +12 -8
- starmallow/utils.py +108 -68
- starmallow/websockets.py +3 -6
- {starmallow-0.8.0.dist-info → starmallow-0.9.0.dist-info}/METADATA +14 -13
- starmallow-0.9.0.dist-info/RECORD +43 -0
- {starmallow-0.8.0.dist-info → starmallow-0.9.0.dist-info}/WHEEL +1 -1
- starmallow-0.8.0.dist-info/RECORD +0 -41
- {starmallow-0.8.0.dist-info → starmallow-0.9.0.dist-info}/licenses/LICENSE.md +0 -0
starmallow/endpoint.py
CHANGED
@@ -1,31 +1,27 @@
|
|
1
1
|
import inspect
|
2
2
|
import logging
|
3
|
+
from collections.abc import Callable, Iterable, Mapping, Sequence
|
3
4
|
from dataclasses import dataclass, field
|
4
5
|
from typing import (
|
5
6
|
TYPE_CHECKING,
|
7
|
+
Annotated,
|
6
8
|
Any,
|
7
|
-
Callable,
|
8
|
-
Dict,
|
9
|
-
Iterable,
|
10
|
-
List,
|
11
|
-
Mapping,
|
12
9
|
NewType,
|
13
|
-
|
14
|
-
Type,
|
15
|
-
Union,
|
10
|
+
cast,
|
16
11
|
get_args,
|
17
12
|
get_origin,
|
18
13
|
)
|
19
14
|
|
20
15
|
import marshmallow as ma
|
21
16
|
import marshmallow.fields as mf
|
17
|
+
import typing_inspect
|
18
|
+
from marshmallow.types import StrSequenceOrSet
|
22
19
|
from marshmallow.utils import missing as missing_
|
23
20
|
from marshmallow_dataclass2 import class_schema, is_generic_alias_of_dataclass
|
24
21
|
from starlette.background import BackgroundTasks
|
25
22
|
from starlette.requests import HTTPConnection, Request
|
26
23
|
from starlette.responses import Response
|
27
24
|
from starlette.websockets import WebSocket
|
28
|
-
from typing_extensions import Annotated
|
29
25
|
|
30
26
|
from starmallow.params import (
|
31
27
|
Body,
|
@@ -43,20 +39,21 @@ from starmallow.params import (
|
|
43
39
|
from starmallow.responses import JSONResponse
|
44
40
|
from starmallow.security.base import SecurityBaseResolver
|
45
41
|
from starmallow.utils import (
|
42
|
+
MaDataclassProtocol,
|
46
43
|
create_response_model,
|
47
44
|
get_model_field,
|
48
45
|
get_path_param_names,
|
49
46
|
get_typed_return_annotation,
|
50
47
|
get_typed_signature,
|
51
48
|
is_marshmallow_dataclass,
|
52
|
-
|
49
|
+
is_marshmallow_field_or_generic,
|
53
50
|
is_marshmallow_schema,
|
54
51
|
is_optional,
|
55
52
|
lenient_issubclass,
|
56
53
|
)
|
57
54
|
|
58
55
|
if TYPE_CHECKING:
|
59
|
-
from starmallow.routing import APIRoute
|
56
|
+
from starmallow.routing import APIRoute, APIWebSocketRoute
|
60
57
|
|
61
58
|
logger = logging.getLogger(__name__)
|
62
59
|
|
@@ -70,52 +67,52 @@ STARMALLOW_PARAM_TYPES = (
|
|
70
67
|
|
71
68
|
@dataclass
|
72
69
|
class EndpointModel:
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
70
|
+
path: str
|
71
|
+
call: Callable[..., Any]
|
72
|
+
route: 'APIRoute | APIWebSocketRoute'
|
73
|
+
params: dict[ParamType, dict[str, Param]] = field(default_factory=dict)
|
74
|
+
flat_params: dict[ParamType, dict[str, Param]] = field(default_factory=dict)
|
75
|
+
name: str | None = None
|
76
|
+
methods: Sequence[str] | None = None
|
77
|
+
response_model: ma.Schema | type[ma.Schema | MaDataclassProtocol] | mf.Field | None = None
|
78
|
+
response_class: type[Response] = JSONResponse
|
79
|
+
status_code: int | None = None
|
83
80
|
|
84
81
|
@property
|
85
|
-
def path_params(self) ->
|
86
|
-
return self.flat_params.get(ParamType.path)
|
82
|
+
def path_params(self) -> dict[str, Path] | None:
|
83
|
+
return cast(dict[str, Path], self.flat_params.get(ParamType.path))
|
87
84
|
|
88
85
|
@property
|
89
|
-
def query_params(self) ->
|
90
|
-
return self.flat_params.get(ParamType.query)
|
86
|
+
def query_params(self) -> dict[str, Query] | None:
|
87
|
+
return cast(dict[str, Query], self.flat_params.get(ParamType.query))
|
91
88
|
|
92
89
|
@property
|
93
|
-
def header_params(self) ->
|
94
|
-
return self.flat_params.get(ParamType.header)
|
90
|
+
def header_params(self) -> dict[str, Header] | None:
|
91
|
+
return cast(dict[str, Header], self.flat_params.get(ParamType.header))
|
95
92
|
|
96
93
|
@property
|
97
|
-
def cookie_params(self) ->
|
98
|
-
return self.flat_params.get(ParamType.cookie)
|
94
|
+
def cookie_params(self) -> dict[str, Cookie] | None:
|
95
|
+
return cast(dict[str, Cookie], self.flat_params.get(ParamType.cookie))
|
99
96
|
|
100
97
|
@property
|
101
|
-
def body_params(self) ->
|
102
|
-
return self.flat_params.get(ParamType.body)
|
98
|
+
def body_params(self) -> dict[str, Body] | None:
|
99
|
+
return cast(dict[str, Body], self.flat_params.get(ParamType.body))
|
103
100
|
|
104
101
|
@property
|
105
|
-
def form_params(self) ->
|
106
|
-
return self.flat_params.get(ParamType.form)
|
102
|
+
def form_params(self) -> dict[str, Form] | None:
|
103
|
+
return cast(dict[str, Form], self.flat_params.get(ParamType.form))
|
107
104
|
|
108
105
|
@property
|
109
|
-
def non_field_params(self) ->
|
110
|
-
return self.flat_params.get(ParamType.noparam)
|
106
|
+
def non_field_params(self) -> dict[str, NoParam] | None:
|
107
|
+
return cast(dict[str, NoParam], self.flat_params.get(ParamType.noparam))
|
111
108
|
|
112
109
|
@property
|
113
|
-
def resolved_params(self) ->
|
114
|
-
return self.flat_params.get(ParamType.resolved)
|
110
|
+
def resolved_params(self) -> dict[str, ResolvedParam] | None:
|
111
|
+
return cast(dict[str, ResolvedParam], self.flat_params.get(ParamType.resolved))
|
115
112
|
|
116
113
|
@property
|
117
|
-
def security_params(self) ->
|
118
|
-
return self.flat_params.get(ParamType.security)
|
114
|
+
def security_params(self) -> dict[str, Security] | None:
|
115
|
+
return cast(dict[str, Security], self.flat_params.get(ParamType.security))
|
119
116
|
|
120
117
|
|
121
118
|
class SchemaMeta:
|
@@ -129,20 +126,20 @@ class SchemaModel(ma.Schema):
|
|
129
126
|
schema: ma.Schema,
|
130
127
|
load_default: Any = missing_,
|
131
128
|
required: bool = True,
|
132
|
-
metadata:
|
129
|
+
metadata: dict[str, Any] | None = None,
|
133
130
|
**kwargs,
|
134
131
|
) -> None:
|
135
132
|
self.schema = schema
|
136
133
|
self.load_default = load_default
|
137
134
|
self.required = required
|
138
|
-
self.title = metadata.get('title')
|
135
|
+
self.title = metadata.get('title') if metadata else None
|
139
136
|
self.metadata = metadata
|
140
137
|
self.kwargs = kwargs
|
141
138
|
|
142
139
|
if not getattr(schema.Meta, "title", None):
|
143
140
|
if schema.Meta is ma.Schema.Meta:
|
144
141
|
# Don't override global Meta object's title
|
145
|
-
schema.Meta = SchemaMeta(self.title)
|
142
|
+
schema.Meta = SchemaMeta(self.title) # type: ignore
|
146
143
|
else:
|
147
144
|
schema.Meta.title = self.title
|
148
145
|
|
@@ -162,14 +159,11 @@ class SchemaModel(ma.Schema):
|
|
162
159
|
|
163
160
|
def load(
|
164
161
|
self,
|
165
|
-
data:
|
166
|
-
Mapping[str, Any],
|
167
|
-
Iterable[Mapping[str, Any]],
|
168
|
-
],
|
162
|
+
data: Mapping[str, Any] | Iterable[Mapping[str, Any]],
|
169
163
|
*,
|
170
|
-
many:
|
171
|
-
partial:
|
172
|
-
unknown:
|
164
|
+
many: bool | None = None,
|
165
|
+
partial: bool | StrSequenceOrSet | None = None,
|
166
|
+
unknown: str | None = None,
|
173
167
|
) -> Any:
|
174
168
|
if not data and self.load_default:
|
175
169
|
return self.load_default
|
@@ -185,7 +179,7 @@ class EndpointMixin:
|
|
185
179
|
parameter: Param,
|
186
180
|
parameter_name: str,
|
187
181
|
default_value: Any,
|
188
|
-
) ->
|
182
|
+
) -> ma.Schema | mf.Field | None:
|
189
183
|
model = getattr(parameter, 'model', None) or type_annotation
|
190
184
|
if isinstance(model, SchemaModel):
|
191
185
|
return model
|
@@ -244,21 +238,22 @@ class EndpointMixin:
|
|
244
238
|
if is_marshmallow_dataclass(model):
|
245
239
|
model = model.Schema
|
246
240
|
|
247
|
-
if is_generic_alias_of_dataclass(model):
|
248
|
-
model = class_schema(model)
|
241
|
+
if is_generic_alias_of_dataclass(model): # type: ignore
|
242
|
+
model = class_schema(model) # type: ignore
|
249
243
|
|
250
|
-
|
251
|
-
|
244
|
+
mmf = getattr(model, '_marshmallow_field', None)
|
245
|
+
if isinstance(model, NewType) and mmf and issubclass(mmf, mf.Field):
|
246
|
+
return mmf(**kwargs)
|
252
247
|
elif is_marshmallow_schema(model):
|
253
248
|
return SchemaModel(model() if inspect.isclass(model) else model, **kwargs)
|
254
|
-
elif
|
255
|
-
if inspect.isclass(model):
|
249
|
+
elif is_marshmallow_field_or_generic(model):
|
250
|
+
if not isinstance(model, mf.Field): # inspect.isclass(model):
|
256
251
|
model = model()
|
257
252
|
|
258
253
|
if model.load_default is not None and model.load_default != kwargs.get('load_default', ma.missing):
|
259
254
|
logger.warning(
|
260
255
|
f"'{parameter_name}' model and annotation have different 'load_default' values."
|
261
|
-
|
256
|
+
f" {model.load_default} <> {kwargs.get('load_default', ma.missing)}",
|
262
257
|
)
|
263
258
|
|
264
259
|
model.required = kwargs['required']
|
@@ -266,6 +261,11 @@ class EndpointMixin:
|
|
266
261
|
model.metadata.update(kwargs['metadata'])
|
267
262
|
|
268
263
|
return model
|
264
|
+
|
265
|
+
# elif is_marshmallow_field_or_generic(model):
|
266
|
+
# if isinstance(model, mf.Field):
|
267
|
+
# model if isinstance(model, mf.Field) else model()
|
268
|
+
|
269
269
|
else:
|
270
270
|
try:
|
271
271
|
return get_model_field(model, **kwargs)
|
@@ -282,11 +282,19 @@ class EndpointMixin:
|
|
282
282
|
|
283
283
|
return resolved_param
|
284
284
|
|
285
|
+
def get_security_param(self, resolved_param: Security, annotation: Any, path: str) -> Security:
    """Attach a resolver to a ``Security`` marker and collect the resolver's params.

    If the marker has no resolver yet, ``annotation`` (the parameter's type
    annotation at the call site) is used as the resolver. The resolver's own
    signature is then analysed with ``_get_params`` against the same route
    ``path`` and stored on ``resolver_params``.

    The ``resolved_param`` object is mutated in place and returned.
    """
    param = resolved_param
    if param.resolver is None:
        # No explicit resolver on the marker: fall back to the annotated type.
        param.resolver = annotation
    param.resolver_params = self._get_params(param.resolver, path=path)
    return param
|
292
|
+
|
285
293
|
def _get_params(
|
286
294
|
self,
|
287
295
|
func: Callable[..., Any],
|
288
296
|
path: str,
|
289
|
-
) ->
|
297
|
+
) -> dict[ParamType, dict[str, Param]]:
|
290
298
|
path_param_names = get_path_param_names(path)
|
291
299
|
params = {param_type: {} for param_type in ParamType}
|
292
300
|
for name, parameter in get_typed_signature(func).parameters.items():
|
@@ -313,8 +321,7 @@ class EndpointMixin:
|
|
313
321
|
]
|
314
322
|
if starmallow_annotations:
|
315
323
|
assert starmallow_param is inspect._empty, (
|
316
|
-
"Cannot specify `Param` in `Annotated` and default value"
|
317
|
-
f" together for {name!r}"
|
324
|
+
f"Cannot specify `Param` in `Annotated` and default value together for {name!r}"
|
318
325
|
)
|
319
326
|
|
320
327
|
starmallow_param = starmallow_annotations[-1]
|
@@ -325,6 +332,20 @@ class EndpointMixin:
|
|
325
332
|
):
|
326
333
|
default_value = starmallow_param.default
|
327
334
|
|
335
|
+
field_annotations = [
|
336
|
+
arg
|
337
|
+
for arg in annotated_args
|
338
|
+
if (
|
339
|
+
isinstance(arg, mf.Field)
|
340
|
+
or lenient_issubclass(arg, mf.Field)
|
341
|
+
or (
|
342
|
+
isinstance(arg, typing_inspect.typingGenericAlias)
|
343
|
+
and lenient_issubclass(get_origin(arg), mf.Field)
|
344
|
+
)
|
345
|
+
)
|
346
|
+
]
|
347
|
+
if field_annotations:
|
348
|
+
type_annotation = field_annotations[-1]
|
328
349
|
if (
|
329
350
|
# Skip 'self' in APIHTTPEndpoint functions
|
330
351
|
(name == 'self' and '.' in func.__qualname__)
|
@@ -332,11 +353,11 @@ class EndpointMixin:
|
|
332
353
|
):
|
333
354
|
continue
|
334
355
|
elif isinstance(starmallow_param, Security):
|
335
|
-
security_param
|
356
|
+
security_param = self.get_security_param(starmallow_param, type_annotation, path=path)
|
336
357
|
params[ParamType.security][name] = security_param
|
337
358
|
continue
|
338
359
|
elif isinstance(starmallow_param, ResolvedParam):
|
339
|
-
resolved_param
|
360
|
+
resolved_param = self.get_resolved_param(starmallow_param, type_annotation, path=path)
|
340
361
|
|
341
362
|
# Allow `ResolvedParam(HTTPBearer())` - treat as securty param
|
342
363
|
if isinstance(resolved_param.resolver, SecurityBaseResolver):
|
@@ -359,7 +380,7 @@ class EndpointMixin:
|
|
359
380
|
continue
|
360
381
|
|
361
382
|
model = self._get_param_model(type_annotation, starmallow_param, name, default_value)
|
362
|
-
model.name = name
|
383
|
+
model.name = name # type: ignore
|
363
384
|
|
364
385
|
if isinstance(starmallow_param, Param):
|
365
386
|
# Create new field_info with processed model
|
@@ -397,18 +418,18 @@ class EndpointMixin:
|
|
397
418
|
def get_endpoint_model(
|
398
419
|
self,
|
399
420
|
path: str,
|
400
|
-
endpoint: Callable[...,
|
401
|
-
route: 'APIRoute',
|
402
|
-
name:
|
403
|
-
methods:
|
404
|
-
|
405
|
-
status_code:
|
406
|
-
response_model:
|
407
|
-
response_class:
|
421
|
+
endpoint: Callable[..., ma.Schema | mf.Field | Response | None],
|
422
|
+
route: 'APIRoute | APIWebSocketRoute',
|
423
|
+
name: str | None = None,
|
424
|
+
methods: Sequence[str] | None = None,
|
425
|
+
|
426
|
+
status_code: int | None = None,
|
427
|
+
response_model: ma.Schema | type[ma.Schema | MaDataclassProtocol] | None = None,
|
428
|
+
response_class: type[Response] = JSONResponse,
|
408
429
|
) -> EndpointModel:
|
409
430
|
params = self._get_params(endpoint, path)
|
410
431
|
|
411
|
-
response_model = create_response_model(response_model or get_typed_return_annotation(endpoint))
|
432
|
+
response_model = create_response_model(response_model or get_typed_return_annotation(endpoint)) # type: ignore
|
412
433
|
|
413
434
|
return EndpointModel(
|
414
435
|
path=path,
|
@@ -425,9 +446,9 @@ class EndpointMixin:
|
|
425
446
|
|
426
447
|
|
427
448
|
def safe_merge_params(
|
428
|
-
left:
|
429
|
-
right:
|
430
|
-
) ->
|
449
|
+
left: dict[str, Param],
|
450
|
+
right: dict[str, Param],
|
451
|
+
) -> dict[str, Param]:
|
431
452
|
res = left.copy()
|
432
453
|
for name, param in right.items():
|
433
454
|
if name not in left:
|
@@ -439,9 +460,9 @@ def safe_merge_params(
|
|
439
460
|
|
440
461
|
|
441
462
|
def safe_merge_all_params(
|
442
|
-
left:
|
443
|
-
right:
|
444
|
-
) ->
|
463
|
+
left: dict[ParamType, dict[str, Param]],
|
464
|
+
right: dict[ParamType, dict[str, Param]],
|
465
|
+
) -> dict[ParamType, dict[str, Param]]:
|
445
466
|
res = {
|
446
467
|
param_type: safe_merge_params(left[param_type], right[param_type])
|
447
468
|
for param_type in ParamType
|
@@ -451,12 +472,13 @@ def safe_merge_all_params(
|
|
451
472
|
|
452
473
|
|
453
474
|
def flatten_parameters(
|
454
|
-
params:
|
455
|
-
) ->
|
475
|
+
params: dict[ParamType, dict[str, Param]],
|
476
|
+
) -> dict[ParamType, dict[str, Param]]:
|
456
477
|
# flat_params = {param_type: {} for param_type in ParamType}
|
457
478
|
flat_params = params.copy()
|
479
|
+
|
458
480
|
for param in params[ParamType.resolved].values():
|
459
|
-
if not param.resolver_params:
|
481
|
+
if not (isinstance(param, ResolvedParam) and param.resolver_params):
|
460
482
|
continue
|
461
483
|
|
462
484
|
flat_params = safe_merge_all_params(flat_params, param.resolver_params)
|
starmallow/endpoints.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
|
-
from
|
1
|
+
from collections.abc import Collection
|
2
|
+
from typing import Any, ClassVar
|
2
3
|
|
3
4
|
HTTP_METHOD_FUNCS = ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
|
4
5
|
|
@@ -8,7 +9,7 @@ class APIHTTPEndpoint:
|
|
8
9
|
#: The methods this view is registered for. Uses the same default
|
9
10
|
#: (``["GET", "HEAD", "OPTIONS"]``) as ``route`` and
|
10
11
|
#: ``add_url_rule`` by default.
|
11
|
-
methods: ClassVar[
|
12
|
+
methods: ClassVar[Collection[str] | None] = None
|
12
13
|
|
13
14
|
def __init_subclass__(cls, **kwargs: Any) -> None:
|
14
15
|
|
starmallow/exceptions.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any
|
1
|
+
from typing import Any
|
2
2
|
|
3
3
|
from starlette.exceptions import HTTPException
|
4
4
|
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_500_INTERNAL_SERVER_ERROR
|
@@ -7,7 +7,7 @@ from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_500_INTERNAL_SE
|
|
7
7
|
class RequestValidationError(HTTPException):
|
8
8
|
def __init__(
|
9
9
|
self,
|
10
|
-
errors:
|
10
|
+
errors: dict[str, Any | list | dict],
|
11
11
|
) -> None:
|
12
12
|
super().__init__(status_code=HTTP_422_UNPROCESSABLE_ENTITY)
|
13
13
|
self.errors = errors
|
@@ -16,7 +16,7 @@ class RequestValidationError(HTTPException):
|
|
16
16
|
class WebSocketRequestValidationError(HTTPException):
|
17
17
|
def __init__(
|
18
18
|
self,
|
19
|
-
errors:
|
19
|
+
errors: dict[str, Any | list | dict],
|
20
20
|
) -> None:
|
21
21
|
super().__init__(status_code=HTTP_422_UNPROCESSABLE_ENTITY)
|
22
22
|
self.errors = errors
|
@@ -97,7 +97,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
|
|
97
97
|
spec=spec,
|
98
98
|
)
|
99
99
|
self.add_attribute_function(self.field2title)
|
100
|
-
self.add_attribute_function(self.
|
100
|
+
self.add_attribute_function(self.field2unique_items)
|
101
101
|
self.add_attribute_function(self.field2enum)
|
102
102
|
self.add_attribute_function(self.field2union)
|
103
103
|
|
@@ -187,7 +187,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
|
|
187
187
|
|
188
188
|
return ret
|
189
189
|
|
190
|
-
def
|
190
|
+
def field2unique_items(self: FieldConverterMixin, field: mf.Field, **kwargs: Any) -> dict:
|
191
191
|
ret = {}
|
192
192
|
|
193
193
|
# If this type isn't directly in the field mapping then check the
|
@@ -203,10 +203,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
|
|
203
203
|
ret = {}
|
204
204
|
|
205
205
|
if isinstance(field, mf.Enum):
|
206
|
-
if field.by_value
|
207
|
-
choices = [x.value for x in field.enum]
|
208
|
-
else:
|
209
|
-
choices = list(field.enum.__members__)
|
206
|
+
choices = [x.value for x in field.enum] if field.by_value else list(field.enum.__members__)
|
210
207
|
|
211
208
|
if choices:
|
212
209
|
ret['enum'] = choices
|
@@ -242,7 +239,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
|
|
242
239
|
|
243
240
|
return ret
|
244
241
|
|
245
|
-
#
|
242
|
+
# Override to add 'deprecated' support
|
246
243
|
def _field2parameter(
|
247
244
|
self, field: mf.Field, *, name: str, location: str,
|
248
245
|
):
|
@@ -289,19 +286,19 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
|
|
289
286
|
:rtype: dict, a JSON Schema Object
|
290
287
|
"""
|
291
288
|
fields = get_fields(schema)
|
292
|
-
Meta = getattr(schema, "Meta", None)
|
289
|
+
meta: ma.Schema.Meta = getattr(schema, "Meta", None)
|
293
290
|
partial = getattr(schema, "partial", None)
|
294
291
|
|
295
292
|
jsonschema = self.fields2jsonschema(fields, partial=partial)
|
296
293
|
|
297
|
-
if hasattr(
|
298
|
-
jsonschema["title"] =
|
294
|
+
if hasattr(meta, "title"):
|
295
|
+
jsonschema["title"] = meta.title
|
299
296
|
else:
|
300
297
|
jsonschema['title'] = schema.__class__.__name__
|
301
298
|
|
302
|
-
if hasattr(
|
303
|
-
jsonschema["description"] =
|
304
|
-
if hasattr(
|
305
|
-
jsonschema["additionalProperties"] =
|
299
|
+
if hasattr(meta, "description"):
|
300
|
+
jsonschema["description"] = meta.description
|
301
|
+
if hasattr(meta, "unknown") and meta.unknown != ma.EXCLUDE:
|
302
|
+
jsonschema["additionalProperties"] = meta.unknown == ma.INCLUDE
|
306
303
|
|
307
304
|
return jsonschema
|
starmallow/fields.py
CHANGED
@@ -7,19 +7,19 @@ import marshmallow.fields as mf
|
|
7
7
|
from .delimited_field import DelimitedFieldMixin
|
8
8
|
|
9
9
|
|
10
|
-
class DelimitedListUUID(DelimitedFieldMixin, mf.List):
|
10
|
+
class DelimitedListUUID(DelimitedFieldMixin, mf.List): # type: ignore
|
11
11
|
def __init__(self, *, delimiter: str | None = None, **kwargs):
|
12
12
|
self.delimiter = delimiter or self.delimiter
|
13
13
|
super().__init__(mf.UUID(), **kwargs)
|
14
14
|
|
15
15
|
|
16
|
-
class DelimitedListStr(DelimitedFieldMixin, mf.List):
|
16
|
+
class DelimitedListStr(DelimitedFieldMixin, mf.List): # type: ignore
|
17
17
|
def __init__(self, *, delimiter: str | None = None, **kwargs):
|
18
18
|
self.delimiter = delimiter or self.delimiter
|
19
19
|
super().__init__(mf.String(), **kwargs)
|
20
20
|
|
21
21
|
|
22
|
-
class DelimitedListInt(DelimitedFieldMixin, mf.List):
|
22
|
+
class DelimitedListInt(DelimitedFieldMixin, mf.List): # type: ignore
|
23
23
|
def __init__(self, *, delimiter: str | None = None, **kwargs):
|
24
24
|
self.delimiter = delimiter or self.delimiter
|
25
25
|
super().__init__(mf.Integer(), **kwargs)
|
starmallow/generics.py
ADDED
@@ -0,0 +1,34 @@
|
|
1
|
+
import inspect
|
2
|
+
|
3
|
+
from typing_inspect import is_generic_type
|
4
|
+
|
5
|
+
|
6
|
+
def get_orig_class(obj):
    """
    Return the runtime origin class (the parameterized generic alias) of ``obj``.

    Allows you to get the runtime origin class inside ``__init__``. If
    ``obj.__orig_class__`` is not available, the call stack is walked, starting
    at the caller. The first frame whose ``self`` local has an ``__origin__``
    identical to ``type(obj)`` supplies the result.

    Near duplicate of https://github.com/Stewori/pytypes/blob/master/pytypes/type_util.py#L182

    Raises:
        ValueError: if ``inspect.currentframe()`` returns ``None`` (no frame support).
        AttributeError: re-raised when ``obj`` has no ``__orig_class__`` and either
            ``type(obj)`` is not a generic type or no matching frame is found.
    """
    try:
        return object.__getattribute__(obj, "__orig_class__")
    except AttributeError:
        cls = object.__getattribute__(obj, "__class__")
        if is_generic_type(cls):
            frame = inspect.currentframe()
            if frame is None:
                raise ValueError('Frame does not have a caller') from None

            # Frame 0 is get_orig_class itself; start searching at the caller.
            frame = frame.f_back
            try:
                while frame is not None:
                    try:
                        candidate = frame.f_locals["self"]
                        if candidate.__origin__ is cls:
                            return candidate
                    except (KeyError, AttributeError):
                        # No `self` local, or `self` is not a generic alias.
                        pass
                    # Always step outward. Advancing only on exceptions would spin
                    # forever on a frame whose `self` has an unrelated `__origin__`.
                    frame = frame.f_back
            finally:
                # Release the frame reference promptly (see inspect docs on frame cycles).
                del frame

        # Nothing found: re-raise the original AttributeError.
        raise
|
@@ -1,5 +1,4 @@
|
|
1
1
|
from contextlib import AsyncExitStack
|
2
|
-
from typing import Optional
|
3
2
|
|
4
3
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
5
4
|
|
@@ -11,7 +10,7 @@ class AsyncExitStackMiddleware:
|
|
11
10
|
|
12
11
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
13
12
|
if AsyncExitStack:
|
14
|
-
dependency_exception:
|
13
|
+
dependency_exception: Exception | None = None
|
15
14
|
async with AsyncExitStack() as stack:
|
16
15
|
scope[self.context_name] = stack
|
17
16
|
try:
|