starmallow 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,14 +4,14 @@ import itertools
4
4
  import re
5
5
  import warnings
6
6
  from collections import defaultdict
7
+ from collections.abc import Generator, Mapping
7
8
  from logging import getLogger
8
- from typing import Any, Dict, Generator, List, Optional, Sequence, Set, Tuple, Type
9
+ from typing import Any, cast
9
10
 
10
11
  import marshmallow as ma
11
12
  import marshmallow.fields as mf
12
13
  from apispec import APISpec
13
- from apispec.ext.marshmallow import OpenAPIConverter, SchemaResolver
14
- # from apispec.ext.marshmallow.openapi import OpenAPIConverter
14
+ from apispec.ext.marshmallow import OpenAPIConverter, SchemaResolver # type: ignore
15
15
  from starlette.responses import Response
16
16
  from starlette.routing import BaseRoute, Mount, compile_path
17
17
  from starlette.schemas import BaseSchemaGenerator
@@ -26,16 +26,17 @@ from starmallow.security.base import SecurityBaseResolver
26
26
  from starmallow.utils import (
27
27
  deep_dict_update,
28
28
  dict_safe_add,
29
- is_marshmallow_field,
30
29
  status_code_ranges,
31
30
  )
32
31
 
33
32
  logger = getLogger(__name__)
34
33
 
34
+
35
35
  class SchemaRegistry(dict):
36
36
  '''
37
37
  Dict that holds all the schemas for each class and lazily resolves them.
38
38
  '''
39
+
39
40
  def __init__(
40
41
  self,
41
42
  spec: APISpec,
@@ -59,7 +60,7 @@ class SchemaRegistry(dict):
59
60
  sec_obj = self.security_references.__getitem__(component_id)
60
61
  except KeyError:
61
62
  # Use marshmallow_dataclass to dump itself
62
- sec_schema = model.Schema().dump(model)
63
+ sec_schema = cast(dict, model.Schema().dump(model))
63
64
  sec_schema['type'] = model.type.value
64
65
 
65
66
  self.spec.components.security_scheme(component_id=component_id, component=sec_schema)
@@ -74,7 +75,7 @@ class SchemaRegistry(dict):
74
75
  if isinstance(item, SecurityBaseResolver):
75
76
  return self._get_security_item(item)
76
77
 
77
- if is_marshmallow_field(item):
78
+ if isinstance(item, mf.Field):
78
79
  # If marshmallow field, just resolve it here without caching
79
80
  prop = self.converter.field2property(item)
80
81
  return prop
@@ -92,7 +93,9 @@ class SchemaRegistry(dict):
92
93
  try:
93
94
  schema = self.spec.components.schemas.__getitem__(component_id)
94
95
  except KeyError:
95
- self.spec.components.schema(component_id=component_id, schema=item)
96
+ self.spec.components.schema(
97
+ component_id=component_id, schema=item,
98
+ )
96
99
 
97
100
  schema = self.resolver.resolve_schema_dict(item)
98
101
  super().__setitem__(schema_class, schema)
@@ -124,19 +127,24 @@ class SchemaGenerator(BaseSchemaGenerator):
124
127
  plugins=[marshmallow_plugin],
125
128
  )
126
129
 
130
+ if marshmallow_plugin.converter is None or marshmallow_plugin.resolver is None:
131
+ raise ValueError(
132
+ "Converter and resolver must be initialized before use.",
133
+ )
134
+
127
135
  self.converter = marshmallow_plugin.converter
128
136
  self.resolver = marshmallow_plugin.resolver
129
137
 
130
138
  # Builtin definitions
131
139
  self.schemas = SchemaRegistry(self.spec, self.converter, self.resolver)
132
140
 
133
- self.operation_ids: Set[str] = set()
141
+ self.operation_ids: set[str] = set()
134
142
 
135
- def get_endpoints(
143
+ def get_endpoints( # type: ignore
136
144
  self,
137
- routes: List[BaseRoute],
145
+ routes: list[BaseRoute],
138
146
  base_path: str = "",
139
- ) -> Dict[str, Sequence[EndpointModel]]:
147
+ ) -> dict[str, list[EndpointModel]]:
140
148
  """
141
149
  Given the routes, yields the following information:
142
150
 
@@ -151,11 +159,11 @@ class SchemaGenerator(BaseSchemaGenerator):
151
159
  This allows each path to have multiple responses.
152
160
  """
153
161
 
154
- endpoints_info: Dict[str, Sequence[APIRoute]] = defaultdict(list)
162
+ endpoints_info: dict[str, list[EndpointModel]] = defaultdict(list)
155
163
 
156
164
  for route in routes:
157
165
  # path is not defined in BaseRoute, but all implementations have it.
158
- _, path, _ = compile_path(base_path + route.path)
166
+ _, path, _ = compile_path(base_path + route.path) # type: ignore
159
167
 
160
168
  if isinstance(route, APIRoute) and route.include_in_schema:
161
169
  if inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
@@ -166,12 +174,16 @@ class SchemaGenerator(BaseSchemaGenerator):
166
174
  endpoints_info[path].append(route.endpoint_model)
167
175
 
168
176
  elif isinstance(route, Mount):
169
- endpoints_info.update(self.get_endpoints(route.routes, base_path=path))
177
+ endpoints_info.update(
178
+ self.get_endpoints(
179
+ route.routes, base_path=path,
180
+ ),
181
+ )
170
182
 
171
183
  return endpoints_info
172
184
 
173
- def _to_parameters(self, *params: Dict[str, Query | Path | Header | Cookie]) -> Generator[Dict[str, Any], None, None]:
174
- for name, field in itertools.chain(*(p.items() for p in params)):
185
+ def _to_parameters(self, *params: Mapping[str, Query | Path | Header | Cookie] | None) -> Generator[dict[str, Any], None, None]:
186
+ for name, field in itertools.chain(*(p.items() for p in params if p is not None)):
175
187
  if not field.include_in_schema:
176
188
  continue
177
189
 
@@ -183,9 +195,13 @@ class SchemaGenerator(BaseSchemaGenerator):
183
195
  location=field.in_.name,
184
196
  )
185
197
 
186
- else:
198
+ elif field.model is not None:
187
199
  yield self.converter._field2parameter(
188
- field.model,
200
+ (
201
+ field.model
202
+ if isinstance(field.model, mf.Field)
203
+ else mf.Nested(field.model)
204
+ ),
189
205
  name=name,
190
206
  location=field.in_.name,
191
207
  )
@@ -193,7 +209,7 @@ class SchemaGenerator(BaseSchemaGenerator):
193
209
  def _add_endpoint_parameters(
194
210
  self,
195
211
  endpoint: EndpointModel,
196
- schema: Dict,
212
+ schema: dict,
197
213
  ):
198
214
  unique_params = {}
199
215
  for param in self._to_parameters(
@@ -207,7 +223,9 @@ class SchemaGenerator(BaseSchemaGenerator):
207
223
  # Duplicate parameter, skip. Could be defined as a field and in a schema.
208
224
  continue
209
225
  else:
210
- raise ValueError(f"Duplicate parameter with name {param['name']} and location {param['in']}")
226
+ raise ValueError(
227
+ f"Duplicate parameter with name {param['name']} and location {param['in']}",
228
+ )
211
229
 
212
230
  unique_params[param['name']] = param
213
231
 
@@ -216,11 +234,14 @@ class SchemaGenerator(BaseSchemaGenerator):
216
234
  def _add_endpoint_body(
217
235
  self,
218
236
  endpoint: EndpointModel,
219
- schema: Dict,
237
+ schema: dict,
220
238
  ):
221
- all_body_params: List[Tuple[str, Body]] = [
222
- *endpoint.body_params.items(),
223
- *endpoint.form_params.items(),
239
+ if not isinstance(endpoint.route, APIRoute):
240
+ return
241
+
242
+ all_body_params: list[tuple[str, Body]] = [
243
+ *(endpoint.body_params.items() if endpoint.body_params else []),
244
+ *(endpoint.form_params.items() if endpoint.form_params else []),
224
245
  ]
225
246
  schema_by_media_type = {}
226
247
  is_body_required = True
@@ -231,8 +252,8 @@ class SchemaGenerator(BaseSchemaGenerator):
231
252
  if body_param.include_in_schema:
232
253
  endpoint_schema = self.schemas[body_param.model]
233
254
 
234
- if endpoint_schema:
235
- schema_by_media_type[body_param.media_type] = {'schema': endpoint_schema}
255
+ if endpoint_schema:
256
+ schema_by_media_type[body_param.media_type] = {'schema': endpoint_schema}
236
257
 
237
258
  if getattr(body_param.model, 'required', True) is False:
238
259
  is_body_required = False
@@ -254,8 +275,11 @@ class SchemaGenerator(BaseSchemaGenerator):
254
275
  component_by_media_type = defaultdict(new_endpoint_schema)
255
276
  for name, value in all_body_params:
256
277
  media_component = component_by_media_type[value.media_type]
257
- endpoint_properties: Dict[str, Any] = media_component['properties']
258
- required_properties: List[Any] = media_component['required']
278
+ endpoint_properties: dict[
279
+ str,
280
+ Any,
281
+ ] = media_component['properties']
282
+ required_properties: list[Any] = media_component['required']
259
283
 
260
284
  if value.include_in_schema:
261
285
  if isinstance(value.model, ma.Schema):
@@ -264,7 +288,9 @@ class SchemaGenerator(BaseSchemaGenerator):
264
288
  required_properties.append(name)
265
289
 
266
290
  elif isinstance(value.model, mf.Field):
267
- endpoint_properties[name] = self.converter.field2property(value.model)
291
+ endpoint_properties[name] = self.converter.field2property(
292
+ value.model,
293
+ )
268
294
  if value.model.required:
269
295
  required_properties.append(name)
270
296
 
@@ -273,7 +299,9 @@ class SchemaGenerator(BaseSchemaGenerator):
273
299
 
274
300
  schema_by_media_type = {}
275
301
  for media_type, component in component_by_media_type.items():
276
- self.spec.components.schema(component_id=component_schema_id, component=component)
302
+ self.spec.components.schema(
303
+ component_id=component_schema_id, component=component,
304
+ )
277
305
 
278
306
  schema_by_media_type[media_type] = {
279
307
  "schema": {
@@ -290,21 +318,26 @@ class SchemaGenerator(BaseSchemaGenerator):
290
318
  def _add_security_params(
291
319
  self,
292
320
  endpoint: EndpointModel,
293
- schema: Dict,
321
+ schema: dict,
294
322
  ):
295
- schema['security'] = [
296
- self.schemas[security_param.resolver]
297
- for param_name, security_param in endpoint.security_params.items()
298
- ]
323
+ if endpoint.security_params:
324
+ schema['security'] = [
325
+ self.schemas[security_param.resolver]
326
+ for param_name, security_param in endpoint.security_params.items()
327
+ ]
299
328
 
300
329
  def _add_endpoint_response(
301
330
  self,
302
331
  endpoint: EndpointModel,
303
- schema: Dict,
332
+ schema: dict,
304
333
  ):
305
334
  operation_responses = schema.setdefault("responses", {})
306
335
  response_codes = list(operation_responses.keys())
307
- main_response = str(endpoint.status_code or (response_codes[0] if response_codes else 200))
336
+ main_response = str(
337
+ endpoint.status_code or (
338
+ response_codes[0] if response_codes else 200
339
+ ),
340
+ )
308
341
 
309
342
  operation_responses[main_response] = {
310
343
  'content': {
@@ -313,17 +346,20 @@ class SchemaGenerator(BaseSchemaGenerator):
313
346
  },
314
347
  },
315
348
  }
349
+ if not isinstance(endpoint.route, APIRoute):
350
+ return
351
+
316
352
  if endpoint.route.response_description:
317
353
  operation_responses[main_response]['description'] = endpoint.route.response_description
318
354
 
319
355
  # Process additional responses
320
356
  route = endpoint.route
321
357
  if isinstance(route.response_class, DefaultPlaceholder):
322
- current_response_class: Type[Response] = route.response_class.value
358
+ current_response_class: type[Response] = route.response_class.value
323
359
  else:
324
360
  current_response_class = route.response_class
325
361
  assert current_response_class, "A response class is needed to generate OpenAPI"
326
- route_response_media_type: Optional[str] = current_response_class.media_type
362
+ route_response_media_type: str | None = current_response_class.media_type
327
363
 
328
364
  if route.responses:
329
365
  for (
@@ -342,17 +378,21 @@ class SchemaGenerator(BaseSchemaGenerator):
342
378
  status_code_key, {},
343
379
  )
344
380
  field = route.response_fields.get(additional_status_code)
345
- additional_field_schema: Optional[Dict[str, Any]] = None
381
+ additional_field_schema: dict[str, Any] | None = None
346
382
  if field:
347
- additional_field_schema = self.schemas[field] if field else {}
383
+ additional_field_schema = self.schemas[field] if field else {
384
+ }
348
385
  media_type = route_response_media_type or "application/json"
349
386
  additional_schema = (
350
387
  process_response.setdefault("content", {})
351
388
  .setdefault(media_type, {})
352
389
  .setdefault("schema", {})
353
390
  )
354
- deep_dict_update(additional_schema, additional_field_schema)
355
- status_text: Optional[str] = status_code_ranges.get(
391
+ deep_dict_update(
392
+ additional_schema,
393
+ additional_field_schema,
394
+ )
395
+ status_text: str | None = status_code_ranges.get(
356
396
  str(additional_status_code).upper(),
357
397
  ) or http.client.responses.get(int(additional_status_code))
358
398
  description = (
@@ -364,7 +404,7 @@ class SchemaGenerator(BaseSchemaGenerator):
364
404
  deep_dict_update(openapi_response, process_response)
365
405
  openapi_response["description"] = description
366
406
 
367
- def _add_default_error_response(self, schema: Dict):
407
+ def _add_default_error_response(self, schema: dict):
368
408
  dict_safe_add(
369
409
  schema,
370
410
  'responses.422.description',
@@ -381,7 +421,7 @@ class SchemaGenerator(BaseSchemaGenerator):
381
421
  return route.summary
382
422
  return re.sub(r"(\w)([A-Z])", r"\1 \2", route.name).replace(".", " ").replace("_", " ").title()
383
423
 
384
- def _get_route_openapi_metadata(self, route: APIRoute) -> Dict[str, Any]:
424
+ def _get_route_openapi_metadata(self, route: APIRoute) -> dict[str, Any]:
385
425
  schema = {}
386
426
  if route.tags:
387
427
  schema["tags"] = route.tags
@@ -391,10 +431,11 @@ class SchemaGenerator(BaseSchemaGenerator):
391
431
  operation_id = route.operation_id or route.unique_id
392
432
  if operation_id in self.operation_ids:
393
433
  message = (
394
- f"Duplicate Operation ID {operation_id} for function "
395
- + f"{route.endpoint.__name__}"
434
+ f"Duplicate Operation ID {operation_id} for function {route.endpoint.__name__}"
396
435
  )
397
- file_name = getattr(route.endpoint, "__globals__", {}).get("__file__")
436
+ file_name = getattr(
437
+ route.endpoint, "__globals__", {},
438
+ ).get("__file__")
398
439
  if file_name:
399
440
  message += f" at {file_name}"
400
441
  warnings.warn(message)
@@ -408,10 +449,14 @@ class SchemaGenerator(BaseSchemaGenerator):
408
449
  def get_endpoint_schema(
409
450
  self,
410
451
  endpoint: EndpointModel,
411
- ) -> Dict[str, Any]:
452
+ ) -> dict[str, Any]:
412
453
  '''
413
454
  Generates the endpoint schema
414
455
  '''
456
+
457
+ if not isinstance(endpoint.route, APIRoute):
458
+ return {}
459
+
415
460
  schema = self._get_route_openapi_metadata(endpoint.route)
416
461
 
417
462
  schema.update(self.parse_docstring(endpoint.call))
@@ -458,18 +503,19 @@ class SchemaGenerator(BaseSchemaGenerator):
458
503
 
459
504
  return schema
460
505
 
461
- def get_operations(self, endpoints: List[EndpointModel]):
506
+ def get_operations(self, endpoints: list[EndpointModel]) -> dict[str, dict[str, Any]]:
462
507
  return {
463
508
  method.lower(): self.get_endpoint_schema(e)
464
509
  for e in endpoints
510
+ if e.methods
465
511
  for method in e.methods
466
512
  if method != 'HEAD'
467
513
  }
468
514
 
469
515
  def get_schema(
470
516
  self,
471
- routes: List[APIRoute],
472
- ) -> Dict[str, Any]:
517
+ routes: list[BaseRoute],
518
+ ) -> dict[str, Any]:
473
519
  '''
474
520
  Generates the schemas for the specified routes..
475
521
  '''
@@ -1,7 +1,7 @@
1
1
  from enum import Enum
2
- from typing import ClassVar, Optional
2
+ from typing import ClassVar
3
3
 
4
- from marshmallow_dataclass import dataclass as ma_dataclass
4
+ from marshmallow_dataclass2 import dataclass as ma_dataclass
5
5
  from starlette.requests import Request
6
6
  from starlette.status import HTTP_403_FORBIDDEN
7
7
 
@@ -28,15 +28,15 @@ class APIKeyQuery(SecurityBaseResolver):
28
28
  self,
29
29
  *,
30
30
  name: str,
31
- scheme_name: Optional[str] = None,
32
- description: Optional[str] = None,
31
+ scheme_name: str | None = None,
32
+ description: str | None = None,
33
33
  auto_error: bool = True,
34
34
  ):
35
35
  self.model: APIKeyModel = APIKeyModel(in_=APIKeyIn.query, name=name, description=description)
36
36
  self.scheme_name = scheme_name or self.__class__.__name__
37
37
  self.auto_error = auto_error
38
38
 
39
- async def __call__(self, request: Request) -> Optional[str]:
39
+ async def __call__(self, request: Request) -> str | None:
40
40
  api_key = request.query_params.get(self.model.name)
41
41
  if not api_key:
42
42
  if self.auto_error:
@@ -53,15 +53,15 @@ class APIKeyHeader(SecurityBaseResolver):
53
53
  self,
54
54
  *,
55
55
  name: str,
56
- scheme_name: Optional[str] = None,
57
- description: Optional[str] = None,
56
+ scheme_name: str | None = None,
57
+ description: str | None = None,
58
58
  auto_error: bool = True,
59
59
  ):
60
60
  self.model: APIKeyModel = APIKeyModel(in_=APIKeyIn.header, name=name, description=description)
61
61
  self.scheme_name = scheme_name or self.__class__.__name__
62
62
  self.auto_error = auto_error
63
63
 
64
- async def __call__(self, request: Request) -> Optional[str]:
64
+ async def __call__(self, request: Request) -> str | None:
65
65
  api_key = request.headers.get(self.model.name)
66
66
  if not api_key:
67
67
  if self.auto_error:
@@ -78,15 +78,15 @@ class APIKeyCookie(SecurityBaseResolver):
78
78
  self,
79
79
  *,
80
80
  name: str,
81
- scheme_name: Optional[str] = None,
82
- description: Optional[str] = None,
81
+ scheme_name: str | None = None,
82
+ description: str | None = None,
83
83
  auto_error: bool = True,
84
84
  ):
85
85
  self.model: APIKeyModel = APIKeyModel(in_=APIKeyIn.cookie, name=name, description=description)
86
86
  self.scheme_name = scheme_name or self.__class__.__name__
87
87
  self.auto_error = auto_error
88
88
 
89
- async def __call__(self, request: Request) -> Optional[str]:
89
+ async def __call__(self, request: Request) -> str | None:
90
90
  api_key = request.cookies.get(self.model.name)
91
91
  if not api_key:
92
92
  if self.auto_error:
@@ -1,8 +1,9 @@
1
1
  from enum import Enum
2
- from typing import Any, ClassVar, Dict, Optional
2
+ from typing import Any, ClassVar
3
3
 
4
4
  import marshmallow as ma
5
- from marshmallow_dataclass import dataclass as ma_dataclass
5
+ from marshmallow_dataclass2 import dataclass as ma_dataclass
6
+ from starlette.requests import Request
6
7
 
7
8
 
8
9
  # Provided by: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#security-scheme-object
@@ -16,11 +17,13 @@ class SecurityTypes(Enum):
16
17
 
17
18
  @ma_dataclass(frozen=True)
18
19
  class SecurityBase:
20
+ Schema: ClassVar[type[ma.Schema]]
21
+
19
22
  type: ClassVar[SecurityTypes]
20
- description: Optional[str] = None
23
+ description: str | None = None
21
24
 
22
25
  @ma.post_dump()
23
- def post_dump(self, data: Dict[str, Any], **kwargs):
26
+ def post_dump(self, data: dict[str, Any], **kwargs):
24
27
  # Remove None values
25
28
  return {
26
29
  key: value
@@ -40,3 +43,8 @@ class SecurityBaseResolver:
40
43
  ) -> None:
41
44
  self.model = model
42
45
  self.schema_name = scheme_name
46
+
47
+ async def __call__(self, request: Request) -> Any | None:
48
+ raise NotImplementedError(
49
+ f"SecurityBaseResolver.__call__ not implemented for {self.__class__.__name__}",
50
+ )
@@ -1,8 +1,9 @@
1
+ import asyncio
1
2
  import binascii
2
3
  from base64 import b64decode
3
- from typing import ClassVar, Optional
4
+ from typing import ClassVar
4
5
 
5
- from marshmallow_dataclass import dataclass as ma_dataclass
6
+ from marshmallow_dataclass2 import dataclass as ma_dataclass
6
7
  from starlette.requests import Request
7
8
  from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
8
9
 
@@ -33,14 +34,14 @@ class HTTPAuthorizationCredentials:
33
34
  @ma_dataclass(frozen=True)
34
35
  class HTTPBaseModel(SecurityBase):
35
36
  type: ClassVar[SecurityTypes] = SecurityTypes.http
36
- description: Optional[str] = None
37
- scheme: str = None
37
+ description: str | None = None
38
+ scheme: str | None = None
38
39
 
39
40
 
40
41
  @ma_dataclass(frozen=True)
41
42
  class HTTPBearerModel(HTTPBaseModel):
42
43
  scheme: str = "bearer"
43
- bearerFormat: Optional[str] = None
44
+ bearer_format: str | None = None
44
45
 
45
46
 
46
47
  class HTTPBase(SecurityBaseResolver):
@@ -49,8 +50,8 @@ class HTTPBase(SecurityBaseResolver):
49
50
  self,
50
51
  *,
51
52
  scheme: str,
52
- scheme_name: Optional[str] = None,
53
- description: Optional[str] = None,
53
+ scheme_name: str | None = None,
54
+ description: str | None = None,
54
55
  auto_error: bool = True,
55
56
  ) -> None:
56
57
  self.model = HTTPBaseModel(scheme=scheme, description=description)
@@ -59,7 +60,7 @@ class HTTPBase(SecurityBaseResolver):
59
60
 
60
61
  async def __call__(
61
62
  self, request: Request,
62
- ) -> Optional[HTTPAuthorizationCredentials]:
63
+ ) -> HTTPAuthorizationCredentials | None:
63
64
  authorization = request.headers.get("Authorization")
64
65
  scheme, credentials = get_authorization_scheme_param(authorization)
65
66
  if not (authorization and scheme and credentials):
@@ -77,9 +78,9 @@ class HTTPBasic(HTTPBase):
77
78
  def __init__(
78
79
  self,
79
80
  *,
80
- scheme_name: Optional[str] = None,
81
- realm: Optional[str] = None,
82
- description: Optional[str] = None,
81
+ scheme_name: str | None = None,
82
+ realm: str | None = None,
83
+ description: str | None = None,
83
84
  auto_error: bool = True,
84
85
  ) -> None:
85
86
  super().__init__(
@@ -93,13 +94,14 @@ class HTTPBasic(HTTPBase):
93
94
 
94
95
  async def __call__( # type: ignore
95
96
  self, request: Request,
96
- ) -> Optional[HTTPBasicCredentials]:
97
+ ) -> HTTPBasicCredentials | None:
97
98
  authorization = request.headers.get("Authorization")
98
99
  scheme, param = get_authorization_scheme_param(authorization)
99
- if self.realm:
100
- unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
101
- else:
102
- unauthorized_headers = {"WWW-Authenticate": "Basic"}
100
+ unauthorized_headers = (
101
+ {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
102
+ if self.realm
103
+ else {"WWW-Authenticate": "Basic"}
104
+ )
103
105
  invalid_user_credentials_exc = HTTPException(
104
106
  status_code=HTTP_401_UNAUTHORIZED,
105
107
  detail="Invalid authentication credentials",
@@ -115,9 +117,12 @@ class HTTPBasic(HTTPBase):
115
117
  else:
116
118
  return None
117
119
  try:
118
- data = b64decode(param).decode("ascii")
120
+ def decode(param: str) -> str:
121
+ return b64decode(param).decode("ascii")
122
+
123
+ data = await asyncio.to_thread(decode, param)
119
124
  except (ValueError, UnicodeDecodeError, binascii.Error):
120
- raise invalid_user_credentials_exc
125
+ raise invalid_user_credentials_exc from None
121
126
  username, separator, password = data.partition(":")
122
127
  if not separator:
123
128
  raise invalid_user_credentials_exc
@@ -129,18 +134,18 @@ class HTTPBearer(HTTPBase):
129
134
  def __init__(
130
135
  self,
131
136
  *,
132
- bearerFormat: Optional[str] = None,
133
- scheme_name: Optional[str] = None,
134
- description: Optional[str] = None,
137
+ bearer_format: str | None = None,
138
+ scheme_name: str | None = None,
139
+ description: str | None = None,
135
140
  auto_error: bool = True,
136
141
  ) -> None:
137
- self.model = HTTPBearerModel(bearerFormat=bearerFormat, description=description)
142
+ self.model = HTTPBearerModel(bearer_format=bearer_format, description=description)
138
143
  self.scheme_name = scheme_name or self.__class__.__name__
139
144
  self.auto_error = auto_error
140
145
 
141
146
  async def __call__(
142
147
  self, request: Request,
143
- ) -> Optional[HTTPAuthorizationCredentials]:
148
+ ) -> HTTPAuthorizationCredentials | None:
144
149
  authorization = request.headers.get("Authorization")
145
150
  scheme, credentials = get_authorization_scheme_param(authorization)
146
151
  if not (authorization and scheme and credentials):
@@ -166,8 +171,8 @@ class HTTPDigest(HTTPBase):
166
171
  def __init__(
167
172
  self,
168
173
  *,
169
- scheme_name: Optional[str] = None,
170
- description: Optional[str] = None,
174
+ scheme_name: str | None = None,
175
+ description: str | None = None,
171
176
  auto_error: bool = True,
172
177
  ) -> None:
173
178
  super().__init__(
@@ -179,7 +184,7 @@ class HTTPDigest(HTTPBase):
179
184
 
180
185
  async def __call__(
181
186
  self, request: Request,
182
- ) -> Optional[HTTPAuthorizationCredentials]:
187
+ ) -> HTTPAuthorizationCredentials | None:
183
188
  authorization = request.headers.get("Authorization")
184
189
  scheme, credentials = get_authorization_scheme_param(authorization)
185
190
  if not (authorization and scheme and credentials):