starmallow 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- starmallow/__init__.py +1 -1
- starmallow/applications.py +196 -232
- starmallow/background.py +1 -1
- starmallow/concurrency.py +8 -7
- starmallow/dataclasses.py +11 -10
- starmallow/datastructures.py +1 -1
- starmallow/decorators.py +31 -30
- starmallow/delimited_field.py +37 -15
- starmallow/docs.py +5 -5
- starmallow/endpoint.py +103 -81
- starmallow/endpoints.py +3 -2
- starmallow/exceptions.py +3 -3
- starmallow/ext/marshmallow/openapi.py +11 -14
- starmallow/fields.py +3 -3
- starmallow/generics.py +34 -0
- starmallow/middleware/asyncexitstack.py +1 -2
- starmallow/params.py +20 -21
- starmallow/py.typed +0 -0
- starmallow/request_resolver.py +62 -58
- starmallow/responses.py +5 -4
- starmallow/routing.py +231 -239
- starmallow/schema_generator.py +98 -52
- starmallow/security/api_key.py +10 -10
- starmallow/security/base.py +11 -3
- starmallow/security/http.py +30 -25
- starmallow/security/oauth2.py +47 -47
- starmallow/security/open_id_connect_url.py +6 -6
- starmallow/security/utils.py +2 -5
- starmallow/serializers.py +59 -63
- starmallow/types.py +12 -8
- starmallow/utils.py +108 -68
- starmallow/websockets.py +3 -6
- {starmallow-0.8.0.dist-info → starmallow-0.9.0.dist-info}/METADATA +14 -13
- starmallow-0.9.0.dist-info/RECORD +43 -0
- {starmallow-0.8.0.dist-info → starmallow-0.9.0.dist-info}/WHEEL +1 -1
- starmallow-0.8.0.dist-info/RECORD +0 -41
- {starmallow-0.8.0.dist-info → starmallow-0.9.0.dist-info}/licenses/LICENSE.md +0 -0
starmallow/schema_generator.py
CHANGED
@@ -4,14 +4,14 @@ import itertools
|
|
4
4
|
import re
|
5
5
|
import warnings
|
6
6
|
from collections import defaultdict
|
7
|
+
from collections.abc import Generator, Mapping
|
7
8
|
from logging import getLogger
|
8
|
-
from typing import Any,
|
9
|
+
from typing import Any, cast
|
9
10
|
|
10
11
|
import marshmallow as ma
|
11
12
|
import marshmallow.fields as mf
|
12
13
|
from apispec import APISpec
|
13
|
-
from apispec.ext.marshmallow import OpenAPIConverter, SchemaResolver
|
14
|
-
# from apispec.ext.marshmallow.openapi import OpenAPIConverter
|
14
|
+
from apispec.ext.marshmallow import OpenAPIConverter, SchemaResolver # type: ignore
|
15
15
|
from starlette.responses import Response
|
16
16
|
from starlette.routing import BaseRoute, Mount, compile_path
|
17
17
|
from starlette.schemas import BaseSchemaGenerator
|
@@ -26,16 +26,17 @@ from starmallow.security.base import SecurityBaseResolver
|
|
26
26
|
from starmallow.utils import (
|
27
27
|
deep_dict_update,
|
28
28
|
dict_safe_add,
|
29
|
-
is_marshmallow_field,
|
30
29
|
status_code_ranges,
|
31
30
|
)
|
32
31
|
|
33
32
|
logger = getLogger(__name__)
|
34
33
|
|
34
|
+
|
35
35
|
class SchemaRegistry(dict):
|
36
36
|
'''
|
37
37
|
Dict that holds all the schemas for each class and lazily resolves them.
|
38
38
|
'''
|
39
|
+
|
39
40
|
def __init__(
|
40
41
|
self,
|
41
42
|
spec: APISpec,
|
@@ -59,7 +60,7 @@ class SchemaRegistry(dict):
|
|
59
60
|
sec_obj = self.security_references.__getitem__(component_id)
|
60
61
|
except KeyError:
|
61
62
|
# Use marshmallow_dataclass to dump itself
|
62
|
-
sec_schema = model.Schema().dump(model)
|
63
|
+
sec_schema = cast(dict, model.Schema().dump(model))
|
63
64
|
sec_schema['type'] = model.type.value
|
64
65
|
|
65
66
|
self.spec.components.security_scheme(component_id=component_id, component=sec_schema)
|
@@ -74,7 +75,7 @@ class SchemaRegistry(dict):
|
|
74
75
|
if isinstance(item, SecurityBaseResolver):
|
75
76
|
return self._get_security_item(item)
|
76
77
|
|
77
|
-
if
|
78
|
+
if isinstance(item, mf.Field):
|
78
79
|
# If marshmallow field, just resolve it here without caching
|
79
80
|
prop = self.converter.field2property(item)
|
80
81
|
return prop
|
@@ -92,7 +93,9 @@ class SchemaRegistry(dict):
|
|
92
93
|
try:
|
93
94
|
schema = self.spec.components.schemas.__getitem__(component_id)
|
94
95
|
except KeyError:
|
95
|
-
self.spec.components.schema(
|
96
|
+
self.spec.components.schema(
|
97
|
+
component_id=component_id, schema=item,
|
98
|
+
)
|
96
99
|
|
97
100
|
schema = self.resolver.resolve_schema_dict(item)
|
98
101
|
super().__setitem__(schema_class, schema)
|
@@ -124,19 +127,24 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
124
127
|
plugins=[marshmallow_plugin],
|
125
128
|
)
|
126
129
|
|
130
|
+
if marshmallow_plugin.converter is None or marshmallow_plugin.resolver is None:
|
131
|
+
raise ValueError(
|
132
|
+
"Converter and resolver must be initialized before use.",
|
133
|
+
)
|
134
|
+
|
127
135
|
self.converter = marshmallow_plugin.converter
|
128
136
|
self.resolver = marshmallow_plugin.resolver
|
129
137
|
|
130
138
|
# Builtin definitions
|
131
139
|
self.schemas = SchemaRegistry(self.spec, self.converter, self.resolver)
|
132
140
|
|
133
|
-
self.operation_ids:
|
141
|
+
self.operation_ids: set[str] = set()
|
134
142
|
|
135
|
-
def get_endpoints(
|
143
|
+
def get_endpoints( # type: ignore
|
136
144
|
self,
|
137
|
-
routes:
|
145
|
+
routes: list[BaseRoute],
|
138
146
|
base_path: str = "",
|
139
|
-
) ->
|
147
|
+
) -> dict[str, list[EndpointModel]]:
|
140
148
|
"""
|
141
149
|
Given the routes, yields the following information:
|
142
150
|
|
@@ -151,11 +159,11 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
151
159
|
This allows each path to have multiple responses.
|
152
160
|
"""
|
153
161
|
|
154
|
-
endpoints_info:
|
162
|
+
endpoints_info: dict[str, list[EndpointModel]] = defaultdict(list)
|
155
163
|
|
156
164
|
for route in routes:
|
157
165
|
# path is not defined in BaseRoute, but all implementations have it.
|
158
|
-
_, path, _ = compile_path(base_path + route.path)
|
166
|
+
_, path, _ = compile_path(base_path + route.path) # type: ignore
|
159
167
|
|
160
168
|
if isinstance(route, APIRoute) and route.include_in_schema:
|
161
169
|
if inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
|
@@ -166,12 +174,16 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
166
174
|
endpoints_info[path].append(route.endpoint_model)
|
167
175
|
|
168
176
|
elif isinstance(route, Mount):
|
169
|
-
endpoints_info.update(
|
177
|
+
endpoints_info.update(
|
178
|
+
self.get_endpoints(
|
179
|
+
route.routes, base_path=path,
|
180
|
+
),
|
181
|
+
)
|
170
182
|
|
171
183
|
return endpoints_info
|
172
184
|
|
173
|
-
def _to_parameters(self, *params:
|
174
|
-
for name, field in itertools.chain(*(p.items() for p in params)):
|
185
|
+
def _to_parameters(self, *params: Mapping[str, Query | Path | Header | Cookie] | None) -> Generator[dict[str, Any], None, None]:
|
186
|
+
for name, field in itertools.chain(*(p.items() for p in params if p is not None)):
|
175
187
|
if not field.include_in_schema:
|
176
188
|
continue
|
177
189
|
|
@@ -183,9 +195,13 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
183
195
|
location=field.in_.name,
|
184
196
|
)
|
185
197
|
|
186
|
-
|
198
|
+
elif field.model is not None:
|
187
199
|
yield self.converter._field2parameter(
|
188
|
-
|
200
|
+
(
|
201
|
+
field.model
|
202
|
+
if isinstance(field.model, mf.Field)
|
203
|
+
else mf.Nested(field.model)
|
204
|
+
),
|
189
205
|
name=name,
|
190
206
|
location=field.in_.name,
|
191
207
|
)
|
@@ -193,7 +209,7 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
193
209
|
def _add_endpoint_parameters(
|
194
210
|
self,
|
195
211
|
endpoint: EndpointModel,
|
196
|
-
schema:
|
212
|
+
schema: dict,
|
197
213
|
):
|
198
214
|
unique_params = {}
|
199
215
|
for param in self._to_parameters(
|
@@ -207,7 +223,9 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
207
223
|
# Duplicate parameter, skip. Could be defined as a field and in a schema.
|
208
224
|
continue
|
209
225
|
else:
|
210
|
-
raise ValueError(
|
226
|
+
raise ValueError(
|
227
|
+
f"Duplicate parameter with name {param['name']} and location {param['in']}",
|
228
|
+
)
|
211
229
|
|
212
230
|
unique_params[param['name']] = param
|
213
231
|
|
@@ -216,11 +234,14 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
216
234
|
def _add_endpoint_body(
|
217
235
|
self,
|
218
236
|
endpoint: EndpointModel,
|
219
|
-
schema:
|
237
|
+
schema: dict,
|
220
238
|
):
|
221
|
-
|
222
|
-
|
223
|
-
|
239
|
+
if not isinstance(endpoint.route, APIRoute):
|
240
|
+
return
|
241
|
+
|
242
|
+
all_body_params: list[tuple[str, Body]] = [
|
243
|
+
*(endpoint.body_params.items() if endpoint.body_params else []),
|
244
|
+
*(endpoint.form_params.items() if endpoint.form_params else []),
|
224
245
|
]
|
225
246
|
schema_by_media_type = {}
|
226
247
|
is_body_required = True
|
@@ -231,8 +252,8 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
231
252
|
if body_param.include_in_schema:
|
232
253
|
endpoint_schema = self.schemas[body_param.model]
|
233
254
|
|
234
|
-
|
235
|
-
|
255
|
+
if endpoint_schema:
|
256
|
+
schema_by_media_type[body_param.media_type] = {'schema': endpoint_schema}
|
236
257
|
|
237
258
|
if getattr(body_param.model, 'required', True) is False:
|
238
259
|
is_body_required = False
|
@@ -254,8 +275,11 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
254
275
|
component_by_media_type = defaultdict(new_endpoint_schema)
|
255
276
|
for name, value in all_body_params:
|
256
277
|
media_component = component_by_media_type[value.media_type]
|
257
|
-
endpoint_properties:
|
258
|
-
|
278
|
+
endpoint_properties: dict[
|
279
|
+
str,
|
280
|
+
Any,
|
281
|
+
] = media_component['properties']
|
282
|
+
required_properties: list[Any] = media_component['required']
|
259
283
|
|
260
284
|
if value.include_in_schema:
|
261
285
|
if isinstance(value.model, ma.Schema):
|
@@ -264,7 +288,9 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
264
288
|
required_properties.append(name)
|
265
289
|
|
266
290
|
elif isinstance(value.model, mf.Field):
|
267
|
-
endpoint_properties[name] = self.converter.field2property(
|
291
|
+
endpoint_properties[name] = self.converter.field2property(
|
292
|
+
value.model,
|
293
|
+
)
|
268
294
|
if value.model.required:
|
269
295
|
required_properties.append(name)
|
270
296
|
|
@@ -273,7 +299,9 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
273
299
|
|
274
300
|
schema_by_media_type = {}
|
275
301
|
for media_type, component in component_by_media_type.items():
|
276
|
-
self.spec.components.schema(
|
302
|
+
self.spec.components.schema(
|
303
|
+
component_id=component_schema_id, component=component,
|
304
|
+
)
|
277
305
|
|
278
306
|
schema_by_media_type[media_type] = {
|
279
307
|
"schema": {
|
@@ -290,21 +318,26 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
290
318
|
def _add_security_params(
|
291
319
|
self,
|
292
320
|
endpoint: EndpointModel,
|
293
|
-
schema:
|
321
|
+
schema: dict,
|
294
322
|
):
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
323
|
+
if endpoint.security_params:
|
324
|
+
schema['security'] = [
|
325
|
+
self.schemas[security_param.resolver]
|
326
|
+
for param_name, security_param in endpoint.security_params.items()
|
327
|
+
]
|
299
328
|
|
300
329
|
def _add_endpoint_response(
|
301
330
|
self,
|
302
331
|
endpoint: EndpointModel,
|
303
|
-
schema:
|
332
|
+
schema: dict,
|
304
333
|
):
|
305
334
|
operation_responses = schema.setdefault("responses", {})
|
306
335
|
response_codes = list(operation_responses.keys())
|
307
|
-
main_response = str(
|
336
|
+
main_response = str(
|
337
|
+
endpoint.status_code or (
|
338
|
+
response_codes[0] if response_codes else 200
|
339
|
+
),
|
340
|
+
)
|
308
341
|
|
309
342
|
operation_responses[main_response] = {
|
310
343
|
'content': {
|
@@ -313,17 +346,20 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
313
346
|
},
|
314
347
|
},
|
315
348
|
}
|
349
|
+
if not isinstance(endpoint.route, APIRoute):
|
350
|
+
return
|
351
|
+
|
316
352
|
if endpoint.route.response_description:
|
317
353
|
operation_responses[main_response]['description'] = endpoint.route.response_description
|
318
354
|
|
319
355
|
# Process additional responses
|
320
356
|
route = endpoint.route
|
321
357
|
if isinstance(route.response_class, DefaultPlaceholder):
|
322
|
-
current_response_class:
|
358
|
+
current_response_class: type[Response] = route.response_class.value
|
323
359
|
else:
|
324
360
|
current_response_class = route.response_class
|
325
361
|
assert current_response_class, "A response class is needed to generate OpenAPI"
|
326
|
-
route_response_media_type:
|
362
|
+
route_response_media_type: str | None = current_response_class.media_type
|
327
363
|
|
328
364
|
if route.responses:
|
329
365
|
for (
|
@@ -342,17 +378,21 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
342
378
|
status_code_key, {},
|
343
379
|
)
|
344
380
|
field = route.response_fields.get(additional_status_code)
|
345
|
-
additional_field_schema:
|
381
|
+
additional_field_schema: dict[str, Any] | None = None
|
346
382
|
if field:
|
347
|
-
additional_field_schema = self.schemas[field] if field else {
|
383
|
+
additional_field_schema = self.schemas[field] if field else {
|
384
|
+
}
|
348
385
|
media_type = route_response_media_type or "application/json"
|
349
386
|
additional_schema = (
|
350
387
|
process_response.setdefault("content", {})
|
351
388
|
.setdefault(media_type, {})
|
352
389
|
.setdefault("schema", {})
|
353
390
|
)
|
354
|
-
deep_dict_update(
|
355
|
-
|
391
|
+
deep_dict_update(
|
392
|
+
additional_schema,
|
393
|
+
additional_field_schema,
|
394
|
+
)
|
395
|
+
status_text: str | None = status_code_ranges.get(
|
356
396
|
str(additional_status_code).upper(),
|
357
397
|
) or http.client.responses.get(int(additional_status_code))
|
358
398
|
description = (
|
@@ -364,7 +404,7 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
364
404
|
deep_dict_update(openapi_response, process_response)
|
365
405
|
openapi_response["description"] = description
|
366
406
|
|
367
|
-
def _add_default_error_response(self, schema:
|
407
|
+
def _add_default_error_response(self, schema: dict):
|
368
408
|
dict_safe_add(
|
369
409
|
schema,
|
370
410
|
'responses.422.description',
|
@@ -381,7 +421,7 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
381
421
|
return route.summary
|
382
422
|
return re.sub(r"(\w)([A-Z])", r"\1 \2", route.name).replace(".", " ").replace("_", " ").title()
|
383
423
|
|
384
|
-
def _get_route_openapi_metadata(self, route: APIRoute) ->
|
424
|
+
def _get_route_openapi_metadata(self, route: APIRoute) -> dict[str, Any]:
|
385
425
|
schema = {}
|
386
426
|
if route.tags:
|
387
427
|
schema["tags"] = route.tags
|
@@ -391,10 +431,11 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
391
431
|
operation_id = route.operation_id or route.unique_id
|
392
432
|
if operation_id in self.operation_ids:
|
393
433
|
message = (
|
394
|
-
f"Duplicate Operation ID {operation_id} for function "
|
395
|
-
+ f"{route.endpoint.__name__}"
|
434
|
+
f"Duplicate Operation ID {operation_id} for function {route.endpoint.__name__}"
|
396
435
|
)
|
397
|
-
file_name = getattr(
|
436
|
+
file_name = getattr(
|
437
|
+
route.endpoint, "__globals__", {},
|
438
|
+
).get("__file__")
|
398
439
|
if file_name:
|
399
440
|
message += f" at {file_name}"
|
400
441
|
warnings.warn(message)
|
@@ -408,10 +449,14 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
408
449
|
def get_endpoint_schema(
|
409
450
|
self,
|
410
451
|
endpoint: EndpointModel,
|
411
|
-
) ->
|
452
|
+
) -> dict[str, Any]:
|
412
453
|
'''
|
413
454
|
Generates the endpoint schema
|
414
455
|
'''
|
456
|
+
|
457
|
+
if not isinstance(endpoint.route, APIRoute):
|
458
|
+
return {}
|
459
|
+
|
415
460
|
schema = self._get_route_openapi_metadata(endpoint.route)
|
416
461
|
|
417
462
|
schema.update(self.parse_docstring(endpoint.call))
|
@@ -458,18 +503,19 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|
458
503
|
|
459
504
|
return schema
|
460
505
|
|
461
|
-
def get_operations(self, endpoints:
|
506
|
+
def get_operations(self, endpoints: list[EndpointModel]) -> dict[str, dict[str, Any]]:
|
462
507
|
return {
|
463
508
|
method.lower(): self.get_endpoint_schema(e)
|
464
509
|
for e in endpoints
|
510
|
+
if e.methods
|
465
511
|
for method in e.methods
|
466
512
|
if method != 'HEAD'
|
467
513
|
}
|
468
514
|
|
469
515
|
def get_schema(
|
470
516
|
self,
|
471
|
-
routes:
|
472
|
-
) ->
|
517
|
+
routes: list[BaseRoute],
|
518
|
+
) -> dict[str, Any]:
|
473
519
|
'''
|
474
520
|
Generates the schemas for the specified routes..
|
475
521
|
'''
|
starmallow/security/api_key.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
from enum import Enum
|
2
|
-
from typing import ClassVar
|
2
|
+
from typing import ClassVar
|
3
3
|
|
4
4
|
from marshmallow_dataclass2 import dataclass as ma_dataclass
|
5
5
|
from starlette.requests import Request
|
@@ -28,15 +28,15 @@ class APIKeyQuery(SecurityBaseResolver):
|
|
28
28
|
self,
|
29
29
|
*,
|
30
30
|
name: str,
|
31
|
-
scheme_name:
|
32
|
-
description:
|
31
|
+
scheme_name: str | None = None,
|
32
|
+
description: str | None = None,
|
33
33
|
auto_error: bool = True,
|
34
34
|
):
|
35
35
|
self.model: APIKeyModel = APIKeyModel(in_=APIKeyIn.query, name=name, description=description)
|
36
36
|
self.scheme_name = scheme_name or self.__class__.__name__
|
37
37
|
self.auto_error = auto_error
|
38
38
|
|
39
|
-
async def __call__(self, request: Request) ->
|
39
|
+
async def __call__(self, request: Request) -> str | None:
|
40
40
|
api_key = request.query_params.get(self.model.name)
|
41
41
|
if not api_key:
|
42
42
|
if self.auto_error:
|
@@ -53,15 +53,15 @@ class APIKeyHeader(SecurityBaseResolver):
|
|
53
53
|
self,
|
54
54
|
*,
|
55
55
|
name: str,
|
56
|
-
scheme_name:
|
57
|
-
description:
|
56
|
+
scheme_name: str | None = None,
|
57
|
+
description: str | None = None,
|
58
58
|
auto_error: bool = True,
|
59
59
|
):
|
60
60
|
self.model: APIKeyModel = APIKeyModel(in_=APIKeyIn.header, name=name, description=description)
|
61
61
|
self.scheme_name = scheme_name or self.__class__.__name__
|
62
62
|
self.auto_error = auto_error
|
63
63
|
|
64
|
-
async def __call__(self, request: Request) ->
|
64
|
+
async def __call__(self, request: Request) -> str | None:
|
65
65
|
api_key = request.headers.get(self.model.name)
|
66
66
|
if not api_key:
|
67
67
|
if self.auto_error:
|
@@ -78,15 +78,15 @@ class APIKeyCookie(SecurityBaseResolver):
|
|
78
78
|
self,
|
79
79
|
*,
|
80
80
|
name: str,
|
81
|
-
scheme_name:
|
82
|
-
description:
|
81
|
+
scheme_name: str | None = None,
|
82
|
+
description: str | None = None,
|
83
83
|
auto_error: bool = True,
|
84
84
|
):
|
85
85
|
self.model: APIKeyModel = APIKeyModel(in_=APIKeyIn.cookie, name=name, description=description)
|
86
86
|
self.scheme_name = scheme_name or self.__class__.__name__
|
87
87
|
self.auto_error = auto_error
|
88
88
|
|
89
|
-
async def __call__(self, request: Request) ->
|
89
|
+
async def __call__(self, request: Request) -> str | None:
|
90
90
|
api_key = request.cookies.get(self.model.name)
|
91
91
|
if not api_key:
|
92
92
|
if self.auto_error:
|
starmallow/security/base.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
from enum import Enum
|
2
|
-
from typing import Any, ClassVar
|
2
|
+
from typing import Any, ClassVar
|
3
3
|
|
4
4
|
import marshmallow as ma
|
5
5
|
from marshmallow_dataclass2 import dataclass as ma_dataclass
|
6
|
+
from starlette.requests import Request
|
6
7
|
|
7
8
|
|
8
9
|
# Provided by: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#security-scheme-object
|
@@ -16,11 +17,13 @@ class SecurityTypes(Enum):
|
|
16
17
|
|
17
18
|
@ma_dataclass(frozen=True)
|
18
19
|
class SecurityBase:
|
20
|
+
Schema: ClassVar[type[ma.Schema]]
|
21
|
+
|
19
22
|
type: ClassVar[SecurityTypes]
|
20
|
-
description:
|
23
|
+
description: str | None = None
|
21
24
|
|
22
25
|
@ma.post_dump()
|
23
|
-
def post_dump(self, data:
|
26
|
+
def post_dump(self, data: dict[str, Any], **kwargs):
|
24
27
|
# Remove None values
|
25
28
|
return {
|
26
29
|
key: value
|
@@ -40,3 +43,8 @@ class SecurityBaseResolver:
|
|
40
43
|
) -> None:
|
41
44
|
self.model = model
|
42
45
|
self.schema_name = scheme_name
|
46
|
+
|
47
|
+
async def __call__(self, request: Request) -> Any | None:
|
48
|
+
raise NotImplementedError(
|
49
|
+
f"SecurityBaseResolver.__call__ not implemented for {self.__class__.__name__}",
|
50
|
+
)
|
starmallow/security/http.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
|
+
import asyncio
|
1
2
|
import binascii
|
2
3
|
from base64 import b64decode
|
3
|
-
from typing import ClassVar
|
4
|
+
from typing import ClassVar
|
4
5
|
|
5
6
|
from marshmallow_dataclass2 import dataclass as ma_dataclass
|
6
7
|
from starlette.requests import Request
|
@@ -33,14 +34,14 @@ class HTTPAuthorizationCredentials:
|
|
33
34
|
@ma_dataclass(frozen=True)
|
34
35
|
class HTTPBaseModel(SecurityBase):
|
35
36
|
type: ClassVar[SecurityTypes] = SecurityTypes.http
|
36
|
-
description:
|
37
|
-
scheme: str = None
|
37
|
+
description: str | None = None
|
38
|
+
scheme: str | None = None
|
38
39
|
|
39
40
|
|
40
41
|
@ma_dataclass(frozen=True)
|
41
42
|
class HTTPBearerModel(HTTPBaseModel):
|
42
43
|
scheme: str = "bearer"
|
43
|
-
|
44
|
+
bearer_format: str | None = None
|
44
45
|
|
45
46
|
|
46
47
|
class HTTPBase(SecurityBaseResolver):
|
@@ -49,8 +50,8 @@ class HTTPBase(SecurityBaseResolver):
|
|
49
50
|
self,
|
50
51
|
*,
|
51
52
|
scheme: str,
|
52
|
-
scheme_name:
|
53
|
-
description:
|
53
|
+
scheme_name: str | None = None,
|
54
|
+
description: str | None = None,
|
54
55
|
auto_error: bool = True,
|
55
56
|
) -> None:
|
56
57
|
self.model = HTTPBaseModel(scheme=scheme, description=description)
|
@@ -59,7 +60,7 @@ class HTTPBase(SecurityBaseResolver):
|
|
59
60
|
|
60
61
|
async def __call__(
|
61
62
|
self, request: Request,
|
62
|
-
) ->
|
63
|
+
) -> HTTPAuthorizationCredentials | None:
|
63
64
|
authorization = request.headers.get("Authorization")
|
64
65
|
scheme, credentials = get_authorization_scheme_param(authorization)
|
65
66
|
if not (authorization and scheme and credentials):
|
@@ -77,9 +78,9 @@ class HTTPBasic(HTTPBase):
|
|
77
78
|
def __init__(
|
78
79
|
self,
|
79
80
|
*,
|
80
|
-
scheme_name:
|
81
|
-
realm:
|
82
|
-
description:
|
81
|
+
scheme_name: str | None = None,
|
82
|
+
realm: str | None = None,
|
83
|
+
description: str | None = None,
|
83
84
|
auto_error: bool = True,
|
84
85
|
) -> None:
|
85
86
|
super().__init__(
|
@@ -93,13 +94,14 @@ class HTTPBasic(HTTPBase):
|
|
93
94
|
|
94
95
|
async def __call__( # type: ignore
|
95
96
|
self, request: Request,
|
96
|
-
) ->
|
97
|
+
) -> HTTPBasicCredentials | None:
|
97
98
|
authorization = request.headers.get("Authorization")
|
98
99
|
scheme, param = get_authorization_scheme_param(authorization)
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
100
|
+
unauthorized_headers = (
|
101
|
+
{"WWW-Authenticate": f'Basic realm="{self.realm}"'}
|
102
|
+
if self.realm
|
103
|
+
else {"WWW-Authenticate": "Basic"}
|
104
|
+
)
|
103
105
|
invalid_user_credentials_exc = HTTPException(
|
104
106
|
status_code=HTTP_401_UNAUTHORIZED,
|
105
107
|
detail="Invalid authentication credentials",
|
@@ -115,9 +117,12 @@ class HTTPBasic(HTTPBase):
|
|
115
117
|
else:
|
116
118
|
return None
|
117
119
|
try:
|
118
|
-
|
120
|
+
def decode(param: str) -> str:
|
121
|
+
return b64decode(param).decode("ascii")
|
122
|
+
|
123
|
+
data = await asyncio.to_thread(decode, param)
|
119
124
|
except (ValueError, UnicodeDecodeError, binascii.Error):
|
120
|
-
raise invalid_user_credentials_exc
|
125
|
+
raise invalid_user_credentials_exc from None
|
121
126
|
username, separator, password = data.partition(":")
|
122
127
|
if not separator:
|
123
128
|
raise invalid_user_credentials_exc
|
@@ -129,18 +134,18 @@ class HTTPBearer(HTTPBase):
|
|
129
134
|
def __init__(
|
130
135
|
self,
|
131
136
|
*,
|
132
|
-
|
133
|
-
scheme_name:
|
134
|
-
description:
|
137
|
+
bearer_format: str | None = None,
|
138
|
+
scheme_name: str | None = None,
|
139
|
+
description: str | None = None,
|
135
140
|
auto_error: bool = True,
|
136
141
|
) -> None:
|
137
|
-
self.model = HTTPBearerModel(
|
142
|
+
self.model = HTTPBearerModel(bearer_format=bearer_format, description=description)
|
138
143
|
self.scheme_name = scheme_name or self.__class__.__name__
|
139
144
|
self.auto_error = auto_error
|
140
145
|
|
141
146
|
async def __call__(
|
142
147
|
self, request: Request,
|
143
|
-
) ->
|
148
|
+
) -> HTTPAuthorizationCredentials | None:
|
144
149
|
authorization = request.headers.get("Authorization")
|
145
150
|
scheme, credentials = get_authorization_scheme_param(authorization)
|
146
151
|
if not (authorization and scheme and credentials):
|
@@ -166,8 +171,8 @@ class HTTPDigest(HTTPBase):
|
|
166
171
|
def __init__(
|
167
172
|
self,
|
168
173
|
*,
|
169
|
-
scheme_name:
|
170
|
-
description:
|
174
|
+
scheme_name: str | None = None,
|
175
|
+
description: str | None = None,
|
171
176
|
auto_error: bool = True,
|
172
177
|
) -> None:
|
173
178
|
super().__init__(
|
@@ -179,7 +184,7 @@ class HTTPDigest(HTTPBase):
|
|
179
184
|
|
180
185
|
async def __call__(
|
181
186
|
self, request: Request,
|
182
|
-
) ->
|
187
|
+
) -> HTTPAuthorizationCredentials | None:
|
183
188
|
authorization = request.headers.get("Authorization")
|
184
189
|
scheme, credentials = get_authorization_scheme_param(authorization)
|
185
190
|
if not (authorization and scheme and credentials):
|