anydi 0.32.2__tar.gz → 0.33.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {anydi-0.32.2 → anydi-0.33.0}/PKG-INFO +1 -1
- {anydi-0.32.2 → anydi-0.33.0}/anydi/_container.py +132 -17
- {anydi-0.32.2 → anydi-0.33.0}/anydi/_context.py +43 -23
- {anydi-0.32.2 → anydi-0.33.0}/anydi/_provider.py +1 -1
- {anydi-0.32.2 → anydi-0.33.0}/anydi/_types.py +6 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/_utils.py +2 -2
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/fastapi.py +1 -1
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/faststream.py +1 -1
- {anydi-0.32.2 → anydi-0.33.0}/pyproject.toml +3 -3
- anydi-0.32.2/anydi/_injector.py +0 -94
- {anydi-0.32.2 → anydi-0.33.0}/LICENSE +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/README.md +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/__init__.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/_logger.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/_module.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/_scanner.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/_utils.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/__init__.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/__init__.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/_container.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/_settings.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/_utils.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/apps.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/middleware.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/ninja/__init__.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/ninja/_operation.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/ninja/_signature.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/pydantic_settings.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/pytest_plugin.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/starlette/__init__.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/starlette/middleware.py +0 -0
- {anydi-0.32.2 → anydi-0.33.0}/anydi/py.typed +0 -0
|
@@ -3,12 +3,14 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import contextlib
|
|
6
|
+
import functools
|
|
6
7
|
import inspect
|
|
7
8
|
import types
|
|
8
9
|
from collections import defaultdict
|
|
9
|
-
from collections.abc import AsyncIterator,
|
|
10
|
+
from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
|
|
10
11
|
from contextvars import ContextVar
|
|
11
12
|
from typing import Any, Callable, TypeVar, cast, overload
|
|
13
|
+
from weakref import WeakKeyDictionary
|
|
12
14
|
|
|
13
15
|
from typing_extensions import ParamSpec, Self, final
|
|
14
16
|
|
|
@@ -19,11 +21,11 @@ from ._context import (
|
|
|
19
21
|
SingletonContext,
|
|
20
22
|
TransientContext,
|
|
21
23
|
)
|
|
22
|
-
from .
|
|
24
|
+
from ._logger import logger
|
|
23
25
|
from ._module import Module, ModuleRegistry
|
|
24
26
|
from ._provider import Provider
|
|
25
27
|
from ._scanner import Scanner
|
|
26
|
-
from ._types import AnyInterface, Interface, Scope
|
|
28
|
+
from ._types import AnyInterface, Interface, Scope, TestInterface, is_marker
|
|
27
29
|
from ._utils import get_full_qualname, get_typed_parameters, is_builtin_type
|
|
28
30
|
|
|
29
31
|
T = TypeVar("T", bound=Any)
|
|
@@ -47,6 +49,7 @@ class Container:
|
|
|
47
49
|
modules: Sequence[Module | type[Module] | Callable[[Container], None] | str]
|
|
48
50
|
| None = None,
|
|
49
51
|
strict: bool = False,
|
|
52
|
+
testing: bool = False,
|
|
50
53
|
) -> None:
|
|
51
54
|
self._providers: dict[type[Any], Provider] = {}
|
|
52
55
|
self._resource_cache: dict[Scope, list[type[Any]]] = defaultdict(list)
|
|
@@ -57,10 +60,13 @@ class Container:
|
|
|
57
60
|
)
|
|
58
61
|
self._override_instances: dict[type[Any], Any] = {}
|
|
59
62
|
self._strict = strict
|
|
63
|
+
self._testing = testing
|
|
60
64
|
self._unresolved_interfaces: set[type[Any]] = set()
|
|
65
|
+
self._inject_cache: WeakKeyDictionary[
|
|
66
|
+
Callable[..., Any], Callable[..., Any]
|
|
67
|
+
] = WeakKeyDictionary()
|
|
61
68
|
|
|
62
69
|
# Components
|
|
63
|
-
self._injector = Injector(self)
|
|
64
70
|
self._modules = ModuleRegistry(self)
|
|
65
71
|
self._scanner = Scanner(self)
|
|
66
72
|
|
|
@@ -79,6 +85,11 @@ class Container:
|
|
|
79
85
|
"""Check if strict mode is enabled."""
|
|
80
86
|
return self._strict
|
|
81
87
|
|
|
88
|
+
@property
|
|
89
|
+
def testing(self) -> bool:
|
|
90
|
+
"""Check if testing mode is enabled."""
|
|
91
|
+
return self._testing
|
|
92
|
+
|
|
82
93
|
@property
|
|
83
94
|
def providers(self) -> dict[type[Any], Provider]:
|
|
84
95
|
"""Get the registered providers."""
|
|
@@ -202,6 +213,10 @@ class Container:
|
|
|
202
213
|
annotation, parent_scope=provider.scope
|
|
203
214
|
)
|
|
204
215
|
except LookupError:
|
|
216
|
+
# Skip unresolved interfaces in non-strict mode
|
|
217
|
+
if not self.strict and parameter.default is not inspect.Parameter.empty:
|
|
218
|
+
continue
|
|
219
|
+
|
|
205
220
|
if provider.scope not in {"singleton", "transient"}:
|
|
206
221
|
self._unresolved_interfaces.add(provider.interface)
|
|
207
222
|
continue
|
|
@@ -225,7 +240,12 @@ class Container:
|
|
|
225
240
|
scopes = set()
|
|
226
241
|
|
|
227
242
|
for parameter in get_typed_parameters(call):
|
|
228
|
-
|
|
243
|
+
try:
|
|
244
|
+
sub_provider = self._get_or_register_provider(parameter.annotation)
|
|
245
|
+
except LookupError:
|
|
246
|
+
if not self.strict and parameter.default is not inspect.Parameter.empty:
|
|
247
|
+
continue
|
|
248
|
+
raise
|
|
229
249
|
scope = sub_provider.scope
|
|
230
250
|
|
|
231
251
|
if scope == "transient":
|
|
@@ -234,7 +254,7 @@ class Container:
|
|
|
234
254
|
|
|
235
255
|
# If all scopes are found, we can return based on priority order
|
|
236
256
|
if {"transient", "request", "singleton"}.issubset(scopes):
|
|
237
|
-
break
|
|
257
|
+
break # pragma: no cover
|
|
238
258
|
|
|
239
259
|
# Determine scope based on priority
|
|
240
260
|
if "request" in scopes:
|
|
@@ -346,7 +366,32 @@ class Container:
|
|
|
346
366
|
|
|
347
367
|
provider = self._get_or_register_provider(interface)
|
|
348
368
|
scoped_context = self._get_scoped_context(provider.scope)
|
|
349
|
-
|
|
369
|
+
instance, created = scoped_context.get_or_create(provider)
|
|
370
|
+
if self.testing and created:
|
|
371
|
+
self._patch_test_resolver(instance)
|
|
372
|
+
return cast(T, instance)
|
|
373
|
+
|
|
374
|
+
def _patch_test_resolver(self, instance: Any) -> None:
|
|
375
|
+
"""Patch the test resolver for the instance."""
|
|
376
|
+
|
|
377
|
+
def _resolver(_self: Any, _name: str) -> Any:
|
|
378
|
+
try:
|
|
379
|
+
test_interfaces = object.__getattribute__(_self, "__test_interfaces__")
|
|
380
|
+
except AttributeError:
|
|
381
|
+
test_interfaces = {
|
|
382
|
+
name: value.interface
|
|
383
|
+
for name, value in object.__getattribute__(
|
|
384
|
+
_self, "__dict__"
|
|
385
|
+
).items()
|
|
386
|
+
if isinstance(value, TestInterface)
|
|
387
|
+
}
|
|
388
|
+
object.__setattr__(_self, "__test_interfaces__", test_interfaces)
|
|
389
|
+
if _name in test_interfaces:
|
|
390
|
+
return self.resolve(test_interfaces[_name])
|
|
391
|
+
return object.__getattribute__(_self, _name)
|
|
392
|
+
|
|
393
|
+
if hasattr(instance, "__class__") and not is_builtin_type(instance.__class__):
|
|
394
|
+
instance.__class__.__getattribute__ = _resolver
|
|
350
395
|
|
|
351
396
|
@overload
|
|
352
397
|
async def aresolve(self, interface: Interface[T]) -> T: ...
|
|
@@ -361,7 +406,10 @@ class Container:
|
|
|
361
406
|
|
|
362
407
|
provider = self._get_or_register_provider(interface)
|
|
363
408
|
scoped_context = self._get_scoped_context(provider.scope)
|
|
364
|
-
|
|
409
|
+
instance, created = await scoped_context.aget_or_create(provider)
|
|
410
|
+
if self.testing and created:
|
|
411
|
+
self._patch_test_resolver(instance)
|
|
412
|
+
return cast(T, instance)
|
|
365
413
|
|
|
366
414
|
def is_resolved(self, interface: AnyInterface) -> bool:
|
|
367
415
|
"""Check if an instance by interface exists."""
|
|
@@ -424,25 +472,92 @@ class Container:
|
|
|
424
472
|
def inject(self) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
|
|
425
473
|
|
|
426
474
|
def inject(
|
|
427
|
-
self, func: Callable[P, T
|
|
428
|
-
) ->
|
|
429
|
-
Callable[[Callable[P, T | Awaitable[T]]], Callable[P, T | Awaitable[T]]]
|
|
430
|
-
| Callable[P, T | Awaitable[T]]
|
|
431
|
-
):
|
|
475
|
+
self, func: Callable[P, T] | None = None
|
|
476
|
+
) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
|
|
432
477
|
"""Decorator to inject dependencies into a callable."""
|
|
433
478
|
|
|
434
479
|
def decorator(
|
|
435
|
-
|
|
436
|
-
) -> Callable[P, T
|
|
437
|
-
return self.
|
|
480
|
+
call: Callable[P, T],
|
|
481
|
+
) -> Callable[P, T]:
|
|
482
|
+
return self._inject(call)
|
|
438
483
|
|
|
439
484
|
if func is None:
|
|
440
485
|
return decorator
|
|
441
486
|
return decorator(func)
|
|
442
487
|
|
|
488
|
+
def _inject(
|
|
489
|
+
self,
|
|
490
|
+
call: Callable[P, T],
|
|
491
|
+
) -> Callable[P, T]:
|
|
492
|
+
"""Inject dependencies into a callable."""
|
|
493
|
+
if call in self._inject_cache:
|
|
494
|
+
return cast(Callable[P, T], self._inject_cache[call])
|
|
495
|
+
|
|
496
|
+
injected_params = self._get_injected_params(call)
|
|
497
|
+
|
|
498
|
+
if inspect.iscoroutinefunction(call):
|
|
499
|
+
|
|
500
|
+
@functools.wraps(call)
|
|
501
|
+
async def awrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
502
|
+
for name, annotation in injected_params.items():
|
|
503
|
+
kwargs[name] = await self.aresolve(annotation)
|
|
504
|
+
return cast(T, await call(*args, **kwargs))
|
|
505
|
+
|
|
506
|
+
self._inject_cache[call] = awrapper
|
|
507
|
+
|
|
508
|
+
return awrapper # type: ignore[return-value]
|
|
509
|
+
|
|
510
|
+
@functools.wraps(call)
|
|
511
|
+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
512
|
+
for name, annotation in injected_params.items():
|
|
513
|
+
kwargs[name] = self.resolve(annotation)
|
|
514
|
+
return call(*args, **kwargs)
|
|
515
|
+
|
|
516
|
+
self._inject_cache[call] = wrapper
|
|
517
|
+
|
|
518
|
+
return wrapper
|
|
519
|
+
|
|
520
|
+
def _get_injected_params(self, call: Callable[..., Any]) -> dict[str, Any]:
|
|
521
|
+
"""Get the injected parameters of a callable object."""
|
|
522
|
+
injected_params = {}
|
|
523
|
+
for parameter in get_typed_parameters(call):
|
|
524
|
+
if not is_marker(parameter.default):
|
|
525
|
+
continue
|
|
526
|
+
try:
|
|
527
|
+
self._validate_injected_parameter(call, parameter)
|
|
528
|
+
except LookupError as exc:
|
|
529
|
+
if not self.strict:
|
|
530
|
+
logger.debug(
|
|
531
|
+
f"Cannot validate the `{get_full_qualname(call)}` parameter "
|
|
532
|
+
f"`{parameter.name}` with an annotation of "
|
|
533
|
+
f"`{get_full_qualname(parameter.annotation)} due to being "
|
|
534
|
+
"in non-strict mode. It will be validated at the first call."
|
|
535
|
+
)
|
|
536
|
+
else:
|
|
537
|
+
raise exc
|
|
538
|
+
injected_params[parameter.name] = parameter.annotation
|
|
539
|
+
return injected_params
|
|
540
|
+
|
|
541
|
+
def _validate_injected_parameter(
|
|
542
|
+
self, call: Callable[..., Any], parameter: inspect.Parameter
|
|
543
|
+
) -> None:
|
|
544
|
+
"""Validate an injected parameter."""
|
|
545
|
+
if parameter.annotation is inspect.Parameter.empty:
|
|
546
|
+
raise TypeError(
|
|
547
|
+
f"Missing `{get_full_qualname(call)}` parameter "
|
|
548
|
+
f"`{parameter.name}` annotation."
|
|
549
|
+
)
|
|
550
|
+
|
|
551
|
+
if not self.is_registered(parameter.annotation):
|
|
552
|
+
raise LookupError(
|
|
553
|
+
f"`{get_full_qualname(call)}` has an unknown dependency parameter "
|
|
554
|
+
f"`{parameter.name}` with an annotation of "
|
|
555
|
+
f"`{get_full_qualname(parameter.annotation)}`."
|
|
556
|
+
)
|
|
557
|
+
|
|
443
558
|
def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
|
|
444
559
|
"""Run the given function with injected dependencies."""
|
|
445
|
-
return
|
|
560
|
+
return self._inject(func)(*args, **kwargs)
|
|
446
561
|
|
|
447
562
|
def scan(
|
|
448
563
|
self,
|
|
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar
|
|
|
9
9
|
from typing_extensions import Self, final
|
|
10
10
|
|
|
11
11
|
from ._provider import CallableKind, Provider
|
|
12
|
-
from ._types import AnyInterface, Scope, is_event_type
|
|
12
|
+
from ._types import AnyInterface, Scope, TestInterface, is_event_type
|
|
13
13
|
from ._utils import get_full_qualname, run_async
|
|
14
14
|
|
|
15
15
|
if TYPE_CHECKING:
|
|
@@ -30,12 +30,12 @@ class ScopedContext(abc.ABC):
|
|
|
30
30
|
self._instances[interface] = instance
|
|
31
31
|
|
|
32
32
|
@abc.abstractmethod
|
|
33
|
-
def
|
|
34
|
-
"""Get an instance of a dependency from the scoped context."""
|
|
33
|
+
def get_or_create(self, provider: Provider) -> tuple[Any, bool]:
|
|
34
|
+
"""Get or create an instance of a dependency from the scoped context."""
|
|
35
35
|
|
|
36
36
|
@abc.abstractmethod
|
|
37
|
-
async def
|
|
38
|
-
"""Get an async instance of a dependency from the scoped context."""
|
|
37
|
+
async def aget_or_create(self, provider: Provider) -> tuple[Any, bool]:
|
|
38
|
+
"""Get or create an async instance of a dependency from the scoped context."""
|
|
39
39
|
|
|
40
40
|
def _create_instance(self, provider: Provider) -> Any:
|
|
41
41
|
"""Create an instance using the provider."""
|
|
@@ -44,12 +44,12 @@ class ScopedContext(abc.ABC):
|
|
|
44
44
|
f"The instance for the coroutine provider `{provider}` cannot be "
|
|
45
45
|
"created in synchronous mode."
|
|
46
46
|
)
|
|
47
|
-
args, kwargs = self.
|
|
47
|
+
args, kwargs = self._get_provided_args(provider)
|
|
48
48
|
return provider.call(*args, **kwargs)
|
|
49
49
|
|
|
50
50
|
async def _acreate_instance(self, provider: Provider) -> Any:
|
|
51
51
|
"""Create an instance asynchronously using the provider."""
|
|
52
|
-
args, kwargs = await self.
|
|
52
|
+
args, kwargs = await self._aget_provided_args(provider)
|
|
53
53
|
if provider.kind == CallableKind.COROUTINE:
|
|
54
54
|
return await provider.call(*args, **kwargs)
|
|
55
55
|
return await run_async(provider.call, *args, **kwargs)
|
|
@@ -78,7 +78,7 @@ class ScopedContext(abc.ABC):
|
|
|
78
78
|
"or set in the scoped context."
|
|
79
79
|
)
|
|
80
80
|
|
|
81
|
-
def
|
|
81
|
+
def _get_provided_args(
|
|
82
82
|
self, provider: Provider
|
|
83
83
|
) -> tuple[list[Any], dict[str, Any]]:
|
|
84
84
|
"""Retrieve the arguments for a provider."""
|
|
@@ -91,14 +91,22 @@ class ScopedContext(abc.ABC):
|
|
|
91
91
|
elif parameter.annotation in self._instances:
|
|
92
92
|
instance = self._instances[parameter.annotation]
|
|
93
93
|
else:
|
|
94
|
-
|
|
94
|
+
try:
|
|
95
|
+
instance = self._resolve_parameter(provider, parameter)
|
|
96
|
+
except LookupError:
|
|
97
|
+
if parameter.default is inspect.Parameter.empty:
|
|
98
|
+
raise
|
|
99
|
+
instance = parameter.default
|
|
100
|
+
else:
|
|
101
|
+
if self.container.testing:
|
|
102
|
+
instance = TestInterface(interface=parameter.annotation)
|
|
95
103
|
if parameter.kind == parameter.POSITIONAL_ONLY:
|
|
96
104
|
args.append(instance)
|
|
97
105
|
else:
|
|
98
106
|
kwargs[parameter.name] = instance
|
|
99
107
|
return args, kwargs
|
|
100
108
|
|
|
101
|
-
async def
|
|
109
|
+
async def _aget_provided_args(
|
|
102
110
|
self, provider: Provider
|
|
103
111
|
) -> tuple[list[Any], dict[str, Any]]:
|
|
104
112
|
"""Asynchronously retrieve the arguments for a provider."""
|
|
@@ -111,7 +119,15 @@ class ScopedContext(abc.ABC):
|
|
|
111
119
|
elif parameter.annotation in self._instances:
|
|
112
120
|
instance = self._instances[parameter.annotation]
|
|
113
121
|
else:
|
|
114
|
-
|
|
122
|
+
try:
|
|
123
|
+
instance = await self._aresolve_parameter(provider, parameter)
|
|
124
|
+
except LookupError:
|
|
125
|
+
if parameter.default is inspect.Parameter.empty:
|
|
126
|
+
raise
|
|
127
|
+
instance = parameter.default
|
|
128
|
+
else:
|
|
129
|
+
if self.container.testing:
|
|
130
|
+
instance = TestInterface(interface=parameter.annotation)
|
|
115
131
|
if parameter.kind == parameter.POSITIONAL_ONLY:
|
|
116
132
|
args.append(instance)
|
|
117
133
|
else:
|
|
@@ -128,7 +144,7 @@ class ResourceScopedContext(ScopedContext):
|
|
|
128
144
|
self._stack = contextlib.ExitStack()
|
|
129
145
|
self._async_stack = contextlib.AsyncExitStack()
|
|
130
146
|
|
|
131
|
-
def
|
|
147
|
+
def get_or_create(self, provider: Provider) -> tuple[Any, bool]:
|
|
132
148
|
"""Get an instance of a dependency from the scoped context."""
|
|
133
149
|
instance = self._instances.get(provider.interface)
|
|
134
150
|
if instance is None:
|
|
@@ -143,9 +159,10 @@ class ResourceScopedContext(ScopedContext):
|
|
|
143
159
|
else:
|
|
144
160
|
instance = self._create_instance(provider)
|
|
145
161
|
self._instances[provider.interface] = instance
|
|
146
|
-
|
|
162
|
+
return instance, True
|
|
163
|
+
return instance, False
|
|
147
164
|
|
|
148
|
-
async def
|
|
165
|
+
async def aget_or_create(self, provider: Provider) -> tuple[Any, bool]:
|
|
149
166
|
"""Get an async instance of a dependency from the scoped context."""
|
|
150
167
|
instance = self._instances.get(provider.interface)
|
|
151
168
|
if instance is None:
|
|
@@ -156,7 +173,8 @@ class ResourceScopedContext(ScopedContext):
|
|
|
156
173
|
else:
|
|
157
174
|
instance = await self._acreate_instance(provider)
|
|
158
175
|
self._instances[provider.interface] = instance
|
|
159
|
-
|
|
176
|
+
return instance, True
|
|
177
|
+
return instance, False
|
|
160
178
|
|
|
161
179
|
def has(self, interface: AnyInterface) -> bool:
|
|
162
180
|
"""Check if the scoped context has an instance of the dependency."""
|
|
@@ -172,7 +190,7 @@ class ResourceScopedContext(ScopedContext):
|
|
|
172
190
|
|
|
173
191
|
def _create_resource(self, provider: Provider) -> Any:
|
|
174
192
|
"""Create a resource using the provider."""
|
|
175
|
-
args, kwargs = self.
|
|
193
|
+
args, kwargs = self._get_provided_args(provider)
|
|
176
194
|
cm = contextlib.contextmanager(provider.call)(*args, **kwargs)
|
|
177
195
|
return self._stack.enter_context(cm)
|
|
178
196
|
|
|
@@ -186,7 +204,7 @@ class ResourceScopedContext(ScopedContext):
|
|
|
186
204
|
|
|
187
205
|
async def _acreate_resource(self, provider: Provider) -> Any:
|
|
188
206
|
"""Create a resource asynchronously using the provider."""
|
|
189
|
-
args, kwargs = await self.
|
|
207
|
+
args, kwargs = await self._aget_provided_args(provider)
|
|
190
208
|
cm = contextlib.asynccontextmanager(provider.call)(*args, **kwargs)
|
|
191
209
|
return await self._async_stack.enter_async_context(cm)
|
|
192
210
|
|
|
@@ -287,10 +305,12 @@ class TransientContext(ScopedContext):
|
|
|
287
305
|
|
|
288
306
|
scope = "transient"
|
|
289
307
|
|
|
290
|
-
def
|
|
291
|
-
"""Get an instance of a dependency from the transient context."""
|
|
292
|
-
return self._create_instance(provider)
|
|
308
|
+
def get_or_create(self, provider: Provider) -> tuple[Any, bool]:
|
|
309
|
+
"""Get or create an instance of a dependency from the transient context."""
|
|
310
|
+
return self._create_instance(provider), True
|
|
293
311
|
|
|
294
|
-
async def
|
|
295
|
-
"""
|
|
296
|
-
|
|
312
|
+
async def aget_or_create(self, provider: Provider) -> tuple[Any, bool]:
|
|
313
|
+
"""
|
|
314
|
+
Get or create an async instance of a dependency from the transient context.
|
|
315
|
+
"""
|
|
316
|
+
return await self._acreate_instance(provider), True
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
+
from dataclasses import dataclass
|
|
4
5
|
from typing import Annotated, Any, TypeVar, Union
|
|
5
6
|
|
|
6
7
|
from typing_extensions import Literal, Self, TypeAlias
|
|
@@ -35,3 +36,8 @@ class Event:
|
|
|
35
36
|
def is_event_type(obj: Any) -> bool:
|
|
36
37
|
"""Checks if an object is an event type."""
|
|
37
38
|
return inspect.isclass(obj) and issubclass(obj, Event)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass(frozen=True)
|
|
42
|
+
class TestInterface:
|
|
43
|
+
interface: type[Any]
|
|
@@ -61,7 +61,7 @@ def patch_annotated_parameter(parameter: inspect.Parameter) -> inspect.Parameter
|
|
|
61
61
|
|
|
62
62
|
|
|
63
63
|
def patch_call_parameter(
|
|
64
|
-
call: Callable[..., Any], parameter: inspect.Parameter
|
|
64
|
+
container: Container, call: Callable[..., Any], parameter: inspect.Parameter
|
|
65
65
|
) -> None:
|
|
66
66
|
"""Patch a parameter to inject dependencies using AnyDI."""
|
|
67
67
|
parameter = patch_annotated_parameter(parameter)
|
|
@@ -78,6 +78,6 @@ def patch_call_parameter(
|
|
|
78
78
|
"first call because it is running in non-strict mode."
|
|
79
79
|
)
|
|
80
80
|
else:
|
|
81
|
-
container.
|
|
81
|
+
container._validate_injected_parameter(call, parameter) # noqa
|
|
82
82
|
|
|
83
83
|
parameter.default.interface = parameter.annotation
|
|
@@ -41,7 +41,7 @@ def install(app: FastAPI, container: Container) -> None:
|
|
|
41
41
|
if not call:
|
|
42
42
|
continue # pragma: no cover
|
|
43
43
|
for parameter in get_typed_parameters(call):
|
|
44
|
-
patch_call_parameter(call, parameter
|
|
44
|
+
patch_call_parameter(container, call, parameter)
|
|
45
45
|
|
|
46
46
|
|
|
47
47
|
def get_container(request: Request) -> Container:
|
|
@@ -26,7 +26,7 @@ def install(broker: BrokerUsecase[Any, Any], container: Container) -> None:
|
|
|
26
26
|
for handler in _get_broken_handlers(broker):
|
|
27
27
|
call = handler._original_call # noqa
|
|
28
28
|
for parameter in get_typed_parameters(call):
|
|
29
|
-
patch_call_parameter(call, parameter
|
|
29
|
+
patch_call_parameter(container, call, parameter)
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
def _get_broken_handlers(broker: BrokerUsecase[Any, Any]) -> list[Any]:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "anydi"
|
|
3
|
-
version = "0.
|
|
3
|
+
version = "0.33.0"
|
|
4
4
|
description = "Dependency Injection library"
|
|
5
5
|
authors = ["Anton Ruhlov <antonruhlov@gmail.com>"]
|
|
6
6
|
license = "MIT"
|
|
@@ -45,8 +45,8 @@ docs = ["mkdocs", "mkdocs-material"]
|
|
|
45
45
|
async = ["anyio"]
|
|
46
46
|
|
|
47
47
|
[tool.poetry.group.dev.dependencies]
|
|
48
|
-
mypy = "^1.
|
|
49
|
-
ruff = "^0.
|
|
48
|
+
mypy = { version = "^1.14.0", extras = ["faster-cache"] }
|
|
49
|
+
ruff = "^0.8.4"
|
|
50
50
|
pytest = "^8.3.1"
|
|
51
51
|
pytest-cov = "^5.0.0"
|
|
52
52
|
fastapi = "^0.100.0"
|
anydi-0.32.2/anydi/_injector.py
DELETED
|
@@ -1,94 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import inspect
|
|
4
|
-
from collections.abc import Awaitable
|
|
5
|
-
from functools import wraps
|
|
6
|
-
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast
|
|
7
|
-
|
|
8
|
-
from typing_extensions import ParamSpec
|
|
9
|
-
|
|
10
|
-
from ._logger import logger
|
|
11
|
-
from ._types import is_marker
|
|
12
|
-
from ._utils import get_full_qualname, get_typed_parameters
|
|
13
|
-
|
|
14
|
-
if TYPE_CHECKING:
|
|
15
|
-
from ._container import Container
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
T = TypeVar("T", bound=Any)
|
|
19
|
-
P = ParamSpec("P")
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
class Injector:
|
|
23
|
-
def __init__(self, container: Container) -> None:
|
|
24
|
-
self.container = container
|
|
25
|
-
|
|
26
|
-
def inject(
|
|
27
|
-
self,
|
|
28
|
-
call: Callable[P, T | Awaitable[T]],
|
|
29
|
-
) -> Callable[P, T | Awaitable[T]]:
|
|
30
|
-
# Check if the inner callable has already been wrapped
|
|
31
|
-
if hasattr(call, "__inject_wrapper__"):
|
|
32
|
-
return cast(Callable[P, Union[T, Awaitable[T]]], call.__inject_wrapper__)
|
|
33
|
-
|
|
34
|
-
injected_params = self._get_injected_params(call)
|
|
35
|
-
|
|
36
|
-
if inspect.iscoroutinefunction(call):
|
|
37
|
-
|
|
38
|
-
@wraps(call)
|
|
39
|
-
async def awrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
40
|
-
for name, annotation in injected_params.items():
|
|
41
|
-
kwargs[name] = await self.container.aresolve(annotation)
|
|
42
|
-
return cast(T, await call(*args, **kwargs))
|
|
43
|
-
|
|
44
|
-
call.__inject_wrapper__ = awrapper # type: ignore[attr-defined]
|
|
45
|
-
|
|
46
|
-
return awrapper
|
|
47
|
-
|
|
48
|
-
@wraps(call)
|
|
49
|
-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
50
|
-
for name, annotation in injected_params.items():
|
|
51
|
-
kwargs[name] = self.container.resolve(annotation)
|
|
52
|
-
return cast(T, call(*args, **kwargs))
|
|
53
|
-
|
|
54
|
-
call.__inject_wrapper__ = wrapper # type: ignore[attr-defined]
|
|
55
|
-
|
|
56
|
-
return wrapper
|
|
57
|
-
|
|
58
|
-
def _get_injected_params(self, call: Callable[..., Any]) -> dict[str, Any]:
|
|
59
|
-
"""Get the injected parameters of a callable object."""
|
|
60
|
-
injected_params = {}
|
|
61
|
-
for parameter in get_typed_parameters(call):
|
|
62
|
-
if not is_marker(parameter.default):
|
|
63
|
-
continue
|
|
64
|
-
try:
|
|
65
|
-
self._validate_injected_parameter(call, parameter)
|
|
66
|
-
except LookupError as exc:
|
|
67
|
-
if not self.container.strict:
|
|
68
|
-
logger.debug(
|
|
69
|
-
f"Cannot validate the `{get_full_qualname(call)}` parameter "
|
|
70
|
-
f"`{parameter.name}` with an annotation of "
|
|
71
|
-
f"`{get_full_qualname(parameter.annotation)} due to being "
|
|
72
|
-
"in non-strict mode. It will be validated at the first call."
|
|
73
|
-
)
|
|
74
|
-
else:
|
|
75
|
-
raise exc
|
|
76
|
-
injected_params[parameter.name] = parameter.annotation
|
|
77
|
-
return injected_params
|
|
78
|
-
|
|
79
|
-
def _validate_injected_parameter(
|
|
80
|
-
self, call: Callable[..., Any], parameter: inspect.Parameter
|
|
81
|
-
) -> None:
|
|
82
|
-
"""Validate an injected parameter."""
|
|
83
|
-
if parameter.annotation is inspect.Parameter.empty:
|
|
84
|
-
raise TypeError(
|
|
85
|
-
f"Missing `{get_full_qualname(call)}` parameter "
|
|
86
|
-
f"`{parameter.name}` annotation."
|
|
87
|
-
)
|
|
88
|
-
|
|
89
|
-
if not self.container.is_registered(parameter.annotation):
|
|
90
|
-
raise LookupError(
|
|
91
|
-
f"`{get_full_qualname(call)}` has an unknown dependency parameter "
|
|
92
|
-
f"`{parameter.name}` with an annotation of "
|
|
93
|
-
f"`{get_full_qualname(parameter.annotation)}`."
|
|
94
|
-
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|