fastapi 0.128.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastapi/__init__.py +25 -0
- fastapi/__main__.py +3 -0
- fastapi/_compat/__init__.py +41 -0
- fastapi/_compat/shared.py +206 -0
- fastapi/_compat/v2.py +568 -0
- fastapi/applications.py +4669 -0
- fastapi/background.py +60 -0
- fastapi/cli.py +13 -0
- fastapi/concurrency.py +41 -0
- fastapi/datastructures.py +183 -0
- fastapi/dependencies/__init__.py +0 -0
- fastapi/dependencies/models.py +193 -0
- fastapi/dependencies/utils.py +1021 -0
- fastapi/encoders.py +346 -0
- fastapi/exception_handlers.py +34 -0
- fastapi/exceptions.py +246 -0
- fastapi/logger.py +3 -0
- fastapi/middleware/__init__.py +1 -0
- fastapi/middleware/asyncexitstack.py +18 -0
- fastapi/middleware/cors.py +1 -0
- fastapi/middleware/gzip.py +1 -0
- fastapi/middleware/httpsredirect.py +3 -0
- fastapi/middleware/trustedhost.py +3 -0
- fastapi/middleware/wsgi.py +1 -0
- fastapi/openapi/__init__.py +0 -0
- fastapi/openapi/constants.py +3 -0
- fastapi/openapi/docs.py +344 -0
- fastapi/openapi/models.py +438 -0
- fastapi/openapi/utils.py +567 -0
- fastapi/param_functions.py +2369 -0
- fastapi/params.py +755 -0
- fastapi/py.typed +0 -0
- fastapi/requests.py +2 -0
- fastapi/responses.py +48 -0
- fastapi/routing.py +4508 -0
- fastapi/security/__init__.py +15 -0
- fastapi/security/api_key.py +318 -0
- fastapi/security/base.py +6 -0
- fastapi/security/http.py +423 -0
- fastapi/security/oauth2.py +663 -0
- fastapi/security/open_id_connect_url.py +94 -0
- fastapi/security/utils.py +10 -0
- fastapi/staticfiles.py +1 -0
- fastapi/templating.py +1 -0
- fastapi/testclient.py +1 -0
- fastapi/types.py +11 -0
- fastapi/utils.py +164 -0
- fastapi/websockets.py +3 -0
- fastapi-0.128.0.dist-info/METADATA +645 -0
- fastapi-0.128.0.dist-info/RECORD +53 -0
- fastapi-0.128.0.dist-info/WHEEL +4 -0
- fastapi-0.128.0.dist-info/entry_points.txt +5 -0
- fastapi-0.128.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1021 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import inspect
|
|
3
|
+
import sys
|
|
4
|
+
from collections.abc import Coroutine, Mapping, Sequence
|
|
5
|
+
from contextlib import AsyncExitStack, contextmanager
|
|
6
|
+
from copy import copy, deepcopy
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import (
|
|
9
|
+
Annotated,
|
|
10
|
+
Any,
|
|
11
|
+
Callable,
|
|
12
|
+
ForwardRef,
|
|
13
|
+
Optional,
|
|
14
|
+
Union,
|
|
15
|
+
cast,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
import anyio
|
|
19
|
+
from fastapi import params
|
|
20
|
+
from fastapi._compat import (
|
|
21
|
+
ModelField,
|
|
22
|
+
RequiredParam,
|
|
23
|
+
Undefined,
|
|
24
|
+
_regenerate_error_with_loc,
|
|
25
|
+
copy_field_info,
|
|
26
|
+
create_body_model,
|
|
27
|
+
evaluate_forwardref,
|
|
28
|
+
field_annotation_is_scalar,
|
|
29
|
+
get_cached_model_fields,
|
|
30
|
+
get_missing_field_error,
|
|
31
|
+
is_bytes_field,
|
|
32
|
+
is_bytes_sequence_field,
|
|
33
|
+
is_scalar_field,
|
|
34
|
+
is_scalar_sequence_field,
|
|
35
|
+
is_sequence_field,
|
|
36
|
+
is_uploadfile_or_nonable_uploadfile_annotation,
|
|
37
|
+
is_uploadfile_sequence_annotation,
|
|
38
|
+
lenient_issubclass,
|
|
39
|
+
sequence_types,
|
|
40
|
+
serialize_sequence_value,
|
|
41
|
+
value_is_sequence,
|
|
42
|
+
)
|
|
43
|
+
from fastapi.background import BackgroundTasks
|
|
44
|
+
from fastapi.concurrency import (
|
|
45
|
+
asynccontextmanager,
|
|
46
|
+
contextmanager_in_threadpool,
|
|
47
|
+
)
|
|
48
|
+
from fastapi.dependencies.models import Dependant
|
|
49
|
+
from fastapi.exceptions import DependencyScopeError
|
|
50
|
+
from fastapi.logger import logger
|
|
51
|
+
from fastapi.security.oauth2 import SecurityScopes
|
|
52
|
+
from fastapi.types import DependencyCacheKey
|
|
53
|
+
from fastapi.utils import create_model_field, get_path_param_names
|
|
54
|
+
from pydantic import BaseModel
|
|
55
|
+
from pydantic.fields import FieldInfo
|
|
56
|
+
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
|
57
|
+
from starlette.concurrency import run_in_threadpool
|
|
58
|
+
from starlette.datastructures import (
|
|
59
|
+
FormData,
|
|
60
|
+
Headers,
|
|
61
|
+
ImmutableMultiDict,
|
|
62
|
+
QueryParams,
|
|
63
|
+
UploadFile,
|
|
64
|
+
)
|
|
65
|
+
from starlette.requests import HTTPConnection, Request
|
|
66
|
+
from starlette.responses import Response
|
|
67
|
+
from starlette.websockets import WebSocket
|
|
68
|
+
from typing_extensions import Literal, get_args, get_origin
|
|
69
|
+
|
|
70
|
+
# Message logged and raised when neither "python-multipart" nor the
# conflicting "multipart" package can be imported.
multipart_not_installed_error = (
    'Form data requires "python-multipart" to be installed. '
    '\nYou can install "python-multipart" with: '
    "\n\npip install python-multipart\n"
)
# Message logged and raised when the "multipart" package is importable but is
# not the "python-multipart" distribution that form parsing needs.
multipart_incorrect_install_error = (
    'Form data requires "python-multipart" to be installed. '
    'It seems you installed "multipart" instead. '
    '\nYou can remove "multipart" with: '
    "\n\npip uninstall multipart"
    '\n\nAnd then install "python-multipart" with: '
    "\n\npip install python-multipart\n"
)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _parse_release_version(version: str) -> tuple[int, ...]:
    """Return the leading numeric release components of ``version``.

    ``"0.0.13"`` -> ``(0, 0, 13)``; ``"1.2.3rc1"`` -> ``(1, 2, 3)``. Parsing
    stops at the first component that does not start with an ASCII digit, and
    after a component that carries a non-numeric suffix. The result compares
    numerically as a tuple, unlike the raw strings (``"0.0.100" < "0.0.12"``
    lexicographically).
    """
    numbers: list[int] = []
    for component in version.split("."):
        digits = ""
        for char in component:
            if char not in "0123456789":
                break
            digits += char
        if not digits:
            break
        numbers.append(int(digits))
        if len(digits) != len(component):
            # e.g. "3rc1": keep the 3, ignore everything after the suffix
            break
    return tuple(numbers)


def ensure_multipart_is_installed() -> None:
    """Check that the "python-multipart" package needed for form data is usable.

    Succeeds silently when ``python_multipart`` can be imported and reports a
    version newer than 0.0.12.

    Raises:
        RuntimeError: otherwise. If a ``multipart`` package is importable but
            lacks ``multipart.multipart.parse_options_header``, the message
            says the wrong "multipart" package is installed; if no
            ``multipart`` package is importable at all, it says
            "python-multipart" is not installed. The message is logged with
            ``logger.error`` before raising.
    """
    try:
        from python_multipart import __version__

        # Import an attribute that can be mocked/deleted in testing.
        # Compare numerically: as plain strings "0.0.100" > "0.0.12" is False.
        # An explicit raise (rather than `assert`) keeps this check under -O.
        if not _parse_release_version(__version__) > (0, 0, 12):
            raise ImportError("python-multipart is older than 0.0.13")
    except (ImportError, AssertionError):
        try:
            # __version__ is available in both multiparts, and can be mocked
            from multipart import __version__  # type: ignore[no-redef,import-untyped]

            assert __version__
            try:
                # parse_options_header is only available in the right multipart
                from multipart.multipart import (  # type: ignore[import-untyped]
                    parse_options_header,
                )

                assert parse_options_header
            except ImportError:
                logger.error(multipart_incorrect_install_error)
                raise RuntimeError(multipart_incorrect_install_error) from None
        except ImportError:
            logger.error(multipart_not_installed_error)
            raise RuntimeError(multipart_not_installed_error) from None
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
    """Build a ``Dependant`` for a ``Depends``/``Security`` marker that has no parameter.

    Args:
        depends: The marker; its ``dependency`` must be callable.
        path: Route path, forwarded to ``get_dependant``.

    Returns:
        The ``Dependant`` for ``depends.dependency``, carrying the marker's
        ``scope`` and, for ``Security`` markers with scopes, those OAuth scopes.
    """
    assert callable(depends.dependency), (
        "A parameter-less dependency must have a callable dependency"
    )
    has_security_scopes = isinstance(depends, params.Security) and depends.scopes
    own_oauth_scopes: list[str] = list(depends.scopes) if has_security_scopes else []
    return get_dependant(
        path=path,
        call=depends.dependency,
        scope=depends.scope,
        own_oauth_scopes=own_oauth_scopes,
    )
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def get_flat_dependant(
    dependant: Dependant,
    *,
    skip_repeats: bool = False,
    visited: Optional[list[DependencyCacheKey]] = None,
    parent_oauth_scopes: Optional[list[str]] = None,
) -> Dependant:
    """Flatten ``dependant`` and all of its sub-dependencies into one ``Dependant``.

    The result copies the top-level metadata of ``dependant`` and accumulates,
    recursively, the path/query/header/cookie/body params of every
    sub-dependency. Its ``dependencies`` list holds each flattened
    sub-dependant followed by that sub-dependant's flattened descendants.

    Args:
        dependant: Root of the dependency tree.
        skip_repeats: When true, a sub-dependency whose ``cache_key`` was
            already visited is skipped entirely.
        visited: Cache keys seen so far. The same list is shared (and
            appended to) across the recursion; a new one is used if omitted.
        parent_oauth_scopes: OAuth scopes inherited from the caller, placed
            before ``dependant.oauth_scopes``.
    """
    seen = [] if visited is None else visited
    seen.append(dependant.cache_key)
    inherited_scopes = (parent_oauth_scopes or []) + (dependant.oauth_scopes or [])

    flat = Dependant(
        path_params=dependant.path_params.copy(),
        query_params=dependant.query_params.copy(),
        header_params=dependant.header_params.copy(),
        cookie_params=dependant.cookie_params.copy(),
        body_params=dependant.body_params.copy(),
        name=dependant.name,
        call=dependant.call,
        request_param_name=dependant.request_param_name,
        websocket_param_name=dependant.websocket_param_name,
        http_connection_param_name=dependant.http_connection_param_name,
        response_param_name=dependant.response_param_name,
        background_tasks_param_name=dependant.background_tasks_param_name,
        security_scopes_param_name=dependant.security_scopes_param_name,
        own_oauth_scopes=dependant.own_oauth_scopes,
        parent_oauth_scopes=inherited_scopes,
        use_cache=dependant.use_cache,
        path=dependant.path,
        scope=dependant.scope,
    )
    # Parameter lists merged from every flattened child, in this order.
    merged_param_lists = (
        "path_params",
        "query_params",
        "header_params",
        "cookie_params",
        "body_params",
    )
    for child in dependant.dependencies:
        if skip_repeats and child.cache_key in seen:
            continue
        flat_child = get_flat_dependant(
            child,
            skip_repeats=skip_repeats,
            visited=seen,
            parent_oauth_scopes=flat.oauth_scopes,
        )
        flat.dependencies.append(flat_child)
        for attr_name in merged_param_lists:
            getattr(flat, attr_name).extend(getattr(flat_child, attr_name))
        flat.dependencies.extend(flat_child.dependencies)
    return flat
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _get_flat_fields_from_params(fields: list[ModelField]) -> list[ModelField]:
    """Expand a lone Pydantic-model parameter into that model's fields.

    When ``fields`` holds exactly one field whose type is a ``BaseModel``
    subclass, the (cached) fields of that model are returned; in every other
    case ``fields`` itself is returned unchanged.
    """
    if len(fields) != 1:
        return fields
    (only_field,) = fields
    if lenient_issubclass(only_field.type_, BaseModel):
        return get_cached_model_fields(only_field.type_)
    return fields
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def get_flat_params(dependant: Dependant) -> list[ModelField]:
    """Return every path, query, header and cookie field of ``dependant``'s tree.

    The tree is flattened with repeated sub-dependencies skipped. For each
    location, a single Pydantic-model parameter is expanded into the model's
    fields. The result is a new list ordered path, query, header, cookie.
    """
    flat = get_flat_dependant(dependant, skip_repeats=True)
    collected = []
    for group in (
        flat.path_params,
        flat.query_params,
        flat.header_params,
        flat.cookie_params,
    ):
        collected.extend(_get_flat_fields_from_params(group))
    return collected
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _get_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Return ``inspect.signature(call)``, evaluating string annotations when possible.

    On Python 3.10+ string annotations are evaluated (``eval_str=True``). If
    that raises ``NameError``, the unevaluated signature is returned instead.
    This happens, for example, with names imported only under
    ``if TYPE_CHECKING``, such as dependency return types FastAPI does not use.
    Older Pythons always get the plain signature.
    """
    if sys.version_info < (3, 10):
        return inspect.signature(call)
    try:
        return inspect.signature(call, eval_str=True)
    except NameError:
        return inspect.signature(call)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Return ``call``'s signature with each parameter annotation resolved.

    Annotations go through ``get_typed_annotation`` using the ``__globals__``
    of the unwrapped callable (or ``{}`` if it has none). Parameter names,
    kinds and defaults are kept. The returned signature carries no return
    annotation.
    """
    raw_signature = _get_signature(call)
    namespace = getattr(inspect.unwrap(call), "__globals__", {})
    return inspect.Signature(
        [
            parameter.replace(
                annotation=get_typed_annotation(parameter.annotation, namespace)
            )
            for parameter in raw_signature.parameters.values()
        ]
    )
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
    """Resolve an annotation into the object FastAPI inspects.

    A string annotation is wrapped in ``ForwardRef`` and evaluated with
    ``evaluate_forwardref`` against ``globalns`` (used as both globals and
    locals). ``NoneType`` is normalised to ``None``. Any other value is
    returned as-is.
    """
    resolved = annotation
    if isinstance(resolved, str):
        resolved = evaluate_forwardref(ForwardRef(resolved), globalns, globalns)
    return None if resolved is type(None) else resolved
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
    """Return ``call``'s resolved return annotation, or ``None`` when it has none.

    The annotation is resolved with ``get_typed_annotation`` against the
    ``__globals__`` of the unwrapped callable (or ``{}``).
    """
    return_annotation = _get_signature(call).return_annotation
    target = inspect.unwrap(call)
    if return_annotation is inspect.Signature.empty:
        return None
    return get_typed_annotation(return_annotation, getattr(target, "__globals__", {}))
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def get_dependant(
    *,
    path: str,
    call: Callable[..., Any],
    name: Optional[str] = None,
    own_oauth_scopes: Optional[list[str]] = None,
    parent_oauth_scopes: Optional[list[str]] = None,
    use_cache: bool = True,
    scope: Union[Literal["function", "request"], None] = None,
) -> Dependant:
    """Build the ``Dependant`` tree for ``call`` by analysing its signature.

    Each parameter is classified with ``analyze_param``:

    * ``Depends``/``Security`` markers become sub-dependants (built
      recursively), inheriting this node's OAuth scopes as parent scopes;
    * special types (``Request``, ``WebSocket``, ``Response``, ...) are
      recorded via ``add_non_field_param_to_dependency``;
    * ``Body`` fields go to ``body_params``; other fields are routed by
      ``add_param_to_fields``.

    Args:
        path: Route path; its parameter names mark path parameters.
        call: The endpoint or dependency callable to analyse.
        name: Parameter name under which this dependant's result is injected.
        own_oauth_scopes: Scopes declared by this node's ``Security`` marker.
        parent_oauth_scopes: Scopes inherited from enclosing dependants.
        use_cache: Whether the result may be reused within a request.
        scope: ``"function"``, ``"request"`` or ``None``.

    Raises:
        DependencyScopeError: if a generator dependency whose computed scope
            is ``"request"`` depends on a ``"function"``-scoped dependency.
    """
    dependant = Dependant(
        call=call,
        name=name,
        path=path,
        use_cache=use_cache,
        scope=scope,
        own_oauth_scopes=own_oauth_scopes,
        parent_oauth_scopes=parent_oauth_scopes,
    )
    current_scopes = (parent_oauth_scopes or []) + (own_oauth_scopes or [])
    path_param_names = get_path_param_names(path)
    endpoint_signature = get_typed_signature(call)
    signature_params = endpoint_signature.parameters
    for param_name, param in signature_params.items():
        is_path_param = param_name in path_param_names
        param_details = analyze_param(
            param_name=param_name,
            annotation=param.annotation,
            value=param.default,
            is_path_param=is_path_param,
        )
        if param_details.depends is not None:
            assert param_details.depends.dependency
            if (
                (dependant.is_gen_callable or dependant.is_async_gen_callable)
                and dependant.computed_scope == "request"
                and param_details.depends.scope == "function"
            ):
                assert dependant.call
                # Callable instances and functools.partial objects have no
                # __name__; fall back to the type name so the intended error
                # is raised instead of an AttributeError.
                call_name = getattr(
                    dependant.call, "__name__", type(dependant.call).__name__
                )
                raise DependencyScopeError(
                    f'The dependency "{call_name}" has a scope of '
                    '"request", it cannot depend on dependencies with scope "function".'
                )
            sub_own_oauth_scopes: list[str] = []
            if isinstance(param_details.depends, params.Security):
                if param_details.depends.scopes:
                    sub_own_oauth_scopes = list(param_details.depends.scopes)
            sub_dependant = get_dependant(
                path=path,
                call=param_details.depends.dependency,
                name=param_name,
                own_oauth_scopes=sub_own_oauth_scopes,
                parent_oauth_scopes=current_scopes,
                use_cache=param_details.depends.use_cache,
                scope=param_details.depends.scope,
            )
            dependant.dependencies.append(sub_dependant)
            continue
        if add_non_field_param_to_dependency(
            param_name=param_name,
            type_annotation=param_details.type_annotation,
            dependant=dependant,
        ):
            assert param_details.field is None, (
                f"Cannot specify multiple FastAPI annotations for {param_name!r}"
            )
            continue
        assert param_details.field is not None
        if isinstance(param_details.field.field_info, params.Body):
            dependant.body_params.append(param_details.field)
        else:
            add_param_to_fields(field=param_details.field, dependant=dependant)
    return dependant
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def add_non_field_param_to_dependency(
    *, param_name: str, type_annotation: Any, dependant: Dependant
) -> Optional[bool]:
    """Record ``param_name`` on ``dependant`` when its type is a special injected object.

    The first matching entry below (checked with ``lenient_issubclass``) sets
    the corresponding ``*_param_name`` attribute. ``Request`` and
    ``WebSocket`` are checked before ``HTTPConnection``, so the more specific
    types win.

    Returns:
        ``True`` if the parameter was recorded, ``None`` otherwise.
    """
    special_types = (
        (Request, "request_param_name"),
        (WebSocket, "websocket_param_name"),
        (HTTPConnection, "http_connection_param_name"),
        (Response, "response_param_name"),
        (StarletteBackgroundTasks, "background_tasks_param_name"),
        (SecurityScopes, "security_scopes_param_name"),
    )
    for special_type, attribute in special_types:
        if lenient_issubclass(type_annotation, special_type):
            setattr(dependant, attribute, param_name)
            return True
    return None
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
@dataclasses.dataclass
class ParamDetails:
    """Outcome of ``analyze_param`` for a single signature parameter."""

    # The parameter's type, with any ``Annotated[...]`` wrapper stripped.
    type_annotation: Any
    # The ``Depends``/``Security`` marker when the parameter is a dependency.
    depends: Optional[params.Depends]
    # The model field built when the parameter is a request parameter or body.
    field: Optional[ModelField]
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def analyze_param(
    *,
    param_name: str,
    annotation: Any,
    value: Any,
    is_path_param: bool,
) -> ParamDetails:
    """Classify one signature parameter as a dependency, a special object or a field.

    Args:
        param_name: Name of the parameter in the callable's signature.
        annotation: The parameter's annotation, or ``inspect.Signature.empty``.
        value: The parameter's default value, or ``inspect.Signature.empty``.
        is_path_param: Whether ``param_name`` is one of the route's path params.

    Returns:
        A ``ParamDetails`` holding the bare type annotation plus either the
        ``Depends`` marker, a ``ModelField`` built from the parameter's
        ``FieldInfo``, or neither (for types such as ``Request``).

    Raises:
        AssertionError: for conflicting or invalid FastAPI markers. These
            checks use ``assert`` and are therefore skipped under ``-O``.
    """
    field_info = None
    depends = None
    type_annotation: Any = Any
    use_annotation: Any = Any
    if annotation is not inspect.Signature.empty:
        use_annotation = annotation
        type_annotation = annotation
    # Extract Annotated info
    if get_origin(use_annotation) is Annotated:
        annotated_args = get_args(annotation)
        # The first Annotated argument is the real type; the rest is metadata.
        type_annotation = annotated_args[0]
        fastapi_annotations = [
            arg
            for arg in annotated_args[1:]
            if isinstance(arg, (FieldInfo, params.Depends))
        ]
        # Only FastAPI's own markers are considered here; a FieldInfo that is
        # not a Param/Body (sub)class is not selected.
        fastapi_specific_annotations = [
            arg
            for arg in fastapi_annotations
            if isinstance(
                arg,
                (
                    params.Param,
                    params.Body,
                    params.Depends,
                ),
            )
        ]
        # When several markers are present, the last one wins.
        if fastapi_specific_annotations:
            fastapi_annotation: Union[FieldInfo, params.Depends, None] = (
                fastapi_specific_annotations[-1]
            )
        else:
            fastapi_annotation = None
        # Set default for Annotated FieldInfo
        if isinstance(fastapi_annotation, FieldInfo):
            # Copy `field_info` because we mutate `field_info.default` below.
            field_info = copy_field_info(
                field_info=fastapi_annotation,  # type: ignore[arg-type]
                annotation=use_annotation,
            )
            assert (
                field_info.default == Undefined or field_info.default == RequiredParam
            ), (
                f"`{field_info.__class__.__name__}` default value cannot be set in"
                f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
            )
            # The default comes from the signature (`= value`), never from the marker.
            if value is not inspect.Signature.empty:
                assert not is_path_param, "Path parameters cannot have default values"
                field_info.default = value
            else:
                field_info.default = RequiredParam
        # Get Annotated Depends
        elif isinstance(fastapi_annotation, params.Depends):
            depends = fastapi_annotation
    # Get Depends from default value
    if isinstance(value, params.Depends):
        assert depends is None, (
            "Cannot specify `Depends` in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        assert field_info is None, (
            "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
            f" default value together for {param_name!r}"
        )
        depends = value
    # Get FieldInfo from default value
    elif isinstance(value, FieldInfo):
        assert field_info is None, (
            "Cannot specify FastAPI annotations in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        field_info = value  # type: ignore[assignment]
        if isinstance(field_info, FieldInfo):
            # NOTE(review): this mutates the FieldInfo object supplied as the
            # parameter's default in place (no copy is made on this path).
            field_info.annotation = type_annotation

    # Get Depends from type annotation
    # `Depends()` with no callable means "depend on the annotated type itself".
    if depends is not None and depends.dependency is None:
        # Copy `depends` before mutating it
        # NOTE(review): dataclasses.replace already returns a new instance, so
        # the preceding copy() looks redundant.
        depends = copy(depends)
        depends = dataclasses.replace(depends, dependency=type_annotation)

    # Handle non-param type annotations like Request
    if lenient_issubclass(
        type_annotation,
        (
            Request,
            WebSocket,
            HTTPConnection,
            Response,
            StarletteBackgroundTasks,
            SecurityScopes,
        ),
    ):
        assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
        assert field_info is None, (
            f"Cannot specify FastAPI annotation for type {type_annotation!r}"
        )
    # Handle default assignations, neither field_info nor depends was not found in Annotated nor default value
    # (i.e. no marker was found in Annotated or in the default value: infer the
    # kind from the path, then UploadFile types, then non-scalar types -> Body,
    # otherwise Query).
    elif field_info is None and depends is None:
        default_value = value if value is not inspect.Signature.empty else RequiredParam
        if is_path_param:
            # We might check here that `default_value is RequiredParam`, but the fact is that the same
            # parameter might sometimes be a path parameter and sometimes not. See
            # `tests/test_infer_param_optionality.py` for an example.
            field_info = params.Path(annotation=use_annotation)
        elif is_uploadfile_or_nonable_uploadfile_annotation(
            type_annotation
        ) or is_uploadfile_sequence_annotation(type_annotation):
            field_info = params.File(annotation=use_annotation, default=default_value)
        elif not field_annotation_is_scalar(annotation=type_annotation):
            field_info = params.Body(annotation=use_annotation, default=default_value)
        else:
            field_info = params.Query(annotation=use_annotation, default=default_value)

    field = None
    # It's a field_info, not a dependency
    if field_info is not None:
        # Handle field_info.in_
        if is_path_param:
            assert isinstance(field_info, params.Path), (
                f"Cannot use `{field_info.__class__.__name__}` for path param"
                f" {param_name!r}"
            )
        elif (
            isinstance(field_info, params.Param)
            and getattr(field_info, "in_", None) is None
        ):
            # A Param without an explicit location defaults to query.
            field_info.in_ = params.ParamTypes.query
        use_annotation_from_field_info = use_annotation
        if isinstance(field_info, params.Form):
            # Form fields need python-multipart; fail early at definition time.
            ensure_multipart_is_installed()
        # With `convert_underscores` set and no explicit alias, snake_case names
        # are exposed with dashes (e.g. user_agent -> user-agent).
        if not field_info.alias and getattr(field_info, "convert_underscores", None):
            alias = param_name.replace("_", "-")
        else:
            alias = field_info.alias or param_name
        field_info.alias = alias
        field = create_model_field(
            name=param_name,
            type_=use_annotation_from_field_info,
            default=field_info.default,
            alias=alias,
            required=field_info.default in (RequiredParam, Undefined),
            field_info=field_info,
        )
        if is_path_param:
            assert is_scalar_field(field=field), (
                "Path params must be of one of the supported types"
            )
        elif isinstance(field_info, params.Query):
            # Query params must be scalars, sequences of scalars, or a single
            # Pydantic model.
            assert (
                is_scalar_field(field)
                or is_scalar_sequence_field(field)
                or (
                    lenient_issubclass(field.type_, BaseModel)
                    # For Pydantic v1
                    and getattr(field, "shape", 1) == 1
                )
            )

    return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
    """Append a non-body ``field`` to the matching parameter list of ``dependant``.

    The list is chosen by the ``in_`` location of ``field.field_info``:
    path, query, header, or (otherwise) cookie. The cookie case is guarded by
    an ``assert`` that the location really is cookie.
    """
    location = getattr(field.field_info, "in_", None)
    routes = (
        (params.ParamTypes.path, dependant.path_params),
        (params.ParamTypes.query, dependant.query_params),
        (params.ParamTypes.header, dependant.header_params),
    )
    for param_type, bucket in routes:
        if location == param_type:
            bucket.append(field)
            return
    assert location == params.ParamTypes.cookie, (
        f"non-body parameters must be in path, query, header or cookie: {field.name}"
    )
    dependant.cookie_params.append(field)
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
async def _solve_generator(
    *, dependant: Dependant, stack: AsyncExitStack, sub_values: dict[str, Any]
) -> Any:
    """Enter a generator dependency on ``stack`` and return the value it yields.

    Async generator functions are wrapped with ``asynccontextmanager``. Sync
    generator functions are wrapped with ``contextmanager`` and run through
    ``contextmanager_in_threadpool``. The dependency's teardown runs when
    ``stack`` is closed.

    Expects ``dependant`` to be a (async) generator dependant. Otherwise no
    context manager is created and an ``UnboundLocalError`` results.
    """
    assert dependant.call
    call = dependant.call
    if dependant.is_async_gen_callable:
        context_manager = asynccontextmanager(call)(**sub_values)
    elif dependant.is_gen_callable:
        sync_context_manager = contextmanager(call)(**sub_values)
        context_manager = contextmanager_in_threadpool(sync_context_manager)
    return await stack.enter_async_context(context_manager)
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
@dataclasses.dataclass
class SolvedDependency:
    """Result returned by ``solve_dependencies`` for one ``Dependant``."""

    # Values gathered for the dependant's parameters: sub-dependency results,
    # request parameters, body values and injected special objects.
    values: dict[str, Any]
    # Errors collected while solving; empty when everything validated.
    errors: list[Any]
    background_tasks: Optional[StarletteBackgroundTasks]
    response: Response
    # Results already computed during this solve, keyed by dependency cache key.
    dependency_cache: dict[DependencyCacheKey, Any]
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[dict[str, Any], FormData]] = None,
    background_tasks: Optional[StarletteBackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[dict[DependencyCacheKey, Any]] = None,
    # TODO: remove this parameter later, no longer used, not removing it yet as some
    # people might be monkey patching this function (although that's not supported)
    async_exit_stack: AsyncExitStack,
    embed_body_fields: bool,
) -> SolvedDependency:
    """Resolve every value needed to call ``dependant.call`` for ``request``.

    Sub-dependencies are solved first, recursively and in declaration order.
    This step honours ``dependency_overrides`` and ``dependency_cache``. The
    dependant's own path/query/header/cookie/body parameters are then
    extracted and validated. Finally the special objects it asked for
    (request/websocket/connection, background tasks, response, security
    scopes) are injected.

    Args:
        request: The incoming ``Request`` or ``WebSocket``. Its scope must
            contain ``AsyncExitStack`` instances under
            ``"fastapi_inner_astack"`` and ``"fastapi_function_astack"``
            (checked with ``assert``).
        dependant: The dependency tree node to solve.
        body: Parsed request body (dict or ``FormData``), if any.
        background_tasks: Background tasks to reuse. This is replaced by the
            value returned from each sub-dependency solve, and created on
            demand if still ``None`` when requested.
        response: Response shared with all sub-dependencies. A placeholder is
            created when omitted.
        dependency_overrides_provider: Object exposing a
            ``dependency_overrides`` mapping from original callable to
            replacement callable.
        dependency_cache: Results already computed, keyed by ``cache_key``.
            A new dict is used when omitted.
        async_exit_stack: Only forwarded to recursive calls.
        embed_body_fields: Forwarded to ``request_body_to_args``.

    Returns:
        A ``SolvedDependency`` with the collected ``values`` and ``errors``.
        A sub-dependency whose own inputs produced errors is not called; its
        errors are added to ``errors`` instead.
    """
    request_astack = request.scope.get("fastapi_inner_astack")
    assert isinstance(request_astack, AsyncExitStack), (
        "fastapi_inner_astack not found in request scope"
    )
    function_astack = request.scope.get("fastapi_function_astack")
    assert isinstance(function_astack, AsyncExitStack), (
        "fastapi_function_astack not found in request scope"
    )
    values: dict[str, Any] = {}
    errors: list[Any] = []
    if response is None:
        # Placeholder response: drop the default content-length header and
        # clear the status code (presumably so later code can tell whether a
        # dependency set them — confirm against the caller).
        response = Response()
        del response.headers["content-length"]
        response.status_code = None  # type: ignore
    if dependency_cache is None:
        dependency_cache = {}
    for sub_dependant in dependant.dependencies:
        # typing-only: cast() returns its argument unchanged at runtime.
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        call = sub_dependant.call
        use_sub_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            # Rebuild the Dependant from the override (or the original callable
            # if it is not overridden) so the override's own parameters are
            # analysed.
            original_call = sub_dependant.call
            call = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            ).get(original_call, original_call)
            use_path: str = sub_dependant.path  # type: ignore
            use_sub_dependant = get_dependant(
                path=use_path,
                call=call,
                name=sub_dependant.name,
                parent_oauth_scopes=sub_dependant.oauth_scopes,
                scope=sub_dependant.scope,
            )

        # NOTE: the sub-dependency's own inputs are solved before the cache
        # check below, even when its result is already cached.
        solved_result = await solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
            async_exit_stack=async_exit_stack,
            embed_body_fields=embed_body_fields,
        )
        background_tasks = solved_result.background_tasks
        if solved_result.errors:
            errors.extend(solved_result.errors)
            continue
        # Cache lookups and stores use the original sub_dependant's cache_key,
        # even when an override is in effect.
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
        elif (
            use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable
        ):
            # "function"-scoped generators are torn down with the function
            # exit stack; all others with the request exit stack.
            use_astack = request_astack
            if sub_dependant.scope == "function":
                use_astack = function_astack
            solved = await _solve_generator(
                dependant=use_sub_dependant,
                stack=use_astack,
                sub_values=solved_result.values,
            )
        elif use_sub_dependant.is_coroutine_callable:
            solved = await call(**solved_result.values)
        else:
            # Plain sync callables run in the thread pool.
            solved = await run_in_threadpool(call, **solved_result.values)
        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        if sub_dependant.cache_key not in dependency_cache:
            dependency_cache[sub_dependant.cache_key] = solved
    # Extract and validate the dependant's own request parameters.
    path_values, path_errors = request_params_to_args(
        dependant.path_params, request.path_params
    )
    query_values, query_errors = request_params_to_args(
        dependant.query_params, request.query_params
    )
    header_values, header_errors = request_params_to_args(
        dependant.header_params, request.headers
    )
    cookie_values, cookie_errors = request_params_to_args(
        dependant.cookie_params, request.cookies
    )
    values.update(path_values)
    values.update(query_values)
    values.update(header_values)
    values.update(cookie_values)
    errors += path_errors + query_errors + header_errors + cookie_errors
    if dependant.body_params:
        (
            body_values,
            body_errors,
        ) = await request_body_to_args(  # body_params checked above
            body_fields=dependant.body_params,
            received_body=body,
            embed_body_fields=embed_body_fields,
        )
        values.update(body_values)
        errors.extend(body_errors)
    # Inject special objects requested through the parameter type.
    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.oauth_scopes
        )
    return SolvedDependency(
        values=values,
        errors=errors,
        background_tasks=background_tasks,
        response=response,
        dependency_cache=dependency_cache,
    )
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
def _validate_value_with_model_field(
    *, field: ModelField, value: Any, values: dict[str, Any], loc: tuple[str, ...]
) -> tuple[Any, list[Any]]:
    """Validate a single raw ``value`` against ``field``.

    Returns a ``(validated_value, errors)`` pair:

    - ``None`` is treated as "not provided". A required field yields a
      missing-field error at ``loc``. An optional field yields a deep copy of
      its default, so the default object is never shared between requests.
    - Otherwise ``field.validate`` is called with ``values`` and ``loc``. When it
      reports a list of errors, they are rebuilt through
      ``_regenerate_error_with_loc`` (empty prefix) and the value is ``None``.
    """
    if value is None:
        if not field.required:
            return deepcopy(field.default), []
        return None, [get_missing_field_error(loc=loc)]

    validated, validation_errors = field.validate(value, values, loc=loc)
    if not isinstance(validation_errors, list):
        return validated, []
    return None, _regenerate_error_with_loc(errors=validation_errors, loc_prefix=())
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
def _get_multidict_value(
    field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
) -> Any:
    """Look up the raw value for ``field`` in ``values``.

    The lookup key is ``alias`` when given, else the field's validation alias.
    For sequence fields read from an ``ImmutableMultiDict`` or ``Headers``,
    every occurrence of the key is collected with ``getlist``. Any other case
    uses a plain ``get``.

    A value counts as missing when it is absent, when it is an empty string
    submitted for a Form field, or when it is an empty sequence for a sequence
    field. A missing value becomes ``None`` for required fields (left for the
    validator to report) or a deep copy of the field's default otherwise.
    """
    key = alias or get_validation_alias(field)
    is_sequence = is_sequence_field(field)

    if is_sequence and isinstance(values, (ImmutableMultiDict, Headers)):
        value = values.getlist(key)
    else:
        value = values.get(key, None)

    is_missing = (
        value is None
        or (
            isinstance(field.field_info, params.Form)
            and isinstance(value, str)  # For type checks
            and value == ""
        )
        or (is_sequence and len(value) == 0)
    )
    if not is_missing:
        return value
    if field.required:
        return None
    return deepcopy(field.default)
|
|
741
|
+
|
|
742
|
+
|
|
743
|
+
def _collect_model_params(
    model_field: ModelField,
    received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> dict[str, Any]:
    """Build the raw input dict for a Pydantic model that groups request params.

    Each declared model field is looked up and stored under its validation
    alias. Any received key that matched no declared field is copied through
    as well, so the model itself decides what to do with extra data.
    """
    # If headers are in a Pydantic model, the way to disable convert_underscores
    # would be with Header(convert_underscores=False) at the Pydantic model level
    default_convert_underscores = getattr(
        model_field.field_info, "convert_underscores", True
    )
    params_to_process: dict[str, Any] = {}
    processed_keys = set()

    for field in get_cached_model_fields(model_field.type_):
        alias = None
        if isinstance(received_params, Headers):
            # Handle fields extracted from a Pydantic Model for a header, each field
            # doesn't have a FieldInfo of type Header with the default convert_underscores=True
            convert_underscores = getattr(
                field.field_info, "convert_underscores", default_convert_underscores
            )
            if convert_underscores:
                alias = get_validation_alias(field)
                if alias == field.name:
                    alias = alias.replace("_", "-")
        value = _get_multidict_value(field, received_params, alias=alias)
        if value is not None:
            params_to_process[get_validation_alias(field)] = value
        processed_keys.add(alias or get_validation_alias(field))

    for key in received_params.keys():
        if key in processed_keys:
            continue
        if hasattr(received_params, "getlist"):
            # Repeated keys become a list; a single occurrence stays a scalar
            value = received_params.getlist(key)
            if isinstance(value, list) and (len(value) == 1):
                params_to_process[key] = value[0]
            else:
                params_to_process[key] = value
        else:
            params_to_process[key] = received_params.get(key)

    return params_to_process


def request_params_to_args(
    fields: Sequence[ModelField],
    received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> tuple[dict[str, Any], list[Any]]:
    """Extract and validate non-body request params (path/query/header/cookie).

    When ``fields`` is a single Pydantic model, the received params are
    gathered into one dict and validated against that model as a whole. Keys
    that match no declared model field are included in that dict. The result
    is keyed by the model parameter's name, and errors are located at
    ``(<param location>,)``.

    Otherwise each field is looked up and validated individually. Errors are
    located at ``(<param location>, <validation alias>)``.

    Returns ``(values, errors)``: validated values keyed by field name, and the
    collected validation errors.
    """
    values: dict[str, Any] = {}
    errors: list[dict[str, Any]] = []

    if not fields:
        return values, errors

    first_field = fields[0]
    if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
        # Only this path consumes the raw params dict, so it is built only here
        # instead of being computed and discarded for ordinary params.
        params_to_process = _collect_model_params(first_field, received_params)
        field_info = first_field.field_info
        assert isinstance(field_info, params.Param), (
            "Params must be subclasses of Param"
        )
        loc: tuple[str, ...] = (field_info.in_.value,)
        v_, errors_ = _validate_value_with_model_field(
            field=first_field, value=params_to_process, values=values, loc=loc
        )
        return {first_field.name: v_}, errors_

    for field in fields:
        value = _get_multidict_value(field, received_params)
        field_info = field.field_info
        assert isinstance(field_info, params.Param), (
            "Params must be subclasses of Param"
        )
        loc = (field_info.in_.value, get_validation_alias(field))
        v_, errors_ = _validate_value_with_model_field(
            field=field, value=value, values=values, loc=loc
        )
        if errors_:
            errors.extend(errors_)
        else:
            values[field.name] = v_
    return values, errors
|
|
824
|
+
|
|
825
|
+
|
|
826
|
+
def is_union_of_base_models(field_type: Any) -> bool:
    """Return True when ``field_type`` is a Union whose members are all BaseModel subclasses.

    Both ``typing.Union[...]`` and PEP 604 ``X | Y`` unions are recognised.
    A union that includes any non-model member (e.g. ``None``) returns False.
    """
    from fastapi.types import UnionType

    origin = get_origin(field_type)
    if origin is not Union and origin is not UnionType:
        return False
    return all(lenient_issubclass(member, BaseModel) for member in get_args(field_type))
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
def _should_embed_body_fields(fields: list[ModelField]) -> bool:
    """Decide whether body params must be nested under their names in the body.

    Embedding is required when there is more than one distinct body param, or
    when the param explicitly asks for it (``embed``). It is also required for
    a Form/File param that is neither a BaseModel nor a union of BaseModels,
    so its key/value pair can be extracted from the submitted form.
    """
    if not fields:
        return False
    # Several dependencies may declare the same body param. It then appears
    # several times, but it is still one param, so only distinct names count.
    if len({field.name for field in fields}) > 1:
        return True
    first_field = fields[0]
    # An explicit embed flag always wins
    if getattr(first_field.field_info, "embed", None):
        return True
    return isinstance(first_field.field_info, params.Form) and not (
        lenient_issubclass(first_field.type_, BaseModel)
        or is_union_of_base_models(first_field.type_)
    )
|
|
867
|
+
|
|
868
|
+
|
|
869
|
+
async def _extract_form_body(
    body_fields: list[ModelField],
    received_body: FormData,
) -> dict[str, Any]:
    """Turn submitted form data into a plain dict ready for validation.

    Each declared field is looked up by its validation alias. For File params
    typed as ``bytes`` (or a sequence of ``bytes``), uploaded files are read
    into memory. Form keys that match no declared field are copied through as
    well: as a single value, or as a list when the key was repeated.
    """
    values = {}

    for field in body_fields:
        value = _get_multidict_value(field, received_body)
        field_info = field.field_info
        if (
            isinstance(field_info, params.File)
            and is_bytes_field(field)
            and isinstance(value, UploadFile)
        ):
            value = await value.read()
        elif (
            is_bytes_sequence_field(field)
            and isinstance(field_info, params.File)
            and value_is_sequence(value)
        ):
            # For types
            assert isinstance(value, sequence_types)
            # Start from a copy of the submitted items. Anything that is not an
            # upload (e.g. a plain text part) is then passed to validation as-is,
            # matching the single-bytes branch above, instead of failing on
            # `.read`. Each upload's slot is overwritten with its content by
            # index, so the order matches the submitted order regardless of
            # which concurrent read finishes first.
            results: list[Any] = list(value)

            async def read_into(
                index: int,
                fn: Callable[[], Coroutine[Any, Any, Any]],
            ) -> None:
                results[index] = await fn()  # noqa: B023

            async with anyio.create_task_group() as tg:
                for index, sub_value in enumerate(value):
                    if isinstance(sub_value, UploadFile):
                        tg.start_soon(read_into, index, sub_value.read)
            value = serialize_sequence_value(field=field, value=results)
        if value is not None:
            values[get_validation_alias(field)] = value

    field_aliases = {get_validation_alias(field) for field in body_fields}
    for key in received_body.keys():
        if key in field_aliases:
            continue
        param_values = received_body.getlist(key)
        values[key] = param_values[0] if len(param_values) == 1 else param_values
    return values
|
|
914
|
+
|
|
915
|
+
|
|
916
|
+
async def request_body_to_args(
    body_fields: list[ModelField],
    received_body: Optional[Union[dict[str, Any], FormData]],
    embed_body_fields: bool,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Validate the request body against the declared body params.

    Form data is first flattened with ``_extract_form_body``. For a lone,
    non-embedded Pydantic model, the model's own fields drive the extraction.

    With a single non-embedded body param, the whole (processed) body is
    validated against it, with errors located at ``("body",)``. Otherwise each
    param is read from the body by its validation alias, with errors located at
    ``("body", <alias>)``. A body without a ``.get`` method (e.g. a JSON list)
    produces a missing-field error for every param.

    Returns ``(values, errors)``: validated values keyed by field name, and the
    collected validation errors.
    """
    values: dict[str, Any] = {}
    errors: list[dict[str, Any]] = []
    assert body_fields, "request_body_to_args() should be called with fields"
    first_field = body_fields[0]
    single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields

    body_to_process = received_body
    if isinstance(received_body, FormData):
        fields_to_extract: list[ModelField] = body_fields
        if single_not_embedded_field and lenient_issubclass(
            first_field.type_, BaseModel
        ):
            fields_to_extract = get_cached_model_fields(first_field.type_)
        body_to_process = await _extract_form_body(fields_to_extract, received_body)

    if single_not_embedded_field:
        whole_body_loc: tuple[str, ...] = ("body",)
        v_, errors_ = _validate_value_with_model_field(
            field=first_field, value=body_to_process, values=values, loc=whole_body_loc
        )
        return {first_field.name: v_}, errors_

    for field in body_fields:
        alias = get_validation_alias(field)
        loc: tuple[str, ...] = ("body", alias)
        value: Optional[Any] = None
        if body_to_process is not None:
            try:
                value = body_to_process.get(alias)
            except AttributeError:
                # The received body is not a mapping (e.g. a JSON list)
                errors.append(get_missing_field_error(loc))
                continue
        v_, errors_ = _validate_value_with_model_field(
            field=field, value=value, values=values, loc=loc
        )
        if errors_:
            errors.extend(errors_)
        else:
            values[field.name] = v_
    return values, errors
|
|
964
|
+
|
|
965
|
+
|
|
966
|
+
def get_body_field(
    *, flat_dependant: Dependant, name: str, embed_body_fields: bool
) -> Optional[ModelField]:
    """
    Return a single ModelField describing the request body of a path operation.

    Returns ``None`` when there are no body params. If the params are not
    embedded, the first body param is returned unchanged. Otherwise all body
    params are combined into a generated ``Body_<name>`` model wrapped in a
    field aliased ``"body"``. The wrapping field is a File field if any param
    is a File, else a Form field if any param is a Form, else a plain Body
    field. It carries a media type only when all Body params agree on one.

    Callers use this to tell form data (``isinstance(body_field, params.Form)``)
    from JSON, and to generate the JSON Schema of the request body.

    It is **not** used to validate/parse the request body; that is done per
    individual body param.
    """
    body_params = flat_dependant.body_params
    if not body_params:
        return None
    if not embed_body_fields:
        return body_params[0]

    body_model = create_body_model(fields=body_params, model_name="Body_" + name)
    required = any(f.required for f in body_params)

    field_info_kwargs: dict[str, Any] = {"annotation": body_model, "alias": "body"}
    if not required:
        field_info_kwargs["default"] = None

    field_info_cls: type[params.Body]
    if any(isinstance(f.field_info, params.File) for f in body_params):
        field_info_cls = params.File
    elif any(isinstance(f.field_info, params.Form) for f in body_params):
        field_info_cls = params.Form
    else:
        field_info_cls = params.Body

    media_types = [
        f.field_info.media_type
        for f in body_params
        if isinstance(f.field_info, params.Body)
    ]
    if len(set(media_types)) == 1:
        field_info_kwargs["media_type"] = media_types[0]

    return create_model_field(
        name="body",
        type_=body_model,
        required=required,
        alias="body",
        field_info=field_info_cls(**field_info_kwargs),
    )
|
|
1017
|
+
|
|
1018
|
+
|
|
1019
|
+
def get_validation_alias(field: ModelField) -> str:
    """Return the key used to look up ``field`` in incoming request data.

    Uses the field's ``validation_alias`` when it exists and is non-empty;
    otherwise falls back to ``field.alias``.
    """
    validation_alias = getattr(field, "validation_alias", None)
    if validation_alias:
        return validation_alias
    return field.alias
|