starmallow 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff compares the contents of two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
starmallow/endpoint.py CHANGED
@@ -1,30 +1,27 @@
1
1
  import inspect
2
2
  import logging
3
+ from collections.abc import Callable, Iterable, Mapping, Sequence
3
4
  from dataclasses import dataclass, field
4
5
  from typing import (
5
6
  TYPE_CHECKING,
7
+ Annotated,
6
8
  Any,
7
- Callable,
8
- Dict,
9
- Iterable,
10
- List,
11
- Mapping,
12
9
  NewType,
13
- Optional,
14
- Type,
15
- Union,
10
+ cast,
16
11
  get_args,
17
12
  get_origin,
18
13
  )
19
14
 
20
15
  import marshmallow as ma
21
16
  import marshmallow.fields as mf
17
+ import typing_inspect
18
+ from marshmallow.types import StrSequenceOrSet
22
19
  from marshmallow.utils import missing as missing_
20
+ from marshmallow_dataclass2 import class_schema, is_generic_alias_of_dataclass
23
21
  from starlette.background import BackgroundTasks
24
22
  from starlette.requests import HTTPConnection, Request
25
23
  from starlette.responses import Response
26
24
  from starlette.websockets import WebSocket
27
- from typing_extensions import Annotated
28
25
 
29
26
  from starmallow.params import (
30
27
  Body,
@@ -42,20 +39,21 @@ from starmallow.params import (
42
39
  from starmallow.responses import JSONResponse
43
40
  from starmallow.security.base import SecurityBaseResolver
44
41
  from starmallow.utils import (
42
+ MaDataclassProtocol,
45
43
  create_response_model,
46
44
  get_model_field,
47
45
  get_path_param_names,
48
46
  get_typed_return_annotation,
49
47
  get_typed_signature,
50
48
  is_marshmallow_dataclass,
51
- is_marshmallow_field,
49
+ is_marshmallow_field_or_generic,
52
50
  is_marshmallow_schema,
53
51
  is_optional,
54
52
  lenient_issubclass,
55
53
  )
56
54
 
57
55
  if TYPE_CHECKING:
58
- from starmallow.routing import APIRoute
56
+ from starmallow.routing import APIRoute, APIWebSocketRoute
59
57
 
60
58
  logger = logging.getLogger(__name__)
61
59
 
@@ -69,52 +67,52 @@ STARMALLOW_PARAM_TYPES = (
69
67
 
70
68
  @dataclass
71
69
  class EndpointModel:
72
- params: Optional[Dict[ParamType, Dict[str, Param]]] = field(default_factory=dict)
73
- flat_params: Optional[Dict[ParamType, Dict[str, Param]]] = field(default_factory=dict)
74
- name: Optional[str] = None
75
- path: Optional[str] = None
76
- methods: Optional[List[str]] = None
77
- call: Optional[Callable[..., Any]] = None
78
- response_model: Optional[ma.Schema | mf.Field] = None
79
- response_class: Type[Response] = JSONResponse
80
- status_code: Optional[int] = None
81
- route: 'APIRoute' = None
70
+ path: str
71
+ call: Callable[..., Any]
72
+ route: 'APIRoute | APIWebSocketRoute'
73
+ params: dict[ParamType, dict[str, Param]] = field(default_factory=dict)
74
+ flat_params: dict[ParamType, dict[str, Param]] = field(default_factory=dict)
75
+ name: str | None = None
76
+ methods: Sequence[str] | None = None
77
+ response_model: ma.Schema | type[ma.Schema | MaDataclassProtocol] | mf.Field | None = None
78
+ response_class: type[Response] = JSONResponse
79
+ status_code: int | None = None
82
80
 
83
81
  @property
84
- def path_params(self) -> Dict[str, Path] | None:
85
- return self.flat_params.get(ParamType.path)
82
+ def path_params(self) -> dict[str, Path] | None:
83
+ return cast(dict[str, Path], self.flat_params.get(ParamType.path))
86
84
 
87
85
  @property
88
- def query_params(self) -> Dict[str, Query] | None:
89
- return self.flat_params.get(ParamType.query)
86
+ def query_params(self) -> dict[str, Query] | None:
87
+ return cast(dict[str, Query], self.flat_params.get(ParamType.query))
90
88
 
91
89
  @property
92
- def header_params(self) -> Dict[str, Header] | None:
93
- return self.flat_params.get(ParamType.header)
90
+ def header_params(self) -> dict[str, Header] | None:
91
+ return cast(dict[str, Header], self.flat_params.get(ParamType.header))
94
92
 
95
93
  @property
96
- def cookie_params(self) -> Dict[str, Cookie] | None:
97
- return self.flat_params.get(ParamType.cookie)
94
+ def cookie_params(self) -> dict[str, Cookie] | None:
95
+ return cast(dict[str, Cookie], self.flat_params.get(ParamType.cookie))
98
96
 
99
97
  @property
100
- def body_params(self) -> Dict[str, Body] | None:
101
- return self.flat_params.get(ParamType.body)
98
+ def body_params(self) -> dict[str, Body] | None:
99
+ return cast(dict[str, Body], self.flat_params.get(ParamType.body))
102
100
 
103
101
  @property
104
- def form_params(self) -> Dict[str, Form] | None:
105
- return self.flat_params.get(ParamType.form)
102
+ def form_params(self) -> dict[str, Form] | None:
103
+ return cast(dict[str, Form], self.flat_params.get(ParamType.form))
106
104
 
107
105
  @property
108
- def non_field_params(self) -> Dict[str, NoParam] | None:
109
- return self.flat_params.get(ParamType.noparam)
106
+ def non_field_params(self) -> dict[str, NoParam] | None:
107
+ return cast(dict[str, NoParam], self.flat_params.get(ParamType.noparam))
110
108
 
111
109
  @property
112
- def resolved_params(self) -> Dict[str, ResolvedParam] | None:
113
- return self.flat_params.get(ParamType.resolved)
110
+ def resolved_params(self) -> dict[str, ResolvedParam] | None:
111
+ return cast(dict[str, ResolvedParam], self.flat_params.get(ParamType.resolved))
114
112
 
115
113
  @property
116
- def security_params(self) -> Dict[str, Security] | None:
117
- return self.flat_params.get(ParamType.security)
114
+ def security_params(self) -> dict[str, Security] | None:
115
+ return cast(dict[str, Security], self.flat_params.get(ParamType.security))
118
116
 
119
117
 
120
118
  class SchemaMeta:
@@ -128,20 +126,20 @@ class SchemaModel(ma.Schema):
128
126
  schema: ma.Schema,
129
127
  load_default: Any = missing_,
130
128
  required: bool = True,
131
- metadata: Dict[str, Any] = None,
129
+ metadata: dict[str, Any] | None = None,
132
130
  **kwargs,
133
131
  ) -> None:
134
132
  self.schema = schema
135
133
  self.load_default = load_default
136
134
  self.required = required
137
- self.title = metadata.get('title')
135
+ self.title = metadata.get('title') if metadata else None
138
136
  self.metadata = metadata
139
137
  self.kwargs = kwargs
140
138
 
141
139
  if not getattr(schema.Meta, "title", None):
142
140
  if schema.Meta is ma.Schema.Meta:
143
141
  # Don't override global Meta object's title
144
- schema.Meta = SchemaMeta(self.title)
142
+ schema.Meta = SchemaMeta(self.title) # type: ignore
145
143
  else:
146
144
  schema.Meta.title = self.title
147
145
 
@@ -161,14 +159,11 @@ class SchemaModel(ma.Schema):
161
159
 
162
160
  def load(
163
161
  self,
164
- data: Union[
165
- Mapping[str, Any],
166
- Iterable[Mapping[str, Any]],
167
- ],
162
+ data: Mapping[str, Any] | Iterable[Mapping[str, Any]],
168
163
  *,
169
- many: Optional[bool] = None,
170
- partial: Optional[bool] = None,
171
- unknown: Optional[str] = None,
164
+ many: bool | None = None,
165
+ partial: bool | StrSequenceOrSet | None = None,
166
+ unknown: str | None = None,
172
167
  ) -> Any:
173
168
  if not data and self.load_default:
174
169
  return self.load_default
@@ -184,7 +179,7 @@ class EndpointMixin:
184
179
  parameter: Param,
185
180
  parameter_name: str,
186
181
  default_value: Any,
187
- ) -> Union[ma.Schema, mf.Field]:
182
+ ) -> ma.Schema | mf.Field | None:
188
183
  model = getattr(parameter, 'model', None) or type_annotation
189
184
  if isinstance(model, SchemaModel):
190
185
  return model
@@ -243,18 +238,22 @@ class EndpointMixin:
243
238
  if is_marshmallow_dataclass(model):
244
239
  model = model.Schema
245
240
 
246
- if isinstance(model, NewType) and getattr(model, '_marshmallow_field', None):
247
- return model._marshmallow_field(**kwargs)
241
+ if is_generic_alias_of_dataclass(model): # type: ignore
242
+ model = class_schema(model) # type: ignore
243
+
244
+ mmf = getattr(model, '_marshmallow_field', None)
245
+ if isinstance(model, NewType) and mmf and issubclass(mmf, mf.Field):
246
+ return mmf(**kwargs)
248
247
  elif is_marshmallow_schema(model):
249
248
  return SchemaModel(model() if inspect.isclass(model) else model, **kwargs)
250
- elif is_marshmallow_field(model):
251
- if inspect.isclass(model):
249
+ elif is_marshmallow_field_or_generic(model):
250
+ if not isinstance(model, mf.Field): # inspect.isclass(model):
252
251
  model = model()
253
252
 
254
253
  if model.load_default is not None and model.load_default != kwargs.get('load_default', ma.missing):
255
254
  logger.warning(
256
255
  f"'{parameter_name}' model and annotation have different 'load_default' values."
257
- + f" {model.load_default} <> {kwargs.get('load_default', ma.missing)}",
256
+ f" {model.load_default} <> {kwargs.get('load_default', ma.missing)}",
258
257
  )
259
258
 
260
259
  model.required = kwargs['required']
@@ -262,6 +261,11 @@ class EndpointMixin:
262
261
  model.metadata.update(kwargs['metadata'])
263
262
 
264
263
  return model
264
+
265
+ # elif is_marshmallow_field_or_generic(model):
266
+ # if isinstance(model, mf.Field):
267
+ # model if isinstance(model, mf.Field) else model()
268
+
265
269
  else:
266
270
  try:
267
271
  return get_model_field(model, **kwargs)
@@ -278,11 +282,19 @@ class EndpointMixin:
278
282
 
279
283
  return resolved_param
280
284
 
285
+ def get_security_param(self, resolved_param: Security, annotation: Any, path: str) -> Security:
286
+ if resolved_param.resolver is None:
287
+ resolved_param.resolver = annotation
288
+
289
+ resolved_param.resolver_params = self._get_params(resolved_param.resolver, path=path)
290
+
291
+ return resolved_param
292
+
281
293
  def _get_params(
282
294
  self,
283
295
  func: Callable[..., Any],
284
296
  path: str,
285
- ) -> Dict[ParamType, Dict[str, Param]]:
297
+ ) -> dict[ParamType, dict[str, Param]]:
286
298
  path_param_names = get_path_param_names(path)
287
299
  params = {param_type: {} for param_type in ParamType}
288
300
  for name, parameter in get_typed_signature(func).parameters.items():
@@ -309,8 +321,7 @@ class EndpointMixin:
309
321
  ]
310
322
  if starmallow_annotations:
311
323
  assert starmallow_param is inspect._empty, (
312
- "Cannot specify `Param` in `Annotated` and default value"
313
- f" together for {name!r}"
324
+ f"Cannot specify `Param` in `Annotated` and default value together for {name!r}"
314
325
  )
315
326
 
316
327
  starmallow_param = starmallow_annotations[-1]
@@ -321,6 +332,20 @@ class EndpointMixin:
321
332
  ):
322
333
  default_value = starmallow_param.default
323
334
 
335
+ field_annotations = [
336
+ arg
337
+ for arg in annotated_args
338
+ if (
339
+ isinstance(arg, mf.Field)
340
+ or lenient_issubclass(arg, mf.Field)
341
+ or (
342
+ isinstance(arg, typing_inspect.typingGenericAlias)
343
+ and lenient_issubclass(get_origin(arg), mf.Field)
344
+ )
345
+ )
346
+ ]
347
+ if field_annotations:
348
+ type_annotation = field_annotations[-1]
324
349
  if (
325
350
  # Skip 'self' in APIHTTPEndpoint functions
326
351
  (name == 'self' and '.' in func.__qualname__)
@@ -328,11 +353,11 @@ class EndpointMixin:
328
353
  ):
329
354
  continue
330
355
  elif isinstance(starmallow_param, Security):
331
- security_param: Security = self.get_resolved_param(starmallow_param, type_annotation, path=path)
356
+ security_param = self.get_security_param(starmallow_param, type_annotation, path=path)
332
357
  params[ParamType.security][name] = security_param
333
358
  continue
334
359
  elif isinstance(starmallow_param, ResolvedParam):
335
- resolved_param: ResolvedParam = self.get_resolved_param(starmallow_param, type_annotation, path=path)
360
+ resolved_param = self.get_resolved_param(starmallow_param, type_annotation, path=path)
336
361
 
337
362
  # Allow `ResolvedParam(HTTPBearer())` - treat as security param
338
363
  if isinstance(resolved_param.resolver, SecurityBaseResolver):
@@ -355,7 +380,7 @@ class EndpointMixin:
355
380
  continue
356
381
 
357
382
  model = self._get_param_model(type_annotation, starmallow_param, name, default_value)
358
- model.name = name
383
+ model.name = name # type: ignore
359
384
 
360
385
  if isinstance(starmallow_param, Param):
361
386
  # Create new field_info with processed model
@@ -393,18 +418,18 @@ class EndpointMixin:
393
418
  def get_endpoint_model(
394
419
  self,
395
420
  path: str,
396
- endpoint: Callable[..., Any],
397
- route: 'APIRoute',
398
- name: Optional[str] = None,
399
- methods: Optional[List[str]] = None,
400
-
401
- status_code: Optional[int] = None,
402
- response_model: Optional[ma.Schema] = None,
403
- response_class: Type[Response] = JSONResponse,
421
+ endpoint: Callable[..., ma.Schema | mf.Field | Response | None],
422
+ route: 'APIRoute | APIWebSocketRoute',
423
+ name: str | None = None,
424
+ methods: Sequence[str] | None = None,
425
+
426
+ status_code: int | None = None,
427
+ response_model: ma.Schema | type[ma.Schema | MaDataclassProtocol] | None = None,
428
+ response_class: type[Response] = JSONResponse,
404
429
  ) -> EndpointModel:
405
430
  params = self._get_params(endpoint, path)
406
431
 
407
- response_model = create_response_model(response_model or get_typed_return_annotation(endpoint))
432
+ response_model = create_response_model(response_model or get_typed_return_annotation(endpoint)) # type: ignore
408
433
 
409
434
  return EndpointModel(
410
435
  path=path,
@@ -421,9 +446,9 @@ class EndpointMixin:
421
446
 
422
447
 
423
448
  def safe_merge_params(
424
- left: Dict[str, Param],
425
- right: Dict[str, Param],
426
- ) -> Dict[str, Param]:
449
+ left: dict[str, Param],
450
+ right: dict[str, Param],
451
+ ) -> dict[str, Param]:
427
452
  res = left.copy()
428
453
  for name, param in right.items():
429
454
  if name not in left:
@@ -435,9 +460,9 @@ def safe_merge_params(
435
460
 
436
461
 
437
462
  def safe_merge_all_params(
438
- left: Dict[ParamType, Dict[str, Param]],
439
- right: Dict[ParamType, Dict[str, Param]],
440
- ) -> Dict[ParamType, Dict[str, Param]]:
463
+ left: dict[ParamType, dict[str, Param]],
464
+ right: dict[ParamType, dict[str, Param]],
465
+ ) -> dict[ParamType, dict[str, Param]]:
441
466
  res = {
442
467
  param_type: safe_merge_params(left[param_type], right[param_type])
443
468
  for param_type in ParamType
@@ -447,12 +472,13 @@ def safe_merge_all_params(
447
472
 
448
473
 
449
474
  def flatten_parameters(
450
- params: Dict[ParamType, Dict[str, Param]],
451
- ) -> Dict[ParamType, Dict[str, Param]]:
475
+ params: dict[ParamType, dict[str, Param]],
476
+ ) -> dict[ParamType, dict[str, Param]]:
452
477
  # flat_params = {param_type: {} for param_type in ParamType}
453
478
  flat_params = params.copy()
479
+
454
480
  for param in params[ParamType.resolved].values():
455
- if not param.resolver_params:
481
+ if not (isinstance(param, ResolvedParam) and param.resolver_params):
456
482
  continue
457
483
 
458
484
  flat_params = safe_merge_all_params(flat_params, param.resolver_params)
starmallow/endpoints.py CHANGED
@@ -1,4 +1,5 @@
1
- from typing import Any, ClassVar, Collection, Optional
1
+ from collections.abc import Collection
2
+ from typing import Any, ClassVar
2
3
 
3
4
  HTTP_METHOD_FUNCS = ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
4
5
 
@@ -8,7 +9,7 @@ class APIHTTPEndpoint:
8
9
  #: The methods this view is registered for. Uses the same default
9
10
  #: (``["GET", "HEAD", "OPTIONS"]``) as ``route`` and
10
11
  #: ``add_url_rule`` by default.
11
- methods: ClassVar[Optional[Collection[str]]] = None
12
+ methods: ClassVar[Collection[str] | None] = None
12
13
 
13
14
  def __init_subclass__(cls, **kwargs: Any) -> None:
14
15
 
starmallow/exceptions.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Union
1
+ from typing import Any
2
2
 
3
3
  from starlette.exceptions import HTTPException
4
4
  from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_500_INTERNAL_SERVER_ERROR
@@ -7,7 +7,7 @@ from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_500_INTERNAL_SE
7
7
  class RequestValidationError(HTTPException):
8
8
  def __init__(
9
9
  self,
10
- errors: Dict[str, Union[Any, List, Dict]],
10
+ errors: dict[str, Any | list | dict],
11
11
  ) -> None:
12
12
  super().__init__(status_code=HTTP_422_UNPROCESSABLE_ENTITY)
13
13
  self.errors = errors
@@ -16,7 +16,7 @@ class RequestValidationError(HTTPException):
16
16
  class WebSocketRequestValidationError(HTTPException):
17
17
  def __init__(
18
18
  self,
19
- errors: Dict[str, Union[Any, List, Dict]],
19
+ errors: dict[str, Any | list | dict],
20
20
  ) -> None:
21
21
  super().__init__(status_code=HTTP_422_UNPROCESSABLE_ENTITY)
22
22
  self.errors = errors
@@ -2,7 +2,7 @@ from typing import Any
2
2
 
3
3
  import marshmallow as ma
4
4
  import marshmallow.fields as mf
5
- import marshmallow_dataclass.collection_field as collection_field
5
+ import marshmallow_dataclass2.collection_field as collection_field
6
6
  from apispec import APISpec
7
7
  from apispec.ext.marshmallow.common import get_fields
8
8
  from apispec.ext.marshmallow.field_converter import (
@@ -12,9 +12,9 @@ from apispec.ext.marshmallow.field_converter import (
12
12
  )
13
13
  from apispec.ext.marshmallow.openapi import OpenAPIConverter as ApiSpecOpenAPIConverter
14
14
  from marshmallow.utils import is_collection
15
+ from marshmallow_dataclass2.union_field import Union as UnionField
15
16
  from packaging.version import Version
16
17
 
17
- from starmallow.union_field import Union as UnionField
18
18
  from starmallow.utils import MARSHMALLOW_ITERABLES
19
19
 
20
20
  # marshmallow field => (JSON Schema type, format)
@@ -97,7 +97,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
97
97
  spec=spec,
98
98
  )
99
99
  self.add_attribute_function(self.field2title)
100
- self.add_attribute_function(self.field2uniqueItems)
100
+ self.add_attribute_function(self.field2unique_items)
101
101
  self.add_attribute_function(self.field2enum)
102
102
  self.add_attribute_function(self.field2union)
103
103
 
@@ -187,7 +187,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
187
187
 
188
188
  return ret
189
189
 
190
- def field2uniqueItems(self: FieldConverterMixin, field: mf.Field, **kwargs: Any) -> dict:
190
+ def field2unique_items(self: FieldConverterMixin, field: mf.Field, **kwargs: Any) -> dict:
191
191
  ret = {}
192
192
 
193
193
  # If this type isn't directly in the field mapping then check the
@@ -203,10 +203,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
203
203
  ret = {}
204
204
 
205
205
  if isinstance(field, mf.Enum):
206
- if field.by_value:
207
- choices = [x.value for x in field.enum]
208
- else:
209
- choices = list(field.enum.__members__)
206
+ choices = [x.value for x in field.enum] if field.by_value else list(field.enum.__members__)
210
207
 
211
208
  if choices:
212
209
  ret['enum'] = choices
@@ -242,7 +239,7 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
242
239
 
243
240
  return ret
244
241
 
245
- # Overrice to add 'deprecated' support
242
+ # Override to add 'deprecated' support
246
243
  def _field2parameter(
247
244
  self, field: mf.Field, *, name: str, location: str,
248
245
  ):
@@ -289,19 +286,19 @@ class OpenAPIConverter(ApiSpecOpenAPIConverter):
289
286
  :rtype: dict, a JSON Schema Object
290
287
  """
291
288
  fields = get_fields(schema)
292
- Meta = getattr(schema, "Meta", None)
289
+ meta: ma.Schema.Meta = getattr(schema, "Meta", None)
293
290
  partial = getattr(schema, "partial", None)
294
291
 
295
292
  jsonschema = self.fields2jsonschema(fields, partial=partial)
296
293
 
297
- if hasattr(Meta, "title"):
298
- jsonschema["title"] = Meta.title
294
+ if hasattr(meta, "title"):
295
+ jsonschema["title"] = meta.title
299
296
  else:
300
297
  jsonschema['title'] = schema.__class__.__name__
301
298
 
302
- if hasattr(Meta, "description"):
303
- jsonschema["description"] = Meta.description
304
- if hasattr(Meta, "unknown") and Meta.unknown != ma.EXCLUDE:
305
- jsonschema["additionalProperties"] = Meta.unknown == ma.INCLUDE
299
+ if hasattr(meta, "description"):
300
+ jsonschema["description"] = meta.description
301
+ if hasattr(meta, "unknown") and meta.unknown != ma.EXCLUDE:
302
+ jsonschema["additionalProperties"] = meta.unknown == ma.INCLUDE
306
303
 
307
304
  return jsonschema
starmallow/fields.py CHANGED
@@ -7,19 +7,19 @@ import marshmallow.fields as mf
7
7
  from .delimited_field import DelimitedFieldMixin
8
8
 
9
9
 
10
- class DelimitedListUUID(DelimitedFieldMixin, mf.List):
10
+ class DelimitedListUUID(DelimitedFieldMixin, mf.List): # type: ignore
11
11
  def __init__(self, *, delimiter: str | None = None, **kwargs):
12
12
  self.delimiter = delimiter or self.delimiter
13
13
  super().__init__(mf.UUID(), **kwargs)
14
14
 
15
15
 
16
- class DelimitedListStr(DelimitedFieldMixin, mf.List):
16
+ class DelimitedListStr(DelimitedFieldMixin, mf.List): # type: ignore
17
17
  def __init__(self, *, delimiter: str | None = None, **kwargs):
18
18
  self.delimiter = delimiter or self.delimiter
19
19
  super().__init__(mf.String(), **kwargs)
20
20
 
21
21
 
22
- class DelimitedListInt(DelimitedFieldMixin, mf.List):
22
+ class DelimitedListInt(DelimitedFieldMixin, mf.List): # type: ignore
23
23
  def __init__(self, *, delimiter: str | None = None, **kwargs):
24
24
  self.delimiter = delimiter or self.delimiter
25
25
  super().__init__(mf.Integer(), **kwargs)
starmallow/generics.py ADDED
@@ -0,0 +1,34 @@
1
+ import inspect
2
+
3
+ from typing_inspect import is_generic_type
4
+
5
+
6
def get_orig_class(obj):
    """
    Return the runtime origin class of ``obj``, usable even inside ``__init__``.

    If ``obj`` already has ``__orig_class__`` set, that value is returned.
    Otherwise, when the object's class is a generic type (according to
    ``is_generic_type``), the call stack is walked outwards from the caller
    looking for a frame whose local ``self`` has ``__origin__`` identical to
    the object's class; that ``self`` is returned.

    Near duplicate of https://github.com/Stewori/pytypes/blob/master/pytypes/type_util.py#L182

    :raises AttributeError: re-raised from the failed ``__orig_class__`` lookup
        when the class is not generic or no matching frame is found.
    :raises ValueError: if the class is generic but ``inspect.currentframe()``
        returns ``None``.
    """
    try:
        return object.__getattribute__(obj, "__orig_class__")
    except AttributeError:
        cls = object.__getattribute__(obj, "__class__")
        if is_generic_type(cls):
            frame = inspect.currentframe()
            if frame is None:
                raise ValueError('Frame does not have a caller') from None

            # Start at the caller: the current frame is get_orig_class itself.
            frame = frame.f_back
            try:
                while frame is not None:
                    candidate = frame.f_locals.get("self")
                    if getattr(candidate, "__origin__", None) is cls:
                        return candidate
                    # Always step outwards. Advancing only when the lookup
                    # raised would loop forever on a frame whose ``self`` has
                    # an ``__origin__`` belonging to a different class.
                    frame = frame.f_back
            finally:
                # Drop the frame reference to avoid a reference cycle.
                del frame

        # Not generic, or no matching frame: re-raise the AttributeError.
        raise
@@ -1,5 +1,4 @@
1
1
  from contextlib import AsyncExitStack
2
- from typing import Optional
3
2
 
4
3
  from starlette.types import ASGIApp, Receive, Scope, Send
5
4
 
@@ -11,7 +10,7 @@ class AsyncExitStackMiddleware:
11
10
 
12
11
  async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
13
12
  if AsyncExitStack:
14
- dependency_exception: Optional[Exception] = None
13
+ dependency_exception: Exception | None = None
15
14
  async with AsyncExitStack() as stack:
16
15
  scope[self.context_name] = stack
17
16
  try: