anydi 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anydi/__init__.py +2 -8
- anydi/_container.py +122 -515
- anydi/_context.py +99 -188
- anydi/_injector.py +94 -0
- anydi/_module.py +4 -41
- anydi/_provider.py +187 -0
- anydi/_scanner.py +14 -68
- anydi/_types.py +4 -103
- anydi/_utils.py +22 -71
- anydi/ext/_utils.py +5 -11
- anydi/ext/django/_settings.py +1 -1
- anydi/ext/django/_utils.py +2 -2
- anydi/ext/django/apps.py +1 -1
- anydi/ext/django/middleware.py +11 -9
- anydi/ext/fastapi.py +4 -24
- anydi/ext/faststream.py +0 -7
- anydi/ext/pydantic_settings.py +2 -2
- anydi/ext/pytest_plugin.py +3 -2
- anydi/ext/starlette/middleware.py +2 -16
- {anydi-0.31.0.dist-info → anydi-0.32.1.dist-info}/METADATA +1 -1
- anydi-0.32.1.dist-info/RECORD +33 -0
- anydi-0.31.0.dist-info/RECORD +0 -31
- {anydi-0.31.0.dist-info → anydi-0.32.1.dist-info}/LICENSE +0 -0
- {anydi-0.31.0.dist-info → anydi-0.32.1.dist-info}/WHEEL +0 -0
- {anydi-0.31.0.dist-info → anydi-0.32.1.dist-info}/entry_points.txt +0 -0
anydi/_context.py
CHANGED
|
@@ -2,19 +2,19 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import abc
|
|
4
4
|
import contextlib
|
|
5
|
+
import inspect
|
|
5
6
|
from types import TracebackType
|
|
6
|
-
from typing import TYPE_CHECKING, Any,
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Callable, ClassVar
|
|
7
8
|
|
|
8
9
|
from typing_extensions import Self, final
|
|
9
10
|
|
|
10
|
-
from .
|
|
11
|
-
from .
|
|
11
|
+
from ._provider import CallableKind, Provider
|
|
12
|
+
from ._types import AnyInterface, Scope, is_event_type
|
|
13
|
+
from ._utils import get_full_qualname, run_async
|
|
12
14
|
|
|
13
15
|
if TYPE_CHECKING:
|
|
14
16
|
from ._container import Container
|
|
15
17
|
|
|
16
|
-
T = TypeVar("T")
|
|
17
|
-
|
|
18
18
|
|
|
19
19
|
class ScopedContext(abc.ABC):
|
|
20
20
|
"""ScopedContext base class."""
|
|
@@ -23,104 +23,95 @@ class ScopedContext(abc.ABC):
|
|
|
23
23
|
|
|
24
24
|
def __init__(self, container: Container) -> None:
|
|
25
25
|
self.container = container
|
|
26
|
+
self._instances: dict[Any, Any] = {}
|
|
26
27
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
Args:
|
|
32
|
-
interface: The interface of the dependency.
|
|
33
|
-
provider: The provider for the instance.
|
|
34
|
-
|
|
35
|
-
Returns:
|
|
36
|
-
An instance of the dependency.
|
|
37
|
-
"""
|
|
28
|
+
def set(self, interface: AnyInterface, instance: Any) -> None:
|
|
29
|
+
"""Set an instance of a dependency in the scoped context."""
|
|
30
|
+
self._instances[interface] = instance
|
|
38
31
|
|
|
39
32
|
@abc.abstractmethod
|
|
40
|
-
|
|
41
|
-
"""Get an
|
|
42
|
-
|
|
43
|
-
Args:
|
|
44
|
-
interface: The interface of the dependency.
|
|
45
|
-
provider: The provider for the instance.
|
|
33
|
+
def get(self, provider: Provider) -> Any:
|
|
34
|
+
"""Get an instance of a dependency from the scoped context."""
|
|
46
35
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
"""
|
|
36
|
+
@abc.abstractmethod
|
|
37
|
+
async def aget(self, provider: Provider) -> Any:
|
|
38
|
+
"""Get an async instance of a dependency from the scoped context."""
|
|
50
39
|
|
|
51
40
|
def _create_instance(self, provider: Provider) -> Any:
|
|
52
|
-
"""Create an instance using the provider.
|
|
53
|
-
|
|
54
|
-
Args:
|
|
55
|
-
provider: The provider for the instance.
|
|
56
|
-
|
|
57
|
-
Returns:
|
|
58
|
-
The created instance.
|
|
59
|
-
|
|
60
|
-
Raises:
|
|
61
|
-
TypeError: If the provider's instance is a coroutine provider
|
|
62
|
-
and synchronous mode is used.
|
|
63
|
-
"""
|
|
64
|
-
if provider.is_coroutine:
|
|
41
|
+
"""Create an instance using the provider."""
|
|
42
|
+
if provider.kind == CallableKind.COROUTINE:
|
|
65
43
|
raise TypeError(
|
|
66
44
|
f"The instance for the coroutine provider `{provider}` cannot be "
|
|
67
45
|
"created in synchronous mode."
|
|
68
46
|
)
|
|
69
|
-
args, kwargs = self.
|
|
70
|
-
return provider.
|
|
47
|
+
args, kwargs = self._get_provider_params(provider)
|
|
48
|
+
return provider.call(*args, **kwargs)
|
|
71
49
|
|
|
72
50
|
async def _acreate_instance(self, provider: Provider) -> Any:
|
|
73
|
-
"""Create an instance asynchronously using the provider.
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
51
|
+
"""Create an instance asynchronously using the provider."""
|
|
52
|
+
args, kwargs = await self._aget_provider_params(provider)
|
|
53
|
+
if provider.kind == CallableKind.COROUTINE:
|
|
54
|
+
return await provider.call(*args, **kwargs)
|
|
55
|
+
return await run_async(provider.call, *args, **kwargs)
|
|
56
|
+
|
|
57
|
+
def _resolve_parameter(
|
|
58
|
+
self, provider: Provider, parameter: inspect.Parameter
|
|
59
|
+
) -> Any:
|
|
60
|
+
self._validate_resolvable_parameter(parameter, call=provider.call)
|
|
61
|
+
return self.container.resolve(parameter.annotation)
|
|
62
|
+
|
|
63
|
+
async def _aresolve_parameter(
|
|
64
|
+
self, provider: Provider, parameter: inspect.Parameter
|
|
65
|
+
) -> Any:
|
|
66
|
+
self._validate_resolvable_parameter(parameter, call=provider.call)
|
|
67
|
+
return await self.container.aresolve(parameter.annotation)
|
|
68
|
+
|
|
69
|
+
def _validate_resolvable_parameter(
|
|
70
|
+
self, parameter: inspect.Parameter, call: Callable[..., Any]
|
|
71
|
+
) -> None:
|
|
72
|
+
"""Ensure that the specified interface is resolved."""
|
|
73
|
+
if parameter.annotation in self.container._unresolved_interfaces: # noqa
|
|
74
|
+
raise LookupError(
|
|
75
|
+
f"You are attempting to get the parameter `{parameter.name}` with the "
|
|
76
|
+
f"annotation `{get_full_qualname(parameter.annotation)}` as a "
|
|
77
|
+
f"dependency into `{get_full_qualname(call)}` which is not registered "
|
|
78
|
+
"or set in the scoped context."
|
|
79
|
+
)
|
|
89
80
|
|
|
90
|
-
def
|
|
81
|
+
def _get_provider_params(
|
|
91
82
|
self, provider: Provider
|
|
92
83
|
) -> tuple[list[Any], dict[str, Any]]:
|
|
93
|
-
"""Retrieve the arguments for a provider.
|
|
84
|
+
"""Retrieve the arguments for a provider."""
|
|
85
|
+
args: list[Any] = []
|
|
86
|
+
kwargs: dict[str, Any] = {}
|
|
94
87
|
|
|
95
|
-
Args:
|
|
96
|
-
provider: The provider object.
|
|
97
|
-
|
|
98
|
-
Returns:
|
|
99
|
-
The arguments for the provider.
|
|
100
|
-
"""
|
|
101
|
-
args, kwargs = [], {}
|
|
102
88
|
for parameter in provider.parameters:
|
|
103
|
-
|
|
89
|
+
if parameter.annotation in self.container._override_instances: # noqa
|
|
90
|
+
instance = self.container._override_instances[parameter.annotation] # noqa
|
|
91
|
+
elif parameter.annotation in self._instances:
|
|
92
|
+
instance = self._instances[parameter.annotation]
|
|
93
|
+
else:
|
|
94
|
+
instance = self._resolve_parameter(provider, parameter)
|
|
104
95
|
if parameter.kind == parameter.POSITIONAL_ONLY:
|
|
105
96
|
args.append(instance)
|
|
106
97
|
else:
|
|
107
98
|
kwargs[parameter.name] = instance
|
|
108
99
|
return args, kwargs
|
|
109
100
|
|
|
110
|
-
async def
|
|
101
|
+
async def _aget_provider_params(
|
|
111
102
|
self, provider: Provider
|
|
112
103
|
) -> tuple[list[Any], dict[str, Any]]:
|
|
113
|
-
"""Asynchronously retrieve the arguments for a provider.
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
provider: The provider object.
|
|
104
|
+
"""Asynchronously retrieve the arguments for a provider."""
|
|
105
|
+
args: list[Any] = []
|
|
106
|
+
kwargs: dict[str, Any] = {}
|
|
117
107
|
|
|
118
|
-
Returns:
|
|
119
|
-
The arguments for the provider.
|
|
120
|
-
"""
|
|
121
|
-
args, kwargs = [], {}
|
|
122
108
|
for parameter in provider.parameters:
|
|
123
|
-
|
|
109
|
+
if parameter.annotation in self.container._override_instances: # noqa
|
|
110
|
+
instance = self.container._override_instances[parameter.annotation] # noqa
|
|
111
|
+
elif parameter.annotation in self._instances:
|
|
112
|
+
instance = self._instances[parameter.annotation]
|
|
113
|
+
else:
|
|
114
|
+
instance = await self._aresolve_parameter(provider, parameter)
|
|
124
115
|
if parameter.kind == parameter.POSITIONAL_ONLY:
|
|
125
116
|
args.append(instance)
|
|
126
117
|
else:
|
|
@@ -134,25 +125,16 @@ class ResourceScopedContext(ScopedContext):
|
|
|
134
125
|
def __init__(self, container: Container) -> None:
|
|
135
126
|
"""Initialize the ScopedContext."""
|
|
136
127
|
super().__init__(container)
|
|
137
|
-
self._instances: dict[type[Any], Any] = {}
|
|
138
128
|
self._stack = contextlib.ExitStack()
|
|
139
129
|
self._async_stack = contextlib.AsyncExitStack()
|
|
140
130
|
|
|
141
|
-
def get(self,
|
|
142
|
-
"""Get an instance of a dependency from the scoped context.
|
|
143
|
-
|
|
144
|
-
Args:
|
|
145
|
-
interface: The interface of the dependency.
|
|
146
|
-
provider: The provider for the instance.
|
|
147
|
-
|
|
148
|
-
Returns:
|
|
149
|
-
An instance of the dependency.
|
|
150
|
-
"""
|
|
151
|
-
instance = self._instances.get(interface)
|
|
131
|
+
def get(self, provider: Provider) -> Any:
|
|
132
|
+
"""Get an instance of a dependency from the scoped context."""
|
|
133
|
+
instance = self._instances.get(provider.interface)
|
|
152
134
|
if instance is None:
|
|
153
|
-
if provider.
|
|
135
|
+
if provider.kind == CallableKind.GENERATOR:
|
|
154
136
|
instance = self._create_resource(provider)
|
|
155
|
-
elif provider.
|
|
137
|
+
elif provider.kind == CallableKind.ASYNC_GENERATOR:
|
|
156
138
|
raise TypeError(
|
|
157
139
|
f"The provider `{provider}` cannot be started in synchronous mode "
|
|
158
140
|
"because it is an asynchronous provider. Please start the provider "
|
|
@@ -160,39 +142,24 @@ class ResourceScopedContext(ScopedContext):
|
|
|
160
142
|
)
|
|
161
143
|
else:
|
|
162
144
|
instance = self._create_instance(provider)
|
|
163
|
-
self._instances[interface] = instance
|
|
164
|
-
return
|
|
165
|
-
|
|
166
|
-
async def aget(self, interface: Interface[T], provider: Provider) -> T:
|
|
167
|
-
"""Get an async instance of a dependency from the scoped context.
|
|
168
|
-
|
|
169
|
-
Args:
|
|
170
|
-
interface: The interface of the dependency.
|
|
171
|
-
provider: The provider for the instance.
|
|
145
|
+
self._instances[provider.interface] = instance
|
|
146
|
+
return instance
|
|
172
147
|
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
instance = self._instances.get(interface)
|
|
148
|
+
async def aget(self, provider: Provider) -> Any:
|
|
149
|
+
"""Get an async instance of a dependency from the scoped context."""
|
|
150
|
+
instance = self._instances.get(provider.interface)
|
|
177
151
|
if instance is None:
|
|
178
|
-
if provider.
|
|
152
|
+
if provider.kind == CallableKind.GENERATOR:
|
|
179
153
|
instance = await run_async(self._create_resource, provider)
|
|
180
|
-
elif provider.
|
|
154
|
+
elif provider.kind == CallableKind.ASYNC_GENERATOR:
|
|
181
155
|
instance = await self._acreate_resource(provider)
|
|
182
156
|
else:
|
|
183
157
|
instance = await self._acreate_instance(provider)
|
|
184
|
-
self._instances[interface] = instance
|
|
185
|
-
return
|
|
158
|
+
self._instances[provider.interface] = instance
|
|
159
|
+
return instance
|
|
186
160
|
|
|
187
161
|
def has(self, interface: AnyInterface) -> bool:
|
|
188
|
-
"""Check if the scoped context has an instance of the dependency.
|
|
189
|
-
|
|
190
|
-
Args:
|
|
191
|
-
interface: The interface of the dependency.
|
|
192
|
-
|
|
193
|
-
Returns:
|
|
194
|
-
Whether the scoped context has an instance of the dependency.
|
|
195
|
-
"""
|
|
162
|
+
"""Check if the scoped context has an instance of the dependency."""
|
|
196
163
|
return interface in self._instances
|
|
197
164
|
|
|
198
165
|
def _create_instance(self, provider: Provider) -> Any:
|
|
@@ -204,16 +171,9 @@ class ResourceScopedContext(ScopedContext):
|
|
|
204
171
|
return instance
|
|
205
172
|
|
|
206
173
|
def _create_resource(self, provider: Provider) -> Any:
|
|
207
|
-
"""Create a resource using the provider.
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
provider: The provider for the resource.
|
|
211
|
-
|
|
212
|
-
Returns:
|
|
213
|
-
The created resource.
|
|
214
|
-
"""
|
|
215
|
-
args, kwargs = self._get_provider_arguments(provider)
|
|
216
|
-
cm = contextlib.contextmanager(provider.obj)(*args, **kwargs)
|
|
174
|
+
"""Create a resource using the provider."""
|
|
175
|
+
args, kwargs = self._get_provider_params(provider)
|
|
176
|
+
cm = contextlib.contextmanager(provider.call)(*args, **kwargs)
|
|
217
177
|
return self._stack.enter_context(cm)
|
|
218
178
|
|
|
219
179
|
async def _acreate_instance(self, provider: Provider) -> Any:
|
|
@@ -225,32 +185,17 @@ class ResourceScopedContext(ScopedContext):
|
|
|
225
185
|
return instance
|
|
226
186
|
|
|
227
187
|
async def _acreate_resource(self, provider: Provider) -> Any:
|
|
228
|
-
"""Create a resource asynchronously using the provider.
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
provider: The provider for the resource.
|
|
232
|
-
|
|
233
|
-
Returns:
|
|
234
|
-
The created resource.
|
|
235
|
-
"""
|
|
236
|
-
args, kwargs = await self._aget_provider_arguments(provider)
|
|
237
|
-
cm = contextlib.asynccontextmanager(provider.obj)(*args, **kwargs)
|
|
188
|
+
"""Create a resource asynchronously using the provider."""
|
|
189
|
+
args, kwargs = await self._aget_provider_params(provider)
|
|
190
|
+
cm = contextlib.asynccontextmanager(provider.call)(*args, **kwargs)
|
|
238
191
|
return await self._async_stack.enter_async_context(cm)
|
|
239
192
|
|
|
240
193
|
def delete(self, interface: AnyInterface) -> None:
|
|
241
|
-
"""Delete a dependency instance from the scoped context.
|
|
242
|
-
|
|
243
|
-
Args:
|
|
244
|
-
interface: The interface of the dependency.
|
|
245
|
-
"""
|
|
194
|
+
"""Delete a dependency instance from the scoped context."""
|
|
246
195
|
self._instances.pop(interface, None)
|
|
247
196
|
|
|
248
197
|
def __enter__(self) -> Self:
|
|
249
|
-
"""Enter the context.
|
|
250
|
-
|
|
251
|
-
Returns:
|
|
252
|
-
The scoped context.
|
|
253
|
-
"""
|
|
198
|
+
"""Enter the context."""
|
|
254
199
|
self.start()
|
|
255
200
|
return self
|
|
256
201
|
|
|
@@ -260,13 +205,7 @@ class ResourceScopedContext(ScopedContext):
|
|
|
260
205
|
exc_val: BaseException | None,
|
|
261
206
|
exc_tb: TracebackType | None,
|
|
262
207
|
) -> bool:
|
|
263
|
-
"""Exit the context.
|
|
264
|
-
|
|
265
|
-
Args:
|
|
266
|
-
exc_type: The type of the exception, if any.
|
|
267
|
-
exc_val: The exception instance, if any.
|
|
268
|
-
exc_tb: The traceback, if any.
|
|
269
|
-
"""
|
|
208
|
+
"""Exit the context."""
|
|
270
209
|
return self._stack.__exit__(exc_type, exc_val, exc_tb) # type: ignore[return-value]
|
|
271
210
|
|
|
272
211
|
@abc.abstractmethod
|
|
@@ -280,11 +219,7 @@ class ResourceScopedContext(ScopedContext):
|
|
|
280
219
|
self._stack.__exit__(None, None, None)
|
|
281
220
|
|
|
282
221
|
async def __aenter__(self) -> Self:
|
|
283
|
-
"""Enter the context asynchronously.
|
|
284
|
-
|
|
285
|
-
Returns:
|
|
286
|
-
The scoped context.
|
|
287
|
-
"""
|
|
222
|
+
"""Enter the context asynchronously."""
|
|
288
223
|
await self.astart()
|
|
289
224
|
return self
|
|
290
225
|
|
|
@@ -294,13 +229,7 @@ class ResourceScopedContext(ScopedContext):
|
|
|
294
229
|
exc_val: BaseException | None,
|
|
295
230
|
exc_tb: TracebackType | None,
|
|
296
231
|
) -> bool:
|
|
297
|
-
"""Exit the context asynchronously.
|
|
298
|
-
|
|
299
|
-
Args:
|
|
300
|
-
exc_type: The type of the exception, if any.
|
|
301
|
-
exc_val: The exception instance, if any.
|
|
302
|
-
exc_tb: The traceback, if any.
|
|
303
|
-
"""
|
|
232
|
+
"""Exit the context asynchronously."""
|
|
304
233
|
return await run_async(
|
|
305
234
|
self.__exit__, exc_type, exc_val, exc_tb
|
|
306
235
|
) or await self._async_stack.__aexit__(exc_type, exc_val, exc_tb)
|
|
@@ -358,28 +287,10 @@ class TransientContext(ScopedContext):
|
|
|
358
287
|
|
|
359
288
|
scope = "transient"
|
|
360
289
|
|
|
361
|
-
def get(self,
|
|
362
|
-
"""Get an instance of a dependency from the transient context.
|
|
363
|
-
|
|
364
|
-
Args:
|
|
365
|
-
interface: The interface of the dependency.
|
|
366
|
-
provider: The provider for the instance.
|
|
367
|
-
|
|
368
|
-
Returns:
|
|
369
|
-
An instance of the dependency.
|
|
370
|
-
"""
|
|
371
|
-
instance = self._create_instance(provider)
|
|
372
|
-
return cast(T, instance)
|
|
373
|
-
|
|
374
|
-
async def aget(self, interface: Interface[T], provider: Provider) -> T:
|
|
375
|
-
"""Get an async instance of a dependency from the transient context.
|
|
376
|
-
|
|
377
|
-
Args:
|
|
378
|
-
interface: The interface of the dependency.
|
|
379
|
-
provider: The provider for the instance.
|
|
290
|
+
def get(self, provider: Provider) -> Any:
|
|
291
|
+
"""Get an instance of a dependency from the transient context."""
|
|
292
|
+
return self._create_instance(provider)
|
|
380
293
|
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
instance = await self._acreate_instance(provider)
|
|
385
|
-
return cast(T, instance)
|
|
294
|
+
async def aget(self, provider: Provider) -> Any:
|
|
295
|
+
"""Get an async instance of a dependency from the transient context."""
|
|
296
|
+
return await self._acreate_instance(provider)
|
anydi/_injector.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from collections.abc import Awaitable
|
|
5
|
+
from functools import wraps
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast
|
|
7
|
+
|
|
8
|
+
from typing_extensions import ParamSpec
|
|
9
|
+
|
|
10
|
+
from ._logger import logger
|
|
11
|
+
from ._types import is_marker
|
|
12
|
+
from ._utils import get_full_qualname, get_typed_parameters
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from ._container import Container
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
T = TypeVar("T", bound=Any)
|
|
19
|
+
P = ParamSpec("P")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Injector:
|
|
23
|
+
def __init__(self, container: Container) -> None:
|
|
24
|
+
self.container = container
|
|
25
|
+
|
|
26
|
+
def inject(
|
|
27
|
+
self,
|
|
28
|
+
call: Callable[P, T | Awaitable[T]],
|
|
29
|
+
) -> Callable[P, T | Awaitable[T]]:
|
|
30
|
+
# Check if the inner callable has already been wrapped
|
|
31
|
+
if hasattr(call, "__inject_wrapper__"):
|
|
32
|
+
return cast(Callable[P, Union[T, Awaitable[T]]], call.__inject_wrapper__)
|
|
33
|
+
|
|
34
|
+
injected_params = self._get_injected_params(call)
|
|
35
|
+
|
|
36
|
+
if inspect.iscoroutinefunction(call):
|
|
37
|
+
|
|
38
|
+
@wraps(call)
|
|
39
|
+
async def awrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
40
|
+
for name, annotation in injected_params.items():
|
|
41
|
+
kwargs[name] = await self.container.aresolve(annotation)
|
|
42
|
+
return cast(T, await call(*args, **kwargs))
|
|
43
|
+
|
|
44
|
+
call.__inject_wrapper__ = awrapper # type: ignore[attr-defined]
|
|
45
|
+
|
|
46
|
+
return awrapper
|
|
47
|
+
|
|
48
|
+
@wraps(call)
|
|
49
|
+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
50
|
+
for name, annotation in injected_params.items():
|
|
51
|
+
kwargs[name] = self.container.resolve(annotation)
|
|
52
|
+
return cast(T, call(*args, **kwargs))
|
|
53
|
+
|
|
54
|
+
call.__inject_wrapper__ = wrapper # type: ignore[attr-defined]
|
|
55
|
+
|
|
56
|
+
return wrapper
|
|
57
|
+
|
|
58
|
+
def _get_injected_params(self, call: Callable[..., Any]) -> dict[str, Any]:
|
|
59
|
+
"""Get the injected parameters of a callable object."""
|
|
60
|
+
injected_params = {}
|
|
61
|
+
for parameter in get_typed_parameters(call):
|
|
62
|
+
if not is_marker(parameter.default):
|
|
63
|
+
continue
|
|
64
|
+
try:
|
|
65
|
+
self._validate_injected_parameter(call, parameter)
|
|
66
|
+
except LookupError as exc:
|
|
67
|
+
if not self.container.strict:
|
|
68
|
+
logger.debug(
|
|
69
|
+
f"Cannot validate the `{get_full_qualname(call)}` parameter "
|
|
70
|
+
f"`{parameter.name}` with an annotation of "
|
|
71
|
+
f"`{get_full_qualname(parameter.annotation)} due to being "
|
|
72
|
+
"in non-strict mode. It will be validated at the first call."
|
|
73
|
+
)
|
|
74
|
+
else:
|
|
75
|
+
raise exc
|
|
76
|
+
injected_params[parameter.name] = parameter.annotation
|
|
77
|
+
return injected_params
|
|
78
|
+
|
|
79
|
+
def _validate_injected_parameter(
|
|
80
|
+
self, call: Callable[..., Any], parameter: inspect.Parameter
|
|
81
|
+
) -> None:
|
|
82
|
+
"""Validate an injected parameter."""
|
|
83
|
+
if parameter.annotation is inspect.Parameter.empty:
|
|
84
|
+
raise TypeError(
|
|
85
|
+
f"Missing `{get_full_qualname(call)}` parameter "
|
|
86
|
+
f"`{parameter.name}` annotation."
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
if not self.container.is_registered(parameter.annotation):
|
|
90
|
+
raise LookupError(
|
|
91
|
+
f"`{get_full_qualname(call)}` has an unknown dependency parameter "
|
|
92
|
+
f"`{parameter.name}` with an annotation of "
|
|
93
|
+
f"`{get_full_qualname(parameter.annotation)}`."
|
|
94
|
+
)
|
anydi/_module.py
CHANGED
|
@@ -19,26 +19,9 @@ P = ParamSpec("P")
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class ModuleMeta(type):
|
|
22
|
-
"""A metaclass used for the Module base class.
|
|
23
|
-
|
|
24
|
-
This metaclass extracts provider information from the class attributes
|
|
25
|
-
and stores it in the `providers` attribute.
|
|
26
|
-
"""
|
|
22
|
+
"""A metaclass used for the Module base class."""
|
|
27
23
|
|
|
28
24
|
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Any:
|
|
29
|
-
"""Create a new instance of the ModuleMeta class.
|
|
30
|
-
|
|
31
|
-
This method extracts provider information from the class attributes and
|
|
32
|
-
stores it in the `providers` attribute.
|
|
33
|
-
|
|
34
|
-
Args:
|
|
35
|
-
name: The name of the class.
|
|
36
|
-
bases: The base classes of the class.
|
|
37
|
-
attrs: The attributes of the class.
|
|
38
|
-
|
|
39
|
-
Returns:
|
|
40
|
-
The new instance of the class.
|
|
41
|
-
"""
|
|
42
25
|
attrs["providers"] = [
|
|
43
26
|
(name, getattr(value, "__provider__"))
|
|
44
27
|
for name, value in attrs.items()
|
|
@@ -53,14 +36,7 @@ class Module(metaclass=ModuleMeta):
|
|
|
53
36
|
providers: list[tuple[str, ProviderDecoratorArgs]]
|
|
54
37
|
|
|
55
38
|
def configure(self, container: Container) -> None:
|
|
56
|
-
"""Configure the AnyDI container with providers and their dependencies.
|
|
57
|
-
|
|
58
|
-
This method can be overridden in derived classes to provide the
|
|
59
|
-
configuration logic.
|
|
60
|
-
|
|
61
|
-
Args:
|
|
62
|
-
container: The AnyDI container to be configured.
|
|
63
|
-
"""
|
|
39
|
+
"""Configure the AnyDI container with providers and their dependencies."""
|
|
64
40
|
|
|
65
41
|
|
|
66
42
|
class ModuleRegistry:
|
|
@@ -70,11 +46,7 @@ class ModuleRegistry:
|
|
|
70
46
|
def register(
|
|
71
47
|
self, module: Module | type[Module] | Callable[[Container], None] | str
|
|
72
48
|
) -> None:
|
|
73
|
-
"""Register a module as a callable, module type, or module instance.
|
|
74
|
-
|
|
75
|
-
Args:
|
|
76
|
-
module: The module to register.
|
|
77
|
-
"""
|
|
49
|
+
"""Register a module as a callable, module type, or module instance."""
|
|
78
50
|
|
|
79
51
|
# Callable Module
|
|
80
52
|
if inspect.isfunction(module):
|
|
@@ -107,16 +79,7 @@ class ProviderDecoratorArgs(NamedTuple):
|
|
|
107
79
|
def provider(
|
|
108
80
|
*, scope: Scope, override: bool = False
|
|
109
81
|
) -> Callable[[Callable[Concatenate[M, P], T]], Callable[Concatenate[M, P], T]]:
|
|
110
|
-
"""Decorator for marking a function or method as a provider in a AnyDI module.
|
|
111
|
-
|
|
112
|
-
Args:
|
|
113
|
-
scope: The scope in which the provided instance should be managed.
|
|
114
|
-
override: Whether the provider should override existing providers
|
|
115
|
-
with the same interface.
|
|
116
|
-
|
|
117
|
-
Returns:
|
|
118
|
-
A decorator that marks the target function or method as a provider.
|
|
119
|
-
"""
|
|
82
|
+
"""Decorator for marking a function or method as a provider in a AnyDI module."""
|
|
120
83
|
|
|
121
84
|
def decorator(
|
|
122
85
|
target: Callable[Concatenate[M, P], T],
|