anydi 0.40.0__py3-none-any.whl → 0.42.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anydi/__init__.py +8 -13
- anydi/_async.py +50 -0
- anydi/_container.py +59 -371
- anydi/_context.py +2 -2
- anydi/_decorators.py +80 -0
- anydi/_module.py +76 -0
- anydi/_provider.py +81 -0
- anydi/_scan.py +110 -0
- anydi/_scope.py +9 -0
- anydi/{_utils.py → _typing.py} +49 -69
- anydi/ext/_utils.py +4 -4
- anydi/ext/django/_utils.py +3 -3
- anydi/ext/django/apps.py +0 -3
- anydi/ext/django/ninja/__init__.py +3 -3
- anydi/ext/django/ninja/_operation.py +1 -1
- anydi/ext/django/ninja/_signature.py +1 -1
- anydi/ext/fastapi.py +2 -2
- anydi/ext/faststream.py +2 -2
- anydi/ext/pytest_plugin.py +7 -3
- anydi/ext/starlette/middleware.py +1 -1
- anydi/testing.py +172 -0
- {anydi-0.40.0.dist-info → anydi-0.42.0.dist-info}/METADATA +2 -2
- anydi-0.42.0.dist-info/RECORD +34 -0
- anydi/_types.py +0 -145
- anydi-0.40.0.dist-info/RECORD +0 -28
- {anydi-0.40.0.dist-info → anydi-0.42.0.dist-info}/WHEEL +0 -0
- {anydi-0.40.0.dist-info → anydi-0.42.0.dist-info}/entry_points.txt +0 -0
- {anydi-0.40.0.dist-info → anydi-0.42.0.dist-info}/licenses/LICENSE +0 -0
anydi/_container.py
CHANGED
|
@@ -4,110 +4,74 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import contextlib
|
|
6
6
|
import functools
|
|
7
|
-
import importlib
|
|
8
7
|
import inspect
|
|
9
8
|
import logging
|
|
10
|
-
import pkgutil
|
|
11
9
|
import types
|
|
12
10
|
import uuid
|
|
13
11
|
from collections import defaultdict
|
|
14
|
-
from collections.abc import AsyncIterator, Iterable, Iterator
|
|
12
|
+
from collections.abc import AsyncIterator, Iterable, Iterator
|
|
15
13
|
from contextvars import ContextVar
|
|
16
|
-
from types import ModuleType
|
|
17
14
|
from typing import Annotated, Any, Callable, TypeVar, cast, overload
|
|
18
15
|
|
|
19
|
-
from typing_extensions import
|
|
16
|
+
from typing_extensions import ParamSpec, Self, get_args, get_origin
|
|
20
17
|
|
|
18
|
+
from ._async import run_sync
|
|
21
19
|
from ._context import InstanceContext
|
|
22
|
-
from .
|
|
23
|
-
|
|
24
|
-
AnyInterface,
|
|
25
|
-
Event,
|
|
26
|
-
InjectableDecoratorArgs,
|
|
27
|
-
InstanceProxy,
|
|
20
|
+
from ._module import ModuleDef, ModuleRegistrar
|
|
21
|
+
from ._provider import (
|
|
28
22
|
Provider,
|
|
29
|
-
|
|
30
|
-
ProviderDecoratorArgs,
|
|
23
|
+
ProviderDef,
|
|
31
24
|
ProviderKind,
|
|
32
|
-
ScannedDependency,
|
|
33
|
-
Scope,
|
|
34
|
-
is_event_type,
|
|
35
|
-
is_marker,
|
|
36
25
|
)
|
|
37
|
-
from .
|
|
38
|
-
|
|
26
|
+
from ._scan import PackageOrIterable, Scanner
|
|
27
|
+
from ._scope import ALLOWED_SCOPES, Scope
|
|
28
|
+
from ._typing import (
|
|
29
|
+
NOT_SET,
|
|
30
|
+
Event,
|
|
39
31
|
get_typed_annotation,
|
|
40
32
|
get_typed_parameters,
|
|
41
|
-
import_string,
|
|
42
33
|
is_async_context_manager,
|
|
43
34
|
is_builtin_type,
|
|
44
35
|
is_context_manager,
|
|
36
|
+
is_event_type,
|
|
45
37
|
is_iterator_type,
|
|
38
|
+
is_marker,
|
|
46
39
|
is_none_type,
|
|
47
|
-
|
|
40
|
+
type_repr,
|
|
48
41
|
)
|
|
49
42
|
|
|
50
43
|
T = TypeVar("T", bound=Any)
|
|
51
|
-
M = TypeVar("M", bound="Module")
|
|
52
44
|
P = ParamSpec("P")
|
|
53
45
|
|
|
54
|
-
ALLOWED_SCOPES: dict[Scope, list[Scope]] = {
|
|
55
|
-
"singleton": ["singleton"],
|
|
56
|
-
"request": ["request", "singleton"],
|
|
57
|
-
"transient": ["transient", "request", "singleton"],
|
|
58
|
-
}
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
class ModuleMeta(type):
|
|
62
|
-
"""A metaclass used for the Module base class."""
|
|
63
|
-
|
|
64
|
-
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Any:
|
|
65
|
-
attrs["providers"] = [
|
|
66
|
-
(name, getattr(value, "__provider__"))
|
|
67
|
-
for name, value in attrs.items()
|
|
68
|
-
if hasattr(value, "__provider__")
|
|
69
|
-
]
|
|
70
|
-
return super().__new__(cls, name, bases, attrs)
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
class Module(metaclass=ModuleMeta):
|
|
74
|
-
"""A base class for defining AnyDI modules."""
|
|
75
|
-
|
|
76
|
-
providers: list[tuple[str, ProviderDecoratorArgs]]
|
|
77
46
|
|
|
78
|
-
def configure(self, container: Container) -> None:
|
|
79
|
-
"""Configure the AnyDI container with providers and their dependencies."""
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
@final
|
|
83
47
|
class Container:
|
|
84
48
|
"""AnyDI is a dependency injection container."""
|
|
85
49
|
|
|
86
50
|
def __init__(
|
|
87
51
|
self,
|
|
88
52
|
*,
|
|
89
|
-
providers:
|
|
90
|
-
modules:
|
|
91
|
-
| None = None,
|
|
53
|
+
providers: Iterable[ProviderDef] | None = None,
|
|
54
|
+
modules: Iterable[ModuleDef] | None = None,
|
|
92
55
|
strict: bool = False,
|
|
93
56
|
default_scope: Scope = "transient",
|
|
94
|
-
testing: bool = False,
|
|
95
57
|
logger: logging.Logger | None = None,
|
|
96
58
|
) -> None:
|
|
97
59
|
self._providers: dict[Any, Provider] = {}
|
|
98
60
|
self._strict = strict
|
|
99
61
|
self._default_scope: Scope = default_scope
|
|
100
|
-
self._testing = testing
|
|
101
62
|
self._logger = logger or logging.getLogger(__name__)
|
|
102
63
|
self._resources: dict[str, list[Any]] = defaultdict(list)
|
|
103
64
|
self._singleton_context = InstanceContext()
|
|
104
65
|
self._request_context_var: ContextVar[InstanceContext | None] = ContextVar(
|
|
105
66
|
"request_context", default=None
|
|
106
67
|
)
|
|
107
|
-
self._override_instances: dict[Any, Any] = {}
|
|
108
68
|
self._unresolved_interfaces: set[Any] = set()
|
|
109
69
|
self._inject_cache: dict[Callable[..., Any], Callable[..., Any]] = {}
|
|
110
70
|
|
|
71
|
+
# Components
|
|
72
|
+
self._modules = ModuleRegistrar(self)
|
|
73
|
+
self._scanner = Scanner(self)
|
|
74
|
+
|
|
111
75
|
# Register providers
|
|
112
76
|
providers = providers or []
|
|
113
77
|
for provider in providers:
|
|
@@ -136,11 +100,6 @@ class Container:
|
|
|
136
100
|
"""Get the default scope."""
|
|
137
101
|
return self._default_scope
|
|
138
102
|
|
|
139
|
-
@property
|
|
140
|
-
def testing(self) -> bool:
|
|
141
|
-
"""Check if testing mode is enabled."""
|
|
142
|
-
return self._testing
|
|
143
|
-
|
|
144
103
|
@property
|
|
145
104
|
def providers(self) -> dict[type[Any], Provider]:
|
|
146
105
|
"""Get the registered providers."""
|
|
@@ -258,7 +217,7 @@ class Container:
|
|
|
258
217
|
|
|
259
218
|
def register(
|
|
260
219
|
self,
|
|
261
|
-
interface:
|
|
220
|
+
interface: Any,
|
|
262
221
|
call: Callable[..., Any],
|
|
263
222
|
*,
|
|
264
223
|
scope: Scope,
|
|
@@ -267,16 +226,15 @@ class Container:
|
|
|
267
226
|
"""Register a provider for the specified interface."""
|
|
268
227
|
return self._register_provider(call, scope, interface, override)
|
|
269
228
|
|
|
270
|
-
def is_registered(self, interface:
|
|
229
|
+
def is_registered(self, interface: Any) -> bool:
|
|
271
230
|
"""Check if a provider is registered for the specified interface."""
|
|
272
231
|
return interface in self._providers
|
|
273
232
|
|
|
274
|
-
def unregister(self, interface:
|
|
233
|
+
def unregister(self, interface: Any) -> None:
|
|
275
234
|
"""Unregister a provider by interface."""
|
|
276
235
|
if not self.is_registered(interface):
|
|
277
236
|
raise LookupError(
|
|
278
|
-
"The provider interface "
|
|
279
|
-
f"`{get_full_qualname(interface)}` not registered."
|
|
237
|
+
f"The provider interface `{type_repr(interface)}` not registered."
|
|
280
238
|
)
|
|
281
239
|
|
|
282
240
|
provider = self._get_provider(interface)
|
|
@@ -314,7 +272,7 @@ class Container:
|
|
|
314
272
|
**defaults: Any,
|
|
315
273
|
) -> Provider:
|
|
316
274
|
"""Register a provider with the specified scope."""
|
|
317
|
-
name =
|
|
275
|
+
name = type_repr(call)
|
|
318
276
|
kind = ProviderKind.from_call(call)
|
|
319
277
|
detected_scope = scope
|
|
320
278
|
|
|
@@ -359,8 +317,7 @@ class Container:
|
|
|
359
317
|
# Check for existing provider
|
|
360
318
|
if interface in self._providers and not override:
|
|
361
319
|
raise LookupError(
|
|
362
|
-
f"The provider interface `{
|
|
363
|
-
"already registered."
|
|
320
|
+
f"The provider interface `{type_repr(interface)}` already registered."
|
|
364
321
|
)
|
|
365
322
|
|
|
366
323
|
unresolved_parameter = None
|
|
@@ -421,7 +378,7 @@ class Container:
|
|
|
421
378
|
else:
|
|
422
379
|
raise LookupError(
|
|
423
380
|
f"The provider `{name}` depends on `{unresolved_parameter.name}` "
|
|
424
|
-
f"of type `{
|
|
381
|
+
f"of type `{type_repr(unresolved_parameter.annotation)}`, "
|
|
425
382
|
"which has not been registered or set. To resolve this, ensure "
|
|
426
383
|
f"that `{unresolved_parameter.name}` is registered before "
|
|
427
384
|
f"attempting to use it."
|
|
@@ -464,19 +421,19 @@ class Container:
|
|
|
464
421
|
"with a transient scope, which is not allowed."
|
|
465
422
|
)
|
|
466
423
|
|
|
467
|
-
def _get_provider(self, interface:
|
|
424
|
+
def _get_provider(self, interface: Any) -> Provider:
|
|
468
425
|
"""Get provider by interface."""
|
|
469
426
|
try:
|
|
470
427
|
return self._providers[interface]
|
|
471
428
|
except KeyError as exc:
|
|
472
429
|
raise LookupError(
|
|
473
|
-
f"The provider interface for `{
|
|
430
|
+
f"The provider interface for `{type_repr(interface)}` has "
|
|
474
431
|
"not been registered. Please ensure that the provider interface is "
|
|
475
432
|
"properly registered before attempting to use it."
|
|
476
433
|
) from exc
|
|
477
434
|
|
|
478
435
|
def _get_or_register_provider(
|
|
479
|
-
self, interface:
|
|
436
|
+
self, interface: Any, parent_scope: Scope | None, /, **defaults: Any
|
|
480
437
|
) -> Provider:
|
|
481
438
|
"""Get or register a provider by interface."""
|
|
482
439
|
try:
|
|
@@ -551,7 +508,7 @@ class Container:
|
|
|
551
508
|
"""Create an instance by interface asynchronously."""
|
|
552
509
|
return await self._aresolve_or_create(interface, True, **defaults)
|
|
553
510
|
|
|
554
|
-
def is_resolved(self, interface:
|
|
511
|
+
def is_resolved(self, interface: Any) -> bool:
|
|
555
512
|
"""Check if an instance by interface exists."""
|
|
556
513
|
try:
|
|
557
514
|
provider = self._get_provider(interface)
|
|
@@ -562,7 +519,7 @@ class Container:
|
|
|
562
519
|
context = self._get_instance_context(provider.scope)
|
|
563
520
|
return interface in context
|
|
564
521
|
|
|
565
|
-
def release(self, interface:
|
|
522
|
+
def release(self, interface: Any) -> None:
|
|
566
523
|
"""Release an instance by interface."""
|
|
567
524
|
provider = self._get_provider(interface)
|
|
568
525
|
if provider.scope == "transient":
|
|
@@ -597,9 +554,6 @@ class Container:
|
|
|
597
554
|
else self._create_instance(provider, context, **defaults)
|
|
598
555
|
)
|
|
599
556
|
|
|
600
|
-
if self.testing:
|
|
601
|
-
instance = self._patch_test_resolver(provider.interface, instance)
|
|
602
|
-
|
|
603
557
|
return cast(T, instance)
|
|
604
558
|
|
|
605
559
|
async def _aresolve_or_create(
|
|
@@ -618,9 +572,6 @@ class Container:
|
|
|
618
572
|
else await self._acreate_instance(provider, context, **defaults)
|
|
619
573
|
)
|
|
620
574
|
|
|
621
|
-
if self.testing:
|
|
622
|
-
instance = self._patch_test_resolver(provider.interface, instance)
|
|
623
|
-
|
|
624
575
|
return cast(T, instance)
|
|
625
576
|
|
|
626
577
|
def _get_or_create_instance(
|
|
@@ -695,9 +646,9 @@ class Container:
|
|
|
695
646
|
cm = contextlib.contextmanager(provider.call)(**provider_kwargs)
|
|
696
647
|
return context.enter(cm)
|
|
697
648
|
|
|
698
|
-
return await
|
|
649
|
+
return await run_sync(_create)
|
|
699
650
|
|
|
700
|
-
instance = await
|
|
651
|
+
instance = await run_sync(provider.call, **provider_kwargs)
|
|
701
652
|
if (
|
|
702
653
|
context is not None
|
|
703
654
|
and provider.is_class
|
|
@@ -712,7 +663,7 @@ class Container:
|
|
|
712
663
|
"""Retrieve the arguments for a provider."""
|
|
713
664
|
provided_kwargs = {}
|
|
714
665
|
for parameter in provider.parameters:
|
|
715
|
-
instance = self._get_provider_instance(
|
|
666
|
+
instance, _ = self._get_provider_instance(
|
|
716
667
|
provider, parameter, context, **defaults
|
|
717
668
|
)
|
|
718
669
|
provided_kwargs[parameter.name] = instance
|
|
@@ -725,12 +676,12 @@ class Container:
|
|
|
725
676
|
context: InstanceContext | None,
|
|
726
677
|
/,
|
|
727
678
|
**defaults: Any,
|
|
728
|
-
) -> Any:
|
|
679
|
+
) -> tuple[Any, bool]:
|
|
729
680
|
"""Retrieve an instance of a dependency from the scoped context."""
|
|
730
681
|
|
|
731
682
|
# Try to get instance from defaults
|
|
732
683
|
if parameter.name in defaults:
|
|
733
|
-
return defaults[parameter.name]
|
|
684
|
+
return defaults[parameter.name], True
|
|
734
685
|
|
|
735
686
|
# Try to get instance from context
|
|
736
687
|
elif context and parameter.annotation in context:
|
|
@@ -743,12 +694,8 @@ class Container:
|
|
|
743
694
|
except LookupError:
|
|
744
695
|
if parameter.default is inspect.Parameter.empty:
|
|
745
696
|
raise
|
|
746
|
-
return parameter.default
|
|
747
|
-
|
|
748
|
-
# Wrap the instance in a proxy for testing
|
|
749
|
-
if self.testing:
|
|
750
|
-
return InstanceProxy(instance, interface=parameter.annotation)
|
|
751
|
-
return instance
|
|
697
|
+
return parameter.default, True
|
|
698
|
+
return instance, False
|
|
752
699
|
|
|
753
700
|
async def _aget_provided_kwargs(
|
|
754
701
|
self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
|
|
@@ -756,7 +703,7 @@ class Container:
|
|
|
756
703
|
"""Asynchronously retrieve the arguments for a provider."""
|
|
757
704
|
provided_kwargs = {}
|
|
758
705
|
for parameter in provider.parameters:
|
|
759
|
-
instance = await self._aget_provider_instance(
|
|
706
|
+
instance, _ = await self._aget_provider_instance(
|
|
760
707
|
provider, parameter, context, **defaults
|
|
761
708
|
)
|
|
762
709
|
provided_kwargs[parameter.name] = instance
|
|
@@ -769,12 +716,12 @@ class Container:
|
|
|
769
716
|
context: InstanceContext | None,
|
|
770
717
|
/,
|
|
771
718
|
**defaults: Any,
|
|
772
|
-
) -> Any:
|
|
719
|
+
) -> tuple[Any, bool]:
|
|
773
720
|
"""Asynchronously retrieve an instance of a dependency from the context."""
|
|
774
721
|
|
|
775
722
|
# Try to get instance from defaults
|
|
776
723
|
if parameter.name in defaults:
|
|
777
|
-
return defaults[parameter.name]
|
|
724
|
+
return defaults[parameter.name], True
|
|
778
725
|
|
|
779
726
|
# Try to get instance from context
|
|
780
727
|
elif context and parameter.annotation in context:
|
|
@@ -787,12 +734,8 @@ class Container:
|
|
|
787
734
|
except LookupError:
|
|
788
735
|
if parameter.default is inspect.Parameter.empty:
|
|
789
736
|
raise
|
|
790
|
-
return parameter.default
|
|
791
|
-
|
|
792
|
-
# Wrap the instance in a proxy for testing
|
|
793
|
-
if self.testing:
|
|
794
|
-
return InstanceProxy(instance, interface=parameter.annotation)
|
|
795
|
-
return instance
|
|
737
|
+
return parameter.default, True
|
|
738
|
+
return instance, False
|
|
796
739
|
|
|
797
740
|
def _resolve_parameter(
|
|
798
741
|
self, provider: Provider, parameter: inspect.Parameter
|
|
@@ -813,83 +756,11 @@ class Container:
|
|
|
813
756
|
if parameter.annotation in self._unresolved_interfaces:
|
|
814
757
|
raise LookupError(
|
|
815
758
|
f"You are attempting to get the parameter `{parameter.name}` with the "
|
|
816
|
-
f"annotation `{
|
|
817
|
-
f"dependency into `{
|
|
759
|
+
f"annotation `{type_repr(parameter.annotation)}` as a "
|
|
760
|
+
f"dependency into `{type_repr(provider.call)}` which is "
|
|
818
761
|
"not registered or set in the scoped context."
|
|
819
762
|
)
|
|
820
763
|
|
|
821
|
-
@contextlib.contextmanager
|
|
822
|
-
def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
|
|
823
|
-
"""
|
|
824
|
-
Override the provider for the specified interface with a specific instance.
|
|
825
|
-
"""
|
|
826
|
-
if not self.testing:
|
|
827
|
-
raise RuntimeError(
|
|
828
|
-
"The `override` method can only be used in testing mode."
|
|
829
|
-
)
|
|
830
|
-
if not self.is_registered(interface) and self.strict:
|
|
831
|
-
raise LookupError(
|
|
832
|
-
f"The provider interface `{get_full_qualname(interface)}` "
|
|
833
|
-
"not registered."
|
|
834
|
-
)
|
|
835
|
-
self._override_instances[interface] = instance
|
|
836
|
-
try:
|
|
837
|
-
yield
|
|
838
|
-
finally:
|
|
839
|
-
self._override_instances.pop(interface, None)
|
|
840
|
-
|
|
841
|
-
############################
|
|
842
|
-
# Testing Methods
|
|
843
|
-
############################
|
|
844
|
-
|
|
845
|
-
def _patch_test_resolver(self, interface: type[Any], instance: Any) -> Any:
|
|
846
|
-
"""Patch the test resolver for the instance."""
|
|
847
|
-
if interface in self._override_instances:
|
|
848
|
-
return self._override_instances[interface]
|
|
849
|
-
|
|
850
|
-
if not hasattr(instance, "__dict__") or hasattr(
|
|
851
|
-
instance, "__resolver_getter__"
|
|
852
|
-
):
|
|
853
|
-
return instance
|
|
854
|
-
|
|
855
|
-
wrapped = {
|
|
856
|
-
name: value.interface
|
|
857
|
-
for name, value in instance.__dict__.items()
|
|
858
|
-
if isinstance(value, InstanceProxy)
|
|
859
|
-
}
|
|
860
|
-
|
|
861
|
-
def __resolver_getter__(name: str) -> Any:
|
|
862
|
-
if name in wrapped:
|
|
863
|
-
_interface = wrapped[name]
|
|
864
|
-
# Resolve the dependency if it's wrapped
|
|
865
|
-
return self.resolve(_interface)
|
|
866
|
-
raise LookupError
|
|
867
|
-
|
|
868
|
-
# Attach the resolver getter to the instance
|
|
869
|
-
instance.__resolver_getter__ = __resolver_getter__
|
|
870
|
-
|
|
871
|
-
if not hasattr(instance.__class__, "__getattribute_patched__"):
|
|
872
|
-
|
|
873
|
-
def __getattribute__(_self: Any, name: str) -> Any:
|
|
874
|
-
# Skip the resolver getter
|
|
875
|
-
if name in {"__resolver_getter__", "__class__"}:
|
|
876
|
-
return object.__getattribute__(_self, name)
|
|
877
|
-
|
|
878
|
-
if hasattr(_self, "__resolver_getter__"):
|
|
879
|
-
try:
|
|
880
|
-
return _self.__resolver_getter__(name)
|
|
881
|
-
except LookupError:
|
|
882
|
-
pass
|
|
883
|
-
|
|
884
|
-
# Fall back to default behavior
|
|
885
|
-
return object.__getattribute__(_self, name)
|
|
886
|
-
|
|
887
|
-
# Apply the patched resolver if wrapped attributes exist
|
|
888
|
-
instance.__class__.__getattribute__ = __getattribute__
|
|
889
|
-
instance.__class__.__getattribute_patched__ = True
|
|
890
|
-
|
|
891
|
-
return instance
|
|
892
|
-
|
|
893
764
|
############################
|
|
894
765
|
# Injector Methods
|
|
895
766
|
############################
|
|
@@ -923,6 +794,9 @@ class Container:
|
|
|
923
794
|
|
|
924
795
|
injected_params = self._get_injected_params(call)
|
|
925
796
|
|
|
797
|
+
if not injected_params:
|
|
798
|
+
return call
|
|
799
|
+
|
|
926
800
|
if inspect.iscoroutinefunction(call):
|
|
927
801
|
|
|
928
802
|
@functools.wraps(call)
|
|
@@ -933,7 +807,7 @@ class Container:
|
|
|
933
807
|
|
|
934
808
|
self._inject_cache[call] = awrapper
|
|
935
809
|
|
|
936
|
-
return awrapper # type: ignore
|
|
810
|
+
return awrapper # type: ignore
|
|
937
811
|
|
|
938
812
|
@functools.wraps(call)
|
|
939
813
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
@@ -956,9 +830,9 @@ class Container:
|
|
|
956
830
|
except LookupError as exc:
|
|
957
831
|
if not self.strict:
|
|
958
832
|
self.logger.debug(
|
|
959
|
-
f"Cannot validate the `{
|
|
833
|
+
f"Cannot validate the `{type_repr(call)}` parameter "
|
|
960
834
|
f"`{parameter.name}` with an annotation of "
|
|
961
|
-
f"`{
|
|
835
|
+
f"`{type_repr(parameter.annotation)} due to being "
|
|
962
836
|
"in non-strict mode. It will be validated at the first call."
|
|
963
837
|
)
|
|
964
838
|
else:
|
|
@@ -972,215 +846,29 @@ class Container:
|
|
|
972
846
|
"""Validate an injected parameter."""
|
|
973
847
|
if parameter.annotation is inspect.Parameter.empty:
|
|
974
848
|
raise TypeError(
|
|
975
|
-
f"Missing `{
|
|
976
|
-
f"`{parameter.name}` annotation."
|
|
849
|
+
f"Missing `{type_repr(call)}` parameter `{parameter.name}` annotation."
|
|
977
850
|
)
|
|
978
851
|
|
|
979
852
|
if not self.is_registered(parameter.annotation):
|
|
980
853
|
raise LookupError(
|
|
981
|
-
f"`{
|
|
854
|
+
f"`{type_repr(call)}` has an unknown dependency parameter "
|
|
982
855
|
f"`{parameter.name}` with an annotation of "
|
|
983
|
-
f"`{
|
|
856
|
+
f"`{type_repr(parameter.annotation)}`."
|
|
984
857
|
)
|
|
985
858
|
|
|
986
859
|
############################
|
|
987
860
|
# Module Methods
|
|
988
861
|
############################
|
|
989
862
|
|
|
990
|
-
def register_module(
|
|
991
|
-
self, module: Module | type[Module] | Callable[[Container], None] | str
|
|
992
|
-
) -> None:
|
|
863
|
+
def register_module(self, module: ModuleDef) -> None:
|
|
993
864
|
"""Register a module as a callable, module type, or module instance."""
|
|
994
|
-
|
|
995
|
-
if inspect.isfunction(module):
|
|
996
|
-
module(self)
|
|
997
|
-
return
|
|
998
|
-
|
|
999
|
-
# Module path
|
|
1000
|
-
if isinstance(module, str):
|
|
1001
|
-
module = import_string(module)
|
|
1002
|
-
|
|
1003
|
-
# Class based Module or Module type
|
|
1004
|
-
if inspect.isclass(module) and issubclass(module, Module):
|
|
1005
|
-
module = module()
|
|
1006
|
-
|
|
1007
|
-
if isinstance(module, Module):
|
|
1008
|
-
module.configure(self)
|
|
1009
|
-
for provider_name, decorator_args in module.providers:
|
|
1010
|
-
obj = getattr(module, provider_name)
|
|
1011
|
-
self.provider(
|
|
1012
|
-
scope=decorator_args.scope,
|
|
1013
|
-
override=decorator_args.override,
|
|
1014
|
-
)(obj)
|
|
1015
|
-
else:
|
|
1016
|
-
raise TypeError(
|
|
1017
|
-
"The module must be a callable, a module type, or a module instance."
|
|
1018
|
-
)
|
|
865
|
+
self._modules.register(module)
|
|
1019
866
|
|
|
1020
867
|
############################
|
|
1021
868
|
# Scanner Methods
|
|
1022
869
|
############################
|
|
1023
870
|
|
|
1024
871
|
def scan(
|
|
1025
|
-
self,
|
|
1026
|
-
/,
|
|
1027
|
-
packages: ModuleType | str | Iterable[ModuleType | str],
|
|
1028
|
-
*,
|
|
1029
|
-
tags: Iterable[str] | None = None,
|
|
872
|
+
self, /, packages: PackageOrIterable, *, tags: Iterable[str] | None = None
|
|
1030
873
|
) -> None:
|
|
1031
|
-
|
|
1032
|
-
dependencies: list[ScannedDependency] = []
|
|
1033
|
-
|
|
1034
|
-
if isinstance(packages, ModuleType | str):
|
|
1035
|
-
scan_packages: Iterable[ModuleType | str] = [packages]
|
|
1036
|
-
else:
|
|
1037
|
-
scan_packages = packages
|
|
1038
|
-
|
|
1039
|
-
for package in scan_packages:
|
|
1040
|
-
dependencies.extend(self._scan_package(package, tags=tags))
|
|
1041
|
-
|
|
1042
|
-
for dependency in dependencies:
|
|
1043
|
-
decorator = self.inject()(dependency.member)
|
|
1044
|
-
setattr(dependency.module, dependency.member.__name__, decorator)
|
|
1045
|
-
|
|
1046
|
-
def _scan_package(
|
|
1047
|
-
self,
|
|
1048
|
-
package: ModuleType | str,
|
|
1049
|
-
*,
|
|
1050
|
-
tags: Iterable[str] | None = None,
|
|
1051
|
-
) -> list[ScannedDependency]:
|
|
1052
|
-
"""Scan a package or module for decorated members."""
|
|
1053
|
-
tags = tags or []
|
|
1054
|
-
if isinstance(package, str):
|
|
1055
|
-
package = importlib.import_module(package)
|
|
1056
|
-
|
|
1057
|
-
package_path = getattr(package, "__path__", None)
|
|
1058
|
-
|
|
1059
|
-
if not package_path:
|
|
1060
|
-
return self._scan_module(package, tags=tags)
|
|
1061
|
-
|
|
1062
|
-
dependencies: list[ScannedDependency] = []
|
|
1063
|
-
|
|
1064
|
-
for module_info in pkgutil.walk_packages(
|
|
1065
|
-
path=package_path, prefix=package.__name__ + "."
|
|
1066
|
-
):
|
|
1067
|
-
module = importlib.import_module(module_info.name)
|
|
1068
|
-
dependencies.extend(self._scan_module(module, tags=tags))
|
|
1069
|
-
|
|
1070
|
-
return dependencies
|
|
1071
|
-
|
|
1072
|
-
def _scan_module(
|
|
1073
|
-
self, module: ModuleType, *, tags: Iterable[str]
|
|
1074
|
-
) -> list[ScannedDependency]:
|
|
1075
|
-
"""Scan a module for decorated members."""
|
|
1076
|
-
dependencies: list[ScannedDependency] = []
|
|
1077
|
-
|
|
1078
|
-
for _, member in inspect.getmembers(module):
|
|
1079
|
-
if getattr(member, "__module__", None) != module.__name__ or not callable(
|
|
1080
|
-
member
|
|
1081
|
-
):
|
|
1082
|
-
continue
|
|
1083
|
-
|
|
1084
|
-
decorator_args: InjectableDecoratorArgs = getattr(
|
|
1085
|
-
member,
|
|
1086
|
-
"__injectable__",
|
|
1087
|
-
InjectableDecoratorArgs(wrapped=False, tags=[]),
|
|
1088
|
-
)
|
|
1089
|
-
|
|
1090
|
-
if tags and (
|
|
1091
|
-
decorator_args.tags
|
|
1092
|
-
and not set(decorator_args.tags).intersection(tags)
|
|
1093
|
-
or not decorator_args.tags
|
|
1094
|
-
):
|
|
1095
|
-
continue
|
|
1096
|
-
|
|
1097
|
-
if decorator_args.wrapped:
|
|
1098
|
-
dependencies.append(
|
|
1099
|
-
self._create_scanned_dependency(member=member, module=module)
|
|
1100
|
-
)
|
|
1101
|
-
continue
|
|
1102
|
-
|
|
1103
|
-
# Get by Marker
|
|
1104
|
-
for parameter in get_typed_parameters(member):
|
|
1105
|
-
if is_marker(parameter.default):
|
|
1106
|
-
dependencies.append(
|
|
1107
|
-
self._create_scanned_dependency(member=member, module=module)
|
|
1108
|
-
)
|
|
1109
|
-
continue
|
|
1110
|
-
|
|
1111
|
-
return dependencies
|
|
1112
|
-
|
|
1113
|
-
def _create_scanned_dependency(
|
|
1114
|
-
self, member: Any, module: ModuleType
|
|
1115
|
-
) -> ScannedDependency:
|
|
1116
|
-
"""Create a `Dependency` object from the scanned member and module."""
|
|
1117
|
-
if hasattr(member, "__wrapped__"):
|
|
1118
|
-
member = member.__wrapped__
|
|
1119
|
-
return ScannedDependency(member=member, module=module)
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
############################
|
|
1123
|
-
# Decorators
|
|
1124
|
-
############################
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
def transient(target: T) -> T:
|
|
1128
|
-
"""Decorator for marking a class as transient scope."""
|
|
1129
|
-
target.__scope__ = "transient"
|
|
1130
|
-
return target
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
def request(target: T) -> T:
|
|
1134
|
-
"""Decorator for marking a class as request scope."""
|
|
1135
|
-
target.__scope__ = "request"
|
|
1136
|
-
return target
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
def singleton(target: T) -> T:
|
|
1140
|
-
"""Decorator for marking a class as singleton scope."""
|
|
1141
|
-
target.__scope__ = "singleton"
|
|
1142
|
-
return target
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
def provider(
|
|
1146
|
-
*, scope: Scope, override: bool = False
|
|
1147
|
-
) -> Callable[[Callable[Concatenate[M, P], T]], Callable[Concatenate[M, P], T]]:
|
|
1148
|
-
"""Decorator for marking a function or method as a provider in a AnyDI module."""
|
|
1149
|
-
|
|
1150
|
-
def decorator(
|
|
1151
|
-
target: Callable[Concatenate[M, P], T],
|
|
1152
|
-
) -> Callable[Concatenate[M, P], T]:
|
|
1153
|
-
target.__provider__ = ProviderDecoratorArgs(scope=scope, override=override) # type: ignore
|
|
1154
|
-
return target
|
|
1155
|
-
|
|
1156
|
-
return decorator
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
@overload
|
|
1160
|
-
def injectable(func: Callable[P, T]) -> Callable[P, T]: ...
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
@overload
|
|
1164
|
-
def injectable(
|
|
1165
|
-
*, tags: Iterable[str] | None = None
|
|
1166
|
-
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
def injectable(
|
|
1170
|
-
func: Callable[P, T] | None = None,
|
|
1171
|
-
tags: Iterable[str] | None = None,
|
|
1172
|
-
) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
|
|
1173
|
-
"""Decorator for marking a function or method as requiring dependency injection."""
|
|
1174
|
-
|
|
1175
|
-
def decorator(inner: Callable[P, T]) -> Callable[P, T]:
|
|
1176
|
-
setattr(
|
|
1177
|
-
inner,
|
|
1178
|
-
"__injectable__",
|
|
1179
|
-
InjectableDecoratorArgs(wrapped=True, tags=tags),
|
|
1180
|
-
)
|
|
1181
|
-
return inner
|
|
1182
|
-
|
|
1183
|
-
if func is None:
|
|
1184
|
-
return decorator
|
|
1185
|
-
|
|
1186
|
-
return decorator(func)
|
|
874
|
+
self._scanner.scan(packages=packages, tags=tags)
|
anydi/_context.py
CHANGED
|
@@ -7,7 +7,7 @@ from typing import Any
|
|
|
7
7
|
|
|
8
8
|
from typing_extensions import Self
|
|
9
9
|
|
|
10
|
-
from .
|
|
10
|
+
from ._async import AsyncRLock, run_sync
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class InstanceContext:
|
|
@@ -78,7 +78,7 @@ class InstanceContext:
|
|
|
78
78
|
exc_tb: TracebackType | None,
|
|
79
79
|
) -> bool:
|
|
80
80
|
"""Exit the context asynchronously."""
|
|
81
|
-
sync_exit = await
|
|
81
|
+
sync_exit = await run_sync(self.__exit__, exc_type, exc_val, exc_tb)
|
|
82
82
|
async_exit = await self._async_stack.__aexit__(exc_type, exc_val, exc_tb)
|
|
83
83
|
return bool(sync_exit) or bool(async_exit)
|
|
84
84
|
|