fastapi 0.128.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fastapi/__init__.py +25 -0
  2. fastapi/__main__.py +3 -0
  3. fastapi/_compat/__init__.py +41 -0
  4. fastapi/_compat/shared.py +206 -0
  5. fastapi/_compat/v2.py +568 -0
  6. fastapi/applications.py +4669 -0
  7. fastapi/background.py +60 -0
  8. fastapi/cli.py +13 -0
  9. fastapi/concurrency.py +41 -0
  10. fastapi/datastructures.py +183 -0
  11. fastapi/dependencies/__init__.py +0 -0
  12. fastapi/dependencies/models.py +193 -0
  13. fastapi/dependencies/utils.py +1021 -0
  14. fastapi/encoders.py +346 -0
  15. fastapi/exception_handlers.py +34 -0
  16. fastapi/exceptions.py +246 -0
  17. fastapi/logger.py +3 -0
  18. fastapi/middleware/__init__.py +1 -0
  19. fastapi/middleware/asyncexitstack.py +18 -0
  20. fastapi/middleware/cors.py +1 -0
  21. fastapi/middleware/gzip.py +1 -0
  22. fastapi/middleware/httpsredirect.py +3 -0
  23. fastapi/middleware/trustedhost.py +3 -0
  24. fastapi/middleware/wsgi.py +1 -0
  25. fastapi/openapi/__init__.py +0 -0
  26. fastapi/openapi/constants.py +3 -0
  27. fastapi/openapi/docs.py +344 -0
  28. fastapi/openapi/models.py +438 -0
  29. fastapi/openapi/utils.py +567 -0
  30. fastapi/param_functions.py +2369 -0
  31. fastapi/params.py +755 -0
  32. fastapi/py.typed +0 -0
  33. fastapi/requests.py +2 -0
  34. fastapi/responses.py +48 -0
  35. fastapi/routing.py +4508 -0
  36. fastapi/security/__init__.py +15 -0
  37. fastapi/security/api_key.py +318 -0
  38. fastapi/security/base.py +6 -0
  39. fastapi/security/http.py +423 -0
  40. fastapi/security/oauth2.py +663 -0
  41. fastapi/security/open_id_connect_url.py +94 -0
  42. fastapi/security/utils.py +10 -0
  43. fastapi/staticfiles.py +1 -0
  44. fastapi/templating.py +1 -0
  45. fastapi/testclient.py +1 -0
  46. fastapi/types.py +11 -0
  47. fastapi/utils.py +164 -0
  48. fastapi/websockets.py +3 -0
  49. fastapi-0.128.0.dist-info/METADATA +645 -0
  50. fastapi-0.128.0.dist-info/RECORD +53 -0
  51. fastapi-0.128.0.dist-info/WHEEL +4 -0
  52. fastapi-0.128.0.dist-info/entry_points.txt +5 -0
  53. fastapi-0.128.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1021 @@
1
+ import dataclasses
2
+ import inspect
3
+ import sys
4
+ from collections.abc import Coroutine, Mapping, Sequence
5
+ from contextlib import AsyncExitStack, contextmanager
6
+ from copy import copy, deepcopy
7
+ from dataclasses import dataclass
8
+ from typing import (
9
+ Annotated,
10
+ Any,
11
+ Callable,
12
+ ForwardRef,
13
+ Optional,
14
+ Union,
15
+ cast,
16
+ )
17
+
18
+ import anyio
19
+ from fastapi import params
20
+ from fastapi._compat import (
21
+ ModelField,
22
+ RequiredParam,
23
+ Undefined,
24
+ _regenerate_error_with_loc,
25
+ copy_field_info,
26
+ create_body_model,
27
+ evaluate_forwardref,
28
+ field_annotation_is_scalar,
29
+ get_cached_model_fields,
30
+ get_missing_field_error,
31
+ is_bytes_field,
32
+ is_bytes_sequence_field,
33
+ is_scalar_field,
34
+ is_scalar_sequence_field,
35
+ is_sequence_field,
36
+ is_uploadfile_or_nonable_uploadfile_annotation,
37
+ is_uploadfile_sequence_annotation,
38
+ lenient_issubclass,
39
+ sequence_types,
40
+ serialize_sequence_value,
41
+ value_is_sequence,
42
+ )
43
+ from fastapi.background import BackgroundTasks
44
+ from fastapi.concurrency import (
45
+ asynccontextmanager,
46
+ contextmanager_in_threadpool,
47
+ )
48
+ from fastapi.dependencies.models import Dependant
49
+ from fastapi.exceptions import DependencyScopeError
50
+ from fastapi.logger import logger
51
+ from fastapi.security.oauth2 import SecurityScopes
52
+ from fastapi.types import DependencyCacheKey
53
+ from fastapi.utils import create_model_field, get_path_param_names
54
+ from pydantic import BaseModel
55
+ from pydantic.fields import FieldInfo
56
+ from starlette.background import BackgroundTasks as StarletteBackgroundTasks
57
+ from starlette.concurrency import run_in_threadpool
58
+ from starlette.datastructures import (
59
+ FormData,
60
+ Headers,
61
+ ImmutableMultiDict,
62
+ QueryParams,
63
+ UploadFile,
64
+ )
65
+ from starlette.requests import HTTPConnection, Request
66
+ from starlette.responses import Response
67
+ from starlette.websockets import WebSocket
68
+ from typing_extensions import Literal, get_args, get_origin
69
+
70
# Error logged and raised by ensure_multipart_is_installed() when no usable
# "python_multipart" module is found and the legacy "multipart" import name
# cannot be imported either.
multipart_not_installed_error = (
    'Form data requires "python-multipart" to be installed. \n'
    'You can install "python-multipart" with: \n\n'
    "pip install python-multipart\n"
)
# Error logged and raised by ensure_multipart_is_installed() when a "multipart"
# module is importable but lacks multipart.multipart.parse_options_header,
# which indicates the unrelated "multipart" package is installed instead of
# "python-multipart".
multipart_incorrect_install_error = (
    'Form data requires "python-multipart" to be installed. '
    'It seems you installed "multipart" instead. \n'
    'You can remove "multipart" with: \n\n'
    "pip uninstall multipart\n\n"
    'And then install "python-multipart" with: \n\n'
    "pip install python-multipart\n"
)
83
+
84
+
85
def ensure_multipart_is_installed() -> None:
    """Check that "python-multipart" (not the unrelated "multipart") is usable.

    Form parameters need "python-multipart". This first imports the
    ``python_multipart`` module and requires its version to be newer than
    0.0.12. When that import fails or the version is too old, it falls back
    to the legacy ``multipart`` import name. A module that ships
    ``multipart.multipart.parse_options_header`` is accepted. Otherwise the
    unrelated "multipart" package is assumed to be installed.

    Raises:
        RuntimeError: With ``multipart_incorrect_install_error`` when the
            unrelated "multipart" package is detected. With
            ``multipart_not_installed_error`` when no ``multipart`` module can
            be imported either. The message is also logged via
            ``logger.error``.
    """
    import re

    try:
        from python_multipart import __version__

        # Import an attribute that can be mocked/deleted in testing.
        # Compare the version numerically. A plain string comparison is
        # lexicographic, e.g. "0.0.9" > "0.0.12" would be True. An explicit
        # raise (instead of assert) keeps the check active under `python -O`.
        version_numbers = tuple(
            int(number) for number in re.findall(r"\d+", str(__version__))
        )
        if version_numbers <= (0, 0, 12):
            raise ImportError(f"python-multipart {__version__} is too old")
    except (ImportError, AssertionError):
        try:
            # __version__ is available in both multiparts, and can be mocked
            from multipart import __version__  # type: ignore[no-redef,import-untyped]

            assert __version__
            try:
                # parse_options_header is only available in the right multipart
                from multipart.multipart import (  # type: ignore[import-untyped]
                    parse_options_header,
                )

                assert parse_options_header
            except ImportError:
                logger.error(multipart_incorrect_install_error)
                raise RuntimeError(multipart_incorrect_install_error) from None
        except ImportError:
            logger.error(multipart_not_installed_error)
            raise RuntimeError(multipart_not_installed_error) from None
110
+
111
+
112
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
    """Build a sub-dependant for a ``Depends`` not bound to a signature parameter.

    For a ``Security`` marker, its scopes become the sub-dependant's own
    OAuth scopes. Any other ``Depends`` gets an empty scope list.

    Raises:
        AssertionError: If ``depends.dependency`` is not callable.
    """
    assert callable(depends.dependency), (
        "A parameter-less dependency must have a callable dependency"
    )
    is_security = isinstance(depends, params.Security)
    scopes: list[str] = (
        list(depends.scopes) if is_security and depends.scopes else []
    )
    return get_dependant(
        path=path,
        call=depends.dependency,
        scope=depends.scope,
        own_oauth_scopes=scopes,
    )
125
+
126
+
127
def get_flat_dependant(
    dependant: Dependant,
    *,
    skip_repeats: bool = False,
    visited: Optional[list[DependencyCacheKey]] = None,
    parent_oauth_scopes: Optional[list[str]] = None,
) -> Dependant:
    """Return a flattened copy of ``dependant`` with sub-dependency params merged in.

    The returned ``Dependant`` starts with copies of this dependant's path,
    query, header, cookie and body param lists. It is then extended,
    recursively, with the params of every sub-dependant. Its
    ``dependencies`` list holds each flattened sub-dependant, followed by
    that sub-dependant's own flattened dependencies.

    Args:
        dependant: The dependant to flatten. Its param lists are copied
            before being extended, so the original lists are not modified.
        skip_repeats: When True, a sub-dependant is skipped entirely if its
            ``cache_key`` was already seen anywhere in this traversal.
        visited: Accumulator of ``cache_key`` values seen so far. It is
            shared across the recursive calls and appended to in place.
        parent_oauth_scopes: Scopes inherited from the caller. They are
            prepended to ``dependant.oauth_scopes`` to form the copy's parent
            scopes.
    """
    if visited is None:
        visited = []
    visited.append(dependant.cache_key)
    use_parent_oauth_scopes = (parent_oauth_scopes or []) + (
        dependant.oauth_scopes or []
    )

    # Copy the param lists so the extend() calls below do not mutate the
    # original dependant.
    flat_dependant = Dependant(
        path_params=dependant.path_params.copy(),
        query_params=dependant.query_params.copy(),
        header_params=dependant.header_params.copy(),
        cookie_params=dependant.cookie_params.copy(),
        body_params=dependant.body_params.copy(),
        name=dependant.name,
        call=dependant.call,
        request_param_name=dependant.request_param_name,
        websocket_param_name=dependant.websocket_param_name,
        http_connection_param_name=dependant.http_connection_param_name,
        response_param_name=dependant.response_param_name,
        background_tasks_param_name=dependant.background_tasks_param_name,
        security_scopes_param_name=dependant.security_scopes_param_name,
        own_oauth_scopes=dependant.own_oauth_scopes,
        parent_oauth_scopes=use_parent_oauth_scopes,
        use_cache=dependant.use_cache,
        path=dependant.path,
        scope=dependant.scope,
    )
    for sub_dependant in dependant.dependencies:
        # `visited` is a list, so this membership test is linear in its size.
        if skip_repeats and sub_dependant.cache_key in visited:
            continue
        flat_sub = get_flat_dependant(
            sub_dependant,
            skip_repeats=skip_repeats,
            visited=visited,
            parent_oauth_scopes=flat_dependant.oauth_scopes,
        )
        flat_dependant.dependencies.append(flat_sub)
        flat_dependant.path_params.extend(flat_sub.path_params)
        flat_dependant.query_params.extend(flat_sub.query_params)
        flat_dependant.header_params.extend(flat_sub.header_params)
        flat_dependant.cookie_params.extend(flat_sub.cookie_params)
        flat_dependant.body_params.extend(flat_sub.body_params)
        # Hoist the sub-dependant's own (already flattened) dependencies up
        # to this level.
        flat_dependant.dependencies.extend(flat_sub.dependencies)

    return flat_dependant
179
+
180
+
181
def _get_flat_fields_from_params(fields: list[ModelField]) -> list[ModelField]:
    """Expand a lone Pydantic-model param into that model's fields.

    If ``fields`` holds exactly one field and its type is a ``BaseModel``
    subclass, that model's cached fields are returned. In every other case
    (including an empty list) ``fields`` itself is returned unchanged.
    """
    if len(fields) != 1:
        return fields
    (only_field,) = fields
    if not lenient_issubclass(only_field.type_, BaseModel):
        return fields
    return get_cached_model_fields(only_field.type_)
189
+
190
+
191
def get_flat_params(dependant: Dependant) -> list[ModelField]:
    """Collect the non-body params of ``dependant`` and all its sub-dependencies.

    The dependency tree is flattened with repeated sub-dependencies skipped.
    Each param group is passed through ``_get_flat_fields_from_params``,
    which expands a lone Pydantic-model param into its fields. The groups
    are concatenated in path, query, header, cookie order into a new list.
    """
    flat = get_flat_dependant(dependant, skip_repeats=True)
    collected: list[ModelField] = []
    for group in (
        flat.path_params,
        flat.query_params,
        flat.header_params,
        flat.cookie_params,
    ):
        collected.extend(_get_flat_fields_from_params(group))
    return collected
198
+
199
+
200
+ def _get_signature(call: Callable[..., Any]) -> inspect.Signature:
201
+ if sys.version_info >= (3, 10):
202
+ try:
203
+ signature = inspect.signature(call, eval_str=True)
204
+ except NameError:
205
+ # Handle type annotations with if TYPE_CHECKING, not used by FastAPI
206
+ # e.g. dependency return types
207
+ signature = inspect.signature(call)
208
+ else:
209
+ signature = inspect.signature(call)
210
+ return signature
211
+
212
+
213
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Return ``call``'s signature with each parameter annotation resolved.

    Each annotation is passed through ``get_typed_annotation``, using the
    globals of the unwrapped callable as the namespace. The result keeps every
    parameter's name, kind and default, but carries no return annotation.
    """
    raw_signature = _get_signature(call)
    namespace = getattr(inspect.unwrap(call), "__globals__", {})
    resolved_params = [
        parameter.replace(
            annotation=get_typed_annotation(parameter.annotation, namespace)
        )
        for parameter in raw_signature.parameters.values()
    ]
    return inspect.Signature(resolved_params)
228
+
229
+
230
def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
    """Resolve a (possibly string) annotation against ``globalns``.

    String annotations are wrapped in ``ForwardRef`` and evaluated. A result
    of ``NoneType`` is normalised to ``None``. Any other value is returned
    as-is.
    """
    if isinstance(annotation, str):
        annotation = evaluate_forwardref(ForwardRef(annotation), globalns, globalns)
    return None if annotation is type(None) else annotation
237
+
238
+
239
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
    """Return ``call``'s resolved return annotation, or ``None`` if it has none.

    The annotation is resolved with ``get_typed_annotation``, using the
    globals of the unwrapped callable as the namespace.
    """
    signature = _get_signature(call)
    target = inspect.unwrap(call)
    returned = signature.return_annotation
    if returned is not inspect.Signature.empty:
        return get_typed_annotation(returned, getattr(target, "__globals__", {}))
    return None
249
+
250
+
251
def get_dependant(
    *,
    path: str,
    call: Callable[..., Any],
    name: Optional[str] = None,
    own_oauth_scopes: Optional[list[str]] = None,
    parent_oauth_scopes: Optional[list[str]] = None,
    use_cache: bool = True,
    scope: Union[Literal["function", "request"], None] = None,
) -> Dependant:
    """Build the ``Dependant`` tree for ``call`` by inspecting its signature.

    Each parameter is classified with ``analyze_param``:

    - ``Depends``/``Security`` parameters become recursively built
      sub-dependants.
    - Special types (``Request``, ``WebSocket``, ...) are recorded by name.
    - Body fields go to ``body_params``.
    - All other fields are sorted into path/query/header/cookie params.

    Args:
        path: Route path, used to detect which parameters are path params.
        call: The endpoint or dependency callable to inspect.
        name: Parameter name under which this dependant's value is injected.
        own_oauth_scopes: Scopes declared by this dependant's ``Security``.
        parent_oauth_scopes: Scopes inherited from enclosing dependants.
        use_cache: Whether ``solve_dependencies`` may reuse a value already
            stored in the dependency cache for this dependant.
        scope: Explicit dependency scope ("function" or "request"), if any.

    Raises:
        DependencyScopeError: If a generator dependency with "request" scope
            depends on a dependency with "function" scope.
    """
    dependant = Dependant(
        call=call,
        name=name,
        path=path,
        use_cache=use_cache,
        scope=scope,
        own_oauth_scopes=own_oauth_scopes,
        parent_oauth_scopes=parent_oauth_scopes,
    )
    # Scopes passed down to every sub-dependency declared by this callable.
    current_scopes = (parent_oauth_scopes or []) + (own_oauth_scopes or [])
    path_param_names = get_path_param_names(path)
    endpoint_signature = get_typed_signature(call)
    signature_params = endpoint_signature.parameters
    for param_name, param in signature_params.items():
        is_path_param = param_name in path_param_names
        param_details = analyze_param(
            param_name=param_name,
            annotation=param.annotation,
            value=param.default,
            is_path_param=is_path_param,
        )
        if param_details.depends is not None:
            assert param_details.depends.dependency
            if (
                (dependant.is_gen_callable or dependant.is_async_gen_callable)
                and dependant.computed_scope == "request"
                and param_details.depends.scope == "function"
            ):
                assert dependant.call
                # Callables such as functools.partial objects or callable
                # class instances have no __name__. Fall back to repr() so the
                # intended DependencyScopeError is raised, not an
                # AttributeError.
                call_name = getattr(dependant.call, "__name__", repr(dependant.call))
                raise DependencyScopeError(
                    f'The dependency "{call_name}" has a scope of '
                    '"request", it cannot depend on dependencies with scope "function".'
                )
            sub_own_oauth_scopes: list[str] = []
            if isinstance(param_details.depends, params.Security):
                if param_details.depends.scopes:
                    sub_own_oauth_scopes = list(param_details.depends.scopes)
            sub_dependant = get_dependant(
                path=path,
                call=param_details.depends.dependency,
                name=param_name,
                own_oauth_scopes=sub_own_oauth_scopes,
                parent_oauth_scopes=current_scopes,
                use_cache=param_details.depends.use_cache,
                scope=param_details.depends.scope,
            )
            dependant.dependencies.append(sub_dependant)
            continue
        if add_non_field_param_to_dependency(
            param_name=param_name,
            type_annotation=param_details.type_annotation,
            dependant=dependant,
        ):
            assert param_details.field is None, (
                f"Cannot specify multiple FastAPI annotations for {param_name!r}"
            )
            continue
        assert param_details.field is not None
        if isinstance(param_details.field.field_info, params.Body):
            dependant.body_params.append(param_details.field)
        else:
            add_param_to_fields(field=param_details.field, dependant=dependant)
    return dependant
324
+
325
+
326
def add_non_field_param_to_dependency(
    *, param_name: str, type_annotation: Any, dependant: Dependant
) -> Optional[bool]:
    """Record ``param_name`` on ``dependant`` if it is a special injected type.

    The candidate types are checked in order, so ``Request`` and ``WebSocket``
    win over their base ``HTTPConnection``. The matching ``*_param_name``
    attribute of ``dependant`` is set to ``param_name``.

    Returns:
        ``True`` when the annotation matched a special type, otherwise ``None``.
    """
    special_slots = (
        (Request, "request_param_name"),
        (WebSocket, "websocket_param_name"),
        (HTTPConnection, "http_connection_param_name"),
        (Response, "response_param_name"),
        (StarletteBackgroundTasks, "background_tasks_param_name"),
        (SecurityScopes, "security_scopes_param_name"),
    )
    for base_type, attribute_name in special_slots:
        if lenient_issubclass(type_annotation, base_type):
            setattr(dependant, attribute_name, param_name)
            return True
    return None
348
+
349
+
350
@dataclass
class ParamDetails:
    """Outcome of ``analyze_param`` for a single signature parameter."""

    # The bare type: the first argument of ``Annotated[...]`` when present,
    # otherwise the annotation itself (``Any`` when the parameter is
    # unannotated).
    type_annotation: Any
    # The ``Depends`` (or ``Security``) marker when the parameter is a
    # dependency, taken from ``Annotated`` metadata or the default value.
    # Otherwise None.
    depends: Optional[params.Depends]
    # The ModelField built when the parameter is a request field. None for
    # dependencies and for special types such as ``Request``.
    field: Optional[ModelField]
355
+
356
+
357
def analyze_param(
    *,
    param_name: str,
    annotation: Any,
    value: Any,
    is_path_param: bool,
) -> ParamDetails:
    """Work out how one signature parameter should be handled by FastAPI.

    Inspects the parameter's annotation (including ``Annotated[...]``
    metadata) and its default value. From these it decides whether the
    parameter is:

    - a dependency (``Depends``/``Security``);
    - a special non-field type (``Request``, ``WebSocket``,
      ``HTTPConnection``, ``Response``, background tasks, ``SecurityScopes``);
    - a request field (path, query, body, file, or whatever ``FieldInfo``
      the user supplied, e.g. header, cookie or form).

    Args:
        param_name: The parameter's name in the signature.
        annotation: The parameter's annotation (as produced by
            ``get_typed_signature`` in ``get_dependant``), or
            ``inspect.Signature.empty``.
        value: The parameter's default, or ``inspect.Signature.empty``.
        is_path_param: Whether ``param_name`` appears in the route path.

    Returns:
        ``ParamDetails`` with the bare type annotation, the ``Depends``
        marker (if any), and a ``ModelField`` when the parameter is a
        request field.

    Raises:
        AssertionError: For invalid combinations. Examples: a default set
            inside ``Annotated``, ``Depends`` given both in ``Annotated`` and
            as a default, a default on a path parameter, or a non-``Path``
            field for a path parameter.
        RuntimeError: From ``ensure_multipart_is_installed`` for ``Form``
            fields when "python-multipart" is not usable.
    """
    field_info = None
    depends = None
    type_annotation: Any = Any
    use_annotation: Any = Any
    if annotation is not inspect.Signature.empty:
        use_annotation = annotation
        type_annotation = annotation
    # Extract Annotated info
    if get_origin(use_annotation) is Annotated:
        annotated_args = get_args(annotation)
        type_annotation = annotated_args[0]
        fastapi_annotations = [
            arg
            for arg in annotated_args[1:]
            if isinstance(arg, (FieldInfo, params.Depends))
        ]
        fastapi_specific_annotations = [
            arg
            for arg in fastapi_annotations
            if isinstance(
                arg,
                (
                    params.Param,
                    params.Body,
                    params.Depends,
                ),
            )
        ]
        # When several FastAPI markers are given, the last one wins.
        if fastapi_specific_annotations:
            fastapi_annotation: Union[FieldInfo, params.Depends, None] = (
                fastapi_specific_annotations[-1]
            )
        else:
            fastapi_annotation = None
        # Set default for Annotated FieldInfo
        if isinstance(fastapi_annotation, FieldInfo):
            # Copy `field_info` because we mutate `field_info.default` below.
            field_info = copy_field_info(
                field_info=fastapi_annotation,  # type: ignore[arg-type]
                annotation=use_annotation,
            )
            assert (
                field_info.default == Undefined or field_info.default == RequiredParam
            ), (
                f"`{field_info.__class__.__name__}` default value cannot be set in"
                f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
            )
            if value is not inspect.Signature.empty:
                assert not is_path_param, "Path parameters cannot have default values"
                field_info.default = value
            else:
                field_info.default = RequiredParam
        # Get Annotated Depends
        elif isinstance(fastapi_annotation, params.Depends):
            depends = fastapi_annotation
    # Get Depends from default value
    if isinstance(value, params.Depends):
        assert depends is None, (
            "Cannot specify `Depends` in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        assert field_info is None, (
            "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
            f" default value together for {param_name!r}"
        )
        depends = value
    # Get FieldInfo from default value
    elif isinstance(value, FieldInfo):
        assert field_info is None, (
            "Cannot specify FastAPI annotations in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        field_info = value  # type: ignore[assignment]
        if isinstance(field_info, FieldInfo):
            field_info.annotation = type_annotation

    # Get Depends from type annotation
    # `Depends()` without an explicit dependency uses the parameter's type.
    if depends is not None and depends.dependency is None:
        # Copy `depends` before mutating it
        depends = copy(depends)
        depends = dataclasses.replace(depends, dependency=type_annotation)

    # Handle non-param type annotations like Request
    if lenient_issubclass(
        type_annotation,
        (
            Request,
            WebSocket,
            HTTPConnection,
            Response,
            StarletteBackgroundTasks,
            SecurityScopes,
        ),
    ):
        assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
        assert field_info is None, (
            f"Cannot specify FastAPI annotation for type {type_annotation!r}"
        )
    # Handle default assignments: neither field_info nor depends was found in
    # Annotated nor in the default value, so infer the param kind from the type
    elif field_info is None and depends is None:
        default_value = value if value is not inspect.Signature.empty else RequiredParam
        if is_path_param:
            # We might check here that `default_value is RequiredParam`, but the fact is that the same
            # parameter might sometimes be a path parameter and sometimes not. See
            # `tests/test_infer_param_optionality.py` for an example.
            field_info = params.Path(annotation=use_annotation)
        elif is_uploadfile_or_nonable_uploadfile_annotation(
            type_annotation
        ) or is_uploadfile_sequence_annotation(type_annotation):
            field_info = params.File(annotation=use_annotation, default=default_value)
        elif not field_annotation_is_scalar(annotation=type_annotation):
            # Non-scalar types (e.g. models, containers) default to the body.
            field_info = params.Body(annotation=use_annotation, default=default_value)
        else:
            field_info = params.Query(annotation=use_annotation, default=default_value)

    field = None
    # It's a field_info, not a dependency
    if field_info is not None:
        # Handle field_info.in_
        if is_path_param:
            assert isinstance(field_info, params.Path), (
                f"Cannot use `{field_info.__class__.__name__}` for path param"
                f" {param_name!r}"
            )
        elif (
            isinstance(field_info, params.Param)
            and getattr(field_info, "in_", None) is None
        ):
            field_info.in_ = params.ParamTypes.query
        use_annotation_from_field_info = use_annotation
        if isinstance(field_info, params.Form):
            ensure_multipart_is_installed()
        # Without an explicit alias, fields with convert_underscores enabled
        # (headers) use the parameter name with "_" replaced by "-".
        if not field_info.alias and getattr(field_info, "convert_underscores", None):
            alias = param_name.replace("_", "-")
        else:
            alias = field_info.alias or param_name
        field_info.alias = alias
        field = create_model_field(
            name=param_name,
            type_=use_annotation_from_field_info,
            default=field_info.default,
            alias=alias,
            required=field_info.default in (RequiredParam, Undefined),
            field_info=field_info,
        )
        if is_path_param:
            assert is_scalar_field(field=field), (
                "Path params must be of one of the supported types"
            )
        elif isinstance(field_info, params.Query):
            # Query params must be scalars, sequences of scalars, or a single
            # Pydantic model.
            assert (
                is_scalar_field(field)
                or is_scalar_sequence_field(field)
                or (
                    lenient_issubclass(field.type_, BaseModel)
                    # For Pydantic v1
                    and getattr(field, "shape", 1) == 1
                )
            )

    return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
525
+
526
+
527
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
    """Append a non-body ``field`` to the matching param list of ``dependant``.

    The target list is chosen from ``field.field_info.in_``: path, query,
    header, or (as the only remaining option) cookie.

    Raises:
        AssertionError: If ``in_`` is missing or is none of the four locations.
    """
    location = getattr(field.field_info, "in_", None)
    buckets = (
        (params.ParamTypes.path, dependant.path_params),
        (params.ParamTypes.query, dependant.query_params),
        (params.ParamTypes.header, dependant.header_params),
    )
    for param_type, bucket in buckets:
        if location == param_type:
            bucket.append(field)
            return
    assert location == params.ParamTypes.cookie, (
        f"non-body parameters must be in path, query, header or cookie: {field.name}"
    )
    dependant.cookie_params.append(field)
541
+
542
+
543
async def _solve_generator(
    *, dependant: Dependant, stack: AsyncExitStack, sub_values: dict[str, Any]
) -> Any:
    """Enter the generator dependency ``dependant.call`` on ``stack``.

    ``sub_values`` are passed to the generator function as keyword
    arguments. Async generators are wrapped with ``asynccontextmanager``.
    Sync generators are wrapped with ``contextmanager`` and then
    ``contextmanager_in_threadpool``. The entered context's value is
    returned. Its exit runs when ``stack`` is closed.

    Callers only use this for dependants with ``is_async_gen_callable`` or
    ``is_gen_callable`` set. With neither flag set, ``manager`` is never
    bound and ``UnboundLocalError`` is raised.
    """
    assert dependant.call
    generator_function = dependant.call
    if dependant.is_async_gen_callable:
        manager = asynccontextmanager(generator_function)(**sub_values)
    elif dependant.is_gen_callable:
        sync_manager = contextmanager(generator_function)(**sub_values)
        manager = contextmanager_in_threadpool(sync_manager)
    return await stack.enter_async_context(manager)
552
+
553
+
554
@dataclass
class SolvedDependency:
    """Result of ``solve_dependencies`` for one dependant."""

    # Resolved argument values keyed by parameter name.
    values: dict[str, Any]
    # Errors collected from sub-dependencies and from request param / body
    # validation.
    errors: list[Any]
    # Background tasks object passed down the tree; created on demand when a
    # parameter requests it, else whatever was passed in (possibly None).
    background_tasks: Optional[StarletteBackgroundTasks]
    # Response object injected into parameters annotated with ``Response``.
    response: Response
    # Solved dependency values keyed by dependency cache key.
    dependency_cache: dict[DependencyCacheKey, Any]
561
+
562
+
563
async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[dict[str, Any], FormData]] = None,
    background_tasks: Optional[StarletteBackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[dict[DependencyCacheKey, Any]] = None,
    # TODO: remove this parameter later, no longer used, not removing it yet as some
    # people might be monkey patching this function (although that's not supported)
    async_exit_stack: AsyncExitStack,
    embed_body_fields: bool,
) -> SolvedDependency:
    """Resolve ``dependant``'s sub-dependencies and request params into values.

    The work happens in three steps:

    1. Sub-dependencies are solved recursively. This honours
       ``dependency_overrides_provider.dependency_overrides`` and reuses
       values from ``dependency_cache`` when ``use_cache`` is set.
    2. Path, query, header, cookie and body params are extracted and
       validated.
    3. Special parameters are injected: request, websocket, HTTP
       connection, background tasks, response and security scopes.

    Generator dependencies are entered on an exit stack from the request
    scope. ``"fastapi_function_astack"`` is used when the sub-dependant's
    ``scope`` is ``"function"``; ``"fastapi_inner_astack"`` otherwise.

    Returns:
        A ``SolvedDependency`` with:

        - the resolved ``values``, keyed by parameter name;
        - the accumulated ``errors``;
        - the ``background_tasks``, ``response`` and ``dependency_cache``
          objects, each created here when it was not passed in or needed.

    Raises:
        AssertionError: If either exit stack is missing from ``request.scope``.
    """
    request_astack = request.scope.get("fastapi_inner_astack")
    assert isinstance(request_astack, AsyncExitStack), (
        "fastapi_inner_astack not found in request scope"
    )
    function_astack = request.scope.get("fastapi_function_astack")
    assert isinstance(function_astack, AsyncExitStack), (
        "fastapi_function_astack not found in request scope"
    )
    values: dict[str, Any] = {}
    errors: list[Any] = []
    if response is None:
        # Placeholder response: drop the default content-length header and
        # leave the status code unset so dependencies can fill them in.
        response = Response()
        del response.headers["content-length"]
        response.status_code = None  # type: ignore
    if dependency_cache is None:
        dependency_cache = {}
    for sub_dependant in dependant.dependencies:
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        call = sub_dependant.call
        use_sub_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            # Swap in the override (if any) and rebuild its dependant so the
            # override's own signature drives param extraction.
            original_call = sub_dependant.call
            call = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            ).get(original_call, original_call)
            use_path: str = sub_dependant.path  # type: ignore
            use_sub_dependant = get_dependant(
                path=use_path,
                call=call,
                name=sub_dependant.name,
                parent_oauth_scopes=sub_dependant.oauth_scopes,
                scope=sub_dependant.scope,
            )

        solved_result = await solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
            async_exit_stack=async_exit_stack,
            embed_body_fields=embed_body_fields,
        )
        background_tasks = solved_result.background_tasks
        # A sub-dependency with errors is not called; its errors are surfaced.
        if solved_result.errors:
            errors.extend(solved_result.errors)
            continue
        # The cache is keyed by the original (not overridden) sub-dependant.
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
        elif (
            use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable
        ):
            use_astack = request_astack
            if sub_dependant.scope == "function":
                use_astack = function_astack
            solved = await _solve_generator(
                dependant=use_sub_dependant,
                stack=use_astack,
                sub_values=solved_result.values,
            )
        elif use_sub_dependant.is_coroutine_callable:
            solved = await call(**solved_result.values)
        else:
            # Plain sync callables run in a thread pool.
            solved = await run_in_threadpool(call, **solved_result.values)
        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        if sub_dependant.cache_key not in dependency_cache:
            dependency_cache[sub_dependant.cache_key] = solved
    path_values, path_errors = request_params_to_args(
        dependant.path_params, request.path_params
    )
    query_values, query_errors = request_params_to_args(
        dependant.query_params, request.query_params
    )
    header_values, header_errors = request_params_to_args(
        dependant.header_params, request.headers
    )
    cookie_values, cookie_errors = request_params_to_args(
        dependant.cookie_params, request.cookies
    )
    values.update(path_values)
    values.update(query_values)
    values.update(header_values)
    values.update(cookie_values)
    errors += path_errors + query_errors + header_errors + cookie_errors
    if dependant.body_params:
        (
            body_values,
            body_errors,
        ) = await request_body_to_args(  # body_params checked above
            body_fields=dependant.body_params,
            received_body=body,
            embed_body_fields=embed_body_fields,
        )
        values.update(body_values)
        errors.extend(body_errors)
    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.oauth_scopes
        )
    return SolvedDependency(
        values=values,
        errors=errors,
        background_tasks=background_tasks,
        response=response,
        dependency_cache=dependency_cache,
    )
701
+
702
+
703
def _validate_value_with_model_field(
    *, field: ModelField, value: Any, values: dict[str, Any], loc: tuple[str, ...]
) -> tuple[Any, list[Any]]:
    """Validate ``value`` against ``field`` and return ``(result, errors)``.

    The outcome depends on ``value`` and the field:

    - ``value`` is ``None`` and the field is required: returns a single
      missing-field error at ``loc``.
    - ``value`` is ``None`` and the field is optional: returns a deep copy
      of the field default.
    - Otherwise the value is validated with ``field.validate``. When that
      reports a list of errors, they are regenerated with no location prefix
      and the result is ``None``.
    """
    if value is not None:
        validated, field_errors = field.validate(value, values, loc=loc)
        if not isinstance(field_errors, list):
            return validated, []
        return None, _regenerate_error_with_loc(errors=field_errors, loc_prefix=())
    if field.required:
        return None, [get_missing_field_error(loc=loc)]
    return deepcopy(field.default), []
717
+
718
+
719
def _get_multidict_value(
    field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
) -> Any:
    """Fetch the raw value for ``field`` from ``values``.

    The key is ``alias`` if given, else the field's validation alias.
    Sequence fields read every occurrence via ``getlist`` when ``values`` is
    an ``ImmutableMultiDict`` or ``Headers``.

    A value counts as missing when it is absent, an empty string for a
    ``Form`` field, or an empty sequence for a sequence field. For a missing
    value, required fields yield ``None`` and optional fields yield a deep
    copy of their default.
    """
    key = alias or get_validation_alias(field)
    wants_list = is_sequence_field(field)
    if wants_list and isinstance(values, (ImmutableMultiDict, Headers)):
        value = values.getlist(key)
    else:
        value = values.get(key, None)
    is_missing = (
        value is None
        or (
            isinstance(field.field_info, params.Form)
            and isinstance(value, str)  # For type checks
            and value == ""
        )
        or (wants_list and len(value) == 0)
    )
    if not is_missing:
        return value
    if field.required:
        return None
    return deepcopy(field.default)
741
+
742
+
743
def request_params_to_args(
    fields: Sequence[ModelField],
    received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> tuple[dict[str, Any], list[Any]]:
    """Extract and validate ``fields`` from a path/query/header/cookie mapping.

    There are two modes:

    - ``fields`` is a single Pydantic-model param: the model's own fields
      are looked up individually, every other received key is added as
      well, and the collected mapping is validated against the model in
      one go.
    - Otherwise, each field is looked up and validated on its own.

    For ``Headers``, a field whose validation alias equals its name has
    underscores converted to hyphens, unless ``convert_underscores`` is
    disabled (on the field, or on the enclosing model param).

    Returns:
        ``(values, errors)``: validated values keyed by field name, and the
        validation errors.
    """
    values: dict[str, Any] = {}
    errors: list[dict[str, Any]] = []

    if not fields:
        return values, errors

    first_field = fields[0]
    fields_to_extract = fields
    single_not_embedded_field = False
    default_convert_underscores = True
    if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
        fields_to_extract = get_cached_model_fields(first_field.type_)
        single_not_embedded_field = True
        # If headers are in a Pydantic model, the way to disable convert_underscores
        # would be with Header(convert_underscores=False) at the Pydantic model level
        default_convert_underscores = getattr(
            first_field.field_info, "convert_underscores", True
        )

    # Only consumed by the single-model branch below; the per-field branch
    # re-reads received_params directly.
    params_to_process: dict[str, Any] = {}

    processed_keys = set()

    for field in fields_to_extract:
        alias = None
        if isinstance(received_params, Headers):
            # Handle fields extracted from a Pydantic Model for a header, each field
            # doesn't have a FieldInfo of type Header with the default convert_underscores=True
            convert_underscores = getattr(
                field.field_info, "convert_underscores", default_convert_underscores
            )
            if convert_underscores:
                alias = get_validation_alias(field)
                if alias == field.name:
                    alias = alias.replace("_", "-")
        value = _get_multidict_value(field, received_params, alias=alias)
        if value is not None:
            params_to_process[get_validation_alias(field)] = value
        processed_keys.add(alias or get_validation_alias(field))

    # Pass through every received key not claimed by a field, so the model
    # validation sees them too. Single-valued multi-dict entries are
    # unwrapped from their list.
    for key in received_params.keys():
        if key not in processed_keys:
            if hasattr(received_params, "getlist"):
                value = received_params.getlist(key)
                if isinstance(value, list) and (len(value) == 1):
                    params_to_process[key] = value[0]
                else:
                    params_to_process[key] = value
            else:
                params_to_process[key] = received_params.get(key)

    if single_not_embedded_field:
        field_info = first_field.field_info
        assert isinstance(field_info, params.Param), (
            "Params must be subclasses of Param"
        )
        # Errors are located at the param location only (e.g. ("query",)).
        loc: tuple[str, ...] = (field_info.in_.value,)
        v_, errors_ = _validate_value_with_model_field(
            field=first_field, value=params_to_process, values=values, loc=loc
        )
        return {first_field.name: v_}, errors_

    for field in fields:
        value = _get_multidict_value(field, received_params)
        field_info = field.field_info
        assert isinstance(field_info, params.Param), (
            "Params must be subclasses of Param"
        )
        loc = (field_info.in_.value, get_validation_alias(field))
        v_, errors_ = _validate_value_with_model_field(
            field=field, value=value, values=values, loc=loc
        )
        if errors_:
            errors.extend(errors_)
        else:
            values[field.name] = v_
    return values, errors
824
+
825
+
826
def is_union_of_base_models(field_type: Any) -> bool:
    """
    Return ``True`` if ``field_type`` is a union whose members are all ``BaseModel`` subclasses.

    Both ``typing.Union[...]`` and ``types.UnionType`` (``A | B``) unions are
    recognized. Any other type returns ``False``, as does a union containing at
    least one member that is not a ``BaseModel`` subclass.
    """
    from fastapi.types import UnionType

    origin = get_origin(field_type)
    is_union = origin is Union or origin is UnionType
    if not is_union:
        return False

    members = get_args(field_type)
    return all(lenient_issubclass(member, BaseModel) for member in members)
843
+
844
+
845
def _should_embed_body_fields(fields: list[ModelField]) -> bool:
    """
    Decide whether body parameters must be nested under their own names.

    Returns ``True`` (embed) when:

    * there is more than one distinct body parameter name, or
    * the single parameter's ``field_info`` has a truthy ``embed`` attribute, or
    * the single parameter is a ``params.Form`` (including ``File``) whose type is
      neither a ``BaseModel`` subclass nor a union of ``BaseModel`` subclasses.

    Returns ``False`` for an empty list and in every other case.
    """
    if not fields:
        return False

    # The same body field can be contributed by several dependencies; those
    # duplicates share a name, so only distinct names count.
    distinct_names = {field.name for field in fields}
    if len(distinct_names) > 1:
        return True

    only_field = fields[0]
    if getattr(only_field.field_info, "embed", None):
        return True

    if not isinstance(only_field.field_info, params.Form):
        return False

    # A Form/File parameter can only be top level when a model (or a union of
    # models) can absorb all the submitted key/value pairs.
    model_typed = lenient_issubclass(
        only_field.type_, BaseModel
    ) or is_union_of_base_models(only_field.type_)
    return not model_typed
867
+
868
+
869
async def _extract_form_body(
    body_fields: list[ModelField],
    received_body: FormData,
) -> dict[str, Any]:
    """
    Convert submitted form data into a plain dict ready for validation.

    Each field in ``body_fields`` is looked up in ``received_body`` and stored
    under its validation alias; fields whose extracted value is ``None`` are left
    out. ``File`` fields annotated as ``bytes`` have their ``UploadFile`` read.
    ``File`` fields annotated as a sequence of ``bytes`` have every upload read
    concurrently, and the results keep the order in which the uploads appear in
    the form.

    Form keys that do not match any field's validation alias are copied through
    unchanged: as a single value, or as a list when the key was sent more than
    once.

    Args:
        body_fields: Fields to extract from the form.
        received_body: The parsed form payload.

    Returns:
        A dict mapping validation aliases (and unmatched form keys) to values.
    """
    values: dict[str, Any] = {}

    for field in body_fields:
        value = _get_multidict_value(field, received_body)
        field_info = field.field_info
        if (
            isinstance(field_info, params.File)
            and is_bytes_field(field)
            and isinstance(value, UploadFile)
        ):
            value = await value.read()
        elif (
            is_bytes_sequence_field(field)
            and isinstance(field_info, params.File)
            and value_is_sequence(value)
        ):
            # Narrow the type for type checkers
            assert isinstance(value, sequence_types)
            # The reads run concurrently and may complete in any order, so each
            # result is written to its upload's position instead of appended;
            # appending would reorder the files by completion time.
            results: list[Union[bytes, str]] = [b""] * len(value)

            async def process_fn(
                index: int,
                fn: Callable[[], Coroutine[Any, Any, Any]],
            ) -> None:
                results[index] = await fn()  # noqa: B023

            async with anyio.create_task_group() as tg:
                for position, sub_value in enumerate(value):
                    tg.start_soon(process_fn, position, sub_value.read)
            value = serialize_sequence_value(field=field, value=results)
        if value is not None:
            values[get_validation_alias(field)] = value

    # Pass through any submitted keys that no declared field consumed
    field_aliases = {get_validation_alias(field) for field in body_fields}
    for key in received_body.keys():
        if key not in field_aliases:
            param_values = received_body.getlist(key)
            if len(param_values) == 1:
                values[key] = param_values[0]
            else:
                values[key] = param_values
    return values
914
+
915
+
916
async def request_body_to_args(
    body_fields: list[ModelField],
    received_body: Optional[Union[dict[str, Any], FormData]],
    embed_body_fields: bool,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """
    Validate the request body against the declared body fields.

    Form payloads are first flattened with ``_extract_form_body``. When there is
    a single, non-embedded body field, the whole payload is validated against it
    at location ``("body",)``. Otherwise each field is looked up in the payload
    by its validation alias and validated at ``("body", <alias>)``; if the
    payload has no ``.get`` (e.g. it is a list), a missing-field error is
    recorded for that field.

    Returns:
        A ``(values, errors)`` tuple: validated values keyed by field name and
        the accumulated validation errors.
    """
    assert body_fields, "request_body_to_args() should be called with fields"
    values: dict[str, Any] = {}
    errors: list[dict[str, Any]] = []
    first_field = body_fields[0]
    is_single_top_level = len(body_fields) == 1 and not embed_body_fields

    payload = received_body
    if isinstance(received_body, FormData):
        form_fields: list[ModelField] = body_fields
        if is_single_top_level and lenient_issubclass(first_field.type_, BaseModel):
            # The whole form maps onto one model, so extract that model's fields
            form_fields = get_cached_model_fields(first_field.type_)
        payload = await _extract_form_body(form_fields, received_body)

    if is_single_top_level:
        whole_value, whole_errors = _validate_value_with_model_field(
            field=first_field, value=payload, values=values, loc=("body",)
        )
        return {first_field.name: whole_value}, whole_errors

    for field in body_fields:
        alias = get_validation_alias(field)
        field_loc = ("body", alias)
        field_value: Optional[Any] = None
        if payload is not None:
            try:
                field_value = payload.get(alias)
            except AttributeError:
                # The body is not a mapping (e.g. a JSON list), so this
                # embedded key cannot be present
                errors.append(get_missing_field_error(field_loc))
                continue
        validated, field_errors = _validate_value_with_model_field(
            field=field, value=field_value, values=values, loc=field_loc
        )
        if field_errors:
            errors.extend(field_errors)
        else:
            values[field.name] = validated
    return values, errors
964
+
965
+
966
def get_body_field(
    *, flat_dependant: Dependant, name: str, embed_body_fields: bool
) -> Optional[ModelField]:
    """
    Get a ModelField representing the request body for a path operation, combining
    all body parameters into a single field if necessary.

    Used to check if it's form data (with `isinstance(body_field, params.Form)`)
    or JSON and to generate the JSON Schema for a request body.

    This is **not** used to validate/parse the request body, that's done with each
    individual body parameter.

    Returns ``None`` when there are no body parameters. When the body is not
    embedded, returns the first body parameter unchanged. Otherwise returns a
    synthetic field named/aliased ``"body"`` whose type is a model named
    ``"Body_" + name`` built from all body parameters.
    """
    body_params = flat_dependant.body_params
    if not body_params:
        return None
    if not embed_body_fields:
        return body_params[0]

    combined_model = create_body_model(
        fields=body_params, model_name="Body_" + name
    )
    is_required = any(param.required for param in body_params)

    field_info_kwargs: dict[str, Any] = {
        "annotation": combined_model,
        "alias": "body",
    }
    if not is_required:
        field_info_kwargs["default"] = None

    # Any File param makes the combined body File; otherwise any Form param
    # makes it Form; otherwise it is a plain Body.
    field_info_class: type[params.Body]
    if any(isinstance(param.field_info, params.File) for param in body_params):
        field_info_class = params.File
    elif any(isinstance(param.field_info, params.Form) for param in body_params):
        field_info_class = params.Form
    else:
        field_info_class = params.Body

    # Only propagate media_type when every Body param declares the same one
    distinct_media_types = {
        param.field_info.media_type
        for param in body_params
        if isinstance(param.field_info, params.Body)
    }
    if len(distinct_media_types) == 1:
        field_info_kwargs["media_type"] = next(iter(distinct_media_types))

    return create_model_field(
        name="body",
        type_=combined_model,
        required=is_required,
        alias="body",
        field_info=field_info_class(**field_info_kwargs),
    )
1017
+
1018
+
1019
def get_validation_alias(field: ModelField) -> str:
    """
    Return the key used to look up ``field`` in incoming request data.

    Uses the field's ``validation_alias`` attribute when it exists and is
    truthy; otherwise falls back to ``field.alias``.
    """
    validation_alias = getattr(field, "validation_alias", None)
    if validation_alias:
        return validation_alias
    return field.alias