anydi 0.41.0__py3-none-any.whl → 0.43.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anydi/__init__.py +8 -13
- anydi/_async.py +50 -0
- anydi/_container.py +65 -380
- anydi/_context.py +2 -2
- anydi/_decorators.py +122 -0
- anydi/_module.py +77 -0
- anydi/_provider.py +81 -0
- anydi/_scan.py +110 -0
- anydi/_scope.py +9 -0
- anydi/{_utils.py → _typing.py} +49 -69
- anydi/ext/_utils.py +4 -4
- anydi/ext/django/_utils.py +2 -42
- anydi/ext/django/apps.py +0 -3
- anydi/ext/django/ninja/__init__.py +3 -3
- anydi/ext/django/ninja/_operation.py +1 -1
- anydi/ext/django/ninja/_signature.py +1 -1
- anydi/ext/fastapi.py +2 -2
- anydi/ext/faststream.py +2 -2
- anydi/ext/pytest_plugin.py +1 -1
- anydi/ext/starlette/middleware.py +1 -1
- anydi/testing.py +172 -0
- {anydi-0.41.0.dist-info → anydi-0.43.0.dist-info}/METADATA +2 -2
- anydi-0.43.0.dist-info/RECORD +34 -0
- anydi/_types.py +0 -145
- anydi-0.41.0.dist-info/RECORD +0 -28
- {anydi-0.41.0.dist-info → anydi-0.43.0.dist-info}/WHEEL +0 -0
- {anydi-0.41.0.dist-info → anydi-0.43.0.dist-info}/entry_points.txt +0 -0
- {anydi-0.41.0.dist-info → anydi-0.43.0.dist-info}/licenses/LICENSE +0 -0
anydi/_container.py
CHANGED
|
@@ -4,110 +4,75 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import contextlib
|
|
6
6
|
import functools
|
|
7
|
-
import importlib
|
|
8
7
|
import inspect
|
|
9
8
|
import logging
|
|
10
|
-
import pkgutil
|
|
11
9
|
import types
|
|
12
10
|
import uuid
|
|
13
11
|
from collections import defaultdict
|
|
14
|
-
from collections.abc import AsyncIterator, Iterable, Iterator
|
|
12
|
+
from collections.abc import AsyncIterator, Iterable, Iterator
|
|
15
13
|
from contextvars import ContextVar
|
|
16
|
-
from
|
|
17
|
-
from typing import Annotated, Any, Callable, TypeVar, cast, overload
|
|
14
|
+
from typing import Any, Callable, TypeVar, cast, overload
|
|
18
15
|
|
|
19
|
-
from typing_extensions import
|
|
16
|
+
from typing_extensions import ParamSpec, Self, get_args, get_origin
|
|
20
17
|
|
|
18
|
+
from ._async import run_sync
|
|
21
19
|
from ._context import InstanceContext
|
|
22
|
-
from .
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
Event,
|
|
26
|
-
InjectableDecoratorArgs,
|
|
27
|
-
InstanceProxy,
|
|
20
|
+
from ._decorators import is_provided
|
|
21
|
+
from ._module import ModuleDef, ModuleRegistrar
|
|
22
|
+
from ._provider import (
|
|
28
23
|
Provider,
|
|
29
|
-
|
|
30
|
-
ProviderDecoratorArgs,
|
|
24
|
+
ProviderDef,
|
|
31
25
|
ProviderKind,
|
|
32
|
-
ScannedDependency,
|
|
33
|
-
Scope,
|
|
34
|
-
is_event_type,
|
|
35
|
-
is_marker,
|
|
36
26
|
)
|
|
37
|
-
from .
|
|
38
|
-
|
|
27
|
+
from ._scan import PackageOrIterable, Scanner
|
|
28
|
+
from ._scope import ALLOWED_SCOPES, Scope
|
|
29
|
+
from ._typing import (
|
|
30
|
+
NOT_SET,
|
|
31
|
+
Event,
|
|
39
32
|
get_typed_annotation,
|
|
40
33
|
get_typed_parameters,
|
|
41
|
-
import_string,
|
|
42
34
|
is_async_context_manager,
|
|
43
35
|
is_builtin_type,
|
|
44
36
|
is_context_manager,
|
|
37
|
+
is_event_type,
|
|
45
38
|
is_iterator_type,
|
|
39
|
+
is_marker,
|
|
46
40
|
is_none_type,
|
|
47
|
-
|
|
41
|
+
type_repr,
|
|
48
42
|
)
|
|
49
43
|
|
|
50
44
|
T = TypeVar("T", bound=Any)
|
|
51
|
-
M = TypeVar("M", bound="Module")
|
|
52
45
|
P = ParamSpec("P")
|
|
53
46
|
|
|
54
|
-
ALLOWED_SCOPES: dict[Scope, list[Scope]] = {
|
|
55
|
-
"singleton": ["singleton"],
|
|
56
|
-
"request": ["request", "singleton"],
|
|
57
|
-
"transient": ["transient", "request", "singleton"],
|
|
58
|
-
}
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
class ModuleMeta(type):
|
|
62
|
-
"""A metaclass used for the Module base class."""
|
|
63
|
-
|
|
64
|
-
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Any:
|
|
65
|
-
attrs["providers"] = [
|
|
66
|
-
(name, getattr(value, "__provider__"))
|
|
67
|
-
for name, value in attrs.items()
|
|
68
|
-
if hasattr(value, "__provider__")
|
|
69
|
-
]
|
|
70
|
-
return super().__new__(cls, name, bases, attrs)
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
class Module(metaclass=ModuleMeta):
|
|
74
|
-
"""A base class for defining AnyDI modules."""
|
|
75
|
-
|
|
76
|
-
providers: list[tuple[str, ProviderDecoratorArgs]]
|
|
77
47
|
|
|
78
|
-
def configure(self, container: Container) -> None:
|
|
79
|
-
"""Configure the AnyDI container with providers and their dependencies."""
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
@final
|
|
83
48
|
class Container:
|
|
84
49
|
"""AnyDI is a dependency injection container."""
|
|
85
50
|
|
|
86
51
|
def __init__(
|
|
87
52
|
self,
|
|
88
53
|
*,
|
|
89
|
-
providers:
|
|
90
|
-
modules:
|
|
91
|
-
| None = None,
|
|
54
|
+
providers: Iterable[ProviderDef] | None = None,
|
|
55
|
+
modules: Iterable[ModuleDef] | None = None,
|
|
92
56
|
strict: bool = False,
|
|
93
57
|
default_scope: Scope = "transient",
|
|
94
|
-
testing: bool = False,
|
|
95
58
|
logger: logging.Logger | None = None,
|
|
96
59
|
) -> None:
|
|
97
60
|
self._providers: dict[Any, Provider] = {}
|
|
98
61
|
self._strict = strict
|
|
99
62
|
self._default_scope: Scope = default_scope
|
|
100
|
-
self._testing = testing
|
|
101
63
|
self._logger = logger or logging.getLogger(__name__)
|
|
102
64
|
self._resources: dict[str, list[Any]] = defaultdict(list)
|
|
103
65
|
self._singleton_context = InstanceContext()
|
|
104
66
|
self._request_context_var: ContextVar[InstanceContext | None] = ContextVar(
|
|
105
67
|
"request_context", default=None
|
|
106
68
|
)
|
|
107
|
-
self._override_instances: dict[Any, Any] = {}
|
|
108
69
|
self._unresolved_interfaces: set[Any] = set()
|
|
109
70
|
self._inject_cache: dict[Callable[..., Any], Callable[..., Any]] = {}
|
|
110
71
|
|
|
72
|
+
# Components
|
|
73
|
+
self._modules = ModuleRegistrar(self)
|
|
74
|
+
self._scanner = Scanner(self)
|
|
75
|
+
|
|
111
76
|
# Register providers
|
|
112
77
|
providers = providers or []
|
|
113
78
|
for provider in providers:
|
|
@@ -136,11 +101,6 @@ class Container:
|
|
|
136
101
|
"""Get the default scope."""
|
|
137
102
|
return self._default_scope
|
|
138
103
|
|
|
139
|
-
@property
|
|
140
|
-
def testing(self) -> bool:
|
|
141
|
-
"""Check if testing mode is enabled."""
|
|
142
|
-
return self._testing
|
|
143
|
-
|
|
144
104
|
@property
|
|
145
105
|
def providers(self) -> dict[type[Any], Provider]:
|
|
146
106
|
"""Get the registered providers."""
|
|
@@ -258,7 +218,7 @@ class Container:
|
|
|
258
218
|
|
|
259
219
|
def register(
|
|
260
220
|
self,
|
|
261
|
-
interface:
|
|
221
|
+
interface: Any,
|
|
262
222
|
call: Callable[..., Any],
|
|
263
223
|
*,
|
|
264
224
|
scope: Scope,
|
|
@@ -267,16 +227,15 @@ class Container:
|
|
|
267
227
|
"""Register a provider for the specified interface."""
|
|
268
228
|
return self._register_provider(call, scope, interface, override)
|
|
269
229
|
|
|
270
|
-
def is_registered(self, interface:
|
|
230
|
+
def is_registered(self, interface: Any) -> bool:
|
|
271
231
|
"""Check if a provider is registered for the specified interface."""
|
|
272
232
|
return interface in self._providers
|
|
273
233
|
|
|
274
|
-
def unregister(self, interface:
|
|
234
|
+
def unregister(self, interface: Any) -> None:
|
|
275
235
|
"""Unregister a provider by interface."""
|
|
276
236
|
if not self.is_registered(interface):
|
|
277
237
|
raise LookupError(
|
|
278
|
-
"The provider interface "
|
|
279
|
-
f"`{get_full_qualname(interface)}` not registered."
|
|
238
|
+
f"The provider interface `{type_repr(interface)}` not registered."
|
|
280
239
|
)
|
|
281
240
|
|
|
282
241
|
provider = self._get_provider(interface)
|
|
@@ -314,7 +273,7 @@ class Container:
|
|
|
314
273
|
**defaults: Any,
|
|
315
274
|
) -> Provider:
|
|
316
275
|
"""Register a provider with the specified scope."""
|
|
317
|
-
name =
|
|
276
|
+
name = type_repr(call)
|
|
318
277
|
kind = ProviderKind.from_call(call)
|
|
319
278
|
detected_scope = scope
|
|
320
279
|
|
|
@@ -359,8 +318,7 @@ class Container:
|
|
|
359
318
|
# Check for existing provider
|
|
360
319
|
if interface in self._providers and not override:
|
|
361
320
|
raise LookupError(
|
|
362
|
-
f"The provider interface `{
|
|
363
|
-
"already registered."
|
|
321
|
+
f"The provider interface `{type_repr(interface)}` already registered."
|
|
364
322
|
)
|
|
365
323
|
|
|
366
324
|
unresolved_parameter = None
|
|
@@ -421,7 +379,7 @@ class Container:
|
|
|
421
379
|
else:
|
|
422
380
|
raise LookupError(
|
|
423
381
|
f"The provider `{name}` depends on `{unresolved_parameter.name}` "
|
|
424
|
-
f"of type `{
|
|
382
|
+
f"of type `{type_repr(unresolved_parameter.annotation)}`, "
|
|
425
383
|
"which has not been registered or set. To resolve this, ensure "
|
|
426
384
|
f"that `{unresolved_parameter.name}` is registered before "
|
|
427
385
|
f"attempting to use it."
|
|
@@ -464,19 +422,19 @@ class Container:
|
|
|
464
422
|
"with a transient scope, which is not allowed."
|
|
465
423
|
)
|
|
466
424
|
|
|
467
|
-
def _get_provider(self, interface:
|
|
425
|
+
def _get_provider(self, interface: Any) -> Provider:
|
|
468
426
|
"""Get provider by interface."""
|
|
469
427
|
try:
|
|
470
428
|
return self._providers[interface]
|
|
471
429
|
except KeyError as exc:
|
|
472
430
|
raise LookupError(
|
|
473
|
-
f"The provider interface for `{
|
|
431
|
+
f"The provider interface for `{type_repr(interface)}` has "
|
|
474
432
|
"not been registered. Please ensure that the provider interface is "
|
|
475
433
|
"properly registered before attempting to use it."
|
|
476
434
|
) from exc
|
|
477
435
|
|
|
478
436
|
def _get_or_register_provider(
|
|
479
|
-
self, interface:
|
|
437
|
+
self, interface: Any, parent_scope: Scope | None, /, **defaults: Any
|
|
480
438
|
) -> Provider:
|
|
481
439
|
"""Get or register a provider by interface."""
|
|
482
440
|
try:
|
|
@@ -484,17 +442,13 @@ class Container:
|
|
|
484
442
|
except LookupError:
|
|
485
443
|
if self.strict or interface is inspect.Parameter.empty:
|
|
486
444
|
raise
|
|
487
|
-
if
|
|
488
|
-
call = args[0]
|
|
489
|
-
else:
|
|
490
|
-
call = interface
|
|
491
|
-
if inspect.isclass(call) and not is_builtin_type(call):
|
|
445
|
+
if inspect.isclass(interface) and not is_builtin_type(interface):
|
|
492
446
|
# Try to get defined scope
|
|
493
|
-
if
|
|
494
|
-
scope =
|
|
447
|
+
if is_provided(interface):
|
|
448
|
+
scope = interface.__provided__["scope"]
|
|
495
449
|
else:
|
|
496
450
|
scope = parent_scope
|
|
497
|
-
return self._register_provider(
|
|
451
|
+
return self._register_provider(interface, scope, interface, **defaults)
|
|
498
452
|
raise
|
|
499
453
|
|
|
500
454
|
def _set_provider(self, provider: Provider) -> None:
|
|
@@ -551,7 +505,7 @@ class Container:
|
|
|
551
505
|
"""Create an instance by interface asynchronously."""
|
|
552
506
|
return await self._aresolve_or_create(interface, True, **defaults)
|
|
553
507
|
|
|
554
|
-
def is_resolved(self, interface:
|
|
508
|
+
def is_resolved(self, interface: Any) -> bool:
|
|
555
509
|
"""Check if an instance by interface exists."""
|
|
556
510
|
try:
|
|
557
511
|
provider = self._get_provider(interface)
|
|
@@ -562,7 +516,7 @@ class Container:
|
|
|
562
516
|
context = self._get_instance_context(provider.scope)
|
|
563
517
|
return interface in context
|
|
564
518
|
|
|
565
|
-
def release(self, interface:
|
|
519
|
+
def release(self, interface: Any) -> None:
|
|
566
520
|
"""Release an instance by interface."""
|
|
567
521
|
provider = self._get_provider(interface)
|
|
568
522
|
if provider.scope == "transient":
|
|
@@ -597,9 +551,6 @@ class Container:
|
|
|
597
551
|
else self._create_instance(provider, context, **defaults)
|
|
598
552
|
)
|
|
599
553
|
|
|
600
|
-
if self.testing:
|
|
601
|
-
instance = self._patch_test_resolver(provider.interface, instance)
|
|
602
|
-
|
|
603
554
|
return cast(T, instance)
|
|
604
555
|
|
|
605
556
|
async def _aresolve_or_create(
|
|
@@ -618,9 +569,6 @@ class Container:
|
|
|
618
569
|
else await self._acreate_instance(provider, context, **defaults)
|
|
619
570
|
)
|
|
620
571
|
|
|
621
|
-
if self.testing:
|
|
622
|
-
instance = self._patch_test_resolver(provider.interface, instance)
|
|
623
|
-
|
|
624
572
|
return cast(T, instance)
|
|
625
573
|
|
|
626
574
|
def _get_or_create_instance(
|
|
@@ -695,9 +643,9 @@ class Container:
|
|
|
695
643
|
cm = contextlib.contextmanager(provider.call)(**provider_kwargs)
|
|
696
644
|
return context.enter(cm)
|
|
697
645
|
|
|
698
|
-
return await
|
|
646
|
+
return await run_sync(_create)
|
|
699
647
|
|
|
700
|
-
instance = await
|
|
648
|
+
instance = await run_sync(provider.call, **provider_kwargs)
|
|
701
649
|
if (
|
|
702
650
|
context is not None
|
|
703
651
|
and provider.is_class
|
|
@@ -712,7 +660,7 @@ class Container:
|
|
|
712
660
|
"""Retrieve the arguments for a provider."""
|
|
713
661
|
provided_kwargs = {}
|
|
714
662
|
for parameter in provider.parameters:
|
|
715
|
-
instance = self._get_provider_instance(
|
|
663
|
+
instance, _ = self._get_provider_instance(
|
|
716
664
|
provider, parameter, context, **defaults
|
|
717
665
|
)
|
|
718
666
|
provided_kwargs[parameter.name] = instance
|
|
@@ -725,12 +673,12 @@ class Container:
|
|
|
725
673
|
context: InstanceContext | None,
|
|
726
674
|
/,
|
|
727
675
|
**defaults: Any,
|
|
728
|
-
) -> Any:
|
|
676
|
+
) -> tuple[Any, bool]:
|
|
729
677
|
"""Retrieve an instance of a dependency from the scoped context."""
|
|
730
678
|
|
|
731
679
|
# Try to get instance from defaults
|
|
732
680
|
if parameter.name in defaults:
|
|
733
|
-
return defaults[parameter.name]
|
|
681
|
+
return defaults[parameter.name], True
|
|
734
682
|
|
|
735
683
|
# Try to get instance from context
|
|
736
684
|
elif context and parameter.annotation in context:
|
|
@@ -743,12 +691,8 @@ class Container:
|
|
|
743
691
|
except LookupError:
|
|
744
692
|
if parameter.default is inspect.Parameter.empty:
|
|
745
693
|
raise
|
|
746
|
-
return parameter.default
|
|
747
|
-
|
|
748
|
-
# Wrap the instance in a proxy for testing
|
|
749
|
-
if self.testing:
|
|
750
|
-
return InstanceProxy(instance, interface=parameter.annotation)
|
|
751
|
-
return instance
|
|
694
|
+
return parameter.default, True
|
|
695
|
+
return instance, False
|
|
752
696
|
|
|
753
697
|
async def _aget_provided_kwargs(
|
|
754
698
|
self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
|
|
@@ -756,7 +700,7 @@ class Container:
|
|
|
756
700
|
"""Asynchronously retrieve the arguments for a provider."""
|
|
757
701
|
provided_kwargs = {}
|
|
758
702
|
for parameter in provider.parameters:
|
|
759
|
-
instance = await self._aget_provider_instance(
|
|
703
|
+
instance, _ = await self._aget_provider_instance(
|
|
760
704
|
provider, parameter, context, **defaults
|
|
761
705
|
)
|
|
762
706
|
provided_kwargs[parameter.name] = instance
|
|
@@ -769,12 +713,12 @@ class Container:
|
|
|
769
713
|
context: InstanceContext | None,
|
|
770
714
|
/,
|
|
771
715
|
**defaults: Any,
|
|
772
|
-
) -> Any:
|
|
716
|
+
) -> tuple[Any, bool]:
|
|
773
717
|
"""Asynchronously retrieve an instance of a dependency from the context."""
|
|
774
718
|
|
|
775
719
|
# Try to get instance from defaults
|
|
776
720
|
if parameter.name in defaults:
|
|
777
|
-
return defaults[parameter.name]
|
|
721
|
+
return defaults[parameter.name], True
|
|
778
722
|
|
|
779
723
|
# Try to get instance from context
|
|
780
724
|
elif context and parameter.annotation in context:
|
|
@@ -787,12 +731,8 @@ class Container:
|
|
|
787
731
|
except LookupError:
|
|
788
732
|
if parameter.default is inspect.Parameter.empty:
|
|
789
733
|
raise
|
|
790
|
-
return parameter.default
|
|
791
|
-
|
|
792
|
-
# Wrap the instance in a proxy for testing
|
|
793
|
-
if self.testing:
|
|
794
|
-
return InstanceProxy(instance, interface=parameter.annotation)
|
|
795
|
-
return instance
|
|
734
|
+
return parameter.default, True
|
|
735
|
+
return instance, False
|
|
796
736
|
|
|
797
737
|
def _resolve_parameter(
|
|
798
738
|
self, provider: Provider, parameter: inspect.Parameter
|
|
@@ -813,83 +753,11 @@ class Container:
|
|
|
813
753
|
if parameter.annotation in self._unresolved_interfaces:
|
|
814
754
|
raise LookupError(
|
|
815
755
|
f"You are attempting to get the parameter `{parameter.name}` with the "
|
|
816
|
-
f"annotation `{
|
|
817
|
-
f"dependency into `{
|
|
756
|
+
f"annotation `{type_repr(parameter.annotation)}` as a "
|
|
757
|
+
f"dependency into `{type_repr(provider.call)}` which is "
|
|
818
758
|
"not registered or set in the scoped context."
|
|
819
759
|
)
|
|
820
760
|
|
|
821
|
-
@contextlib.contextmanager
|
|
822
|
-
def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
|
|
823
|
-
"""
|
|
824
|
-
Override the provider for the specified interface with a specific instance.
|
|
825
|
-
"""
|
|
826
|
-
if not self.testing:
|
|
827
|
-
raise RuntimeError(
|
|
828
|
-
"The `override` method can only be used in testing mode."
|
|
829
|
-
)
|
|
830
|
-
if not self.is_registered(interface) and self.strict:
|
|
831
|
-
raise LookupError(
|
|
832
|
-
f"The provider interface `{get_full_qualname(interface)}` "
|
|
833
|
-
"not registered."
|
|
834
|
-
)
|
|
835
|
-
self._override_instances[interface] = instance
|
|
836
|
-
try:
|
|
837
|
-
yield
|
|
838
|
-
finally:
|
|
839
|
-
self._override_instances.pop(interface, None)
|
|
840
|
-
|
|
841
|
-
############################
|
|
842
|
-
# Testing Methods
|
|
843
|
-
############################
|
|
844
|
-
|
|
845
|
-
def _patch_test_resolver(self, interface: type[Any], instance: Any) -> Any:
|
|
846
|
-
"""Patch the test resolver for the instance."""
|
|
847
|
-
if interface in self._override_instances:
|
|
848
|
-
return self._override_instances[interface]
|
|
849
|
-
|
|
850
|
-
if not hasattr(instance, "__dict__") or hasattr(
|
|
851
|
-
instance, "__resolver_getter__"
|
|
852
|
-
):
|
|
853
|
-
return instance
|
|
854
|
-
|
|
855
|
-
wrapped = {
|
|
856
|
-
name: value.interface
|
|
857
|
-
for name, value in instance.__dict__.items()
|
|
858
|
-
if isinstance(value, InstanceProxy)
|
|
859
|
-
}
|
|
860
|
-
|
|
861
|
-
def __resolver_getter__(name: str) -> Any:
|
|
862
|
-
if name in wrapped:
|
|
863
|
-
_interface = wrapped[name]
|
|
864
|
-
# Resolve the dependency if it's wrapped
|
|
865
|
-
return self.resolve(_interface)
|
|
866
|
-
raise LookupError
|
|
867
|
-
|
|
868
|
-
# Attach the resolver getter to the instance
|
|
869
|
-
instance.__resolver_getter__ = __resolver_getter__
|
|
870
|
-
|
|
871
|
-
if not hasattr(instance.__class__, "__getattribute_patched__"):
|
|
872
|
-
|
|
873
|
-
def __getattribute__(_self: Any, name: str) -> Any:
|
|
874
|
-
# Skip the resolver getter
|
|
875
|
-
if name in {"__resolver_getter__", "__class__"}:
|
|
876
|
-
return object.__getattribute__(_self, name)
|
|
877
|
-
|
|
878
|
-
if hasattr(_self, "__resolver_getter__"):
|
|
879
|
-
try:
|
|
880
|
-
return _self.__resolver_getter__(name)
|
|
881
|
-
except LookupError:
|
|
882
|
-
pass
|
|
883
|
-
|
|
884
|
-
# Fall back to default behavior
|
|
885
|
-
return object.__getattribute__(_self, name)
|
|
886
|
-
|
|
887
|
-
# Apply the patched resolver if wrapped attributes exist
|
|
888
|
-
instance.__class__.__getattribute__ = __getattribute__
|
|
889
|
-
instance.__class__.__getattribute_patched__ = True
|
|
890
|
-
|
|
891
|
-
return instance
|
|
892
|
-
|
|
893
761
|
############################
|
|
894
762
|
# Injector Methods
|
|
895
763
|
############################
|
|
@@ -922,6 +790,9 @@ class Container:
|
|
|
922
790
|
return cast(Callable[P, T], self._inject_cache[call])
|
|
923
791
|
|
|
924
792
|
injected_params = self._get_injected_params(call)
|
|
793
|
+
if not injected_params:
|
|
794
|
+
self._inject_cache[call] = call
|
|
795
|
+
return call
|
|
925
796
|
|
|
926
797
|
if inspect.iscoroutinefunction(call):
|
|
927
798
|
|
|
@@ -933,7 +804,7 @@ class Container:
|
|
|
933
804
|
|
|
934
805
|
self._inject_cache[call] = awrapper
|
|
935
806
|
|
|
936
|
-
return awrapper # type: ignore
|
|
807
|
+
return awrapper # type: ignore
|
|
937
808
|
|
|
938
809
|
@functools.wraps(call)
|
|
939
810
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
@@ -956,9 +827,9 @@ class Container:
|
|
|
956
827
|
except LookupError as exc:
|
|
957
828
|
if not self.strict:
|
|
958
829
|
self.logger.debug(
|
|
959
|
-
f"Cannot validate the `{
|
|
830
|
+
f"Cannot validate the `{type_repr(call)}` parameter "
|
|
960
831
|
f"`{parameter.name}` with an annotation of "
|
|
961
|
-
f"`{
|
|
832
|
+
f"`{type_repr(parameter.annotation)} due to being "
|
|
962
833
|
"in non-strict mode. It will be validated at the first call."
|
|
963
834
|
)
|
|
964
835
|
else:
|
|
@@ -972,215 +843,29 @@ class Container:
|
|
|
972
843
|
"""Validate an injected parameter."""
|
|
973
844
|
if parameter.annotation is inspect.Parameter.empty:
|
|
974
845
|
raise TypeError(
|
|
975
|
-
f"Missing `{
|
|
976
|
-
f"`{parameter.name}` annotation."
|
|
846
|
+
f"Missing `{type_repr(call)}` parameter `{parameter.name}` annotation."
|
|
977
847
|
)
|
|
978
848
|
|
|
979
849
|
if not self.is_registered(parameter.annotation):
|
|
980
850
|
raise LookupError(
|
|
981
|
-
f"`{
|
|
851
|
+
f"`{type_repr(call)}` has an unknown dependency parameter "
|
|
982
852
|
f"`{parameter.name}` with an annotation of "
|
|
983
|
-
f"`{
|
|
853
|
+
f"`{type_repr(parameter.annotation)}`."
|
|
984
854
|
)
|
|
985
855
|
|
|
986
856
|
############################
|
|
987
857
|
# Module Methods
|
|
988
858
|
############################
|
|
989
859
|
|
|
990
|
-
def register_module(
|
|
991
|
-
self, module: Module | type[Module] | Callable[[Container], None] | str
|
|
992
|
-
) -> None:
|
|
860
|
+
def register_module(self, module: ModuleDef) -> None:
|
|
993
861
|
"""Register a module as a callable, module type, or module instance."""
|
|
994
|
-
|
|
995
|
-
if inspect.isfunction(module):
|
|
996
|
-
module(self)
|
|
997
|
-
return
|
|
998
|
-
|
|
999
|
-
# Module path
|
|
1000
|
-
if isinstance(module, str):
|
|
1001
|
-
module = import_string(module)
|
|
1002
|
-
|
|
1003
|
-
# Class based Module or Module type
|
|
1004
|
-
if inspect.isclass(module) and issubclass(module, Module):
|
|
1005
|
-
module = module()
|
|
1006
|
-
|
|
1007
|
-
if isinstance(module, Module):
|
|
1008
|
-
module.configure(self)
|
|
1009
|
-
for provider_name, decorator_args in module.providers:
|
|
1010
|
-
obj = getattr(module, provider_name)
|
|
1011
|
-
self.provider(
|
|
1012
|
-
scope=decorator_args.scope,
|
|
1013
|
-
override=decorator_args.override,
|
|
1014
|
-
)(obj)
|
|
1015
|
-
else:
|
|
1016
|
-
raise TypeError(
|
|
1017
|
-
"The module must be a callable, a module type, or a module instance."
|
|
1018
|
-
)
|
|
862
|
+
self._modules.register(module)
|
|
1019
863
|
|
|
1020
864
|
############################
|
|
1021
865
|
# Scanner Methods
|
|
1022
866
|
############################
|
|
1023
867
|
|
|
1024
868
|
def scan(
|
|
1025
|
-
self,
|
|
1026
|
-
/,
|
|
1027
|
-
packages: ModuleType | str | Iterable[ModuleType | str],
|
|
1028
|
-
*,
|
|
1029
|
-
tags: Iterable[str] | None = None,
|
|
869
|
+
self, /, packages: PackageOrIterable, *, tags: Iterable[str] | None = None
|
|
1030
870
|
) -> None:
|
|
1031
|
-
|
|
1032
|
-
dependencies: list[ScannedDependency] = []
|
|
1033
|
-
|
|
1034
|
-
if isinstance(packages, ModuleType | str):
|
|
1035
|
-
scan_packages: Iterable[ModuleType | str] = [packages]
|
|
1036
|
-
else:
|
|
1037
|
-
scan_packages = packages
|
|
1038
|
-
|
|
1039
|
-
for package in scan_packages:
|
|
1040
|
-
dependencies.extend(self._scan_package(package, tags=tags))
|
|
1041
|
-
|
|
1042
|
-
for dependency in dependencies:
|
|
1043
|
-
decorator = self.inject()(dependency.member)
|
|
1044
|
-
setattr(dependency.module, dependency.member.__name__, decorator)
|
|
1045
|
-
|
|
1046
|
-
def _scan_package(
|
|
1047
|
-
self,
|
|
1048
|
-
package: ModuleType | str,
|
|
1049
|
-
*,
|
|
1050
|
-
tags: Iterable[str] | None = None,
|
|
1051
|
-
) -> list[ScannedDependency]:
|
|
1052
|
-
"""Scan a package or module for decorated members."""
|
|
1053
|
-
tags = tags or []
|
|
1054
|
-
if isinstance(package, str):
|
|
1055
|
-
package = importlib.import_module(package)
|
|
1056
|
-
|
|
1057
|
-
package_path = getattr(package, "__path__", None)
|
|
1058
|
-
|
|
1059
|
-
if not package_path:
|
|
1060
|
-
return self._scan_module(package, tags=tags)
|
|
1061
|
-
|
|
1062
|
-
dependencies: list[ScannedDependency] = []
|
|
1063
|
-
|
|
1064
|
-
for module_info in pkgutil.walk_packages(
|
|
1065
|
-
path=package_path, prefix=package.__name__ + "."
|
|
1066
|
-
):
|
|
1067
|
-
module = importlib.import_module(module_info.name)
|
|
1068
|
-
dependencies.extend(self._scan_module(module, tags=tags))
|
|
1069
|
-
|
|
1070
|
-
return dependencies
|
|
1071
|
-
|
|
1072
|
-
def _scan_module(
|
|
1073
|
-
self, module: ModuleType, *, tags: Iterable[str]
|
|
1074
|
-
) -> list[ScannedDependency]:
|
|
1075
|
-
"""Scan a module for decorated members."""
|
|
1076
|
-
dependencies: list[ScannedDependency] = []
|
|
1077
|
-
|
|
1078
|
-
for _, member in inspect.getmembers(module):
|
|
1079
|
-
if getattr(member, "__module__", None) != module.__name__ or not callable(
|
|
1080
|
-
member
|
|
1081
|
-
):
|
|
1082
|
-
continue
|
|
1083
|
-
|
|
1084
|
-
decorator_args: InjectableDecoratorArgs = getattr(
|
|
1085
|
-
member,
|
|
1086
|
-
"__injectable__",
|
|
1087
|
-
InjectableDecoratorArgs(wrapped=False, tags=[]),
|
|
1088
|
-
)
|
|
1089
|
-
|
|
1090
|
-
if tags and (
|
|
1091
|
-
decorator_args.tags
|
|
1092
|
-
and not set(decorator_args.tags).intersection(tags)
|
|
1093
|
-
or not decorator_args.tags
|
|
1094
|
-
):
|
|
1095
|
-
continue
|
|
1096
|
-
|
|
1097
|
-
if decorator_args.wrapped:
|
|
1098
|
-
dependencies.append(
|
|
1099
|
-
self._create_scanned_dependency(member=member, module=module)
|
|
1100
|
-
)
|
|
1101
|
-
continue
|
|
1102
|
-
|
|
1103
|
-
# Get by Marker
|
|
1104
|
-
for parameter in get_typed_parameters(member):
|
|
1105
|
-
if is_marker(parameter.default):
|
|
1106
|
-
dependencies.append(
|
|
1107
|
-
self._create_scanned_dependency(member=member, module=module)
|
|
1108
|
-
)
|
|
1109
|
-
continue
|
|
1110
|
-
|
|
1111
|
-
return dependencies
|
|
1112
|
-
|
|
1113
|
-
def _create_scanned_dependency(
|
|
1114
|
-
self, member: Any, module: ModuleType
|
|
1115
|
-
) -> ScannedDependency:
|
|
1116
|
-
"""Create a `Dependency` object from the scanned member and module."""
|
|
1117
|
-
if hasattr(member, "__wrapped__"):
|
|
1118
|
-
member = member.__wrapped__
|
|
1119
|
-
return ScannedDependency(member=member, module=module)
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
############################
|
|
1123
|
-
# Decorators
|
|
1124
|
-
############################
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
def transient(target: T) -> T:
|
|
1128
|
-
"""Decorator for marking a class as transient scope."""
|
|
1129
|
-
target.__scope__ = "transient"
|
|
1130
|
-
return target
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
def request(target: T) -> T:
|
|
1134
|
-
"""Decorator for marking a class as request scope."""
|
|
1135
|
-
target.__scope__ = "request"
|
|
1136
|
-
return target
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
def singleton(target: T) -> T:
|
|
1140
|
-
"""Decorator for marking a class as singleton scope."""
|
|
1141
|
-
target.__scope__ = "singleton"
|
|
1142
|
-
return target
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
def provider(
|
|
1146
|
-
*, scope: Scope, override: bool = False
|
|
1147
|
-
) -> Callable[[Callable[Concatenate[M, P], T]], Callable[Concatenate[M, P], T]]:
|
|
1148
|
-
"""Decorator for marking a function or method as a provider in a AnyDI module."""
|
|
1149
|
-
|
|
1150
|
-
def decorator(
|
|
1151
|
-
target: Callable[Concatenate[M, P], T],
|
|
1152
|
-
) -> Callable[Concatenate[M, P], T]:
|
|
1153
|
-
target.__provider__ = ProviderDecoratorArgs(scope=scope, override=override) # type: ignore
|
|
1154
|
-
return target
|
|
1155
|
-
|
|
1156
|
-
return decorator
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
@overload
|
|
1160
|
-
def injectable(func: Callable[P, T]) -> Callable[P, T]: ...
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
@overload
|
|
1164
|
-
def injectable(
|
|
1165
|
-
*, tags: Iterable[str] | None = None
|
|
1166
|
-
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
def injectable(
|
|
1170
|
-
func: Callable[P, T] | None = None,
|
|
1171
|
-
tags: Iterable[str] | None = None,
|
|
1172
|
-
) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
|
|
1173
|
-
"""Decorator for marking a function or method as requiring dependency injection."""
|
|
1174
|
-
|
|
1175
|
-
def decorator(inner: Callable[P, T]) -> Callable[P, T]:
|
|
1176
|
-
setattr(
|
|
1177
|
-
inner,
|
|
1178
|
-
"__injectable__",
|
|
1179
|
-
InjectableDecoratorArgs(wrapped=True, tags=tags),
|
|
1180
|
-
)
|
|
1181
|
-
return inner
|
|
1182
|
-
|
|
1183
|
-
if func is None:
|
|
1184
|
-
return decorator
|
|
1185
|
-
|
|
1186
|
-
return decorator(func)
|
|
871
|
+
self._scanner.scan(packages=packages, tags=tags)
|