anydi 0.56.0__py3-none-any.whl → 0.58.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anydi/__init__.py +4 -2
- anydi/_container.py +181 -157
- anydi/_injector.py +132 -0
- anydi/_resolver.py +51 -24
- anydi/_scanner.py +52 -44
- anydi/_types.py +49 -8
- anydi/ext/fastapi.py +31 -33
- anydi/ext/faststream.py +25 -31
- anydi/ext/pydantic_settings.py +2 -1
- anydi/ext/pytest_plugin.py +380 -50
- anydi/ext/starlette/middleware.py +1 -1
- {anydi-0.56.0.dist-info → anydi-0.58.0.dist-info}/METADATA +32 -13
- anydi-0.58.0.dist-info/RECORD +25 -0
- anydi-0.56.0.dist-info/RECORD +0 -24
- {anydi-0.56.0.dist-info → anydi-0.58.0.dist-info}/WHEEL +0 -0
- {anydi-0.56.0.dist-info → anydi-0.58.0.dist-info}/entry_points.txt +0 -0
anydi/__init__.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
"""AnyDI public objects and functions."""
|
|
2
2
|
|
|
3
|
-
from ._container import Container
|
|
3
|
+
from ._container import Container, import_container
|
|
4
4
|
from ._decorators import injectable, provided, provider, request, singleton, transient
|
|
5
5
|
from ._module import Module
|
|
6
6
|
from ._provider import ProviderDef as Provider
|
|
7
|
-
from ._types import Inject, Scope
|
|
7
|
+
from ._types import Inject, Provide, Scope
|
|
8
8
|
|
|
9
9
|
# Alias for dependency auto marker
|
|
10
10
|
# TODO: deprecate it
|
|
@@ -15,9 +15,11 @@ __all__ = [
|
|
|
15
15
|
"Container",
|
|
16
16
|
"Inject",
|
|
17
17
|
"Module",
|
|
18
|
+
"Provide",
|
|
18
19
|
"Provider",
|
|
19
20
|
"Scope",
|
|
20
21
|
"auto",
|
|
22
|
+
"import_container",
|
|
21
23
|
"injectable",
|
|
22
24
|
"provided",
|
|
23
25
|
"provider",
|
anydi/_container.py
CHANGED
|
@@ -3,43 +3,30 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import contextlib
|
|
6
|
-
import functools
|
|
6
|
+
import importlib
|
|
7
7
|
import inspect
|
|
8
8
|
import logging
|
|
9
9
|
import types
|
|
10
10
|
import uuid
|
|
11
11
|
from collections import defaultdict
|
|
12
|
-
from collections.abc import AsyncIterator, Callable, Iterable, Iterator
|
|
12
|
+
from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Sequence
|
|
13
13
|
from contextvars import ContextVar
|
|
14
|
-
from typing import Annotated, Any, TypeVar, cast, get_args, get_origin, overload
|
|
14
|
+
from typing import Any, TypeVar, get_args, get_origin, overload
|
|
15
15
|
|
|
16
16
|
from typing_extensions import ParamSpec, Self, type_repr
|
|
17
17
|
|
|
18
18
|
from ._context import InstanceContext
|
|
19
19
|
from ._decorators import is_provided
|
|
20
|
+
from ._injector import Injector
|
|
20
21
|
from ._module import ModuleDef, ModuleRegistrar
|
|
21
22
|
from ._provider import Provider, ProviderDef, ProviderKind, ProviderParameter
|
|
22
23
|
from ._resolver import Resolver
|
|
23
24
|
from ._scanner import PackageOrIterable, Scanner
|
|
24
|
-
from ._types import (
|
|
25
|
-
NOT_SET,
|
|
26
|
-
Event,
|
|
27
|
-
Scope,
|
|
28
|
-
is_event_type,
|
|
29
|
-
is_inject_marker,
|
|
30
|
-
is_iterator_type,
|
|
31
|
-
is_none_type,
|
|
32
|
-
)
|
|
25
|
+
from ._types import NOT_SET, Event, Scope, is_event_type, is_iterator_type, is_none_type
|
|
33
26
|
|
|
34
27
|
T = TypeVar("T", bound=Any)
|
|
35
28
|
P = ParamSpec("P")
|
|
36
29
|
|
|
37
|
-
ALLOWED_SCOPES: dict[Scope, list[Scope]] = {
|
|
38
|
-
"singleton": ["singleton"],
|
|
39
|
-
"request": ["request", "singleton"],
|
|
40
|
-
"transient": ["transient", "request", "singleton"],
|
|
41
|
-
}
|
|
42
|
-
|
|
43
30
|
|
|
44
31
|
class Container:
|
|
45
32
|
"""AnyDI is a dependency injection container."""
|
|
@@ -53,18 +40,24 @@ class Container:
|
|
|
53
40
|
) -> None:
|
|
54
41
|
self._providers: dict[Any, Provider] = {}
|
|
55
42
|
self._logger = logger or logging.getLogger(__name__)
|
|
43
|
+
self._scopes: dict[str, Sequence[str]] = {
|
|
44
|
+
"transient": ("transient", "singleton"),
|
|
45
|
+
"singleton": ("singleton",),
|
|
46
|
+
}
|
|
47
|
+
|
|
56
48
|
self._resources: dict[str, list[Any]] = defaultdict(list)
|
|
57
49
|
self._singleton_context = InstanceContext()
|
|
58
|
-
self.
|
|
59
|
-
"request_context", default=None
|
|
60
|
-
)
|
|
61
|
-
self._inject_cache: dict[Callable[..., Any], Callable[..., Any]] = {}
|
|
50
|
+
self._scoped_context: dict[str, ContextVar[InstanceContext]] = {}
|
|
62
51
|
|
|
63
52
|
# Components
|
|
64
53
|
self._resolver = Resolver(self)
|
|
54
|
+
self._injector = Injector(self)
|
|
65
55
|
self._modules = ModuleRegistrar(self)
|
|
66
56
|
self._scanner = Scanner(self)
|
|
67
57
|
|
|
58
|
+
# Register default scopes
|
|
59
|
+
self.register_scope("request")
|
|
60
|
+
|
|
68
61
|
# Register providers
|
|
69
62
|
providers = providers or []
|
|
70
63
|
for provider in providers:
|
|
@@ -141,66 +134,142 @@ class Container:
|
|
|
141
134
|
await self._singleton_context.aclose()
|
|
142
135
|
|
|
143
136
|
@contextlib.contextmanager
|
|
144
|
-
def request_context(self) -> Iterator[InstanceContext]:
|
|
137
|
+
def scoped_context(self, scope: str) -> Iterator[InstanceContext]:
|
|
145
138
|
"""Obtain a context manager for the request-scoped context."""
|
|
146
|
-
|
|
139
|
+
context_var = self._get_scoped_context_var(scope)
|
|
147
140
|
|
|
148
|
-
|
|
141
|
+
# Check if context already exists (re-entering same scope)
|
|
142
|
+
context = context_var.get(None)
|
|
143
|
+
if context is not None:
|
|
144
|
+
# Reuse existing context, don't create a new one
|
|
145
|
+
yield context
|
|
146
|
+
return
|
|
147
|
+
|
|
148
|
+
# Create new context
|
|
149
|
+
context = InstanceContext()
|
|
150
|
+
token = context_var.set(context)
|
|
149
151
|
|
|
150
152
|
# Resolve all request resources
|
|
151
|
-
for interface in self._resources.get(
|
|
153
|
+
for interface in self._resources.get(scope, []):
|
|
152
154
|
if not is_event_type(interface):
|
|
153
155
|
continue
|
|
154
156
|
self.resolve(interface)
|
|
155
157
|
|
|
156
158
|
with context:
|
|
157
159
|
yield context
|
|
158
|
-
|
|
160
|
+
context_var.reset(token)
|
|
159
161
|
|
|
160
162
|
@contextlib.asynccontextmanager
|
|
161
|
-
async def arequest_context(self) -> AsyncIterator[InstanceContext]:
|
|
162
|
-
"""Obtain
|
|
163
|
-
|
|
163
|
+
async def ascoped_context(self, scope: str) -> AsyncIterator[InstanceContext]:
|
|
164
|
+
"""Obtain a context manager for the specified scoped context."""
|
|
165
|
+
context_var = self._get_scoped_context_var(scope)
|
|
166
|
+
|
|
167
|
+
# Check if context already exists (re-entering same scope)
|
|
168
|
+
context = context_var.get(None)
|
|
169
|
+
if context is not None:
|
|
170
|
+
# Reuse existing context, don't create a new one
|
|
171
|
+
yield context
|
|
172
|
+
return
|
|
164
173
|
|
|
165
|
-
|
|
174
|
+
# Create new context
|
|
175
|
+
context = InstanceContext()
|
|
176
|
+
token = context_var.set(context)
|
|
166
177
|
|
|
167
|
-
|
|
178
|
+
# Resolve all request resources
|
|
179
|
+
for interface in self._resources.get(scope, []):
|
|
168
180
|
if not is_event_type(interface):
|
|
169
181
|
continue
|
|
170
182
|
await self.aresolve(interface)
|
|
171
183
|
|
|
172
184
|
async with context:
|
|
173
185
|
yield context
|
|
174
|
-
|
|
186
|
+
context_var.reset(token)
|
|
187
|
+
|
|
188
|
+
@contextlib.contextmanager
|
|
189
|
+
def request_context(self) -> Iterator[InstanceContext]:
|
|
190
|
+
"""Obtain a context manager for the request-scoped context."""
|
|
191
|
+
with self.scoped_context("request") as context:
|
|
192
|
+
yield context
|
|
193
|
+
|
|
194
|
+
@contextlib.asynccontextmanager
|
|
195
|
+
async def arequest_context(self) -> AsyncIterator[InstanceContext]:
|
|
196
|
+
"""Obtain an async context manager for the request-scoped context."""
|
|
197
|
+
async with self.ascoped_context("request") as context:
|
|
198
|
+
yield context
|
|
175
199
|
|
|
176
|
-
def
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
200
|
+
def _get_scoped_context(self, scope: str) -> InstanceContext:
|
|
201
|
+
scoped_context_var = self._get_scoped_context_var(scope)
|
|
202
|
+
try:
|
|
203
|
+
scoped_context = scoped_context_var.get()
|
|
204
|
+
except LookupError as exc:
|
|
180
205
|
raise LookupError(
|
|
181
|
-
"The
|
|
182
|
-
"the
|
|
206
|
+
f"The {scope} context has not been started. Please ensure that "
|
|
207
|
+
f"the {scope} context is properly initialized before attempting "
|
|
183
208
|
"to use it."
|
|
209
|
+
) from exc
|
|
210
|
+
return scoped_context
|
|
211
|
+
|
|
212
|
+
def _get_scoped_context_var(self, scope: str) -> ContextVar[InstanceContext]:
|
|
213
|
+
"""Get the context variable for the specified scope."""
|
|
214
|
+
# Validate that scope is registered and not reserved
|
|
215
|
+
if scope in ("transient", "singleton"):
|
|
216
|
+
raise ValueError(
|
|
217
|
+
f"Cannot get context variable for reserved scope `{scope}`."
|
|
184
218
|
)
|
|
185
|
-
|
|
219
|
+
if scope not in self._scopes:
|
|
220
|
+
raise ValueError(
|
|
221
|
+
f"Cannot get context variable for not registered scope `{scope}`. "
|
|
222
|
+
f"Please register the scope first using register_scope()."
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
if scope not in self._scoped_context:
|
|
226
|
+
self._scoped_context[scope] = ContextVar(f"{scope}_context")
|
|
227
|
+
return self._scoped_context[scope]
|
|
186
228
|
|
|
187
229
|
def _get_instance_context(self, scope: Scope) -> InstanceContext:
|
|
188
230
|
"""Get the instance context for the specified scope."""
|
|
189
231
|
if scope == "singleton":
|
|
190
232
|
return self._singleton_context
|
|
191
|
-
return self.
|
|
233
|
+
return self._get_scoped_context(scope)
|
|
234
|
+
|
|
235
|
+
# == Scopes == #
|
|
236
|
+
|
|
237
|
+
def register_scope(
|
|
238
|
+
self, scope: str, *, parents: Sequence[str] | None = None
|
|
239
|
+
) -> None:
|
|
240
|
+
"""Register a new scope with the specified parents."""
|
|
241
|
+
# Check if the scope is reserved
|
|
242
|
+
if scope in ("transient", "singleton"):
|
|
243
|
+
raise ValueError(
|
|
244
|
+
f"The scope `{scope}` is reserved and cannot be overridden."
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
# Check if the scope is already registered
|
|
248
|
+
if scope in self._scopes:
|
|
249
|
+
raise ValueError(f"The scope `{scope}` is already registered.")
|
|
250
|
+
|
|
251
|
+
# Validate parents
|
|
252
|
+
parents = parents or []
|
|
253
|
+
for parent in parents:
|
|
254
|
+
if parent not in self._scopes:
|
|
255
|
+
raise ValueError(f"The parent scope `{parent}` is not registered.")
|
|
256
|
+
|
|
257
|
+
# Register the scope
|
|
258
|
+
self._scopes[scope] = tuple({scope, "singleton"} | set(parents))
|
|
192
259
|
|
|
193
260
|
# == Provider Registry ==
|
|
194
261
|
|
|
195
262
|
def register(
|
|
196
263
|
self,
|
|
197
264
|
interface: Any,
|
|
198
|
-
call: Callable[..., Any],
|
|
265
|
+
call: Callable[..., Any] = NOT_SET,
|
|
199
266
|
*,
|
|
200
|
-
scope: Scope,
|
|
267
|
+
scope: Scope = "singleton",
|
|
201
268
|
override: bool = False,
|
|
202
269
|
) -> Provider:
|
|
203
270
|
"""Register a provider for the specified interface."""
|
|
271
|
+
if call is NOT_SET:
|
|
272
|
+
call = interface
|
|
204
273
|
return self._register_provider(call, scope, interface, override)
|
|
205
274
|
|
|
206
275
|
def is_registered(self, interface: Any) -> bool:
|
|
@@ -298,7 +367,7 @@ class Container:
|
|
|
298
367
|
unresolved_parameter = None
|
|
299
368
|
unresolved_exc: LookupError | None = None
|
|
300
369
|
parameters: list[ProviderParameter] = []
|
|
301
|
-
|
|
370
|
+
scope_provider: dict[Scope, Provider] = {}
|
|
302
371
|
|
|
303
372
|
for parameter in signature.parameters.values():
|
|
304
373
|
if parameter.annotation is inspect.Parameter.empty:
|
|
@@ -329,8 +398,8 @@ class Container:
|
|
|
329
398
|
continue
|
|
330
399
|
|
|
331
400
|
# Store first provider for each scope
|
|
332
|
-
if sub_provider.scope not in scopes:
|
|
333
|
-
|
|
401
|
+
if sub_provider.scope not in scope_provider:
|
|
402
|
+
scope_provider[sub_provider.scope] = sub_provider
|
|
334
403
|
|
|
335
404
|
parameters.append(
|
|
336
405
|
ProviderParameter(
|
|
@@ -343,6 +412,18 @@ class Container:
|
|
|
343
412
|
)
|
|
344
413
|
)
|
|
345
414
|
|
|
415
|
+
# Check scope compatibility
|
|
416
|
+
# Transient scope can use any scoped dependencies
|
|
417
|
+
if scope != "transient":
|
|
418
|
+
for sub_provider in scope_provider.values():
|
|
419
|
+
if sub_provider.scope not in self._scopes.get(scope, []):
|
|
420
|
+
raise ValueError(
|
|
421
|
+
f"The provider `{name}` with a `{scope}` scope "
|
|
422
|
+
f"cannot depend on `{sub_provider}` with a "
|
|
423
|
+
f"`{sub_provider.scope}` scope. Please ensure all "
|
|
424
|
+
"providers are registered with matching scopes."
|
|
425
|
+
)
|
|
426
|
+
|
|
346
427
|
# Check for unresolved parameters
|
|
347
428
|
if unresolved_parameter:
|
|
348
429
|
if scope not in ("singleton", "transient"):
|
|
@@ -356,15 +437,6 @@ class Container:
|
|
|
356
437
|
f"attempting to use it."
|
|
357
438
|
) from unresolved_exc
|
|
358
439
|
|
|
359
|
-
# Check scope compatibility
|
|
360
|
-
for sub_provider in scopes.values():
|
|
361
|
-
if sub_provider.scope not in ALLOWED_SCOPES.get(scope, []):
|
|
362
|
-
raise ValueError(
|
|
363
|
-
f"The provider `{name}` with a `{scope}` scope cannot "
|
|
364
|
-
f"depend on `{sub_provider}` with a `{sub_provider.scope}` scope. "
|
|
365
|
-
"Please ensure all providers are registered with matching scopes."
|
|
366
|
-
)
|
|
367
|
-
|
|
368
440
|
is_coroutine = kind == ProviderKind.COROUTINE
|
|
369
441
|
is_generator = kind == ProviderKind.GENERATOR
|
|
370
442
|
is_async_generator = kind == ProviderKind.ASYNC_GENERATOR
|
|
@@ -387,13 +459,14 @@ class Container:
|
|
|
387
459
|
self._set_provider(provider)
|
|
388
460
|
return provider
|
|
389
461
|
|
|
390
|
-
|
|
391
|
-
|
|
462
|
+
def _validate_provider_scope(
|
|
463
|
+
self, scope: Scope, name: str, is_resource: bool
|
|
464
|
+
) -> None:
|
|
392
465
|
"""Validate the provider scope."""
|
|
393
|
-
if scope not in ALLOWED_SCOPES:
|
|
466
|
+
if scope not in self._scopes:
|
|
394
467
|
raise ValueError(
|
|
395
468
|
f"The provider `{name}` scope is invalid. Only the following "
|
|
396
|
-
f"scopes are supported: {', '.join(
|
|
469
|
+
f"scopes are supported: {', '.join(self._scopes.keys())}. "
|
|
397
470
|
"Please use one of the supported scopes when registering a provider."
|
|
398
471
|
)
|
|
399
472
|
if scope == "transient" and is_resource:
|
|
@@ -548,115 +621,21 @@ class Container:
|
|
|
548
621
|
"""Decorator to inject dependencies into a callable."""
|
|
549
622
|
|
|
550
623
|
def decorator(call: Callable[P, T]) -> Callable[P, T]:
|
|
551
|
-
return self._inject(call)
|
|
624
|
+
return self._injector.inject(call)
|
|
552
625
|
|
|
553
626
|
if func is None:
|
|
554
627
|
return decorator
|
|
555
628
|
return decorator(func)
|
|
556
629
|
|
|
557
|
-
def run(self, func: Callable[
|
|
630
|
+
def run(self, func: Callable[..., T], /, *args: Any, **kwargs: Any) -> T:
|
|
558
631
|
"""Run the given function with injected dependencies."""
|
|
559
|
-
return self.
|
|
560
|
-
|
|
561
|
-
def _inject(self, call: Callable[P, T]) -> Callable[P, T]:
|
|
562
|
-
"""Inject dependencies into a callable."""
|
|
563
|
-
if call in self._inject_cache:
|
|
564
|
-
return cast(Callable[P, T], self._inject_cache[call])
|
|
565
|
-
|
|
566
|
-
injected_params = self._get_injected_params(call)
|
|
567
|
-
if not injected_params:
|
|
568
|
-
self._inject_cache[call] = call
|
|
569
|
-
return call
|
|
570
|
-
|
|
571
|
-
if inspect.iscoroutinefunction(call):
|
|
572
|
-
|
|
573
|
-
@functools.wraps(call)
|
|
574
|
-
async def awrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
575
|
-
for name, annotation in injected_params.items():
|
|
576
|
-
kwargs[name] = await self.aresolve(annotation)
|
|
577
|
-
return cast(T, await call(*args, **kwargs))
|
|
578
|
-
|
|
579
|
-
self._inject_cache[call] = awrapper
|
|
580
|
-
|
|
581
|
-
return awrapper # type: ignore
|
|
582
|
-
|
|
583
|
-
@functools.wraps(call)
|
|
584
|
-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
585
|
-
for name, annotation in injected_params.items():
|
|
586
|
-
kwargs[name] = self.resolve(annotation)
|
|
587
|
-
return call(*args, **kwargs)
|
|
588
|
-
|
|
589
|
-
self._inject_cache[call] = wrapper
|
|
590
|
-
|
|
591
|
-
return wrapper
|
|
592
|
-
|
|
593
|
-
def _get_injected_params(self, call: Callable[..., Any]) -> dict[str, Any]:
|
|
594
|
-
"""Get the injected parameters of a callable object."""
|
|
595
|
-
injected_params: dict[str, Any] = {}
|
|
596
|
-
for parameter in inspect.signature(call, eval_str=True).parameters.values():
|
|
597
|
-
interface, should_inject = self.validate_injected_parameter(
|
|
598
|
-
parameter, call=call
|
|
599
|
-
)
|
|
600
|
-
if should_inject:
|
|
601
|
-
injected_params[parameter.name] = interface
|
|
602
|
-
return injected_params
|
|
603
|
-
|
|
604
|
-
@staticmethod
|
|
605
|
-
def _unwrap_injected_parameter(parameter: inspect.Parameter) -> inspect.Parameter:
|
|
606
|
-
if get_origin(parameter.annotation) is not Annotated:
|
|
607
|
-
return parameter
|
|
608
|
-
|
|
609
|
-
origin, *metadata = get_args(parameter.annotation)
|
|
610
|
-
|
|
611
|
-
if not metadata or not is_inject_marker(metadata[-1]):
|
|
612
|
-
return parameter
|
|
613
|
-
|
|
614
|
-
if is_inject_marker(parameter.default):
|
|
615
|
-
raise TypeError(
|
|
616
|
-
"Cannot specify `Inject` in `Annotated` and "
|
|
617
|
-
f"default value together for '{parameter.name}'"
|
|
618
|
-
)
|
|
619
|
-
|
|
620
|
-
if parameter.default is not inspect.Parameter.empty:
|
|
621
|
-
return parameter
|
|
622
|
-
|
|
623
|
-
marker = metadata[-1]
|
|
624
|
-
new_metadata = metadata[:-1]
|
|
625
|
-
if new_metadata:
|
|
626
|
-
if hasattr(Annotated, "__getitem__"):
|
|
627
|
-
new_annotation = Annotated.__getitem__((origin, *new_metadata)) # type: ignore
|
|
628
|
-
else:
|
|
629
|
-
new_annotation = Annotated.__class_getitem__((origin, *new_metadata)) # type: ignore
|
|
630
|
-
else:
|
|
631
|
-
new_annotation = origin
|
|
632
|
-
return parameter.replace(annotation=new_annotation, default=marker)
|
|
632
|
+
return self._injector.inject(func)(*args, **kwargs)
|
|
633
633
|
|
|
634
634
|
def validate_injected_parameter(
|
|
635
635
|
self, parameter: inspect.Parameter, *, call: Callable[..., Any]
|
|
636
636
|
) -> tuple[Any, bool]:
|
|
637
637
|
"""Validate an injected parameter."""
|
|
638
|
-
|
|
639
|
-
interface = parameter.annotation
|
|
640
|
-
|
|
641
|
-
if not is_inject_marker(parameter.default):
|
|
642
|
-
return interface, False
|
|
643
|
-
|
|
644
|
-
if interface is inspect.Parameter.empty:
|
|
645
|
-
raise TypeError(
|
|
646
|
-
f"Missing `{type_repr(call)}` parameter `{parameter.name}` annotation."
|
|
647
|
-
)
|
|
648
|
-
|
|
649
|
-
# Set inject marker interface
|
|
650
|
-
parameter.default.interface = interface
|
|
651
|
-
|
|
652
|
-
if not self.has_provider_for(interface):
|
|
653
|
-
raise LookupError(
|
|
654
|
-
f"`{type_repr(call)}` has an unknown dependency parameter "
|
|
655
|
-
f"`{parameter.name}` with an annotation of "
|
|
656
|
-
f"`{type_repr(interface)}`."
|
|
657
|
-
)
|
|
658
|
-
|
|
659
|
-
return interface, True
|
|
638
|
+
return self._injector.validate_parameter(parameter, call=call)
|
|
660
639
|
|
|
661
640
|
# == Module Registration ==
|
|
662
641
|
|
|
@@ -685,3 +664,48 @@ class Container:
|
|
|
685
664
|
yield
|
|
686
665
|
finally:
|
|
687
666
|
self._resolver.remove_override(interface)
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
def import_container(container_path: str) -> Container:
|
|
670
|
+
"""Import container from a string path."""
|
|
671
|
+
# Replace colon with dot for unified processing
|
|
672
|
+
container_path = container_path.replace(":", ".")
|
|
673
|
+
|
|
674
|
+
try:
|
|
675
|
+
module_path, attr_name = container_path.rsplit(".", 1)
|
|
676
|
+
except ValueError as exc:
|
|
677
|
+
raise ImportError(
|
|
678
|
+
f"Invalid container path '{container_path}'. "
|
|
679
|
+
"Expected format: 'module.path:attribute' or 'module.path.attribute'"
|
|
680
|
+
) from exc
|
|
681
|
+
|
|
682
|
+
try:
|
|
683
|
+
module = importlib.import_module(module_path)
|
|
684
|
+
except ImportError as exc:
|
|
685
|
+
raise ImportError(
|
|
686
|
+
f"Failed to import module '{module_path}' "
|
|
687
|
+
f"from container path '{container_path}'"
|
|
688
|
+
) from exc
|
|
689
|
+
|
|
690
|
+
try:
|
|
691
|
+
container_or_factory = getattr(module, attr_name)
|
|
692
|
+
except AttributeError as exc:
|
|
693
|
+
raise ImportError(
|
|
694
|
+
f"Module '{module_path}' has no attribute '{attr_name}'"
|
|
695
|
+
) from exc
|
|
696
|
+
|
|
697
|
+
# If it's a callable (factory), call it
|
|
698
|
+
if callable(container_or_factory) and not isinstance(
|
|
699
|
+
container_or_factory, Container
|
|
700
|
+
):
|
|
701
|
+
container = container_or_factory()
|
|
702
|
+
else:
|
|
703
|
+
container = container_or_factory
|
|
704
|
+
|
|
705
|
+
if not isinstance(container, Container):
|
|
706
|
+
raise ImportError(
|
|
707
|
+
f"Expected Container instance, got {type(container).__name__} "
|
|
708
|
+
f"from '{container_path}'"
|
|
709
|
+
)
|
|
710
|
+
|
|
711
|
+
return container
|
anydi/_injector.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""Dependency injection utilities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import functools
|
|
6
|
+
import inspect
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from typing import (
|
|
9
|
+
TYPE_CHECKING,
|
|
10
|
+
Annotated,
|
|
11
|
+
Any,
|
|
12
|
+
TypeVar,
|
|
13
|
+
cast,
|
|
14
|
+
get_args,
|
|
15
|
+
get_origin,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from typing_extensions import ParamSpec, type_repr
|
|
19
|
+
|
|
20
|
+
from ._types import is_provide_marker
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from ._container import Container
|
|
24
|
+
|
|
25
|
+
T = TypeVar("T", bound=Any)
|
|
26
|
+
P = ParamSpec("P")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Injector:
|
|
30
|
+
"""Handles dependency injection for callables."""
|
|
31
|
+
|
|
32
|
+
def __init__(self, container: Container) -> None:
|
|
33
|
+
self.container = container
|
|
34
|
+
self._cache: dict[Callable[..., Any], Callable[..., Any]] = {}
|
|
35
|
+
|
|
36
|
+
def inject(self, call: Callable[P, T]) -> Callable[P, T]:
|
|
37
|
+
"""Inject dependencies into a callable."""
|
|
38
|
+
if call in self._cache:
|
|
39
|
+
return cast(Callable[P, T], self._cache[call])
|
|
40
|
+
|
|
41
|
+
injected_params = self._get_injected_params(call)
|
|
42
|
+
if not injected_params:
|
|
43
|
+
self._cache[call] = call
|
|
44
|
+
return call
|
|
45
|
+
|
|
46
|
+
if inspect.iscoroutinefunction(call):
|
|
47
|
+
|
|
48
|
+
@functools.wraps(call)
|
|
49
|
+
async def awrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
50
|
+
for name, annotation in injected_params.items():
|
|
51
|
+
kwargs[name] = await self.container.aresolve(annotation)
|
|
52
|
+
return cast(T, await call(*args, **kwargs))
|
|
53
|
+
|
|
54
|
+
self._cache[call] = awrapper
|
|
55
|
+
|
|
56
|
+
return awrapper # type: ignore
|
|
57
|
+
|
|
58
|
+
@functools.wraps(call)
|
|
59
|
+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
60
|
+
for name, annotation in injected_params.items():
|
|
61
|
+
kwargs[name] = self.container.resolve(annotation)
|
|
62
|
+
return call(*args, **kwargs)
|
|
63
|
+
|
|
64
|
+
self._cache[call] = wrapper
|
|
65
|
+
|
|
66
|
+
return wrapper
|
|
67
|
+
|
|
68
|
+
def _get_injected_params(self, call: Callable[..., Any]) -> dict[str, Any]:
|
|
69
|
+
"""Get the injected parameters of a callable object."""
|
|
70
|
+
injected_params: dict[str, Any] = {}
|
|
71
|
+
for parameter in inspect.signature(call, eval_str=True).parameters.values():
|
|
72
|
+
interface, should_inject = self.validate_parameter(parameter, call=call)
|
|
73
|
+
if should_inject:
|
|
74
|
+
injected_params[parameter.name] = interface
|
|
75
|
+
return injected_params
|
|
76
|
+
|
|
77
|
+
def validate_parameter(
|
|
78
|
+
self, parameter: inspect.Parameter, *, call: Callable[..., Any]
|
|
79
|
+
) -> tuple[Any, bool]:
|
|
80
|
+
"""Validate an injected parameter."""
|
|
81
|
+
parameter = self.unwrap_parameter(parameter)
|
|
82
|
+
interface = parameter.annotation
|
|
83
|
+
|
|
84
|
+
if not is_provide_marker(parameter.default):
|
|
85
|
+
return interface, False
|
|
86
|
+
|
|
87
|
+
if interface is inspect.Parameter.empty:
|
|
88
|
+
raise TypeError(
|
|
89
|
+
f"Missing `{type_repr(call)}` parameter `{parameter.name}` annotation."
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
# Set inject marker interface
|
|
93
|
+
parameter.default.interface = interface
|
|
94
|
+
|
|
95
|
+
if not self.container.has_provider_for(interface):
|
|
96
|
+
raise LookupError(
|
|
97
|
+
f"`{type_repr(call)}` has an unknown dependency parameter "
|
|
98
|
+
f"`{parameter.name}` with an annotation of "
|
|
99
|
+
f"`{type_repr(interface)}`."
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
return interface, True
|
|
103
|
+
|
|
104
|
+
@staticmethod
|
|
105
|
+
def unwrap_parameter(parameter: inspect.Parameter) -> inspect.Parameter:
|
|
106
|
+
if get_origin(parameter.annotation) is not Annotated:
|
|
107
|
+
return parameter
|
|
108
|
+
|
|
109
|
+
origin, *metadata = get_args(parameter.annotation)
|
|
110
|
+
|
|
111
|
+
if not metadata or not is_provide_marker(metadata[-1]):
|
|
112
|
+
return parameter
|
|
113
|
+
|
|
114
|
+
if is_provide_marker(parameter.default):
|
|
115
|
+
raise TypeError(
|
|
116
|
+
"Cannot specify `Inject` in `Annotated` and "
|
|
117
|
+
f"default value together for '{parameter.name}'"
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
if parameter.default is not inspect.Parameter.empty:
|
|
121
|
+
return parameter
|
|
122
|
+
|
|
123
|
+
marker = metadata[-1]
|
|
124
|
+
new_metadata = metadata[:-1]
|
|
125
|
+
if new_metadata:
|
|
126
|
+
if hasattr(Annotated, "__getitem__"):
|
|
127
|
+
new_annotation = Annotated.__getitem__((origin, *new_metadata)) # type: ignore
|
|
128
|
+
else:
|
|
129
|
+
new_annotation = Annotated.__class_getitem__((origin, *new_metadata)) # type: ignore
|
|
130
|
+
else:
|
|
131
|
+
new_annotation = origin
|
|
132
|
+
return parameter.replace(annotation=new_annotation, default=marker)
|