anydi 0.22.1__py3-none-any.whl → 0.37.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anydi/__init__.py +14 -14
- anydi/_container.py +811 -571
- anydi/_context.py +39 -281
- anydi/_provider.py +232 -0
- anydi/_types.py +49 -96
- anydi/_utils.py +108 -77
- anydi/ext/_utils.py +49 -28
- anydi/ext/django/__init__.py +9 -0
- anydi/ext/django/_container.py +18 -0
- anydi/ext/django/_settings.py +39 -0
- anydi/ext/django/_utils.py +128 -0
- anydi/ext/django/apps.py +85 -0
- anydi/ext/django/middleware.py +28 -0
- anydi/ext/django/ninja/__init__.py +16 -0
- anydi/ext/django/ninja/_operation.py +75 -0
- anydi/ext/django/ninja/_signature.py +64 -0
- anydi/ext/fastapi.py +11 -27
- anydi/ext/faststream.py +58 -0
- anydi/ext/pydantic_settings.py +48 -0
- anydi/ext/pytest_plugin.py +67 -41
- anydi/ext/starlette/middleware.py +2 -16
- {anydi-0.22.1.dist-info → anydi-0.37.4.dist-info}/METADATA +71 -21
- anydi-0.37.4.dist-info/RECORD +29 -0
- {anydi-0.22.1.dist-info → anydi-0.37.4.dist-info}/WHEEL +1 -1
- anydi-0.37.4.dist-info/entry_points.txt +2 -0
- anydi/_logger.py +0 -3
- anydi/_module.py +0 -124
- anydi/_scanner.py +0 -233
- anydi-0.22.1.dist-info/RECORD +0 -20
- anydi-0.22.1.dist-info/entry_points.txt +0 -3
- {anydi-0.22.1.dist-info → anydi-0.37.4.dist-info/licenses}/LICENSE +0 -0
anydi/_container.py
CHANGED
|
@@ -3,104 +3,115 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import contextlib
|
|
6
|
+
import functools
|
|
7
|
+
import importlib
|
|
6
8
|
import inspect
|
|
9
|
+
import logging
|
|
10
|
+
import pkgutil
|
|
11
|
+
import threading
|
|
7
12
|
import types
|
|
8
|
-
import
|
|
13
|
+
from collections import defaultdict
|
|
14
|
+
from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
|
|
9
15
|
from contextvars import ContextVar
|
|
10
|
-
from
|
|
11
|
-
from typing import
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
Union,
|
|
28
|
-
cast,
|
|
29
|
-
overload,
|
|
16
|
+
from types import ModuleType
|
|
17
|
+
from typing import Any, Callable, TypeVar, Union, cast, overload
|
|
18
|
+
from weakref import WeakKeyDictionary
|
|
19
|
+
|
|
20
|
+
from typing_extensions import Concatenate, ParamSpec, Self, final
|
|
21
|
+
|
|
22
|
+
from ._context import InstanceContext
|
|
23
|
+
from ._provider import Provider
|
|
24
|
+
from ._types import (
|
|
25
|
+
AnyInterface,
|
|
26
|
+
Dependency,
|
|
27
|
+
InjectableDecoratorArgs,
|
|
28
|
+
InstanceProxy,
|
|
29
|
+
ProviderDecoratorArgs,
|
|
30
|
+
Scope,
|
|
31
|
+
is_event_type,
|
|
32
|
+
is_marker,
|
|
30
33
|
)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
from ._context import (
|
|
41
|
-
RequestContext,
|
|
42
|
-
ResourceScopedContext,
|
|
43
|
-
ScopedContext,
|
|
44
|
-
SingletonContext,
|
|
45
|
-
TransientContext,
|
|
34
|
+
from ._utils import (
|
|
35
|
+
AsyncRLock,
|
|
36
|
+
get_full_qualname,
|
|
37
|
+
get_typed_parameters,
|
|
38
|
+
import_string,
|
|
39
|
+
is_async_context_manager,
|
|
40
|
+
is_builtin_type,
|
|
41
|
+
is_context_manager,
|
|
42
|
+
run_async,
|
|
46
43
|
)
|
|
47
|
-
from ._logger import logger
|
|
48
|
-
from ._module import Module, ModuleRegistry
|
|
49
|
-
from ._scanner import Scanner
|
|
50
|
-
from ._types import AnyInterface, Interface, Marker, Provider, Scope
|
|
51
|
-
from ._utils import get_full_qualname, get_signature, is_builtin_type
|
|
52
44
|
|
|
53
45
|
T = TypeVar("T", bound=Any)
|
|
46
|
+
M = TypeVar("M", bound="Module")
|
|
54
47
|
P = ParamSpec("P")
|
|
55
48
|
|
|
56
|
-
ALLOWED_SCOPES:
|
|
49
|
+
ALLOWED_SCOPES: dict[Scope, list[Scope]] = {
|
|
57
50
|
"singleton": ["singleton"],
|
|
58
51
|
"request": ["request", "singleton"],
|
|
59
|
-
"transient": ["transient", "
|
|
52
|
+
"transient": ["transient", "request", "singleton"],
|
|
60
53
|
}
|
|
61
54
|
|
|
62
55
|
|
|
56
|
+
class ModuleMeta(type):
|
|
57
|
+
"""A metaclass used for the Module base class."""
|
|
58
|
+
|
|
59
|
+
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Any:
|
|
60
|
+
attrs["providers"] = [
|
|
61
|
+
(name, getattr(value, "__provider__"))
|
|
62
|
+
for name, value in attrs.items()
|
|
63
|
+
if hasattr(value, "__provider__")
|
|
64
|
+
]
|
|
65
|
+
return super().__new__(cls, name, bases, attrs)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class Module(metaclass=ModuleMeta):
|
|
69
|
+
"""A base class for defining AnyDI modules."""
|
|
70
|
+
|
|
71
|
+
providers: list[tuple[str, ProviderDecoratorArgs]]
|
|
72
|
+
|
|
73
|
+
def configure(self, container: Container) -> None:
|
|
74
|
+
"""Configure the AnyDI container with providers and their dependencies."""
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# noinspection PyShadowingNames
|
|
63
78
|
@final
|
|
64
79
|
class Container:
|
|
65
|
-
"""AnyDI is a dependency injection container.
|
|
66
|
-
|
|
67
|
-
Args:
|
|
68
|
-
modules: Optional sequence of modules to register during initialization.
|
|
69
|
-
"""
|
|
80
|
+
"""AnyDI is a dependency injection container."""
|
|
70
81
|
|
|
71
82
|
def __init__(
|
|
72
83
|
self,
|
|
73
84
|
*,
|
|
74
|
-
providers:
|
|
75
|
-
modules:
|
|
76
|
-
|
|
77
|
-
] = None,
|
|
85
|
+
providers: Sequence[Provider] | None = None,
|
|
86
|
+
modules: Sequence[Module | type[Module] | Callable[[Container], None] | str]
|
|
87
|
+
| None = None,
|
|
78
88
|
strict: bool = False,
|
|
89
|
+
default_scope: Scope = "transient",
|
|
90
|
+
testing: bool = False,
|
|
91
|
+
logger: logging.Logger | None = None,
|
|
79
92
|
) -> None:
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
self.
|
|
88
|
-
self.
|
|
89
|
-
self.
|
|
90
|
-
self._request_context_var: ContextVar[Optional[RequestContext]] = ContextVar(
|
|
93
|
+
self._providers: dict[type[Any], Provider] = {}
|
|
94
|
+
self._strict = strict
|
|
95
|
+
self._default_scope = default_scope
|
|
96
|
+
self._testing = testing
|
|
97
|
+
self._logger = logger or logging.getLogger(__name__)
|
|
98
|
+
self._resources: dict[str, list[type[Any]]] = defaultdict(list)
|
|
99
|
+
self._singleton_context = InstanceContext()
|
|
100
|
+
self._singleton_lock = threading.RLock()
|
|
101
|
+
self._singleton_async_lock = AsyncRLock()
|
|
102
|
+
self._request_context_var: ContextVar[InstanceContext | None] = ContextVar(
|
|
91
103
|
"request_context", default=None
|
|
92
104
|
)
|
|
93
|
-
self._override_instances:
|
|
94
|
-
self.
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
self._scanner = Scanner(self)
|
|
105
|
+
self._override_instances: dict[type[Any], Any] = {}
|
|
106
|
+
self._unresolved_interfaces: set[type[Any]] = set()
|
|
107
|
+
self._inject_cache: WeakKeyDictionary[
|
|
108
|
+
Callable[..., Any], Callable[..., Any]
|
|
109
|
+
] = WeakKeyDictionary()
|
|
99
110
|
|
|
100
111
|
# Register providers
|
|
101
|
-
providers = providers or
|
|
102
|
-
for
|
|
103
|
-
self.
|
|
112
|
+
providers = providers or []
|
|
113
|
+
for provider in providers:
|
|
114
|
+
self._register_provider(provider, False)
|
|
104
115
|
|
|
105
116
|
# Register modules
|
|
106
117
|
modules = modules or []
|
|
@@ -109,353 +120,314 @@ class Container:
|
|
|
109
120
|
|
|
110
121
|
@property
|
|
111
122
|
def strict(self) -> bool:
|
|
112
|
-
"""Check if strict mode is enabled.
|
|
113
|
-
|
|
114
|
-
Returns:
|
|
115
|
-
True if strict mode is enabled, False otherwise.
|
|
116
|
-
"""
|
|
123
|
+
"""Check if strict mode is enabled."""
|
|
117
124
|
return self._strict
|
|
118
125
|
|
|
119
126
|
@property
|
|
120
|
-
def
|
|
121
|
-
"""Get the
|
|
127
|
+
def default_scope(self) -> Scope:
|
|
128
|
+
"""Get the default scope."""
|
|
129
|
+
return self._default_scope
|
|
122
130
|
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
"""
|
|
126
|
-
return self.
|
|
131
|
+
@property
|
|
132
|
+
def testing(self) -> bool:
|
|
133
|
+
"""Check if testing mode is enabled."""
|
|
134
|
+
return self._testing
|
|
127
135
|
|
|
128
|
-
|
|
129
|
-
|
|
136
|
+
@property
|
|
137
|
+
def providers(self) -> dict[type[Any], Provider]:
|
|
138
|
+
"""Get the registered providers."""
|
|
139
|
+
return self._providers
|
|
130
140
|
|
|
131
|
-
|
|
132
|
-
|
|
141
|
+
@property
|
|
142
|
+
def logger(self) -> logging.Logger:
|
|
143
|
+
"""Get the logger instance."""
|
|
144
|
+
return self._logger
|
|
133
145
|
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
"""
|
|
146
|
+
def is_registered(self, interface: AnyInterface) -> bool:
|
|
147
|
+
"""Check if a provider is registered for the specified interface."""
|
|
137
148
|
return interface in self._providers
|
|
138
149
|
|
|
139
150
|
def register(
|
|
140
151
|
self,
|
|
141
152
|
interface: AnyInterface,
|
|
142
|
-
|
|
153
|
+
call: Callable[..., Any],
|
|
143
154
|
*,
|
|
144
155
|
scope: Scope,
|
|
145
156
|
override: bool = False,
|
|
146
157
|
) -> Provider:
|
|
147
|
-
"""Register a provider for the specified interface.
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
interface: The interface for which the provider is being registered.
|
|
151
|
-
obj: The provider function or callable object.
|
|
152
|
-
scope: The scope of the provider.
|
|
153
|
-
override: If True, override an existing provider for the interface
|
|
154
|
-
if one is already registered. Defaults to False.
|
|
155
|
-
|
|
156
|
-
Returns:
|
|
157
|
-
The registered provider.
|
|
158
|
-
|
|
159
|
-
Raises:
|
|
160
|
-
LookupError: If a provider for the interface is already registered
|
|
161
|
-
and override is False.
|
|
162
|
-
|
|
163
|
-
Notes:
|
|
164
|
-
- If the provider is a resource or an asynchronous resource, and the
|
|
165
|
-
interface is None, an Event type will be automatically created and used
|
|
166
|
-
as the interface.
|
|
167
|
-
- The provider will be validated for its scope, type, and matching scopes.
|
|
168
|
-
"""
|
|
169
|
-
provider = Provider(obj=obj, scope=scope)
|
|
170
|
-
|
|
171
|
-
# Create Event type
|
|
172
|
-
if provider.is_resource and (interface is NoneType or interface is None):
|
|
173
|
-
interface = type(f"Event{uuid.uuid4().hex}", (), {})
|
|
158
|
+
"""Register a provider for the specified interface."""
|
|
159
|
+
provider = Provider(call=call, scope=scope, interface=interface)
|
|
160
|
+
return self._register_provider(provider, override)
|
|
174
161
|
|
|
175
|
-
|
|
162
|
+
def _register_provider(
|
|
163
|
+
self, provider: Provider, override: bool, /, **defaults: Any
|
|
164
|
+
) -> Provider:
|
|
165
|
+
"""Register a provider."""
|
|
166
|
+
if provider.interface in self._providers:
|
|
176
167
|
if override:
|
|
177
|
-
self.
|
|
168
|
+
self._set_provider(provider)
|
|
178
169
|
return provider
|
|
179
170
|
|
|
180
171
|
raise LookupError(
|
|
181
|
-
f"The provider interface `{get_full_qualname(interface)}` "
|
|
172
|
+
f"The provider interface `{get_full_qualname(provider.interface)}` "
|
|
182
173
|
"already registered."
|
|
183
174
|
)
|
|
184
175
|
|
|
185
|
-
|
|
186
|
-
self.
|
|
187
|
-
self._validate_provider_type(provider)
|
|
188
|
-
self._validate_provider_match_scopes(interface, provider)
|
|
189
|
-
|
|
190
|
-
self._providers[interface] = provider
|
|
176
|
+
self._validate_sub_providers(provider, **defaults)
|
|
177
|
+
self._set_provider(provider)
|
|
191
178
|
return provider
|
|
192
179
|
|
|
193
180
|
def unregister(self, interface: AnyInterface) -> None:
|
|
194
|
-
"""Unregister a provider by interface.
|
|
195
|
-
|
|
196
|
-
Args:
|
|
197
|
-
interface: The interface of the provider to unregister.
|
|
198
|
-
|
|
199
|
-
Raises:
|
|
200
|
-
LookupError: If the provider interface is not registered.
|
|
201
|
-
|
|
202
|
-
Notes:
|
|
203
|
-
- The method cleans up any scoped context instance associated with
|
|
204
|
-
the provider's scope.
|
|
205
|
-
- The method removes the provider reference from the internal dictionary
|
|
206
|
-
of registered providers.
|
|
207
|
-
"""
|
|
181
|
+
"""Unregister a provider by interface."""
|
|
208
182
|
if not self.is_registered(interface):
|
|
209
183
|
raise LookupError(
|
|
210
184
|
"The provider interface "
|
|
211
185
|
f"`{get_full_qualname(interface)}` not registered."
|
|
212
186
|
)
|
|
213
187
|
|
|
214
|
-
provider = self.
|
|
188
|
+
provider = self._get_provider(interface)
|
|
215
189
|
|
|
216
|
-
# Cleanup
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
190
|
+
# Cleanup instance context
|
|
191
|
+
if provider.scope != "transient":
|
|
192
|
+
try:
|
|
193
|
+
context = self._get_scoped_context(provider.scope)
|
|
194
|
+
except LookupError:
|
|
195
|
+
pass
|
|
196
|
+
else:
|
|
197
|
+
del context[interface]
|
|
224
198
|
|
|
225
199
|
# Cleanup provider references
|
|
226
|
-
self.
|
|
227
|
-
|
|
228
|
-
def _get_or_register_provider(self, interface: AnyInterface) -> Provider:
|
|
229
|
-
"""Get or register a provider by interface.
|
|
230
|
-
|
|
231
|
-
Args:
|
|
232
|
-
interface: The interface for which to retrieve the provider.
|
|
200
|
+
self._delete_provider(provider)
|
|
233
201
|
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
Raises:
|
|
238
|
-
LookupError: If the provider interface has not been registered.
|
|
239
|
-
"""
|
|
202
|
+
def _get_provider(self, interface: AnyInterface) -> Provider:
|
|
203
|
+
"""Get provider by interface."""
|
|
240
204
|
try:
|
|
241
205
|
return self._providers[interface]
|
|
242
206
|
except KeyError as exc:
|
|
207
|
+
raise LookupError(
|
|
208
|
+
f"The provider interface for `{get_full_qualname(interface)}` has "
|
|
209
|
+
"not been registered. Please ensure that the provider interface is "
|
|
210
|
+
"properly registered before attempting to use it."
|
|
211
|
+
) from exc
|
|
212
|
+
|
|
213
|
+
def _get_or_register_provider(
|
|
214
|
+
self, interface: AnyInterface, parent_scope: Scope | None, /, **defaults: Any
|
|
215
|
+
) -> Provider:
|
|
216
|
+
"""Get or register a provider by interface."""
|
|
217
|
+
try:
|
|
218
|
+
return self._get_provider(interface)
|
|
219
|
+
except LookupError:
|
|
243
220
|
if (
|
|
244
221
|
not self.strict
|
|
245
222
|
and inspect.isclass(interface)
|
|
246
223
|
and not is_builtin_type(interface)
|
|
224
|
+
and interface is not inspect.Parameter.empty
|
|
247
225
|
):
|
|
248
226
|
# Try to get defined scope
|
|
249
|
-
scope = getattr(interface, "__scope__",
|
|
227
|
+
scope = getattr(interface, "__scope__", parent_scope)
|
|
250
228
|
# Try to detect scope
|
|
251
229
|
if scope is None:
|
|
252
|
-
scope = self._detect_scope(interface)
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
raise LookupError(
|
|
262
|
-
f"The provider interface for `{get_full_qualname(interface)}` has "
|
|
263
|
-
"not been registered. Please ensure that the provider interface is "
|
|
264
|
-
"properly registered before attempting to use it."
|
|
265
|
-
) from exc
|
|
266
|
-
|
|
267
|
-
def _validate_provider_scope(self, provider: Provider) -> None:
|
|
268
|
-
"""Validate the scope of a provider.
|
|
269
|
-
|
|
270
|
-
Args:
|
|
271
|
-
provider: The provider to validate.
|
|
272
|
-
|
|
273
|
-
Raises:
|
|
274
|
-
ValueError: If the scope provided is invalid.
|
|
275
|
-
"""
|
|
276
|
-
if provider.scope not in get_args(Scope):
|
|
277
|
-
raise ValueError(
|
|
278
|
-
"The scope provided is invalid. Only the following scopes are "
|
|
279
|
-
f"supported: {', '.join(get_args(Scope))}. Please use one of the "
|
|
280
|
-
"supported scopes when registering a provider."
|
|
281
|
-
)
|
|
282
|
-
|
|
283
|
-
def _validate_provider_type(self, provider: Provider) -> None:
|
|
284
|
-
"""Validate the type of provider.
|
|
285
|
-
|
|
286
|
-
Args:
|
|
287
|
-
provider: The provider to validate.
|
|
288
|
-
|
|
289
|
-
Raises:
|
|
290
|
-
TypeError: If the provider has an invalid type.
|
|
291
|
-
"""
|
|
292
|
-
if provider.is_function or provider.is_class:
|
|
293
|
-
return
|
|
294
|
-
|
|
230
|
+
scope = self._detect_scope(interface, **defaults)
|
|
231
|
+
scope = scope or self.default_scope
|
|
232
|
+
provider = Provider(call=interface, scope=scope, interface=interface)
|
|
233
|
+
return self._register_provider(provider, False, **defaults)
|
|
234
|
+
raise
|
|
235
|
+
|
|
236
|
+
def _set_provider(self, provider: Provider) -> None:
|
|
237
|
+
"""Set a provider by interface."""
|
|
238
|
+
self._providers[provider.interface] = provider
|
|
295
239
|
if provider.is_resource:
|
|
296
|
-
|
|
297
|
-
raise TypeError(
|
|
298
|
-
f"The resource provider `{provider}` is attempting to register "
|
|
299
|
-
"with a transient scope, which is not allowed. Please update the "
|
|
300
|
-
"provider's scope to an appropriate value before registering it."
|
|
301
|
-
)
|
|
302
|
-
return
|
|
303
|
-
|
|
304
|
-
raise TypeError(
|
|
305
|
-
f"The provider `{provider.obj}` is invalid because it is not a callable "
|
|
306
|
-
"object. Only callable providers are allowed. Please update the provider "
|
|
307
|
-
"to a callable object before attempting to register it."
|
|
308
|
-
)
|
|
309
|
-
|
|
310
|
-
def _validate_provider_match_scopes(
|
|
311
|
-
self, interface: AnyInterface, provider: Provider
|
|
312
|
-
) -> None:
|
|
313
|
-
"""Validate that the provider and its dependencies have matching scopes.
|
|
240
|
+
self._resources[provider.scope].append(provider.interface)
|
|
314
241
|
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
242
|
+
def _delete_provider(self, provider: Provider) -> None:
|
|
243
|
+
"""Delete a provider."""
|
|
244
|
+
if provider.interface in self._providers:
|
|
245
|
+
del self._providers[provider.interface]
|
|
246
|
+
if provider.is_resource:
|
|
247
|
+
self._resources[provider.scope].remove(provider.interface)
|
|
318
248
|
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
TypeError: If a dependency is missing an annotation.
|
|
322
|
-
"""
|
|
323
|
-
related_providers = []
|
|
249
|
+
def _validate_sub_providers(self, provider: Provider, /, **defaults: Any) -> None:
|
|
250
|
+
"""Validate the sub-providers of a provider."""
|
|
324
251
|
|
|
325
|
-
for parameter in provider.parameters
|
|
326
|
-
if parameter.annotation is inspect.
|
|
252
|
+
for parameter in provider.parameters:
|
|
253
|
+
if parameter.annotation is inspect.Parameter.empty:
|
|
327
254
|
raise TypeError(
|
|
328
255
|
f"Missing provider `{provider}` "
|
|
329
256
|
f"dependency `{parameter.name}` annotation."
|
|
330
257
|
)
|
|
258
|
+
|
|
331
259
|
try:
|
|
332
|
-
sub_provider = self._get_or_register_provider(
|
|
260
|
+
sub_provider = self._get_or_register_provider(
|
|
261
|
+
parameter.annotation, provider.scope
|
|
262
|
+
)
|
|
333
263
|
except LookupError:
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
264
|
+
if self._parameter_has_default(parameter, **defaults):
|
|
265
|
+
continue
|
|
266
|
+
|
|
267
|
+
if provider.scope not in {"singleton", "transient"}:
|
|
268
|
+
self._unresolved_interfaces.add(provider.interface)
|
|
269
|
+
continue
|
|
270
|
+
raise ValueError(
|
|
271
|
+
f"The provider `{provider}` depends on `{parameter.name}` of type "
|
|
337
272
|
f"`{get_full_qualname(parameter.annotation)}`, which "
|
|
338
|
-
"has not been registered. To resolve this, ensure that "
|
|
273
|
+
"has not been registered or set. To resolve this, ensure that "
|
|
339
274
|
f"`{parameter.name}` is registered before attempting to use it."
|
|
340
275
|
) from None
|
|
341
|
-
related_providers.append(sub_provider)
|
|
342
276
|
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
allowed_scopes = ALLOWED_SCOPES.get(right_scope) or []
|
|
346
|
-
if left_scope not in allowed_scopes:
|
|
277
|
+
# Check scope compatibility
|
|
278
|
+
if sub_provider.scope not in ALLOWED_SCOPES.get(provider.scope, []):
|
|
347
279
|
raise ValueError(
|
|
348
|
-
f"The provider `{provider}` with a {provider.scope} scope
|
|
349
|
-
"
|
|
350
|
-
|
|
351
|
-
"which is not allowed. Please ensure that all providers are "
|
|
352
|
-
"registered with matching scopes."
|
|
280
|
+
f"The provider `{provider}` with a `{provider.scope}` scope cannot "
|
|
281
|
+
f"depend on `{sub_provider}` with a `{sub_provider.scope}` scope. "
|
|
282
|
+
"Please ensure all providers are registered with matching scopes."
|
|
353
283
|
)
|
|
354
284
|
|
|
355
|
-
def _detect_scope(self,
|
|
356
|
-
"""Detect the scope for a
|
|
285
|
+
def _detect_scope(self, call: Callable[..., Any], **defaults: Any) -> Scope | None:
|
|
286
|
+
"""Detect the scope for a callable."""
|
|
287
|
+
scopes = set()
|
|
357
288
|
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
return
|
|
374
|
-
|
|
289
|
+
for parameter in get_typed_parameters(call):
|
|
290
|
+
try:
|
|
291
|
+
sub_provider = self._get_or_register_provider(
|
|
292
|
+
parameter.annotation, None
|
|
293
|
+
)
|
|
294
|
+
except LookupError:
|
|
295
|
+
if self._parameter_has_default(parameter, **defaults):
|
|
296
|
+
continue
|
|
297
|
+
raise
|
|
298
|
+
scope = sub_provider.scope
|
|
299
|
+
|
|
300
|
+
if scope == "transient":
|
|
301
|
+
return "transient"
|
|
302
|
+
scopes.add(scope)
|
|
303
|
+
|
|
304
|
+
# If all scopes are found, we can return based on priority order
|
|
305
|
+
if {"transient", "request", "singleton"}.issubset(scopes):
|
|
306
|
+
break # pragma: no cover
|
|
307
|
+
|
|
308
|
+
# Determine scope based on priority
|
|
309
|
+
if "request" in scopes:
|
|
375
310
|
return "request"
|
|
376
|
-
if
|
|
311
|
+
if "singleton" in scopes:
|
|
377
312
|
return "singleton"
|
|
313
|
+
|
|
378
314
|
return None
|
|
379
315
|
|
|
316
|
+
def _parameter_has_default(
|
|
317
|
+
self, parameter: inspect.Parameter, /, **defaults: Any
|
|
318
|
+
) -> bool:
|
|
319
|
+
return (defaults and parameter.name in defaults) or (
|
|
320
|
+
not self.strict and parameter.default is not inspect.Parameter.empty
|
|
321
|
+
)
|
|
322
|
+
|
|
380
323
|
def register_module(
|
|
381
|
-
self, module:
|
|
324
|
+
self, module: Module | type[Module] | Callable[[Container], None] | str
|
|
382
325
|
) -> None:
|
|
383
|
-
"""Register a module as a callable, module type, or module instance.
|
|
326
|
+
"""Register a module as a callable, module type, or module instance."""
|
|
327
|
+
# Callable Module
|
|
328
|
+
if inspect.isfunction(module):
|
|
329
|
+
module(self)
|
|
330
|
+
return
|
|
384
331
|
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
332
|
+
# Module path
|
|
333
|
+
if isinstance(module, str):
|
|
334
|
+
module = import_string(module)
|
|
335
|
+
|
|
336
|
+
# Class based Module or Module type
|
|
337
|
+
if inspect.isclass(module) and issubclass(module, Module):
|
|
338
|
+
module = module()
|
|
339
|
+
|
|
340
|
+
if isinstance(module, Module):
|
|
341
|
+
module.configure(self)
|
|
342
|
+
for provider_name, decorator_args in module.providers:
|
|
343
|
+
obj = getattr(module, provider_name)
|
|
344
|
+
self.provider(
|
|
345
|
+
scope=decorator_args.scope,
|
|
346
|
+
override=decorator_args.override,
|
|
347
|
+
)(obj)
|
|
348
|
+
|
|
349
|
+
def __enter__(self) -> Self:
|
|
350
|
+
"""Enter the singleton context."""
|
|
351
|
+
self.start()
|
|
352
|
+
return self
|
|
353
|
+
|
|
354
|
+
def __exit__(
|
|
355
|
+
self,
|
|
356
|
+
exc_type: type[BaseException] | None,
|
|
357
|
+
exc_val: BaseException | None,
|
|
358
|
+
exc_tb: types.TracebackType | None,
|
|
359
|
+
) -> Any:
|
|
360
|
+
"""Exit the singleton context."""
|
|
361
|
+
return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
|
|
389
362
|
|
|
390
363
|
def start(self) -> None:
|
|
391
364
|
"""Start the singleton context."""
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
365
|
+
# Resolve all singleton resources
|
|
366
|
+
for interface in self._resources.get("singleton", []):
|
|
367
|
+
self.resolve(interface)
|
|
395
368
|
|
|
396
369
|
def close(self) -> None:
|
|
397
370
|
"""Close the singleton context."""
|
|
398
371
|
self._singleton_context.close()
|
|
399
372
|
|
|
400
|
-
|
|
401
|
-
|
|
373
|
+
@contextlib.contextmanager
|
|
374
|
+
def request_context(self) -> Iterator[InstanceContext]:
|
|
375
|
+
"""Obtain a context manager for the request-scoped context."""
|
|
376
|
+
context = InstanceContext()
|
|
402
377
|
|
|
403
|
-
|
|
404
|
-
A context manager for the request-scoped context.
|
|
405
|
-
"""
|
|
406
|
-
return contextlib.contextmanager(self._request_context)()
|
|
378
|
+
token = self._request_context_var.set(context)
|
|
407
379
|
|
|
408
|
-
|
|
409
|
-
|
|
380
|
+
# Resolve all request resources
|
|
381
|
+
for interface in self._resources.get("request", []):
|
|
382
|
+
if not is_event_type(interface):
|
|
383
|
+
continue
|
|
384
|
+
self.resolve(interface)
|
|
410
385
|
|
|
411
|
-
Yields:
|
|
412
|
-
Yield control to the code block within the request context.
|
|
413
|
-
"""
|
|
414
|
-
context = RequestContext(self)
|
|
415
|
-
token = self._request_context_var.set(context)
|
|
416
386
|
with context:
|
|
417
|
-
yield
|
|
387
|
+
yield context
|
|
418
388
|
self._request_context_var.reset(token)
|
|
419
389
|
|
|
390
|
+
async def __aenter__(self) -> Self:
|
|
391
|
+
"""Enter the singleton context."""
|
|
392
|
+
await self.astart()
|
|
393
|
+
return self
|
|
394
|
+
|
|
395
|
+
async def __aexit__(
|
|
396
|
+
self,
|
|
397
|
+
exc_type: type[BaseException] | None,
|
|
398
|
+
exc_val: BaseException | None,
|
|
399
|
+
exc_tb: types.TracebackType | None,
|
|
400
|
+
) -> bool:
|
|
401
|
+
"""Exit the singleton context."""
|
|
402
|
+
return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)
|
|
403
|
+
|
|
420
404
|
async def astart(self) -> None:
|
|
421
405
|
"""Start the singleton context asynchronously."""
|
|
422
|
-
for interface
|
|
423
|
-
|
|
424
|
-
await self.aresolve(interface) # noqa
|
|
406
|
+
for interface in self._resources.get("singleton", []):
|
|
407
|
+
await self.aresolve(interface)
|
|
425
408
|
|
|
426
409
|
async def aclose(self) -> None:
|
|
427
410
|
"""Close the singleton context asynchronously."""
|
|
428
411
|
await self._singleton_context.aclose()
|
|
429
412
|
|
|
430
|
-
|
|
431
|
-
|
|
413
|
+
@contextlib.asynccontextmanager
|
|
414
|
+
async def arequest_context(self) -> AsyncIterator[InstanceContext]:
|
|
415
|
+
"""Obtain an async context manager for the request-scoped context."""
|
|
416
|
+
context = InstanceContext()
|
|
432
417
|
|
|
433
|
-
|
|
434
|
-
An async context manager for the request-scoped context.
|
|
435
|
-
"""
|
|
436
|
-
return contextlib.asynccontextmanager(self._arequest_context)()
|
|
418
|
+
token = self._request_context_var.set(context)
|
|
437
419
|
|
|
438
|
-
|
|
439
|
-
|
|
420
|
+
for interface in self._resources.get("request", []):
|
|
421
|
+
if not is_event_type(interface):
|
|
422
|
+
continue
|
|
423
|
+
await self.aresolve(interface)
|
|
440
424
|
|
|
441
|
-
Yields:
|
|
442
|
-
Yield control to the code block within the request context.
|
|
443
|
-
"""
|
|
444
|
-
context = RequestContext(self)
|
|
445
|
-
token = self._request_context_var.set(context)
|
|
446
425
|
async with context:
|
|
447
|
-
yield
|
|
426
|
+
yield context
|
|
448
427
|
self._request_context_var.reset(token)
|
|
449
428
|
|
|
450
|
-
def _get_request_context(self) ->
|
|
451
|
-
"""Get the current request context.
|
|
452
|
-
|
|
453
|
-
Returns:
|
|
454
|
-
RequestContext: The current request context.
|
|
455
|
-
|
|
456
|
-
Raises:
|
|
457
|
-
LookupError: If the request context has not been started.
|
|
458
|
-
"""
|
|
429
|
+
def _get_request_context(self) -> InstanceContext:
|
|
430
|
+
"""Get the current request context."""
|
|
459
431
|
request_context = self._request_context_var.get()
|
|
460
432
|
if request_context is None:
|
|
461
433
|
raise LookupError(
|
|
@@ -468,123 +440,365 @@ class Container:
|
|
|
468
440
|
def reset(self) -> None:
|
|
469
441
|
"""Reset resolved instances."""
|
|
470
442
|
for interface, provider in self._providers.items():
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
443
|
+
if provider.scope == "transient":
|
|
444
|
+
continue
|
|
445
|
+
try:
|
|
446
|
+
context = self._get_scoped_context(provider.scope)
|
|
447
|
+
except LookupError:
|
|
448
|
+
continue
|
|
449
|
+
del context[interface]
|
|
474
450
|
|
|
475
451
|
@overload
|
|
476
|
-
def resolve(self, interface:
|
|
452
|
+
def resolve(self, interface: type[T]) -> T: ...
|
|
477
453
|
|
|
478
454
|
@overload
|
|
479
455
|
def resolve(self, interface: T) -> T: ...
|
|
480
456
|
|
|
481
|
-
def resolve(self, interface:
|
|
482
|
-
"""Resolve an instance by interface.
|
|
457
|
+
def resolve(self, interface: type[T]) -> T:
|
|
458
|
+
"""Resolve an instance by interface."""
|
|
459
|
+
provider = self._get_or_register_provider(interface, None)
|
|
460
|
+
if provider.scope == "transient":
|
|
461
|
+
instance = self._create_instance(provider, None)
|
|
462
|
+
else:
|
|
463
|
+
context = self._get_scoped_context(provider.scope)
|
|
464
|
+
if provider.scope == "singleton":
|
|
465
|
+
with self._singleton_lock:
|
|
466
|
+
instance = self._get_or_create_instance(provider, context)
|
|
467
|
+
else:
|
|
468
|
+
instance = self._get_or_create_instance(provider, context)
|
|
469
|
+
if self.testing:
|
|
470
|
+
instance = self._patch_test_resolver(provider.interface, instance)
|
|
471
|
+
return cast(T, instance)
|
|
483
472
|
|
|
484
|
-
|
|
485
|
-
|
|
473
|
+
@overload
|
|
474
|
+
async def aresolve(self, interface: type[T]) -> T: ...
|
|
486
475
|
|
|
487
|
-
|
|
488
|
-
|
|
476
|
+
@overload
|
|
477
|
+
async def aresolve(self, interface: T) -> T: ...
|
|
489
478
|
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
if
|
|
494
|
-
|
|
479
|
+
async def aresolve(self, interface: type[T]) -> T:
|
|
480
|
+
"""Resolve an instance by interface asynchronously."""
|
|
481
|
+
provider = self._get_or_register_provider(interface, None)
|
|
482
|
+
if provider.scope == "transient":
|
|
483
|
+
instance = await self._acreate_instance(provider, None)
|
|
484
|
+
else:
|
|
485
|
+
context = self._get_scoped_context(provider.scope)
|
|
486
|
+
if provider.scope == "singleton":
|
|
487
|
+
async with self._singleton_async_lock:
|
|
488
|
+
instance = await self._aget_or_create_instance(provider, context)
|
|
489
|
+
else:
|
|
490
|
+
instance = await self._aget_or_create_instance(provider, context)
|
|
491
|
+
if self.testing:
|
|
492
|
+
instance = self._patch_test_resolver(interface, instance)
|
|
493
|
+
return cast(T, instance)
|
|
494
|
+
|
|
495
|
+
def create(self, interface: type[T], **defaults: Any) -> T:
|
|
496
|
+
"""Create an instance by interface."""
|
|
497
|
+
provider = self._get_or_register_provider(interface, None, **defaults)
|
|
498
|
+
if provider.scope == "transient":
|
|
499
|
+
instance = self._create_instance(provider, None, **defaults)
|
|
500
|
+
else:
|
|
501
|
+
context = self._get_scoped_context(provider.scope)
|
|
502
|
+
if provider.scope == "singleton":
|
|
503
|
+
with self._singleton_lock:
|
|
504
|
+
instance = self._create_instance(provider, context, **defaults)
|
|
505
|
+
else:
|
|
506
|
+
instance = self._create_instance(provider, context, **defaults)
|
|
507
|
+
return cast(T, instance)
|
|
508
|
+
|
|
509
|
+
async def acreate(self, interface: type[T], **defaults: Any) -> T:
|
|
510
|
+
"""Create an instance by interface."""
|
|
511
|
+
provider = self._get_or_register_provider(interface, None, **defaults)
|
|
512
|
+
if provider.scope == "transient":
|
|
513
|
+
instance = await self._acreate_instance(provider, None, **defaults)
|
|
514
|
+
else:
|
|
515
|
+
context = self._get_scoped_context(provider.scope)
|
|
516
|
+
if provider.scope == "singleton":
|
|
517
|
+
async with self._singleton_async_lock:
|
|
518
|
+
instance = await self._acreate_instance(
|
|
519
|
+
provider, context, **defaults
|
|
520
|
+
)
|
|
521
|
+
else:
|
|
522
|
+
instance = await self._acreate_instance(provider, context, **defaults)
|
|
523
|
+
return cast(T, instance)
|
|
524
|
+
|
|
525
|
+
def _get_or_create_instance(
|
|
526
|
+
self, provider: Provider, context: InstanceContext
|
|
527
|
+
) -> Any:
|
|
528
|
+
"""Get an instance of a dependency from the scoped context."""
|
|
529
|
+
instance = context.get(provider.interface)
|
|
530
|
+
if instance is None:
|
|
531
|
+
instance = self._create_instance(provider, context)
|
|
532
|
+
context.set(provider.interface, instance)
|
|
533
|
+
return instance
|
|
534
|
+
return instance
|
|
535
|
+
|
|
536
|
+
async def _aget_or_create_instance(
|
|
537
|
+
self, provider: Provider, context: InstanceContext
|
|
538
|
+
) -> Any:
|
|
539
|
+
"""Get an async instance of a dependency from the scoped context."""
|
|
540
|
+
instance = context.get(provider.interface)
|
|
541
|
+
if instance is None:
|
|
542
|
+
instance = await self._acreate_instance(provider, context)
|
|
543
|
+
context.set(provider.interface, instance)
|
|
544
|
+
return instance
|
|
545
|
+
return instance
|
|
546
|
+
|
|
547
|
+
def _create_instance(
|
|
548
|
+
self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
|
|
549
|
+
) -> Any:
|
|
550
|
+
"""Create an instance using the provider."""
|
|
551
|
+
if provider.is_async:
|
|
552
|
+
raise TypeError(
|
|
553
|
+
f"The instance for the provider `{provider}` cannot be created in "
|
|
554
|
+
"synchronous mode."
|
|
555
|
+
)
|
|
495
556
|
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
557
|
+
provider_kwargs = self._get_provided_kwargs(provider, context, **defaults)
|
|
558
|
+
|
|
559
|
+
if provider.is_generator:
|
|
560
|
+
if context is None:
|
|
561
|
+
raise ValueError("The context is required for generator providers.")
|
|
562
|
+
cm = contextlib.contextmanager(provider.call)(**provider_kwargs)
|
|
563
|
+
return context.enter(cm)
|
|
564
|
+
|
|
565
|
+
instance = provider.call(**provider_kwargs)
|
|
566
|
+
if context is not None and provider.is_class and is_context_manager(instance):
|
|
567
|
+
context.enter(instance)
|
|
568
|
+
return instance
|
|
569
|
+
|
|
570
|
+
async def _acreate_instance(
|
|
571
|
+
self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
|
|
572
|
+
) -> Any:
|
|
573
|
+
"""Create an instance asynchronously using the provider."""
|
|
574
|
+
provider_kwargs = await self._aget_provided_kwargs(
|
|
575
|
+
provider, context, **defaults
|
|
576
|
+
)
|
|
499
577
|
|
|
500
|
-
|
|
501
|
-
|
|
578
|
+
if provider.is_coroutine:
|
|
579
|
+
return await provider.call(**provider_kwargs)
|
|
502
580
|
|
|
503
|
-
|
|
504
|
-
|
|
581
|
+
if provider.is_async_generator:
|
|
582
|
+
if context is None:
|
|
583
|
+
raise ValueError(
|
|
584
|
+
"The async stack is required for async generator providers."
|
|
585
|
+
)
|
|
586
|
+
cm = contextlib.asynccontextmanager(provider.call)(**provider_kwargs)
|
|
587
|
+
return await context.aenter(cm)
|
|
588
|
+
|
|
589
|
+
if provider.is_generator:
|
|
590
|
+
|
|
591
|
+
def _create() -> Any:
|
|
592
|
+
if context is None:
|
|
593
|
+
raise ValueError("The stack is required for generator providers.")
|
|
594
|
+
cm = contextlib.contextmanager(provider.call)(**provider_kwargs)
|
|
595
|
+
return context.enter(cm)
|
|
596
|
+
|
|
597
|
+
return await run_async(_create)
|
|
598
|
+
|
|
599
|
+
instance = await run_async(provider.call, **provider_kwargs)
|
|
600
|
+
if (
|
|
601
|
+
context is not None
|
|
602
|
+
and provider.is_class
|
|
603
|
+
and is_async_context_manager(instance)
|
|
604
|
+
):
|
|
605
|
+
await context.aenter(instance)
|
|
606
|
+
return instance
|
|
607
|
+
|
|
608
|
+
def _get_provided_kwargs(
|
|
609
|
+
self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
|
|
610
|
+
) -> dict[str, Any]:
|
|
611
|
+
"""Retrieve the arguments for a provider."""
|
|
612
|
+
provided_kwargs = {}
|
|
613
|
+
for parameter in provider.parameters:
|
|
614
|
+
instance = self._get_provider_instance(
|
|
615
|
+
provider, parameter, context, **defaults
|
|
616
|
+
)
|
|
617
|
+
provided_kwargs[parameter.name] = instance
|
|
618
|
+
return {**defaults, **provided_kwargs}
|
|
505
619
|
|
|
506
|
-
|
|
507
|
-
|
|
620
|
+
def _get_provider_instance(
|
|
621
|
+
self,
|
|
622
|
+
provider: Provider,
|
|
623
|
+
parameter: inspect.Parameter,
|
|
624
|
+
context: InstanceContext | None,
|
|
625
|
+
/,
|
|
626
|
+
**defaults: Any,
|
|
627
|
+
) -> Any:
|
|
628
|
+
"""Retrieve an instance of a dependency from the scoped context."""
|
|
508
629
|
|
|
509
|
-
|
|
510
|
-
|
|
630
|
+
# Try to get instance from defaults
|
|
631
|
+
if parameter.name in defaults:
|
|
632
|
+
return defaults[parameter.name]
|
|
511
633
|
|
|
512
|
-
|
|
513
|
-
|
|
634
|
+
# Try to get instance from context
|
|
635
|
+
elif context and parameter.annotation in context:
|
|
636
|
+
instance = context[parameter.annotation]
|
|
514
637
|
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
638
|
+
# Resolve new instance
|
|
639
|
+
else:
|
|
640
|
+
try:
|
|
641
|
+
instance = self._resolve_parameter(provider, parameter)
|
|
642
|
+
except LookupError:
|
|
643
|
+
if parameter.default is inspect.Parameter.empty:
|
|
644
|
+
raise
|
|
645
|
+
return parameter.default
|
|
646
|
+
|
|
647
|
+
# Wrap the instance in a proxy for testing
|
|
648
|
+
if self.testing:
|
|
649
|
+
return InstanceProxy(instance, interface=parameter.annotation)
|
|
650
|
+
return instance
|
|
651
|
+
|
|
652
|
+
async def _aget_provided_kwargs(
|
|
653
|
+
self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
|
|
654
|
+
) -> dict[str, Any]:
|
|
655
|
+
"""Asynchronously retrieve the arguments for a provider."""
|
|
656
|
+
provided_kwargs = {}
|
|
657
|
+
for parameter in provider.parameters:
|
|
658
|
+
instance = await self._aget_provider_instance(
|
|
659
|
+
provider, parameter, context, **defaults
|
|
660
|
+
)
|
|
661
|
+
provided_kwargs[parameter.name] = instance
|
|
662
|
+
return {**defaults, **provided_kwargs}
|
|
520
663
|
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
664
|
+
async def _aget_provider_instance(
|
|
665
|
+
self,
|
|
666
|
+
provider: Provider,
|
|
667
|
+
parameter: inspect.Parameter,
|
|
668
|
+
context: InstanceContext | None,
|
|
669
|
+
/,
|
|
670
|
+
**defaults: Any,
|
|
671
|
+
) -> Any:
|
|
672
|
+
"""Asynchronously retrieve an instance of a dependency from the context."""
|
|
524
673
|
|
|
525
|
-
|
|
526
|
-
|
|
674
|
+
# Try to get instance from defaults
|
|
675
|
+
if parameter.name in defaults:
|
|
676
|
+
return defaults[parameter.name]
|
|
527
677
|
|
|
528
|
-
|
|
529
|
-
|
|
678
|
+
# Try to get instance from context
|
|
679
|
+
elif context and parameter.annotation in context:
|
|
680
|
+
instance = context[parameter.annotation]
|
|
530
681
|
|
|
531
|
-
|
|
532
|
-
True if the instance exists, otherwise False.
|
|
533
|
-
"""
|
|
534
|
-
try:
|
|
535
|
-
provider = self._get_or_register_provider(interface)
|
|
536
|
-
except LookupError:
|
|
537
|
-
pass
|
|
682
|
+
# Resolve new instance
|
|
538
683
|
else:
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
684
|
+
try:
|
|
685
|
+
instance = await self._aresolve_parameter(provider, parameter)
|
|
686
|
+
except LookupError:
|
|
687
|
+
if parameter.default is inspect.Parameter.empty:
|
|
688
|
+
raise
|
|
689
|
+
return parameter.default
|
|
690
|
+
|
|
691
|
+
# Wrap the instance in a proxy for testing
|
|
692
|
+
if self.testing:
|
|
693
|
+
return InstanceProxy(instance, interface=parameter.annotation)
|
|
694
|
+
return instance
|
|
695
|
+
|
|
696
|
+
def _resolve_parameter(
|
|
697
|
+
self, provider: Provider, parameter: inspect.Parameter
|
|
698
|
+
) -> Any:
|
|
699
|
+
self._validate_resolvable_parameter(parameter, call=provider.call)
|
|
700
|
+
return self.resolve(parameter.annotation)
|
|
701
|
+
|
|
702
|
+
async def _aresolve_parameter(
|
|
703
|
+
self, provider: Provider, parameter: inspect.Parameter
|
|
704
|
+
) -> Any:
|
|
705
|
+
self._validate_resolvable_parameter(parameter, call=provider.call)
|
|
706
|
+
return await self.aresolve(parameter.annotation)
|
|
707
|
+
|
|
708
|
+
def _validate_resolvable_parameter(
|
|
709
|
+
self, parameter: inspect.Parameter, call: Callable[..., Any]
|
|
710
|
+
) -> None:
|
|
711
|
+
"""Ensure that the specified interface is resolved."""
|
|
712
|
+
if parameter.annotation in self._unresolved_interfaces:
|
|
713
|
+
raise LookupError(
|
|
714
|
+
f"You are attempting to get the parameter `{parameter.name}` with the "
|
|
715
|
+
f"annotation `{get_full_qualname(parameter.annotation)}` as a "
|
|
716
|
+
f"dependency into `{get_full_qualname(call)}` which is not registered "
|
|
717
|
+
"or set in the scoped context."
|
|
718
|
+
)
|
|
543
719
|
|
|
544
|
-
def
|
|
545
|
-
"""
|
|
720
|
+
def _patch_test_resolver(self, interface: type[Any], instance: Any) -> Any:
|
|
721
|
+
"""Patch the test resolver for the instance."""
|
|
722
|
+
if interface in self._override_instances:
|
|
723
|
+
return self._override_instances[interface]
|
|
546
724
|
|
|
547
|
-
|
|
548
|
-
|
|
725
|
+
if not hasattr(instance, "__dict__") or hasattr(
|
|
726
|
+
instance, "__resolver_getter__"
|
|
727
|
+
):
|
|
728
|
+
return instance
|
|
549
729
|
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
if isinstance(scoped_context, ResourceScopedContext):
|
|
556
|
-
scoped_context.delete(interface)
|
|
730
|
+
wrapped = {
|
|
731
|
+
name: value.interface
|
|
732
|
+
for name, value in instance.__dict__.items()
|
|
733
|
+
if isinstance(value, InstanceProxy)
|
|
734
|
+
}
|
|
557
735
|
|
|
558
|
-
|
|
559
|
-
|
|
736
|
+
def __resolver_getter__(name: str) -> Any:
|
|
737
|
+
if name in wrapped:
|
|
738
|
+
_interface = wrapped[name]
|
|
739
|
+
# Resolve the dependency if it's wrapped
|
|
740
|
+
return self.resolve(_interface)
|
|
741
|
+
raise LookupError
|
|
560
742
|
|
|
561
|
-
|
|
562
|
-
|
|
743
|
+
# Attach the resolver getter to the instance
|
|
744
|
+
instance.__resolver_getter__ = __resolver_getter__
|
|
563
745
|
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
746
|
+
if not hasattr(instance.__class__, "__getattribute_patched__"):
|
|
747
|
+
|
|
748
|
+
def __getattribute__(_self: Any, name: str) -> Any:
|
|
749
|
+
# Skip the resolver getter
|
|
750
|
+
if name in {"__resolver_getter__"}:
|
|
751
|
+
return object.__getattribute__(_self, name)
|
|
752
|
+
|
|
753
|
+
if hasattr(_self, "__resolver_getter__"):
|
|
754
|
+
try:
|
|
755
|
+
return _self.__resolver_getter__(name)
|
|
756
|
+
except LookupError:
|
|
757
|
+
pass
|
|
758
|
+
|
|
759
|
+
# Fall back to default behavior
|
|
760
|
+
return object.__getattribute__(_self, name)
|
|
761
|
+
|
|
762
|
+
# Apply the patched resolver if wrapped attributes exist
|
|
763
|
+
instance.__class__.__getattribute__ = __getattribute__
|
|
764
|
+
instance.__class__.__getattribute_patched__ = True
|
|
765
|
+
|
|
766
|
+
return instance
|
|
767
|
+
|
|
768
|
+
def is_resolved(self, interface: AnyInterface) -> bool:
|
|
769
|
+
"""Check if an instance by interface exists."""
|
|
770
|
+
try:
|
|
771
|
+
provider = self._get_provider(interface)
|
|
772
|
+
except LookupError:
|
|
773
|
+
return False
|
|
774
|
+
if provider.scope == "transient":
|
|
775
|
+
return False
|
|
776
|
+
context = self._get_scoped_context(provider.scope)
|
|
777
|
+
return interface in context
|
|
778
|
+
|
|
779
|
+
def release(self, interface: AnyInterface) -> None:
|
|
780
|
+
"""Release an instance by interface."""
|
|
781
|
+
provider = self._get_provider(interface)
|
|
782
|
+
if provider.scope == "transient":
|
|
783
|
+
return None
|
|
784
|
+
context = self._get_scoped_context(provider.scope)
|
|
785
|
+
del context[interface]
|
|
786
|
+
|
|
787
|
+
def _get_scoped_context(self, scope: Scope) -> InstanceContext:
|
|
788
|
+
"""Get the instance context for the specified scope."""
|
|
567
789
|
if scope == "singleton":
|
|
568
790
|
return self._singleton_context
|
|
569
|
-
|
|
570
|
-
request_context = self._get_request_context()
|
|
571
|
-
return request_context
|
|
572
|
-
return self._transient_context
|
|
791
|
+
return self._get_request_context()
|
|
573
792
|
|
|
574
793
|
@contextlib.contextmanager
|
|
575
794
|
def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
|
|
576
|
-
"""Override the provider for the specified interface with a specific instance.
|
|
577
|
-
|
|
578
|
-
Args:
|
|
579
|
-
interface: The interface type to override.
|
|
580
|
-
instance: The instance to use as the override.
|
|
581
|
-
|
|
582
|
-
Yields:
|
|
583
|
-
None
|
|
584
|
-
|
|
585
|
-
Raises:
|
|
586
|
-
LookupError: If the provider for the interface is not registered.
|
|
587
795
|
"""
|
|
796
|
+
Override the provider for the specified interface with a specific instance.
|
|
797
|
+
"""
|
|
798
|
+
if not self.testing:
|
|
799
|
+
raise RuntimeError(
|
|
800
|
+
"The `override` method can only be used in testing mode."
|
|
801
|
+
)
|
|
588
802
|
if not self.is_registered(interface) and self.strict:
|
|
589
803
|
raise LookupError(
|
|
590
804
|
f"The provider interface `{get_full_qualname(interface)}` "
|
|
@@ -597,166 +811,74 @@ class Container:
|
|
|
597
811
|
def provider(
|
|
598
812
|
self, *, scope: Scope, override: bool = False
|
|
599
813
|
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
|
600
|
-
"""Decorator to register a provider function with the specified scope.
|
|
601
|
-
|
|
602
|
-
Args:
|
|
603
|
-
scope : The scope of the provider.
|
|
604
|
-
override: Whether the provider should override an existing provider
|
|
605
|
-
for the same interface. Defaults to False.
|
|
606
|
-
|
|
607
|
-
Returns:
|
|
608
|
-
The decorator function.
|
|
609
|
-
"""
|
|
814
|
+
"""Decorator to register a provider function with the specified scope."""
|
|
610
815
|
|
|
611
|
-
def decorator(
|
|
612
|
-
|
|
613
|
-
self.
|
|
614
|
-
return
|
|
816
|
+
def decorator(call: Callable[P, T]) -> Callable[P, T]:
|
|
817
|
+
provider = Provider(call=call, scope=scope)
|
|
818
|
+
self._register_provider(provider, override)
|
|
819
|
+
return call
|
|
615
820
|
|
|
616
821
|
return decorator
|
|
617
822
|
|
|
618
823
|
@overload
|
|
619
|
-
def inject(self,
|
|
824
|
+
def inject(self, func: Callable[P, T]) -> Callable[P, T]: ...
|
|
620
825
|
|
|
621
826
|
@overload
|
|
622
827
|
def inject(self) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
|
|
623
828
|
|
|
624
829
|
def inject(
|
|
625
|
-
self,
|
|
626
|
-
) ->
|
|
627
|
-
|
|
628
|
-
[Callable[P, Union[T, Awaitable[T]]]],
|
|
629
|
-
Callable[P, Union[T, Awaitable[T]]],
|
|
630
|
-
],
|
|
631
|
-
Callable[P, Union[T, Awaitable[T]]],
|
|
632
|
-
]:
|
|
633
|
-
"""Decorator to inject dependencies into a callable.
|
|
634
|
-
|
|
635
|
-
Args:
|
|
636
|
-
obj: The callable object to be decorated. If None, returns
|
|
637
|
-
the decorator itself.
|
|
638
|
-
|
|
639
|
-
Returns:
|
|
640
|
-
The decorated callable object or decorator function.
|
|
641
|
-
"""
|
|
642
|
-
|
|
643
|
-
def decorator(
|
|
644
|
-
obj: Callable[P, Union[T, Awaitable[T]]],
|
|
645
|
-
) -> Callable[P, Union[T, Awaitable[T]]]:
|
|
646
|
-
injected_params = self._get_injected_params(obj)
|
|
647
|
-
|
|
648
|
-
if inspect.iscoroutinefunction(obj):
|
|
830
|
+
self, func: Callable[P, T] | None = None
|
|
831
|
+
) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
|
|
832
|
+
"""Decorator to inject dependencies into a callable."""
|
|
649
833
|
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
for name, annotation in injected_params.items():
|
|
653
|
-
kwargs[name] = await self.aresolve(annotation)
|
|
654
|
-
return cast(T, await obj(*args, **kwargs))
|
|
834
|
+
def decorator(call: Callable[P, T]) -> Callable[P, T]:
|
|
835
|
+
return self._inject(call)
|
|
655
836
|
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
@wraps(obj)
|
|
659
|
-
def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
660
|
-
for name, annotation in injected_params.items():
|
|
661
|
-
kwargs[name] = self.resolve(annotation)
|
|
662
|
-
return cast(T, obj(*args, **kwargs))
|
|
663
|
-
|
|
664
|
-
return wrapped
|
|
665
|
-
|
|
666
|
-
if obj is None:
|
|
837
|
+
if func is None:
|
|
667
838
|
return decorator
|
|
668
|
-
return decorator(
|
|
839
|
+
return decorator(func)
|
|
669
840
|
|
|
670
|
-
def
|
|
671
|
-
"""
|
|
841
|
+
def _inject(self, call: Callable[P, T]) -> Callable[P, T]:
|
|
842
|
+
"""Inject dependencies into a callable."""
|
|
843
|
+
if call in self._inject_cache:
|
|
844
|
+
return cast(Callable[P, T], self._inject_cache[call])
|
|
672
845
|
|
|
673
|
-
|
|
674
|
-
obj: The callable object.
|
|
675
|
-
args: The positional arguments to pass to the object.
|
|
676
|
-
kwargs: The keyword arguments to pass to the object.
|
|
677
|
-
|
|
678
|
-
Returns:
|
|
679
|
-
The result of the callable object.
|
|
680
|
-
"""
|
|
681
|
-
return self.inject(obj)(*args, **kwargs)
|
|
682
|
-
|
|
683
|
-
def scan(
|
|
684
|
-
self,
|
|
685
|
-
/,
|
|
686
|
-
packages: Union[
|
|
687
|
-
Union[types.ModuleType, str],
|
|
688
|
-
Iterable[Union[types.ModuleType, str]],
|
|
689
|
-
],
|
|
690
|
-
*,
|
|
691
|
-
tags: Optional[Iterable[str]] = None,
|
|
692
|
-
) -> None:
|
|
693
|
-
"""Scan packages or modules for decorated members and inject dependencies.
|
|
694
|
-
|
|
695
|
-
Args:
|
|
696
|
-
packages: A single package or module to scan,
|
|
697
|
-
or an iterable of packages or modules to scan.
|
|
698
|
-
tags: Optional list of tags to filter the scanned members. Only members
|
|
699
|
-
with at least one matching tag will be scanned. Defaults to None.
|
|
700
|
-
"""
|
|
701
|
-
self._scanner.scan(packages, tags=tags)
|
|
846
|
+
injected_params = self._get_injected_params(call)
|
|
702
847
|
|
|
703
|
-
|
|
704
|
-
"""Retrieve the provider return annotation from a callable object.
|
|
848
|
+
if inspect.iscoroutinefunction(call):
|
|
705
849
|
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
Raises:
|
|
713
|
-
TypeError: If the provider return annotation is missing or invalid.
|
|
714
|
-
"""
|
|
715
|
-
annotation = get_signature(obj).return_annotation
|
|
716
|
-
|
|
717
|
-
if annotation is inspect._empty: # noqa
|
|
718
|
-
raise TypeError(
|
|
719
|
-
f"Missing `{get_full_qualname(obj)}` provider return annotation."
|
|
720
|
-
)
|
|
850
|
+
@functools.wraps(call)
|
|
851
|
+
async def awrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
852
|
+
for name, annotation in injected_params.items():
|
|
853
|
+
kwargs[name] = await self.aresolve(annotation)
|
|
854
|
+
return cast(T, await call(*args, **kwargs))
|
|
721
855
|
|
|
722
|
-
|
|
723
|
-
args = get_args(annotation)
|
|
856
|
+
self._inject_cache[call] = awrapper
|
|
724
857
|
|
|
725
|
-
|
|
726
|
-
if origin in (list, dict, tuple, Annotated):
|
|
727
|
-
if args:
|
|
728
|
-
return annotation
|
|
729
|
-
else:
|
|
730
|
-
raise TypeError(
|
|
731
|
-
f"Cannot use `{get_full_qualname(obj)}` generic type annotation "
|
|
732
|
-
"without actual type."
|
|
733
|
-
)
|
|
858
|
+
return awrapper # type: ignore[return-value]
|
|
734
859
|
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
860
|
+
@functools.wraps(call)
|
|
861
|
+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
862
|
+
for name, annotation in injected_params.items():
|
|
863
|
+
kwargs[name] = self.resolve(annotation)
|
|
864
|
+
return call(*args, **kwargs)
|
|
739
865
|
|
|
740
|
-
|
|
741
|
-
"""Get the injected parameters of a callable object.
|
|
866
|
+
self._inject_cache[call] = wrapper
|
|
742
867
|
|
|
743
|
-
|
|
744
|
-
obj: The callable object.
|
|
868
|
+
return wrapper
|
|
745
869
|
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
of the injected parameters.
|
|
749
|
-
"""
|
|
870
|
+
def _get_injected_params(self, call: Callable[..., Any]) -> dict[str, Any]:
|
|
871
|
+
"""Get the injected parameters of a callable object."""
|
|
750
872
|
injected_params = {}
|
|
751
|
-
for parameter in
|
|
752
|
-
if not
|
|
873
|
+
for parameter in get_typed_parameters(call):
|
|
874
|
+
if not is_marker(parameter.default):
|
|
753
875
|
continue
|
|
754
876
|
try:
|
|
755
|
-
self._validate_injected_parameter(
|
|
877
|
+
self._validate_injected_parameter(call, parameter)
|
|
756
878
|
except LookupError as exc:
|
|
757
879
|
if not self.strict:
|
|
758
|
-
logger.debug(
|
|
759
|
-
f"Cannot validate the `{get_full_qualname(
|
|
880
|
+
self.logger.debug(
|
|
881
|
+
f"Cannot validate the `{get_full_qualname(call)}` parameter "
|
|
760
882
|
f"`{parameter.name}` with an annotation of "
|
|
761
883
|
f"`{get_full_qualname(parameter.annotation)} due to being "
|
|
762
884
|
"in non-strict mode. It will be validated at the first call."
|
|
@@ -767,65 +889,183 @@ class Container:
|
|
|
767
889
|
return injected_params
|
|
768
890
|
|
|
769
891
|
def _validate_injected_parameter(
|
|
770
|
-
self,
|
|
892
|
+
self, call: Callable[..., Any], parameter: inspect.Parameter
|
|
771
893
|
) -> None:
|
|
772
|
-
"""Validate an injected parameter.
|
|
773
|
-
|
|
774
|
-
Args:
|
|
775
|
-
obj: The callable object.
|
|
776
|
-
parameter: The parameter to validate.
|
|
777
|
-
|
|
778
|
-
Raises:
|
|
779
|
-
TypeError: If the parameter annotation is missing or an unknown dependency.
|
|
780
|
-
"""
|
|
781
|
-
if parameter.annotation is inspect._empty: # noqa
|
|
894
|
+
"""Validate an injected parameter."""
|
|
895
|
+
if parameter.annotation is inspect.Parameter.empty:
|
|
782
896
|
raise TypeError(
|
|
783
|
-
f"Missing `{get_full_qualname(
|
|
897
|
+
f"Missing `{get_full_qualname(call)}` parameter "
|
|
784
898
|
f"`{parameter.name}` annotation."
|
|
785
899
|
)
|
|
786
900
|
|
|
787
901
|
if not self.is_registered(parameter.annotation):
|
|
788
902
|
raise LookupError(
|
|
789
|
-
f"`{get_full_qualname(
|
|
903
|
+
f"`{get_full_qualname(call)}` has an unknown dependency parameter "
|
|
790
904
|
f"`{parameter.name}` with an annotation of "
|
|
791
905
|
f"`{get_full_qualname(parameter.annotation)}`."
|
|
792
906
|
)
|
|
793
907
|
|
|
908
|
+
def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
|
|
909
|
+
"""Run the given function with injected dependencies."""
|
|
910
|
+
return self._inject(func)(*args, **kwargs)
|
|
794
911
|
|
|
795
|
-
def
|
|
796
|
-
|
|
912
|
+
def scan(
|
|
913
|
+
self,
|
|
914
|
+
/,
|
|
915
|
+
packages: ModuleType | str | Iterable[ModuleType | str],
|
|
916
|
+
*,
|
|
917
|
+
tags: Iterable[str] | None = None,
|
|
918
|
+
) -> None:
|
|
919
|
+
"""Scan packages or modules for decorated members and inject dependencies."""
|
|
920
|
+
dependencies: list[Dependency] = []
|
|
921
|
+
|
|
922
|
+
if isinstance(packages, Iterable) and not isinstance(packages, str):
|
|
923
|
+
scan_packages: Iterable[ModuleType | str] = packages
|
|
924
|
+
else:
|
|
925
|
+
scan_packages = cast(Iterable[Union[ModuleType, str]], [packages])
|
|
926
|
+
|
|
927
|
+
for package in scan_packages:
|
|
928
|
+
dependencies.extend(self._scan_package(package, tags=tags))
|
|
929
|
+
|
|
930
|
+
for dependency in dependencies:
|
|
931
|
+
decorator = self.inject()(dependency.member)
|
|
932
|
+
setattr(dependency.module, dependency.member.__name__, decorator)
|
|
933
|
+
|
|
934
|
+
def _scan_package(
|
|
935
|
+
self,
|
|
936
|
+
package: ModuleType | str,
|
|
937
|
+
*,
|
|
938
|
+
tags: Iterable[str] | None = None,
|
|
939
|
+
) -> list[Dependency]:
|
|
940
|
+
"""Scan a package or module for decorated members."""
|
|
941
|
+
tags = tags or []
|
|
942
|
+
if isinstance(package, str):
|
|
943
|
+
package = importlib.import_module(package)
|
|
944
|
+
|
|
945
|
+
package_path = getattr(package, "__path__", None)
|
|
946
|
+
|
|
947
|
+
if not package_path:
|
|
948
|
+
return self._scan_module(package, tags=tags)
|
|
949
|
+
|
|
950
|
+
dependencies: list[Dependency] = []
|
|
951
|
+
|
|
952
|
+
for module_info in pkgutil.walk_packages(
|
|
953
|
+
path=package_path, prefix=package.__name__ + "."
|
|
954
|
+
):
|
|
955
|
+
module = importlib.import_module(module_info.name)
|
|
956
|
+
dependencies.extend(self._scan_module(module, tags=tags))
|
|
957
|
+
|
|
958
|
+
return dependencies
|
|
959
|
+
|
|
960
|
+
def _scan_module(
|
|
961
|
+
self, module: ModuleType, *, tags: Iterable[str]
|
|
962
|
+
) -> list[Dependency]:
|
|
963
|
+
"""Scan a module for decorated members."""
|
|
964
|
+
dependencies: list[Dependency] = []
|
|
965
|
+
|
|
966
|
+
for _, member in inspect.getmembers(module):
|
|
967
|
+
if getattr(member, "__module__", None) != module.__name__ or not callable(
|
|
968
|
+
member
|
|
969
|
+
):
|
|
970
|
+
continue
|
|
971
|
+
|
|
972
|
+
decorator_args: InjectableDecoratorArgs = getattr(
|
|
973
|
+
member,
|
|
974
|
+
"__injectable__",
|
|
975
|
+
InjectableDecoratorArgs(wrapped=False, tags=[]),
|
|
976
|
+
)
|
|
977
|
+
|
|
978
|
+
if tags and (
|
|
979
|
+
decorator_args.tags
|
|
980
|
+
and not set(decorator_args.tags).intersection(tags)
|
|
981
|
+
or not decorator_args.tags
|
|
982
|
+
):
|
|
983
|
+
continue
|
|
984
|
+
|
|
985
|
+
if decorator_args.wrapped:
|
|
986
|
+
dependencies.append(
|
|
987
|
+
self._create_dependency(member=member, module=module)
|
|
988
|
+
)
|
|
989
|
+
continue
|
|
990
|
+
|
|
991
|
+
# Get by Marker
|
|
992
|
+
for parameter in get_typed_parameters(member):
|
|
993
|
+
if is_marker(parameter.default):
|
|
994
|
+
dependencies.append(
|
|
995
|
+
self._create_dependency(member=member, module=module)
|
|
996
|
+
)
|
|
997
|
+
continue
|
|
998
|
+
|
|
999
|
+
return dependencies
|
|
1000
|
+
|
|
1001
|
+
def _create_dependency(self, member: Any, module: ModuleType) -> Dependency:
|
|
1002
|
+
"""Create a `Dependency` object from the scanned member and module."""
|
|
1003
|
+
if hasattr(member, "__wrapped__"):
|
|
1004
|
+
member = member.__wrapped__
|
|
1005
|
+
return Dependency(member=member, module=module)
|
|
797
1006
|
|
|
798
|
-
Args:
|
|
799
|
-
target: The target class to be decorated.
|
|
800
1007
|
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
"""
|
|
1008
|
+
def transient(target: T) -> T:
|
|
1009
|
+
"""Decorator for marking a class as transient scope."""
|
|
804
1010
|
setattr(target, "__scope__", "transient")
|
|
805
1011
|
return target
|
|
806
1012
|
|
|
807
1013
|
|
|
808
1014
|
def request(target: T) -> T:
|
|
809
|
-
"""Decorator for marking a class as request scope.
|
|
810
|
-
|
|
811
|
-
Args:
|
|
812
|
-
target: The target class to be decorated.
|
|
813
|
-
|
|
814
|
-
Returns:
|
|
815
|
-
The decorated target class.
|
|
816
|
-
"""
|
|
1015
|
+
"""Decorator for marking a class as request scope."""
|
|
817
1016
|
setattr(target, "__scope__", "request")
|
|
818
1017
|
return target
|
|
819
1018
|
|
|
820
1019
|
|
|
821
1020
|
def singleton(target: T) -> T:
|
|
822
|
-
"""Decorator for marking a class as singleton scope.
|
|
823
|
-
|
|
824
|
-
Args:
|
|
825
|
-
target: The target class to be decorated.
|
|
826
|
-
|
|
827
|
-
Returns:
|
|
828
|
-
The decorated target class.
|
|
829
|
-
"""
|
|
1021
|
+
"""Decorator for marking a class as singleton scope."""
|
|
830
1022
|
setattr(target, "__scope__", "singleton")
|
|
831
1023
|
return target
|
|
1024
|
+
|
|
1025
|
+
|
|
1026
|
+
def provider(
|
|
1027
|
+
*, scope: Scope, override: bool = False
|
|
1028
|
+
) -> Callable[[Callable[Concatenate[M, P], T]], Callable[Concatenate[M, P], T]]:
|
|
1029
|
+
"""Decorator for marking a function or method as a provider in a AnyDI module."""
|
|
1030
|
+
|
|
1031
|
+
def decorator(
|
|
1032
|
+
target: Callable[Concatenate[M, P], T],
|
|
1033
|
+
) -> Callable[Concatenate[M, P], T]:
|
|
1034
|
+
setattr(
|
|
1035
|
+
target,
|
|
1036
|
+
"__provider__",
|
|
1037
|
+
ProviderDecoratorArgs(scope=scope, override=override),
|
|
1038
|
+
)
|
|
1039
|
+
return target
|
|
1040
|
+
|
|
1041
|
+
return decorator
|
|
1042
|
+
|
|
1043
|
+
|
|
1044
|
+
@overload
|
|
1045
|
+
def injectable(func: Callable[P, T]) -> Callable[P, T]: ...
|
|
1046
|
+
|
|
1047
|
+
|
|
1048
|
+
@overload
|
|
1049
|
+
def injectable(
|
|
1050
|
+
*, tags: Iterable[str] | None = None
|
|
1051
|
+
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
|
|
1052
|
+
|
|
1053
|
+
|
|
1054
|
+
def injectable(
|
|
1055
|
+
func: Callable[P, T] | None = None,
|
|
1056
|
+
tags: Iterable[str] | None = None,
|
|
1057
|
+
) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
|
|
1058
|
+
"""Decorator for marking a function or method as requiring dependency injection."""
|
|
1059
|
+
|
|
1060
|
+
def decorator(inner: Callable[P, T]) -> Callable[P, T]:
|
|
1061
|
+
setattr(
|
|
1062
|
+
inner,
|
|
1063
|
+
"__injectable__",
|
|
1064
|
+
InjectableDecoratorArgs(wrapped=True, tags=tags),
|
|
1065
|
+
)
|
|
1066
|
+
return inner
|
|
1067
|
+
|
|
1068
|
+
if func is None:
|
|
1069
|
+
return decorator
|
|
1070
|
+
|
|
1071
|
+
return decorator(func)
|