starmallow 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
starmallow/endpoint.py CHANGED
@@ -1,31 +1,27 @@
1
1
  import inspect
2
2
  import logging
3
+ from collections.abc import Callable, Iterable, Mapping, Sequence
3
4
  from dataclasses import dataclass, field
4
5
  from typing import (
5
6
  TYPE_CHECKING,
7
+ Annotated,
6
8
  Any,
7
- Callable,
8
- Dict,
9
- Iterable,
10
- List,
11
- Mapping,
12
9
  NewType,
13
- Optional,
14
- Type,
15
- Union,
10
+ cast,
16
11
  get_args,
17
12
  get_origin,
18
13
  )
19
14
 
20
15
  import marshmallow as ma
21
16
  import marshmallow.fields as mf
17
+ import typing_inspect
18
+ from marshmallow.types import StrSequenceOrSet
22
19
  from marshmallow.utils import missing as missing_
23
20
  from marshmallow_dataclass2 import class_schema, is_generic_alias_of_dataclass
24
21
  from starlette.background import BackgroundTasks
25
22
  from starlette.requests import HTTPConnection, Request
26
23
  from starlette.responses import Response
27
24
  from starlette.websockets import WebSocket
28
- from typing_extensions import Annotated
29
25
 
30
26
  from starmallow.params import (
31
27
  Body,
@@ -43,20 +39,21 @@ from starmallow.params import (
43
39
  from starmallow.responses import JSONResponse
44
40
  from starmallow.security.base import SecurityBaseResolver
45
41
  from starmallow.utils import (
42
+ MaDataclassProtocol,
46
43
  create_response_model,
47
44
  get_model_field,
48
45
  get_path_param_names,
49
46
  get_typed_return_annotation,
50
47
  get_typed_signature,
51
48
  is_marshmallow_dataclass,
52
- is_marshmallow_field,
49
+ is_marshmallow_field_or_generic,
53
50
  is_marshmallow_schema,
54
51
  is_optional,
55
52
  lenient_issubclass,
56
53
  )
57
54
 
58
55
  if TYPE_CHECKING:
59
- from starmallow.routing import APIRoute
56
+ from starmallow.routing import APIRoute, APIWebSocketRoute
60
57
 
61
58
  logger = logging.getLogger(__name__)
62
59
 
@@ -70,52 +67,52 @@ STARMALLOW_PARAM_TYPES = (
70
67
 
71
68
  @dataclass
72
69
  class EndpointModel:
73
- params: Optional[Dict[ParamType, Dict[str, Param]]] = field(default_factory=dict)
74
- flat_params: Optional[Dict[ParamType, Dict[str, Param]]] = field(default_factory=dict)
75
- name: Optional[str] = None
76
- path: Optional[str] = None
77
- methods: Optional[List[str]] = None
78
- call: Optional[Callable[..., Any]] = None
79
- response_model: Optional[ma.Schema | mf.Field] = None
80
- response_class: Type[Response] = JSONResponse
81
- status_code: Optional[int] = None
82
- route: 'APIRoute' = None
70
+ path: str
71
+ call: Callable[..., Any]
72
+ route: 'APIRoute | APIWebSocketRoute'
73
+ params: dict[ParamType, dict[str, Param]] = field(default_factory=dict)
74
+ flat_params: dict[ParamType, dict[str, Param]] = field(default_factory=dict)
75
+ name: str | None = None
76
+ methods: Sequence[str] | None = None
77
+ response_model: ma.Schema | type[ma.Schema | MaDataclassProtocol] | mf.Field | None = None
78
+ response_class: type[Response] = JSONResponse
79
+ status_code: int | None = None
83
80
 
84
81
  @property
85
- def path_params(self) -> Dict[str, Path] | None:
86
- return self.flat_params.get(ParamType.path)
82
+ def path_params(self) -> dict[str, Path] | None:
83
+ return cast(dict[str, Path], self.flat_params.get(ParamType.path))
87
84
 
88
85
  @property
89
- def query_params(self) -> Dict[str, Query] | None:
90
- return self.flat_params.get(ParamType.query)
86
+ def query_params(self) -> dict[str, Query] | None:
87
+ return cast(dict[str, Query], self.flat_params.get(ParamType.query))
91
88
 
92
89
  @property
93
- def header_params(self) -> Dict[str, Header] | None:
94
- return self.flat_params.get(ParamType.header)
90
+ def header_params(self) -> dict[str, Header] | None:
91
+ return cast(dict[str, Header], self.flat_params.get(ParamType.header))
95
92
 
96
93
  @property
97
- def cookie_params(self) -> Dict[str, Cookie] | None:
98
- return self.flat_params.get(ParamType.cookie)
94
+ def cookie_params(self) -> dict[str, Cookie] | None:
95
+ return cast(dict[str, Cookie], self.flat_params.get(ParamType.cookie))
99
96
 
100
97
  @property
101
- def body_params(self) -> Dict[str, Body] | None:
102
- return self.flat_params.get(ParamType.body)
98
+ def body_params(self) -> dict[str, Body] | None:
99
+ return cast(dict[str, Body], self.flat_params.get(ParamType.body))
103
100
 
104
101
  @property
105
- def form_params(self) -> Dict[str, Form] | None:
106
- return self.flat_params.get(ParamType.form)
102
+ def form_params(self) -> dict[str, Form] | None:
103
+ return cast(dict[str, Form], self.flat_params.get(ParamType.form))
107
104
 
108
105
  @property
109
- def non_field_params(self) -> Dict[str, NoParam] | None:
110
- return self.flat_params.get(ParamType.noparam)
106
+ def non_field_params(self) -> dict[str, NoParam] | None:
107
+ return cast(dict[str, NoParam], self.flat_params.get(ParamType.noparam))
111
108
 
112
109
  @property
113
- def resolved_params(self) -> Dict[str, ResolvedParam] | None:
114
- return self.flat_params.get(ParamType.resolved)
110
+ def resolved_params(self) -> dict[str, ResolvedParam] | None:
111
+ return cast(dict[str, ResolvedParam], self.flat_params.get(ParamType.resolved))
115
112
 
116
113
  @property
117
- def security_params(self) -> Dict[str, Security] | None:
118
- return self.flat_params.get(ParamType.security)
114
+ def security_params(self) -> dict[str, Security] | None:
115
+ return cast(dict[str, Security], self.flat_params.get(ParamType.security))
119
116
 
120
117
 
121
118
  class SchemaMeta:
@@ -129,20 +126,20 @@ class SchemaModel(ma.Schema):
129
126
  schema: ma.Schema,
130
127
  load_default: Any = missing_,
131
128
  required: bool = True,
132
- metadata: Dict[str, Any] = None,
129
+ metadata: dict[str, Any] | None = None,
133
130
  **kwargs,
134
131
  ) -> None:
135
132
  self.schema = schema
136
133
  self.load_default = load_default
137
134
  self.required = required
138
- self.title = metadata.get('title')
135
+ self.title = metadata.get('title') if metadata else None
139
136
  self.metadata = metadata
140
137
  self.kwargs = kwargs
141
138
 
142
139
  if not getattr(schema.Meta, "title", None):
143
140
  if schema.Meta is ma.Schema.Meta:
144
141
  # Don't override global Meta object's title
145
- schema.Meta = SchemaMeta(self.title)
142
+ schema.Meta = SchemaMeta(self.title) # type: ignore
146
143
  else:
147
144
  schema.Meta.title = self.title
148
145
 
@@ -162,14 +159,11 @@ class SchemaModel(ma.Schema):
162
159
 
163
160
  def load(
164
161
  self,
165
- data: Union[
166
- Mapping[str, Any],
167
- Iterable[Mapping[str, Any]],
168
- ],
162
+ data: Mapping[str, Any] | Iterable[Mapping[str, Any]],
169
163
  *,
170
- many: Optional[bool] = None,
171
- partial: Optional[bool] = None,
172
- unknown: Optional[str] = None,
164
+ many: bool | None = None,
165
+ partial: bool | StrSequenceOrSet | None = None,
166
+ unknown: str | None = None,
173
167
  ) -> Any:
174
168
  if not data and self.load_default:
175
169
  return self.load_default
@@ -185,7 +179,7 @@ class EndpointMixin:
185
179
  parameter: Param,
186
180
  parameter_name: str,
187
181
  default_value: Any,
188
- ) -> Union[ma.Schema, mf.Field]:
182
+ ) -> ma.Schema | mf.Field | None:
189
183
  model = getattr(parameter, 'model', None) or type_annotation
190
184
  if isinstance(model, SchemaModel):
191
185
  return model
@@ -244,21 +238,22 @@ class EndpointMixin:
244
238
  if is_marshmallow_dataclass(model):
245
239
  model = model.Schema
246
240
 
247
- if is_generic_alias_of_dataclass(model):
248
- model = class_schema(model)
241
+ if is_generic_alias_of_dataclass(model): # type: ignore
242
+ model = class_schema(model) # type: ignore
249
243
 
250
- if isinstance(model, NewType) and getattr(model, '_marshmallow_field', None):
251
- return model._marshmallow_field(**kwargs)
244
+ mmf = getattr(model, '_marshmallow_field', None)
245
+ if isinstance(model, NewType) and mmf and issubclass(mmf, mf.Field):
246
+ return mmf(**kwargs)
252
247
  elif is_marshmallow_schema(model):
253
248
  return SchemaModel(model() if inspect.isclass(model) else model, **kwargs)
254
- elif is_marshmallow_field(model):
255
- if inspect.isclass(model):
249
+ elif is_marshmallow_field_or_generic(model):
250
+ if not isinstance(model, mf.Field): # inspect.isclass(model):
256
251
  model = model()
257
252
 
258
253
  if model.load_default is not None and model.load_default != kwargs.get('load_default', ma.missing):
259
254
  logger.warning(
260
255
  f"'{parameter_name}' model and annotation have different 'load_default' values."
261
- + f" {model.load_default} <> {kwargs.get('load_default', ma.missing)}",
256
+ f" {model.load_default} <> {kwargs.get('load_default', ma.missing)}",
262
257
  )
263
258
 
264
259
  model.required = kwargs['required']
@@ -266,6 +261,11 @@ class EndpointMixin:
266
261
  model.metadata.update(kwargs['metadata'])
267
262
 
268
263
  return model
264
+
265
+ # elif is_marshmallow_field_or_generic(model):
266
+ # if isinstance(model, mf.Field):
267
+ # model if isinstance(model, mf.Field) else model()
268
+
269
269
  else:
270
270
  try:
271
271
  return get_model_field(model, **kwargs)
@@ -282,11 +282,19 @@ class EndpointMixin:
282
282
 
283
283
  return resolved_param
284
284
 
285
+ def get_security_param(self, resolved_param: Security, annotation: Any, path: str) -> Security:
286
+ if resolved_param.resolver is None:
287
+ resolved_param.resolver = annotation
288
+
289
+ resolved_param.resolver_params = self._get_params(resolved_param.resolver, path=path)
290
+
291
+ return resolved_param
292
+
285
293
  def _get_params(
286
294
  self,
287
295
  func: Callable[..., Any],
288
296
  path: str,
289
- ) -> Dict[ParamType, Dict[str, Param]]:
297
+ ) -> dict[ParamType, dict[str, Param]]:
290
298
  path_param_names = get_path_param_names(path)
291
299
  params = {param_type: {} for param_type in ParamType}
292
300
  for name, parameter in get_typed_signature(func).parameters.items():
@@ -313,8 +321,7 @@ class EndpointMixin:
313
321
  ]
314
322
  if starmallow_annotations:
315
323
  assert starmallow_param is inspect._empty, (
316
- "Cannot specify `Param` in `Annotated` and default value"
317
- f" together for {name!r}"
324
+ f"Cannot specify `Param` in `Annotated` and default value together for {name!r}"
318
325
  )
319
326
 
320
327
  starmallow_param = starmallow_annotations[-1]
@@ -325,6 +332,20 @@ class EndpointMixin:
325
332
  ):
326
333
  default_value = starmallow_param.default
327
334
 
335
+ field_annotations = [
336
+ arg
337
+ for arg in annotated_args
338
+ if (
339
+ isinstance(arg, mf.Field)
340
+ or lenient_issubclass(arg, mf.Field)
341
+ or (
342
+ isinstance(arg, typing_inspect.typingGenericAlias)
343
+ and lenient_issubclass(get_origin(arg), mf.Field)
344
+ )
345
+ )
346
+ ]
347
+ if field_annotations:
348
+ type_annotation = field_annotations[-1]
328
349
  if (
329
350
  # Skip 'self' in APIHTTPEndpoint functions
330
351
  (name == 'self' and '.' in func.__qualname__)
@@ -332,11 +353,11 @@ class EndpointMixin:
332
353
  ):
333
354
  continue
334
355
  elif isinstance(starmallow_param, Security):
335
- security_param: Security = self.get_resolved_param(starmallow_param, type_annotation, path=path)
356
+ security_param = self.get_security_param(starmallow_param, type_annotation, path=path)
336
357
  params[ParamType.security][name] = security_param
337
358
  continue
338
359
  elif isinstance(starmallow_param, ResolvedParam):
339
- resolved_param: ResolvedParam = self.get_resolved_param(starmallow_param, type_annotation, path=path)
360
+ resolved_param = self.get_resolved_param(starmallow_param, type_annotation, path=path)
340
361
 
341
362
  # Allow `ResolvedParam(HTTPBearer())` - treat as securty param
342
363
  if isinstance(resolved_param.resolver, SecurityBaseResolver):
@@ -359,7 +380,7 @@ class EndpointMixin:
359
380
  continue
360
381
 
361
382
  model = self._get_param_model(type_annotation, starmallow_param, name, default_value)
362
- model.name = name
383
+ model.name = name # type: ignore
363
384
 
364
385
  if isinstance(starmallow_param, Param):
365
386
  # Create new field_info with processed model
@@ -397,18 +418,18 @@ class EndpointMixin:
397
418
  def get_endpoint_model(
398
419
  self,
399
420
  path: str,
400
- endpoint: Callable[..., Any],
401
- route: 'APIRoute',
402
- name: Optional[str] = None,
403
- methods: Optional[List[str]] = None,
404
-
405
- status_code: Optional[int] = None,
406
- response_model: Optional[ma.Schema] = None,
407
- response_class: Type[Response] = JSONResponse,
421
+ endpoint: Callable[..., ma.Schema | mf.Field | Response | None],
422
+ route: 'APIRoute | APIWebSocketRoute',
423
+ name: str | None = None,
424
+ methods: Sequence[str] | None = None,
425
+
426
+ status_code: int | None = None,
427
+ response_model: ma.Schema | type[ma.Schema | MaDataclassProtocol] | None = None,
428
+ response_class: type[Response] = JSONResponse,
408
429
  ) -> EndpointModel:
409
430
  params = self._get_params(endpoint, path)
410
431
 
411
- response_model = create_response_model(response_model or get_typed_return_annotation(endpoint))
432
+ response_model = create_response_model(response_model or get_typed_return_annotation(endpoint)) # type: ignore
412
433
 
413
434
  return EndpointModel(
414
435
  path=path,
@@ -425,9 +446,9 @@ class EndpointMixin:
425
446
 
426
447
 
427
448
  def safe_merge_params(
428
- left: Dict[str, Param],
429
- right: Dict[str, Param],
430
- ) -> Dict[str, Param]:
449
+ left: dict[str, Param],
450
+ right: dict[str, Param],
451
+ ) -> dict[str, Param]:
431
452
  res = left.copy()
432
453
  for name, param in right.items():
433
454
  if name not in left:
@@ -439,9 +460,9 @@ def safe_merge_params(
439
460
 
440
461
 
441
462
  def safe_merge_all_params(
442
- left: Dict[ParamType, Dict[str, Param]],
443
- right: Dict[ParamType, Dict[str, Param]],
444
- ) -> Dict[ParamType, Dict[str, Param]]:
463
+ left: dict[ParamType, dict[str, Param]],
464
+ right: dict[ParamType, dict[str, Param]],
465
+ ) -> dict[ParamType, dict[str, Param]]:
445
466
  res = {
446
467
  param_type: safe_merge_params(left[param_type], right[param_type])
447
468
  for param_type in ParamType
@@ -451,12 +472,13 @@ def safe_merge_all_params(
451
472
 
452
473
 
453
474
  def flatten_parameters(
454
- params: Dict[ParamType, Dict[str, Param]],
455
- ) -> Dict[ParamType, Dict[str, Param]]:
475
+ params: dict[ParamType, dict[str, Param]],
476
+ ) -> dict[ParamType, dict[str, Param]]:
456
477
  # flat_params = {param_type: {} for param_type in ParamType}
457
478
  flat_params = params.copy()
479
+
458
480
  for param in params[ParamType.resolved].values():
459
- if not param.resolver_params:
481
+ if not (isinstance(param, ResolvedParam) and param.resolver_params):
460
482
  continue
461
483
 
462
484
  flat_params = safe_merge_all_params(flat_params, param.resolver_params)
starmallow/endpoints.py CHANGED
@@ -1,4 +1,5 @@
1
- from typing import Any, ClassVar, Collection, Optional
1
+ from collections.abc import Collection
2
+ from typing import Any, ClassVar
2
3
 
3
4
  HTTP_METHOD_FUNCS = ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
4
5
 
@@ -8,7 +9,7 @@ class APIHTTPEndpoint:
8
9
  #: The methods this view is registered for. Uses the same default
9
10
  #: (``["GET", "HEAD", "OPTIONS"]``) as ``route`` and
10
11
  #: ``add_url_rule`` by default.
11
- methods: ClassVar[Optional[Collection[str]]] = None
12
+ methods: ClassVar[Collection[str] | None] = None
12
13
 
13
14
  def __init_subclass__(cls, **kwargs: Any) -> None:
14
15
 
starmallow/exceptions.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Union
1
+ from typing import Any
2
2
 
3
3
  from starlette.exceptions import HTTPException
4
4
  from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_500_INTERNAL_SERVER_ERROR
@@ -7,7 +7,7 @@ from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_500_INTERNAL_SE
7
7
  class RequestValidationError(HTTPException):
8
8
  def __init__(
9
9
  self,
10
- errors: Dict[str, Union[Any, List, Dict]],
10
+ errors: dict[str, Any | list | dict],
11
11
  ) -> None:
12
12
  super().__init__(status_code=HTTP_422_UNPROCESSABLE_ENTITY)
13
13
  self.errors = errors
@@ -16,7 +16,7 @@ class RequestValidationError(HTTPException):
16
16
  class WebSocketRequestValidationError(HTTPException):
17
17
  def __init__(
18
18
  self,
19
- errors: Dict[str, Union[Any, List, Dict]],
19
+ errors: dict[str, Any | list | dict],
20
20
  ) -> None:
21
21
  super().__init__(status_code=HTTP_422_UNPROCESSABLE_ENTITY)
22
22
  self.errors = errors
@@ -97,7 +97,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
97
97
  spec=spec,
98
98
  )
99
99
  self.add_attribute_function(self.field2title)
100
- self.add_attribute_function(self.field2uniqueItems)
100
+ self.add_attribute_function(self.field2unique_items)
101
101
  self.add_attribute_function(self.field2enum)
102
102
  self.add_attribute_function(self.field2union)
103
103
 
@@ -187,7 +187,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
187
187
 
188
188
  return ret
189
189
 
190
- def field2uniqueItems(self: FieldConverterMixin, field: mf.Field, **kwargs: Any) -> dict:
190
+ def field2unique_items(self: FieldConverterMixin, field: mf.Field, **kwargs: Any) -> dict:
191
191
  ret = {}
192
192
 
193
193
  # If this type isn't directly in the field mapping then check the
@@ -203,10 +203,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
203
203
  ret = {}
204
204
 
205
205
  if isinstance(field, mf.Enum):
206
- if field.by_value:
207
- choices = [x.value for x in field.enum]
208
- else:
209
- choices = list(field.enum.__members__)
206
+ choices = [x.value for x in field.enum] if field.by_value else list(field.enum.__members__)
210
207
 
211
208
  if choices:
212
209
  ret['enum'] = choices
@@ -242,7 +239,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
242
239
 
243
240
  return ret
244
241
 
245
- # Overrice to add 'deprecated' support
242
+ # Override to add 'deprecated' support
246
243
  def _field2parameter(
247
244
  self, field: mf.Field, *, name: str, location: str,
248
245
  ):
@@ -289,19 +286,19 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
289
286
  :rtype: dict, a JSON Schema Object
290
287
  """
291
288
  fields = get_fields(schema)
292
- Meta = getattr(schema, "Meta", None)
289
+ meta: ma.Schema.Meta = getattr(schema, "Meta", None)
293
290
  partial = getattr(schema, "partial", None)
294
291
 
295
292
  jsonschema = self.fields2jsonschema(fields, partial=partial)
296
293
 
297
- if hasattr(Meta, "title"):
298
- jsonschema["title"] = Meta.title
294
+ if hasattr(meta, "title"):
295
+ jsonschema["title"] = meta.title
299
296
  else:
300
297
  jsonschema['title'] = schema.__class__.__name__
301
298
 
302
- if hasattr(Meta, "description"):
303
- jsonschema["description"] = Meta.description
304
- if hasattr(Meta, "unknown") and Meta.unknown != ma.EXCLUDE:
305
- jsonschema["additionalProperties"] = Meta.unknown == ma.INCLUDE
299
+ if hasattr(meta, "description"):
300
+ jsonschema["description"] = meta.description
301
+ if hasattr(meta, "unknown") and meta.unknown != ma.EXCLUDE:
302
+ jsonschema["additionalProperties"] = meta.unknown == ma.INCLUDE
306
303
 
307
304
  return jsonschema
starmallow/fields.py CHANGED
@@ -7,19 +7,19 @@ import marshmallow.fields as mf
7
7
  from .delimited_field import DelimitedFieldMixin
8
8
 
9
9
 
10
- class DelimitedListUUID(DelimitedFieldMixin, mf.List):
10
+ class DelimitedListUUID(DelimitedFieldMixin, mf.List): # type: ignore
11
11
  def __init__(self, *, delimiter: str | None = None, **kwargs):
12
12
  self.delimiter = delimiter or self.delimiter
13
13
  super().__init__(mf.UUID(), **kwargs)
14
14
 
15
15
 
16
- class DelimitedListStr(DelimitedFieldMixin, mf.List):
16
+ class DelimitedListStr(DelimitedFieldMixin, mf.List): # type: ignore
17
17
  def __init__(self, *, delimiter: str | None = None, **kwargs):
18
18
  self.delimiter = delimiter or self.delimiter
19
19
  super().__init__(mf.String(), **kwargs)
20
20
 
21
21
 
22
- class DelimitedListInt(DelimitedFieldMixin, mf.List):
22
+ class DelimitedListInt(DelimitedFieldMixin, mf.List): # type: ignore
23
23
  def __init__(self, *, delimiter: str | None = None, **kwargs):
24
24
  self.delimiter = delimiter or self.delimiter
25
25
  super().__init__(mf.Integer(), **kwargs)
starmallow/generics.py ADDED
@@ -0,0 +1,34 @@
1
+ import inspect
2
+
3
+ from typing_inspect import is_generic_type
4
+
5
+
6
def get_orig_class(obj):
    """
    Allows you to get the runtime origin class inside __init__.

    ``typing`` attaches ``__orig_class__`` to an instance of a parameterized
    generic (e.g. ``MyGeneric[int]()``) only after the constructor returns, so
    it is not yet available from inside ``__init__``. As a fallback, this walks
    up the call stack looking for a frame whose ``self`` local is a generic
    alias whose ``__origin__`` is ``type(obj)``, and returns that alias.

    Near duplicate of https://github.com/Stewori/pytypes/blob/master/pytypes/type_util.py#L182

    Args:
        obj: The (possibly generic) instance to inspect.

    Returns:
        ``obj.__orig_class__`` if present, otherwise the matching generic alias
        found on the call stack.

    Raises:
        AttributeError: ``obj`` has no ``__orig_class__`` and no matching
            generic alias was found (or ``type(obj)`` is not generic).
        ValueError: frame introspection is unavailable
            (``inspect.currentframe()`` returned ``None``).
    """
    try:
        # object.__getattribute__ bypasses any __getattr__/__getattribute__
        # overrides on the instance's class.
        return object.__getattribute__(obj, "__orig_class__")
    except AttributeError:
        cls = object.__getattribute__(obj, "__class__")
        if is_generic_type(cls):
            frame = inspect.currentframe()
            if frame is None:
                raise ValueError('Frame does not have a caller') from None

            # Frame 0 is get_orig_class itself; start the search at the caller.
            frame = frame.f_back
            try:
                while frame is not None:
                    candidate = frame.f_locals.get("self")
                    if getattr(candidate, "__origin__", None) is cls:
                        return candidate
                    # Always advance. Advancing only on KeyError/AttributeError
                    # loops forever when a frame's ``self`` has a
                    # non-matching ``__origin__``.
                    frame = frame.f_back
            finally:
                # Drop the frame reference to avoid a frame <-> locals cycle.
                del frame
        # Re-raise the original AttributeError for __orig_class__.
        raise
@@ -1,5 +1,4 @@
1
1
  from contextlib import AsyncExitStack
2
- from typing import Optional
3
2
 
4
3
  from starlette.types import ASGIApp, Receive, Scope, Send
5
4
 
@@ -11,7 +10,7 @@ class AsyncExitStackMiddleware:
11
10
 
12
11
  async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
13
12
  if AsyncExitStack:
14
- dependency_exception: Optional[Exception] = None
13
+ dependency_exception: Exception | None = None
15
14
  async with AsyncExitStack() as stack:
16
15
  scope[self.context_name] = stack
17
16
  try: