anydi 0.31.0__tar.gz → 0.32.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {anydi-0.31.0 → anydi-0.32.1}/PKG-INFO +1 -1
- {anydi-0.31.0 → anydi-0.32.1}/anydi/__init__.py +2 -8
- anydi-0.32.1/anydi/_container.py +473 -0
- {anydi-0.31.0 → anydi-0.32.1}/anydi/_context.py +99 -188
- anydi-0.32.1/anydi/_injector.py +94 -0
- {anydi-0.31.0 → anydi-0.32.1}/anydi/_module.py +4 -41
- anydi-0.32.1/anydi/_provider.py +187 -0
- {anydi-0.31.0 → anydi-0.32.1}/anydi/_scanner.py +14 -68
- anydi-0.32.1/anydi/_types.py +37 -0
- anydi-0.32.1/anydi/_utils.py +105 -0
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/_utils.py +5 -11
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/django/_settings.py +1 -1
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/django/_utils.py +2 -2
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/django/apps.py +1 -1
- anydi-0.32.1/anydi/ext/django/middleware.py +28 -0
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/fastapi.py +4 -24
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/faststream.py +0 -7
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/pydantic_settings.py +2 -2
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/pytest_plugin.py +3 -2
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/starlette/middleware.py +2 -16
- {anydi-0.31.0 → anydi-0.32.1}/pyproject.toml +9 -2
- anydi-0.31.0/anydi/_container.py +0 -866
- anydi-0.31.0/anydi/_types.py +0 -136
- anydi-0.31.0/anydi/_utils.py +0 -154
- anydi-0.31.0/anydi/ext/django/middleware.py +0 -26
- {anydi-0.31.0 → anydi-0.32.1}/LICENSE +0 -0
- {anydi-0.31.0 → anydi-0.32.1}/README.md +0 -0
- {anydi-0.31.0 → anydi-0.32.1}/anydi/_logger.py +0 -0
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/__init__.py +0 -0
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/django/__init__.py +0 -0
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/django/_container.py +0 -0
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/django/ninja/__init__.py +0 -0
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/django/ninja/_operation.py +0 -0
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/django/ninja/_signature.py +0 -0
- {anydi-0.31.0 → anydi-0.32.1}/anydi/ext/starlette/__init__.py +0 -0
- {anydi-0.31.0 → anydi-0.32.1}/anydi/py.typed +0 -0
|
@@ -4,14 +4,9 @@ from typing import Any, cast
|
|
|
4
4
|
|
|
5
5
|
from ._container import Container, request, singleton, transient
|
|
6
6
|
from ._module import Module, provider
|
|
7
|
+
from ._provider import Provider
|
|
7
8
|
from ._scanner import injectable
|
|
8
|
-
from ._types import Marker,
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def dep() -> Any:
|
|
12
|
-
"""A marker for dependency injection."""
|
|
13
|
-
return Marker()
|
|
14
|
-
|
|
9
|
+
from ._types import Marker, Scope
|
|
15
10
|
|
|
16
11
|
# Alias for the dependency "auto" marker. `cast` is a runtime no-op that hands
# back the Marker instance unchanged; it only types it as `Any` for checkers.
auto = cast("Any", Marker())
|
|
@@ -23,7 +18,6 @@ __all__ = [
|
|
|
23
18
|
"Provider",
|
|
24
19
|
"Scope",
|
|
25
20
|
"auto",
|
|
26
|
-
"dep",
|
|
27
21
|
"injectable",
|
|
28
22
|
"provider",
|
|
29
23
|
"request",
|
|
@@ -0,0 +1,473 @@
|
|
|
1
|
+
"""AnyDI core implementation module."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import contextlib
|
|
6
|
+
import inspect
|
|
7
|
+
import types
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
from collections.abc import AsyncIterator, Awaitable, Iterable, Iterator, Sequence
|
|
10
|
+
from contextvars import ContextVar
|
|
11
|
+
from typing import Any, Callable, TypeVar, cast, overload
|
|
12
|
+
|
|
13
|
+
from typing_extensions import ParamSpec, Self, final
|
|
14
|
+
|
|
15
|
+
from ._context import (
|
|
16
|
+
RequestContext,
|
|
17
|
+
ResourceScopedContext,
|
|
18
|
+
ScopedContext,
|
|
19
|
+
SingletonContext,
|
|
20
|
+
TransientContext,
|
|
21
|
+
)
|
|
22
|
+
from ._injector import Injector
|
|
23
|
+
from ._module import Module, ModuleRegistry
|
|
24
|
+
from ._provider import Provider
|
|
25
|
+
from ._scanner import Scanner
|
|
26
|
+
from ._types import AnyInterface, Interface, Scope
|
|
27
|
+
from ._utils import get_full_qualname, get_typed_parameters, is_builtin_type
|
|
28
|
+
|
|
29
|
+
T = TypeVar("T", bound=Any)
|
|
30
|
+
P = ParamSpec("P")
|
|
31
|
+
|
|
32
|
+
ALLOWED_SCOPES: dict[Scope, list[Scope]] = {
|
|
33
|
+
"singleton": ["singleton"],
|
|
34
|
+
"request": ["request", "singleton"],
|
|
35
|
+
"transient": ["transient", "singleton", "request"],
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@final
class Container:
    """AnyDI is a dependency injection container.

    The container keeps the provider registry, the singleton and transient
    scoped contexts, and the currently active request context (stored in a
    ``ContextVar``). It exposes registration, resolution, injection and
    scanning on top of them.
    """

    def __init__(
        self,
        *,
        providers: Sequence[Provider] | None = None,
        modules: Sequence[Module | type[Module] | Callable[[Container], None] | str]
        | None = None,
        strict: bool = False,
    ) -> None:
        """Create a container.

        Args:
            providers: Providers to register immediately.
            modules: Modules to register immediately; each entry is passed to
                :meth:`register_module`.
            strict: When true, unregistered classes are not auto-registered on
                lookup, and :meth:`override` requires a registered interface.
        """
        self._providers: dict[type[Any], Provider] = {}
        # Interfaces of resource providers, grouped by provider scope.
        self._resource_cache: dict[Scope, list[type[Any]]] = defaultdict(list)
        self._singleton_context = SingletonContext(self)
        self._transient_context = TransientContext(self)
        # Request context active in the current execution context (None if none).
        self._request_context_var: ContextVar[RequestContext | None] = ContextVar(
            "request_context", default=None
        )
        # Instances installed by `override`; checked before any provider.
        self._override_instances: dict[type[Any], Any] = {}
        self._strict = strict
        # Non-singleton, non-transient providers with a dependency that could
        # not be looked up at registration time.
        self._unresolved_interfaces: set[type[Any]] = set()

        # Components
        self._injector = Injector(self)
        self._modules = ModuleRegistry(self)
        self._scanner = Scanner(self)

        # Register providers
        for provider in providers or []:
            self._register_provider(provider)

        # Register modules
        for module in modules or []:
            self.register_module(module)

    @property
    def strict(self) -> bool:
        """Check if strict mode is enabled."""
        return self._strict

    @property
    def providers(self) -> dict[type[Any], Provider]:
        """Get the registered providers (the live registry mapping)."""
        return self._providers

    def is_registered(self, interface: AnyInterface) -> bool:
        """Check if a provider is registered for the specified interface."""
        return interface in self._providers

    def register(
        self,
        interface: AnyInterface,
        call: Callable[..., Any],
        *,
        scope: Scope,
        override: bool = False,
    ) -> Provider:
        """Register a provider for the specified interface.

        Raises:
            LookupError: If the interface is already registered and
                ``override`` is false.
        """
        provider = Provider(call=call, scope=scope, interface=interface)
        return self._register_provider(provider, override=override)

    def _register_provider(
        self, provider: Provider, *, override: bool = False
    ) -> Provider:
        """Register a provider.

        New providers have their sub-providers validated first. When the
        interface is already registered, the provider replaces the existing
        one only if ``override`` is true; the replacement path does not
        re-run sub-provider validation.

        Raises:
            LookupError: If the interface is already registered and
                ``override`` is false.
        """
        if provider.interface in self._providers:
            if override:
                self._set_provider(provider)
                return provider

            raise LookupError(
                f"The provider interface `{get_full_qualname(provider.interface)}` "
                "already registered."
            )

        self._validate_sub_providers(provider)
        self._set_provider(provider)
        return provider

    def unregister(self, interface: AnyInterface) -> None:
        """Unregister a provider by interface.

        Any cached instance in the provider's scoped context is dropped too
        (skipped when that context is unavailable, e.g. no active request).

        Raises:
            LookupError: If the interface is not registered.
        """
        if not self.is_registered(interface):
            raise LookupError(
                "The provider interface "
                f"`{get_full_qualname(interface)}` not registered."
            )

        provider = self._get_provider(interface)

        # Cleanup scoped context instance
        try:
            scoped_context = self._get_scoped_context(provider.scope)
        except LookupError:
            pass
        else:
            if isinstance(scoped_context, ResourceScopedContext):
                scoped_context.delete(interface)

        # Cleanup provider references
        self._delete_provider(provider)

    def _get_provider(self, interface: AnyInterface) -> Provider:
        """Get provider by interface.

        Raises:
            LookupError: If no provider is registered for ``interface``.
        """
        try:
            return self._providers[interface]
        except KeyError as exc:
            raise LookupError(
                f"The provider interface for `{get_full_qualname(interface)}` has "
                "not been registered. Please ensure that the provider interface is "
                "properly registered before attempting to use it."
            ) from exc

    def _get_or_register_provider(
        self, interface: AnyInterface, parent_scope: Scope | None = None
    ) -> Provider:
        """Get a provider by interface, auto-registering classes if allowed.

        Outside strict mode, an unregistered non-builtin class is registered
        as its own provider. Its scope is taken from a ``__scope__`` attribute,
        else ``parent_scope``, else detected from its dependencies, falling
        back to ``"transient"``.

        Raises:
            LookupError: If the interface is unregistered and cannot be
                auto-registered.
        """
        try:
            return self._get_provider(interface)
        except LookupError:
            if (
                not self.strict
                and inspect.isclass(interface)
                and not is_builtin_type(interface)
                and interface is not inspect.Parameter.empty
            ):
                # Try to get defined scope
                scope = getattr(interface, "__scope__", parent_scope)
                # Try to detect scope
                if scope is None:
                    scope = self._detect_scope(interface)
                return self.register(interface, interface, scope=scope or "transient")
            raise

    def _set_provider(self, provider: Provider) -> None:
        """Set a provider by interface, replacing any existing one."""
        previous = self._providers.get(provider.interface)
        if previous is not None and previous.is_resource:
            # An override replaces the provider in place; drop the old
            # resource entry so the cache holds no duplicate or stale
            # interface (otherwise `_delete_provider` removes only one).
            resources = self._resource_cache[previous.scope]
            if previous.interface in resources:
                resources.remove(previous.interface)
        self._providers[provider.interface] = provider
        if provider.is_resource:
            self._resource_cache[provider.scope].append(provider.interface)

    def _delete_provider(self, provider: Provider) -> None:
        """Delete a provider and its resource-cache entry."""
        if provider.interface in self._providers:
            del self._providers[provider.interface]
        if provider.is_resource:
            self._resource_cache[provider.scope].remove(provider.interface)

    def _validate_sub_providers(self, provider: Provider) -> None:
        """Validate the sub-providers of a provider.

        Raises:
            TypeError: If a dependency parameter has no annotation.
            ValueError: If a dependency cannot be found for a singleton or
                transient provider, or its scope is not allowed by
                ``ALLOWED_SCOPES`` for the provider's scope.
        """
        for parameter in provider.parameters:
            annotation = parameter.annotation

            if annotation is inspect.Parameter.empty:
                raise TypeError(
                    f"Missing provider `{provider}` "
                    f"dependency `{parameter.name}` annotation."
                )

            try:
                sub_provider = self._get_or_register_provider(
                    annotation, parent_scope=provider.scope
                )
            except LookupError:
                if provider.scope not in {"singleton", "transient"}:
                    # Deferred: remember the provider instead of failing now.
                    self._unresolved_interfaces.add(provider.interface)
                    continue
                raise ValueError(
                    f"The provider `{provider}` depends on `{parameter.name}` of type "
                    f"`{get_full_qualname(annotation)}`, which "
                    "has not been registered or set. To resolve this, ensure that "
                    f"`{parameter.name}` is registered before attempting to use it."
                ) from None

            # Check scope compatibility
            if sub_provider.scope not in ALLOWED_SCOPES.get(provider.scope, []):
                raise ValueError(
                    f"The provider `{provider}` with a `{provider.scope}` scope cannot "
                    f"depend on `{sub_provider}` with a `{sub_provider.scope}` scope. "
                    "Please ensure all providers are registered with matching scopes."
                )

    def _detect_scope(self, call: Callable[..., Any]) -> Scope | None:
        """Detect the scope for a callable from its dependencies' scopes.

        Returns ``"transient"`` as soon as any dependency is transient,
        otherwise ``"request"`` if any dependency is request-scoped,
        otherwise ``"singleton"`` if any is singleton, else ``None``.
        """
        scopes: set[Scope] = set()

        for parameter in get_typed_parameters(call):
            sub_provider = self._get_or_register_provider(parameter.annotation)
            scope = sub_provider.scope

            if scope == "transient":
                # Transient has the highest priority; no need to look further.
                return "transient"
            scopes.add(scope)

        # Determine scope based on priority
        if "request" in scopes:
            return "request"
        if "singleton" in scopes:
            return "singleton"

        return None

    def register_module(
        self, module: Module | type[Module] | Callable[[Container], None] | str
    ) -> None:
        """Register a module as a callable, module type, or module instance."""
        self._modules.register(module)

    def __enter__(self) -> Self:
        """Enter the singleton context."""
        self.start()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> bool:
        """Exit the singleton context."""
        return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)

    def start(self) -> None:
        """Start the singleton context."""
        self._singleton_context.start()

    def close(self) -> None:
        """Close the singleton context."""
        self._singleton_context.close()

    @contextlib.contextmanager
    def request_context(self) -> Iterator[RequestContext]:
        """Obtain a context manager for the request-scoped context.

        The new request context is made current for the duration of the
        ``with`` block and the previous one is restored afterwards.
        """
        context = RequestContext(self)
        token = self._request_context_var.set(context)
        try:
            with context:
                yield context
        finally:
            # Restore even when the body raises, so a finished request
            # context is never left active for later resolutions.
            self._request_context_var.reset(token)

    async def __aenter__(self) -> Self:
        """Enter the singleton context."""
        await self.astart()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> bool:
        """Exit the singleton context."""
        return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)

    async def astart(self) -> None:
        """Start the singleton context asynchronously."""
        await self._singleton_context.astart()

    async def aclose(self) -> None:
        """Close the singleton context asynchronously."""
        await self._singleton_context.aclose()

    @contextlib.asynccontextmanager
    async def arequest_context(self) -> AsyncIterator[RequestContext]:
        """Obtain an async context manager for the request-scoped context.

        The new request context is made current for the duration of the
        ``async with`` block and the previous one is restored afterwards.
        """
        context = RequestContext(self)
        token = self._request_context_var.set(context)
        try:
            async with context:
                yield context
        finally:
            # Restore even when the body raises (see `request_context`).
            self._request_context_var.reset(token)

    def _get_request_context(self) -> RequestContext:
        """Get the current request context.

        Raises:
            LookupError: If no request context is active.
        """
        request_context = self._request_context_var.get()
        if request_context is None:
            raise LookupError(
                "The request context has not been started. Please ensure that "
                "the request context is properly initialized before attempting "
                "to use it."
            )
        return request_context

    def reset(self) -> None:
        """Reset resolved instances held by the reachable scoped contexts."""
        for interface, provider in self._providers.items():
            try:
                scoped_context = self._get_scoped_context(provider.scope)
            except LookupError:
                continue
            if isinstance(scoped_context, ResourceScopedContext):
                scoped_context.delete(interface)

    @overload
    def resolve(self, interface: Interface[T]) -> T: ...

    @overload
    def resolve(self, interface: T) -> T: ...

    def resolve(self, interface: Interface[T]) -> T:
        """Resolve an instance by interface; overrides take precedence."""
        if interface in self._override_instances:
            return cast(T, self._override_instances[interface])

        provider = self._get_or_register_provider(interface)
        scoped_context = self._get_scoped_context(provider.scope)
        return cast(T, scoped_context.get(provider))

    @overload
    async def aresolve(self, interface: Interface[T]) -> T: ...

    @overload
    async def aresolve(self, interface: T) -> T: ...

    async def aresolve(self, interface: Interface[T]) -> T:
        """Resolve an instance by interface asynchronously."""
        if interface in self._override_instances:
            return cast(T, self._override_instances[interface])

        provider = self._get_or_register_provider(interface)
        scoped_context = self._get_scoped_context(provider.scope)
        return cast(T, await scoped_context.aget(provider))

    def is_resolved(self, interface: AnyInterface) -> bool:
        """Check if an instance by interface exists.

        Returns False for unregistered interfaces and for request-scoped
        interfaces when no request context is active.
        """
        try:
            provider = self._get_provider(interface)
            scoped_context = self._get_scoped_context(provider.scope)
        except LookupError:
            return False
        if isinstance(scoped_context, ResourceScopedContext):
            return scoped_context.has(interface)
        return False

    def release(self, interface: AnyInterface) -> None:
        """Release an instance by interface.

        Raises:
            LookupError: If the interface is not registered or its scoped
                context is not available.
        """
        provider = self._get_provider(interface)
        scoped_context = self._get_scoped_context(provider.scope)
        if isinstance(scoped_context, ResourceScopedContext):
            scoped_context.delete(interface)

    def _get_scoped_context(self, scope: Scope) -> ScopedContext:
        """Get the scoped context based on the specified scope.

        Raises:
            LookupError: For ``"request"`` when no request context is active.
        """
        if scope == "singleton":
            return self._singleton_context
        if scope == "request":
            return self._get_request_context()
        return self._transient_context

    @contextlib.contextmanager
    def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
        """
        Override the provider for the specified interface with a specific instance.

        The previous override (if any) is restored on exit, including when
        the body raises, so nested overrides of one interface unwind
        correctly.

        Raises:
            LookupError: In strict mode, if the interface is not registered.
        """
        if not self.is_registered(interface) and self.strict:
            raise LookupError(
                f"The provider interface `{get_full_qualname(interface)}` "
                "not registered."
            )
        missing = object()
        previous = self._override_instances.get(interface, missing)
        self._override_instances[interface] = instance
        try:
            yield
        finally:
            if previous is missing:
                self._override_instances.pop(interface, None)
            else:
                self._override_instances[interface] = previous

    def provider(
        self, *, scope: Scope, override: bool = False
    ) -> Callable[[Callable[P, T]], Callable[P, T]]:
        """Decorator to register a provider function with the specified scope."""

        def decorator(call: Callable[P, T]) -> Callable[P, T]:
            provider = Provider(call=call, scope=scope)
            self._register_provider(provider, override=override)
            return call

        return decorator

    @overload
    def inject(self, func: Callable[P, T]) -> Callable[P, T]: ...

    @overload
    def inject(self) -> Callable[[Callable[P, T]], Callable[P, T]]: ...

    def inject(
        self, func: Callable[P, T | Awaitable[T]] | None = None
    ) -> (
        Callable[[Callable[P, T | Awaitable[T]]], Callable[P, T | Awaitable[T]]]
        | Callable[P, T | Awaitable[T]]
    ):
        """Decorator to inject dependencies into a callable.

        Usable both bare (``@container.inject``) and called
        (``@container.inject()``).
        """

        def decorator(
            inner: Callable[P, T | Awaitable[T]],
        ) -> Callable[P, T | Awaitable[T]]:
            return self._injector.inject(inner)

        if func is None:
            return decorator
        return decorator(func)

    def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
        """Run the given function with injected dependencies."""
        return cast(T, self._injector.inject(func)(*args, **kwargs))

    def scan(
        self,
        /,
        packages: types.ModuleType | str | Iterable[types.ModuleType | str],
        *,
        tags: Iterable[str] | None = None,
    ) -> None:
        """Scan packages or modules for decorated members and inject dependencies."""
        self._scanner.scan(packages, tags=tags)
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def transient(target: T) -> T:
    """Tag ``target`` with the transient scope and hand it back unchanged.

    The tag is stored as a ``__scope__`` attribute on the target.
    """
    target.__scope__ = "transient"  # type: ignore[attr-defined]
    return target
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def request(target: T) -> T:
    """Tag ``target`` with the request scope and hand it back unchanged.

    The tag is stored as a ``__scope__`` attribute on the target.
    """
    target.__scope__ = "request"  # type: ignore[attr-defined]
    return target
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def singleton(target: T) -> T:
    """Tag ``target`` with the singleton scope and hand it back unchanged.

    The tag is stored as a ``__scope__`` attribute on the target.
    """
    target.__scope__ = "singleton"  # type: ignore[attr-defined]
    return target
|