anydi 0.38.0__py3-none-any.whl → 0.38.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anydi/__init__.py +1 -2
- anydi/_container.py +414 -270
- anydi/_types.py +80 -2
- anydi/ext/pytest_plugin.py +6 -21
- {anydi-0.38.0.dist-info → anydi-0.38.2rc0.dist-info}/METADATA +1 -1
- {anydi-0.38.0.dist-info → anydi-0.38.2rc0.dist-info}/RECORD +9 -10
- anydi/_provider.py +0 -232
- {anydi-0.38.0.dist-info → anydi-0.38.2rc0.dist-info}/WHEEL +0 -0
- {anydi-0.38.0.dist-info → anydi-0.38.2rc0.dist-info}/entry_points.txt +0 -0
- {anydi-0.38.0.dist-info → anydi-0.38.2rc0.dist-info}/licenses/LICENSE +0 -0
anydi/_container.py
CHANGED
|
@@ -10,23 +10,27 @@ import logging
|
|
|
10
10
|
import pkgutil
|
|
11
11
|
import threading
|
|
12
12
|
import types
|
|
13
|
+
import uuid
|
|
13
14
|
from collections import defaultdict
|
|
14
15
|
from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
|
|
15
16
|
from contextvars import ContextVar
|
|
16
17
|
from types import ModuleType
|
|
17
|
-
from typing import Any, Callable, TypeVar, Union, cast, overload
|
|
18
|
-
from weakref import WeakKeyDictionary
|
|
18
|
+
from typing import Any, Callable, TypeVar, Union, cast, get_origin, overload
|
|
19
19
|
|
|
20
|
-
from typing_extensions import Concatenate, ParamSpec, Self, final
|
|
20
|
+
from typing_extensions import Concatenate, ParamSpec, Self, final, get_args
|
|
21
21
|
|
|
22
22
|
from ._context import InstanceContext
|
|
23
|
-
from ._provider import Provider
|
|
24
23
|
from ._types import (
|
|
24
|
+
NOT_SET,
|
|
25
25
|
AnyInterface,
|
|
26
|
-
|
|
26
|
+
Event,
|
|
27
27
|
InjectableDecoratorArgs,
|
|
28
28
|
InstanceProxy,
|
|
29
|
+
Provider,
|
|
30
|
+
ProviderArgs,
|
|
29
31
|
ProviderDecoratorArgs,
|
|
32
|
+
ProviderKind,
|
|
33
|
+
ScannedDependency,
|
|
30
34
|
Scope,
|
|
31
35
|
is_event_type,
|
|
32
36
|
is_marker,
|
|
@@ -34,6 +38,7 @@ from ._types import (
|
|
|
34
38
|
from ._utils import (
|
|
35
39
|
AsyncRLock,
|
|
36
40
|
get_full_qualname,
|
|
41
|
+
get_typed_annotation,
|
|
37
42
|
get_typed_parameters,
|
|
38
43
|
import_string,
|
|
39
44
|
is_async_context_manager,
|
|
@@ -42,6 +47,12 @@ from ._utils import (
|
|
|
42
47
|
run_async,
|
|
43
48
|
)
|
|
44
49
|
|
|
50
|
+
try:
|
|
51
|
+
from types import NoneType
|
|
52
|
+
except ImportError:
|
|
53
|
+
NoneType = type(None) # type: ignore[misc]
|
|
54
|
+
|
|
55
|
+
|
|
45
56
|
T = TypeVar("T", bound=Any)
|
|
46
57
|
M = TypeVar("M", bound="Module")
|
|
47
58
|
P = ParamSpec("P")
|
|
@@ -74,7 +85,6 @@ class Module(metaclass=ModuleMeta):
|
|
|
74
85
|
"""Configure the AnyDI container with providers and their dependencies."""
|
|
75
86
|
|
|
76
87
|
|
|
77
|
-
# noinspection PyShadowingNames
|
|
78
88
|
@final
|
|
79
89
|
class Container:
|
|
80
90
|
"""AnyDI is a dependency injection container."""
|
|
@@ -82,7 +92,7 @@ class Container:
|
|
|
82
92
|
def __init__(
|
|
83
93
|
self,
|
|
84
94
|
*,
|
|
85
|
-
providers: Sequence[
|
|
95
|
+
providers: Sequence[ProviderArgs] | None = None,
|
|
86
96
|
modules: Sequence[Module | type[Module] | Callable[[Container], None] | str]
|
|
87
97
|
| None = None,
|
|
88
98
|
strict: bool = False,
|
|
@@ -104,20 +114,27 @@ class Container:
|
|
|
104
114
|
)
|
|
105
115
|
self._override_instances: dict[type[Any], Any] = {}
|
|
106
116
|
self._unresolved_interfaces: set[type[Any]] = set()
|
|
107
|
-
self._inject_cache:
|
|
108
|
-
Callable[..., Any], Callable[..., Any]
|
|
109
|
-
] = WeakKeyDictionary()
|
|
117
|
+
self._inject_cache: dict[Callable[..., Any], Callable[..., Any]] = {}
|
|
110
118
|
|
|
111
119
|
# Register providers
|
|
112
120
|
providers = providers or []
|
|
113
121
|
for provider in providers:
|
|
114
|
-
self.
|
|
122
|
+
_provider = self._create_provider(
|
|
123
|
+
call=provider.call,
|
|
124
|
+
scope=provider.scope,
|
|
125
|
+
interface=provider.interface,
|
|
126
|
+
)
|
|
127
|
+
self._register_provider(_provider, False)
|
|
115
128
|
|
|
116
129
|
# Register modules
|
|
117
130
|
modules = modules or []
|
|
118
131
|
for module in modules:
|
|
119
132
|
self.register_module(module)
|
|
120
133
|
|
|
134
|
+
############################
|
|
135
|
+
# Properties
|
|
136
|
+
############################
|
|
137
|
+
|
|
121
138
|
@property
|
|
122
139
|
def strict(self) -> bool:
|
|
123
140
|
"""Check if strict mode is enabled."""
|
|
@@ -143,9 +160,110 @@ class Container:
|
|
|
143
160
|
"""Get the logger instance."""
|
|
144
161
|
return self._logger
|
|
145
162
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
163
|
+
############################
|
|
164
|
+
# Lifespan/Context Methods
|
|
165
|
+
############################
|
|
166
|
+
|
|
167
|
+
def __enter__(self) -> Self:
|
|
168
|
+
"""Enter the singleton context."""
|
|
169
|
+
self.start()
|
|
170
|
+
return self
|
|
171
|
+
|
|
172
|
+
def __exit__(
|
|
173
|
+
self,
|
|
174
|
+
exc_type: type[BaseException] | None,
|
|
175
|
+
exc_val: BaseException | None,
|
|
176
|
+
exc_tb: types.TracebackType | None,
|
|
177
|
+
) -> Any:
|
|
178
|
+
"""Exit the singleton context."""
|
|
179
|
+
return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
|
|
180
|
+
|
|
181
|
+
def start(self) -> None:
|
|
182
|
+
"""Start the singleton context."""
|
|
183
|
+
# Resolve all singleton resources
|
|
184
|
+
for interface in self._resources.get("singleton", []):
|
|
185
|
+
self.resolve(interface)
|
|
186
|
+
|
|
187
|
+
def close(self) -> None:
|
|
188
|
+
"""Close the singleton context."""
|
|
189
|
+
self._singleton_context.close()
|
|
190
|
+
|
|
191
|
+
async def __aenter__(self) -> Self:
|
|
192
|
+
"""Enter the singleton context."""
|
|
193
|
+
await self.astart()
|
|
194
|
+
return self
|
|
195
|
+
|
|
196
|
+
async def __aexit__(
|
|
197
|
+
self,
|
|
198
|
+
exc_type: type[BaseException] | None,
|
|
199
|
+
exc_val: BaseException | None,
|
|
200
|
+
exc_tb: types.TracebackType | None,
|
|
201
|
+
) -> bool:
|
|
202
|
+
"""Exit the singleton context."""
|
|
203
|
+
return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)
|
|
204
|
+
|
|
205
|
+
async def astart(self) -> None:
|
|
206
|
+
"""Start the singleton context asynchronously."""
|
|
207
|
+
for interface in self._resources.get("singleton", []):
|
|
208
|
+
await self.aresolve(interface)
|
|
209
|
+
|
|
210
|
+
async def aclose(self) -> None:
|
|
211
|
+
"""Close the singleton context asynchronously."""
|
|
212
|
+
await self._singleton_context.aclose()
|
|
213
|
+
|
|
214
|
+
@contextlib.contextmanager
|
|
215
|
+
def request_context(self) -> Iterator[InstanceContext]:
|
|
216
|
+
"""Obtain a context manager for the request-scoped context."""
|
|
217
|
+
context = InstanceContext()
|
|
218
|
+
|
|
219
|
+
token = self._request_context_var.set(context)
|
|
220
|
+
|
|
221
|
+
# Resolve all request resources
|
|
222
|
+
for interface in self._resources.get("request", []):
|
|
223
|
+
if not is_event_type(interface):
|
|
224
|
+
continue
|
|
225
|
+
self.resolve(interface)
|
|
226
|
+
|
|
227
|
+
with context:
|
|
228
|
+
yield context
|
|
229
|
+
self._request_context_var.reset(token)
|
|
230
|
+
|
|
231
|
+
@contextlib.asynccontextmanager
|
|
232
|
+
async def arequest_context(self) -> AsyncIterator[InstanceContext]:
|
|
233
|
+
"""Obtain an async context manager for the request-scoped context."""
|
|
234
|
+
context = InstanceContext()
|
|
235
|
+
|
|
236
|
+
token = self._request_context_var.set(context)
|
|
237
|
+
|
|
238
|
+
for interface in self._resources.get("request", []):
|
|
239
|
+
if not is_event_type(interface):
|
|
240
|
+
continue
|
|
241
|
+
await self.aresolve(interface)
|
|
242
|
+
|
|
243
|
+
async with context:
|
|
244
|
+
yield context
|
|
245
|
+
self._request_context_var.reset(token)
|
|
246
|
+
|
|
247
|
+
def _get_request_context(self) -> InstanceContext:
|
|
248
|
+
"""Get the current request context."""
|
|
249
|
+
request_context = self._request_context_var.get()
|
|
250
|
+
if request_context is None:
|
|
251
|
+
raise LookupError(
|
|
252
|
+
"The request context has not been started. Please ensure that "
|
|
253
|
+
"the request context is properly initialized before attempting "
|
|
254
|
+
"to use it."
|
|
255
|
+
)
|
|
256
|
+
return request_context
|
|
257
|
+
|
|
258
|
+
def _get_scoped_context(self, scope: Scope) -> InstanceContext:
|
|
259
|
+
"""Get the instance context for the specified scope."""
|
|
260
|
+
if scope == "singleton":
|
|
261
|
+
return self._singleton_context
|
|
262
|
+
return self._get_request_context()
|
|
263
|
+
|
|
264
|
+
############################
|
|
265
|
+
# Provider Methods
|
|
266
|
+
############################
|
|
149
267
|
|
|
150
268
|
def register(
|
|
151
269
|
self,
|
|
@@ -156,26 +274,12 @@ class Container:
|
|
|
156
274
|
override: bool = False,
|
|
157
275
|
) -> Provider:
|
|
158
276
|
"""Register a provider for the specified interface."""
|
|
159
|
-
provider =
|
|
277
|
+
provider = self._create_provider(call=call, scope=scope, interface=interface)
|
|
160
278
|
return self._register_provider(provider, override)
|
|
161
279
|
|
|
162
|
-
def
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
"""Register a provider."""
|
|
166
|
-
if provider.interface in self._providers:
|
|
167
|
-
if override:
|
|
168
|
-
self._set_provider(provider)
|
|
169
|
-
return provider
|
|
170
|
-
|
|
171
|
-
raise LookupError(
|
|
172
|
-
f"The provider interface `{get_full_qualname(provider.interface)}` "
|
|
173
|
-
"already registered."
|
|
174
|
-
)
|
|
175
|
-
|
|
176
|
-
self._validate_sub_providers(provider, **defaults)
|
|
177
|
-
self._set_provider(provider)
|
|
178
|
-
return provider
|
|
280
|
+
def is_registered(self, interface: AnyInterface) -> bool:
|
|
281
|
+
"""Check if a provider is registered for the specified interface."""
|
|
282
|
+
return interface in self._providers
|
|
179
283
|
|
|
180
284
|
def unregister(self, interface: AnyInterface) -> None:
|
|
181
285
|
"""Unregister a provider by interface."""
|
|
@@ -199,6 +303,120 @@ class Container:
|
|
|
199
303
|
# Cleanup provider references
|
|
200
304
|
self._delete_provider(provider)
|
|
201
305
|
|
|
306
|
+
def provider(
|
|
307
|
+
self, *, scope: Scope, override: bool = False
|
|
308
|
+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
|
309
|
+
"""Decorator to register a provider function with the specified scope."""
|
|
310
|
+
|
|
311
|
+
def decorator(call: Callable[P, T]) -> Callable[P, T]:
|
|
312
|
+
provider = self._create_provider(call=call, scope=scope)
|
|
313
|
+
self._register_provider(provider, override)
|
|
314
|
+
return call
|
|
315
|
+
|
|
316
|
+
return decorator
|
|
317
|
+
|
|
318
|
+
def _create_provider( # noqa: C901
|
|
319
|
+
self, call: Callable[..., Any], *, scope: Scope, interface: Any = NOT_SET
|
|
320
|
+
) -> Provider:
|
|
321
|
+
name = get_full_qualname(call)
|
|
322
|
+
|
|
323
|
+
# Detect the kind of callable provider
|
|
324
|
+
kind = ProviderKind.from_call(call)
|
|
325
|
+
|
|
326
|
+
# Validate the scope of the provider
|
|
327
|
+
if scope not in get_args(Scope):
|
|
328
|
+
raise ValueError(
|
|
329
|
+
f"The provider `{name}` scope is invalid. Only the following "
|
|
330
|
+
f"scopes are supported: {', '.join(get_args(Scope))}. "
|
|
331
|
+
"Please use one of the supported scopes when registering a provider."
|
|
332
|
+
)
|
|
333
|
+
|
|
334
|
+
# Validate the scope of the provider
|
|
335
|
+
if (
|
|
336
|
+
kind in {ProviderKind.GENERATOR, ProviderKind.ASYNC_GENERATOR}
|
|
337
|
+
and scope == "transient"
|
|
338
|
+
):
|
|
339
|
+
raise TypeError(
|
|
340
|
+
f"The resource provider `{name}` is attempting to register "
|
|
341
|
+
"with a transient scope, which is not allowed."
|
|
342
|
+
)
|
|
343
|
+
|
|
344
|
+
# Get the signature
|
|
345
|
+
globalns = getattr(call, "__globals__", {})
|
|
346
|
+
signature = inspect.signature(call, globals=globalns)
|
|
347
|
+
|
|
348
|
+
# Detect the interface
|
|
349
|
+
if kind == ProviderKind.CLASS:
|
|
350
|
+
interface = call
|
|
351
|
+
else:
|
|
352
|
+
if interface is NOT_SET:
|
|
353
|
+
interface = signature.return_annotation
|
|
354
|
+
if interface is inspect.Signature.empty:
|
|
355
|
+
interface = None
|
|
356
|
+
else:
|
|
357
|
+
interface = get_typed_annotation(interface, globalns)
|
|
358
|
+
|
|
359
|
+
# If the callable is an iterator, return the actual type
|
|
360
|
+
iterator_types = {Iterator, AsyncIterator}
|
|
361
|
+
if interface in iterator_types or get_origin(interface) in iterator_types:
|
|
362
|
+
if args := get_args(interface):
|
|
363
|
+
interface = args[0]
|
|
364
|
+
# If the callable is a generator, return the resource type
|
|
365
|
+
if interface in {None, NoneType}:
|
|
366
|
+
interface = type(f"Event_{uuid.uuid4().hex}", (Event,), {})
|
|
367
|
+
else:
|
|
368
|
+
raise TypeError(
|
|
369
|
+
f"Cannot use `{name}` resource type annotation "
|
|
370
|
+
"without actual type argument."
|
|
371
|
+
)
|
|
372
|
+
|
|
373
|
+
# None interface is not allowed
|
|
374
|
+
if interface in {None, NoneType}:
|
|
375
|
+
raise TypeError(f"Missing `{name}` provider return annotation.")
|
|
376
|
+
|
|
377
|
+
# Detect the parameters
|
|
378
|
+
parameters = []
|
|
379
|
+
for parameter in signature.parameters.values():
|
|
380
|
+
if parameter.annotation is inspect.Parameter.empty:
|
|
381
|
+
raise TypeError(
|
|
382
|
+
f"Missing provider `{name}` "
|
|
383
|
+
f"dependency `{parameter.name}` annotation."
|
|
384
|
+
)
|
|
385
|
+
if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
|
|
386
|
+
raise TypeError(
|
|
387
|
+
"Positional-only parameters "
|
|
388
|
+
f"are not allowed in the provider `{name}`."
|
|
389
|
+
)
|
|
390
|
+
annotation = get_typed_annotation(parameter.annotation, globalns)
|
|
391
|
+
parameters.append(parameter.replace(annotation=annotation))
|
|
392
|
+
|
|
393
|
+
return Provider(
|
|
394
|
+
call=call,
|
|
395
|
+
scope=scope,
|
|
396
|
+
interface=interface,
|
|
397
|
+
name=name,
|
|
398
|
+
kind=kind,
|
|
399
|
+
parameters=parameters,
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
def _register_provider(
|
|
403
|
+
self, provider: Provider, override: bool, /, **defaults: Any
|
|
404
|
+
) -> Provider:
|
|
405
|
+
"""Register a provider."""
|
|
406
|
+
if provider.interface in self._providers:
|
|
407
|
+
if override:
|
|
408
|
+
self._set_provider(provider)
|
|
409
|
+
return provider
|
|
410
|
+
|
|
411
|
+
raise LookupError(
|
|
412
|
+
f"The provider interface `{get_full_qualname(provider.interface)}` "
|
|
413
|
+
"already registered."
|
|
414
|
+
)
|
|
415
|
+
|
|
416
|
+
self._validate_sub_providers(provider, **defaults)
|
|
417
|
+
self._set_provider(provider)
|
|
418
|
+
return provider
|
|
419
|
+
|
|
202
420
|
def _get_provider(self, interface: AnyInterface) -> Provider:
|
|
203
421
|
"""Get provider by interface."""
|
|
204
422
|
try:
|
|
@@ -227,9 +445,11 @@ class Container:
|
|
|
227
445
|
scope = getattr(interface, "__scope__", parent_scope)
|
|
228
446
|
# Try to detect scope
|
|
229
447
|
if scope is None:
|
|
230
|
-
scope = self.
|
|
448
|
+
scope = self._detect_provider_scope(interface, **defaults)
|
|
231
449
|
scope = scope or self.default_scope
|
|
232
|
-
provider =
|
|
450
|
+
provider = self._create_provider(
|
|
451
|
+
call=interface, scope=scope, interface=interface
|
|
452
|
+
)
|
|
233
453
|
return self._register_provider(provider, False, **defaults)
|
|
234
454
|
raise
|
|
235
455
|
|
|
@@ -250,12 +470,6 @@ class Container:
|
|
|
250
470
|
"""Validate the sub-providers of a provider."""
|
|
251
471
|
|
|
252
472
|
for parameter in provider.parameters:
|
|
253
|
-
if parameter.annotation is inspect.Parameter.empty:
|
|
254
|
-
raise TypeError(
|
|
255
|
-
f"Missing provider `{provider}` "
|
|
256
|
-
f"dependency `{parameter.name}` annotation."
|
|
257
|
-
)
|
|
258
|
-
|
|
259
473
|
try:
|
|
260
474
|
sub_provider = self._get_or_register_provider(
|
|
261
475
|
parameter.annotation, provider.scope
|
|
@@ -282,7 +496,9 @@ class Container:
|
|
|
282
496
|
"Please ensure all providers are registered with matching scopes."
|
|
283
497
|
)
|
|
284
498
|
|
|
285
|
-
def
|
|
499
|
+
def _detect_provider_scope(
|
|
500
|
+
self, call: Callable[..., Any], /, **defaults: Any
|
|
501
|
+
) -> Scope | None:
|
|
286
502
|
"""Detect the scope for a callable."""
|
|
287
503
|
scopes = set()
|
|
288
504
|
|
|
@@ -320,122 +536,56 @@ class Container:
|
|
|
320
536
|
not self.strict and parameter.default is not inspect.Parameter.empty
|
|
321
537
|
)
|
|
322
538
|
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
"""Register a module as a callable, module type, or module instance."""
|
|
327
|
-
# Callable Module
|
|
328
|
-
if inspect.isfunction(module):
|
|
329
|
-
module(self)
|
|
330
|
-
return
|
|
331
|
-
|
|
332
|
-
# Module path
|
|
333
|
-
if isinstance(module, str):
|
|
334
|
-
module = import_string(module)
|
|
335
|
-
|
|
336
|
-
# Class based Module or Module type
|
|
337
|
-
if inspect.isclass(module) and issubclass(module, Module):
|
|
338
|
-
module = module()
|
|
339
|
-
|
|
340
|
-
if isinstance(module, Module):
|
|
341
|
-
module.configure(self)
|
|
342
|
-
for provider_name, decorator_args in module.providers:
|
|
343
|
-
obj = getattr(module, provider_name)
|
|
344
|
-
self.provider(
|
|
345
|
-
scope=decorator_args.scope,
|
|
346
|
-
override=decorator_args.override,
|
|
347
|
-
)(obj)
|
|
348
|
-
|
|
349
|
-
def __enter__(self) -> Self:
|
|
350
|
-
"""Enter the singleton context."""
|
|
351
|
-
self.start()
|
|
352
|
-
return self
|
|
353
|
-
|
|
354
|
-
def __exit__(
|
|
355
|
-
self,
|
|
356
|
-
exc_type: type[BaseException] | None,
|
|
357
|
-
exc_val: BaseException | None,
|
|
358
|
-
exc_tb: types.TracebackType | None,
|
|
359
|
-
) -> Any:
|
|
360
|
-
"""Exit the singleton context."""
|
|
361
|
-
return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
|
|
362
|
-
|
|
363
|
-
def start(self) -> None:
|
|
364
|
-
"""Start the singleton context."""
|
|
365
|
-
# Resolve all singleton resources
|
|
366
|
-
for interface in self._resources.get("singleton", []):
|
|
367
|
-
self.resolve(interface)
|
|
368
|
-
|
|
369
|
-
def close(self) -> None:
|
|
370
|
-
"""Close the singleton context."""
|
|
371
|
-
self._singleton_context.close()
|
|
372
|
-
|
|
373
|
-
@contextlib.contextmanager
|
|
374
|
-
def request_context(self) -> Iterator[InstanceContext]:
|
|
375
|
-
"""Obtain a context manager for the request-scoped context."""
|
|
376
|
-
context = InstanceContext()
|
|
377
|
-
|
|
378
|
-
token = self._request_context_var.set(context)
|
|
539
|
+
############################
|
|
540
|
+
# Instance Methods
|
|
541
|
+
############################
|
|
379
542
|
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
if not is_event_type(interface):
|
|
383
|
-
continue
|
|
384
|
-
self.resolve(interface)
|
|
385
|
-
|
|
386
|
-
with context:
|
|
387
|
-
yield context
|
|
388
|
-
self._request_context_var.reset(token)
|
|
543
|
+
@overload
|
|
544
|
+
def resolve(self, interface: type[T]) -> T: ...
|
|
389
545
|
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
await self.astart()
|
|
393
|
-
return self
|
|
546
|
+
@overload
|
|
547
|
+
def resolve(self, interface: T) -> T: ...
|
|
394
548
|
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
exc_val: BaseException | None,
|
|
399
|
-
exc_tb: types.TracebackType | None,
|
|
400
|
-
) -> bool:
|
|
401
|
-
"""Exit the singleton context."""
|
|
402
|
-
return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)
|
|
549
|
+
def resolve(self, interface: type[T]) -> T:
|
|
550
|
+
"""Resolve an instance by interface."""
|
|
551
|
+
return self._resolve_or_create(interface, False)
|
|
403
552
|
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
for interface in self._resources.get("singleton", []):
|
|
407
|
-
await self.aresolve(interface)
|
|
553
|
+
@overload
|
|
554
|
+
async def aresolve(self, interface: type[T]) -> T: ...
|
|
408
555
|
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
await self._singleton_context.aclose()
|
|
556
|
+
@overload
|
|
557
|
+
async def aresolve(self, interface: T) -> T: ...
|
|
412
558
|
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
context = InstanceContext()
|
|
559
|
+
async def aresolve(self, interface: type[T]) -> T:
|
|
560
|
+
"""Resolve an instance by interface asynchronously."""
|
|
561
|
+
return await self._aresolve_or_acreate(interface, False)
|
|
417
562
|
|
|
418
|
-
|
|
563
|
+
def create(self, interface: type[T], /, **defaults: Any) -> T:
|
|
564
|
+
"""Create an instance by interface."""
|
|
565
|
+
return self._resolve_or_create(interface, True, **defaults)
|
|
419
566
|
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
await self.aresolve(interface)
|
|
567
|
+
async def acreate(self, interface: type[T], /, **defaults: Any) -> T:
|
|
568
|
+
"""Create an instance by interface asynchronously."""
|
|
569
|
+
return await self._aresolve_or_acreate(interface, True, **defaults)
|
|
424
570
|
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
571
|
+
def is_resolved(self, interface: AnyInterface) -> bool:
|
|
572
|
+
"""Check if an instance by interface exists."""
|
|
573
|
+
try:
|
|
574
|
+
provider = self._get_provider(interface)
|
|
575
|
+
except LookupError:
|
|
576
|
+
return False
|
|
577
|
+
if provider.scope == "transient":
|
|
578
|
+
return False
|
|
579
|
+
context = self._get_scoped_context(provider.scope)
|
|
580
|
+
return interface in context
|
|
428
581
|
|
|
429
|
-
def
|
|
430
|
-
"""
|
|
431
|
-
|
|
432
|
-
if
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
"to use it."
|
|
437
|
-
)
|
|
438
|
-
return request_context
|
|
582
|
+
def release(self, interface: AnyInterface) -> None:
|
|
583
|
+
"""Release an instance by interface."""
|
|
584
|
+
provider = self._get_provider(interface)
|
|
585
|
+
if provider.scope == "transient":
|
|
586
|
+
return None
|
|
587
|
+
context = self._get_scoped_context(provider.scope)
|
|
588
|
+
del context[interface]
|
|
439
589
|
|
|
440
590
|
def reset(self) -> None:
|
|
441
591
|
"""Reset resolved instances."""
|
|
@@ -448,66 +598,38 @@ class Container:
|
|
|
448
598
|
continue
|
|
449
599
|
del context[interface]
|
|
450
600
|
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
def resolve(self, interface: type[T]) -> T:
|
|
458
|
-
"""Resolve an instance by interface."""
|
|
459
|
-
provider = self._get_or_register_provider(interface, None)
|
|
601
|
+
def _resolve_or_create(
|
|
602
|
+
self, interface: type[T], create: bool, /, **defaults: Any
|
|
603
|
+
) -> T:
|
|
604
|
+
"""Internal method to handle instance resolution and creation."""
|
|
605
|
+
provider = self._get_or_register_provider(interface, None, **defaults)
|
|
460
606
|
if provider.scope == "transient":
|
|
461
|
-
instance = self._create_instance(provider, None)
|
|
607
|
+
instance = self._create_instance(provider, None, **defaults)
|
|
462
608
|
else:
|
|
463
609
|
context = self._get_scoped_context(provider.scope)
|
|
464
610
|
if provider.scope == "singleton":
|
|
465
611
|
with self._singleton_lock:
|
|
466
|
-
instance =
|
|
612
|
+
instance = (
|
|
613
|
+
self._get_or_create_instance(provider, context)
|
|
614
|
+
if not create
|
|
615
|
+
else self._create_instance(provider, context, **defaults)
|
|
616
|
+
)
|
|
467
617
|
else:
|
|
468
|
-
instance =
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
@overload
|
|
474
|
-
async def aresolve(self, interface: type[T]) -> T: ...
|
|
475
|
-
|
|
476
|
-
@overload
|
|
477
|
-
async def aresolve(self, interface: T) -> T: ...
|
|
618
|
+
instance = (
|
|
619
|
+
self._get_or_create_instance(provider, context)
|
|
620
|
+
if not create
|
|
621
|
+
else self._create_instance(provider, context, **defaults)
|
|
622
|
+
)
|
|
478
623
|
|
|
479
|
-
async def aresolve(self, interface: type[T]) -> T:
|
|
480
|
-
"""Resolve an instance by interface asynchronously."""
|
|
481
|
-
provider = self._get_or_register_provider(interface, None)
|
|
482
|
-
if provider.scope == "transient":
|
|
483
|
-
instance = await self._acreate_instance(provider, None)
|
|
484
|
-
else:
|
|
485
|
-
context = self._get_scoped_context(provider.scope)
|
|
486
|
-
if provider.scope == "singleton":
|
|
487
|
-
async with self._singleton_async_lock:
|
|
488
|
-
instance = await self._aget_or_create_instance(provider, context)
|
|
489
|
-
else:
|
|
490
|
-
instance = await self._aget_or_create_instance(provider, context)
|
|
491
624
|
if self.testing:
|
|
492
|
-
instance = self._patch_test_resolver(interface, instance)
|
|
493
|
-
return cast(T, instance)
|
|
625
|
+
instance = self._patch_test_resolver(provider.interface, instance)
|
|
494
626
|
|
|
495
|
-
def create(self, interface: type[T], **defaults: Any) -> T:
|
|
496
|
-
"""Create an instance by interface."""
|
|
497
|
-
provider = self._get_or_register_provider(interface, None, **defaults)
|
|
498
|
-
if provider.scope == "transient":
|
|
499
|
-
instance = self._create_instance(provider, None, **defaults)
|
|
500
|
-
else:
|
|
501
|
-
context = self._get_scoped_context(provider.scope)
|
|
502
|
-
if provider.scope == "singleton":
|
|
503
|
-
with self._singleton_lock:
|
|
504
|
-
instance = self._create_instance(provider, context, **defaults)
|
|
505
|
-
else:
|
|
506
|
-
instance = self._create_instance(provider, context, **defaults)
|
|
507
627
|
return cast(T, instance)
|
|
508
628
|
|
|
509
|
-
async def
|
|
510
|
-
|
|
629
|
+
async def _aresolve_or_acreate(
|
|
630
|
+
self, interface: type[T], create: bool, /, **defaults: Any
|
|
631
|
+
) -> T:
|
|
632
|
+
"""Internal method to handle instance resolution and creation asynchronously."""
|
|
511
633
|
provider = self._get_or_register_provider(interface, None, **defaults)
|
|
512
634
|
if provider.scope == "transient":
|
|
513
635
|
instance = await self._acreate_instance(provider, None, **defaults)
|
|
@@ -515,11 +637,21 @@ class Container:
|
|
|
515
637
|
context = self._get_scoped_context(provider.scope)
|
|
516
638
|
if provider.scope == "singleton":
|
|
517
639
|
async with self._singleton_async_lock:
|
|
518
|
-
instance =
|
|
519
|
-
provider, context
|
|
640
|
+
instance = (
|
|
641
|
+
await self._aget_or_create_instance(provider, context)
|
|
642
|
+
if not create
|
|
643
|
+
else await self._acreate_instance(provider, context, **defaults)
|
|
520
644
|
)
|
|
521
645
|
else:
|
|
522
|
-
instance =
|
|
646
|
+
instance = (
|
|
647
|
+
await self._aget_or_create_instance(provider, context)
|
|
648
|
+
if not create
|
|
649
|
+
else await self._acreate_instance(provider, context, **defaults)
|
|
650
|
+
)
|
|
651
|
+
|
|
652
|
+
if self.testing:
|
|
653
|
+
instance = self._patch_test_resolver(provider.interface, instance)
|
|
654
|
+
|
|
523
655
|
return cast(T, instance)
|
|
524
656
|
|
|
525
657
|
def _get_or_create_instance(
|
|
@@ -696,26 +828,48 @@ class Container:
|
|
|
696
828
|
def _resolve_parameter(
|
|
697
829
|
self, provider: Provider, parameter: inspect.Parameter
|
|
698
830
|
) -> Any:
|
|
699
|
-
self._validate_resolvable_parameter(
|
|
831
|
+
self._validate_resolvable_parameter(provider, parameter)
|
|
700
832
|
return self.resolve(parameter.annotation)
|
|
701
833
|
|
|
702
834
|
async def _aresolve_parameter(
|
|
703
835
|
self, provider: Provider, parameter: inspect.Parameter
|
|
704
836
|
) -> Any:
|
|
705
|
-
self._validate_resolvable_parameter(
|
|
837
|
+
self._validate_resolvable_parameter(provider, parameter)
|
|
706
838
|
return await self.aresolve(parameter.annotation)
|
|
707
839
|
|
|
708
840
|
def _validate_resolvable_parameter(
|
|
709
|
-
self, parameter: inspect.Parameter
|
|
841
|
+
self, provider: Provider, parameter: inspect.Parameter
|
|
710
842
|
) -> None:
|
|
711
843
|
"""Ensure that the specified interface is resolved."""
|
|
712
844
|
if parameter.annotation in self._unresolved_interfaces:
|
|
713
845
|
raise LookupError(
|
|
714
846
|
f"You are attempting to get the parameter `{parameter.name}` with the "
|
|
715
847
|
f"annotation `{get_full_qualname(parameter.annotation)}` as a "
|
|
716
|
-
f"dependency into `{get_full_qualname(call)}` which is
|
|
717
|
-
"or set in the scoped context."
|
|
848
|
+
f"dependency into `{get_full_qualname(provider.call)}` which is "
|
|
849
|
+
"not registered or set in the scoped context."
|
|
850
|
+
)
|
|
851
|
+
|
|
852
|
+
@contextlib.contextmanager
|
|
853
|
+
def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
|
|
854
|
+
"""
|
|
855
|
+
Override the provider for the specified interface with a specific instance.
|
|
856
|
+
"""
|
|
857
|
+
if not self.testing:
|
|
858
|
+
raise RuntimeError(
|
|
859
|
+
"The `override` method can only be used in testing mode."
|
|
718
860
|
)
|
|
861
|
+
if not self.is_registered(interface) and self.strict:
|
|
862
|
+
raise LookupError(
|
|
863
|
+
f"The provider interface `{get_full_qualname(interface)}` "
|
|
864
|
+
"not registered."
|
|
865
|
+
)
|
|
866
|
+
self._override_instances[interface] = instance
|
|
867
|
+
yield
|
|
868
|
+
del self._override_instances[interface]
|
|
869
|
+
|
|
870
|
+
############################
|
|
871
|
+
# Testing Methods
|
|
872
|
+
############################
|
|
719
873
|
|
|
720
874
|
def _patch_test_resolver(self, interface: type[Any], instance: Any) -> Any:
|
|
721
875
|
"""Patch the test resolver for the instance."""
|
|
@@ -765,60 +919,9 @@ class Container:
|
|
|
765
919
|
|
|
766
920
|
return instance
|
|
767
921
|
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
provider = self._get_provider(interface)
|
|
772
|
-
except LookupError:
|
|
773
|
-
return False
|
|
774
|
-
if provider.scope == "transient":
|
|
775
|
-
return False
|
|
776
|
-
context = self._get_scoped_context(provider.scope)
|
|
777
|
-
return interface in context
|
|
778
|
-
|
|
779
|
-
def release(self, interface: AnyInterface) -> None:
|
|
780
|
-
"""Release an instance by interface."""
|
|
781
|
-
provider = self._get_provider(interface)
|
|
782
|
-
if provider.scope == "transient":
|
|
783
|
-
return None
|
|
784
|
-
context = self._get_scoped_context(provider.scope)
|
|
785
|
-
del context[interface]
|
|
786
|
-
|
|
787
|
-
def _get_scoped_context(self, scope: Scope) -> InstanceContext:
|
|
788
|
-
"""Get the instance context for the specified scope."""
|
|
789
|
-
if scope == "singleton":
|
|
790
|
-
return self._singleton_context
|
|
791
|
-
return self._get_request_context()
|
|
792
|
-
|
|
793
|
-
@contextlib.contextmanager
|
|
794
|
-
def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
|
|
795
|
-
"""
|
|
796
|
-
Override the provider for the specified interface with a specific instance.
|
|
797
|
-
"""
|
|
798
|
-
if not self.testing:
|
|
799
|
-
raise RuntimeError(
|
|
800
|
-
"The `override` method can only be used in testing mode."
|
|
801
|
-
)
|
|
802
|
-
if not self.is_registered(interface) and self.strict:
|
|
803
|
-
raise LookupError(
|
|
804
|
-
f"The provider interface `{get_full_qualname(interface)}` "
|
|
805
|
-
"not registered."
|
|
806
|
-
)
|
|
807
|
-
self._override_instances[interface] = instance
|
|
808
|
-
yield
|
|
809
|
-
del self._override_instances[interface]
|
|
810
|
-
|
|
811
|
-
def provider(
|
|
812
|
-
self, *, scope: Scope, override: bool = False
|
|
813
|
-
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
|
814
|
-
"""Decorator to register a provider function with the specified scope."""
|
|
815
|
-
|
|
816
|
-
def decorator(call: Callable[P, T]) -> Callable[P, T]:
|
|
817
|
-
provider = Provider(call=call, scope=scope)
|
|
818
|
-
self._register_provider(provider, override)
|
|
819
|
-
return call
|
|
820
|
-
|
|
821
|
-
return decorator
|
|
922
|
+
############################
|
|
923
|
+
# Injector Methods
|
|
924
|
+
############################
|
|
822
925
|
|
|
823
926
|
@overload
|
|
824
927
|
def inject(self, func: Callable[P, T]) -> Callable[P, T]: ...
|
|
@@ -838,6 +941,10 @@ class Container:
|
|
|
838
941
|
return decorator
|
|
839
942
|
return decorator(func)
|
|
840
943
|
|
|
944
|
+
def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
|
|
945
|
+
"""Run the given function with injected dependencies."""
|
|
946
|
+
return self._inject(func)(*args, **kwargs)
|
|
947
|
+
|
|
841
948
|
def _inject(self, call: Callable[P, T]) -> Callable[P, T]:
|
|
842
949
|
"""Inject dependencies into a callable."""
|
|
843
950
|
if call in self._inject_cache:
|
|
@@ -905,9 +1012,39 @@ class Container:
|
|
|
905
1012
|
f"`{get_full_qualname(parameter.annotation)}`."
|
|
906
1013
|
)
|
|
907
1014
|
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
1015
|
+
############################
|
|
1016
|
+
# Module Methods
|
|
1017
|
+
############################
|
|
1018
|
+
|
|
1019
|
+
def register_module(
|
|
1020
|
+
self, module: Module | type[Module] | Callable[[Container], None] | str
|
|
1021
|
+
) -> None:
|
|
1022
|
+
"""Register a module as a callable, module type, or module instance."""
|
|
1023
|
+
# Callable Module
|
|
1024
|
+
if inspect.isfunction(module):
|
|
1025
|
+
module(self)
|
|
1026
|
+
return
|
|
1027
|
+
|
|
1028
|
+
# Module path
|
|
1029
|
+
if isinstance(module, str):
|
|
1030
|
+
module = import_string(module)
|
|
1031
|
+
|
|
1032
|
+
# Class based Module or Module type
|
|
1033
|
+
if inspect.isclass(module) and issubclass(module, Module):
|
|
1034
|
+
module = module()
|
|
1035
|
+
|
|
1036
|
+
if isinstance(module, Module):
|
|
1037
|
+
module.configure(self)
|
|
1038
|
+
for provider_name, decorator_args in module.providers:
|
|
1039
|
+
obj = getattr(module, provider_name)
|
|
1040
|
+
self.provider(
|
|
1041
|
+
scope=decorator_args.scope,
|
|
1042
|
+
override=decorator_args.override,
|
|
1043
|
+
)(obj)
|
|
1044
|
+
|
|
1045
|
+
############################
|
|
1046
|
+
# Scanner Methods
|
|
1047
|
+
############################
|
|
911
1048
|
|
|
912
1049
|
def scan(
|
|
913
1050
|
self,
|
|
@@ -917,7 +1054,7 @@ class Container:
|
|
|
917
1054
|
tags: Iterable[str] | None = None,
|
|
918
1055
|
) -> None:
|
|
919
1056
|
"""Scan packages or modules for decorated members and inject dependencies."""
|
|
920
|
-
dependencies: list[
|
|
1057
|
+
dependencies: list[ScannedDependency] = []
|
|
921
1058
|
|
|
922
1059
|
if isinstance(packages, Iterable) and not isinstance(packages, str):
|
|
923
1060
|
scan_packages: Iterable[ModuleType | str] = packages
|
|
@@ -936,7 +1073,7 @@ class Container:
|
|
|
936
1073
|
package: ModuleType | str,
|
|
937
1074
|
*,
|
|
938
1075
|
tags: Iterable[str] | None = None,
|
|
939
|
-
) -> list[
|
|
1076
|
+
) -> list[ScannedDependency]:
|
|
940
1077
|
"""Scan a package or module for decorated members."""
|
|
941
1078
|
tags = tags or []
|
|
942
1079
|
if isinstance(package, str):
|
|
@@ -947,7 +1084,7 @@ class Container:
|
|
|
947
1084
|
if not package_path:
|
|
948
1085
|
return self._scan_module(package, tags=tags)
|
|
949
1086
|
|
|
950
|
-
dependencies: list[
|
|
1087
|
+
dependencies: list[ScannedDependency] = []
|
|
951
1088
|
|
|
952
1089
|
for module_info in pkgutil.walk_packages(
|
|
953
1090
|
path=package_path, prefix=package.__name__ + "."
|
|
@@ -959,9 +1096,9 @@ class Container:
|
|
|
959
1096
|
|
|
960
1097
|
def _scan_module(
|
|
961
1098
|
self, module: ModuleType, *, tags: Iterable[str]
|
|
962
|
-
) -> list[
|
|
1099
|
+
) -> list[ScannedDependency]:
|
|
963
1100
|
"""Scan a module for decorated members."""
|
|
964
|
-
dependencies: list[
|
|
1101
|
+
dependencies: list[ScannedDependency] = []
|
|
965
1102
|
|
|
966
1103
|
for _, member in inspect.getmembers(module):
|
|
967
1104
|
if getattr(member, "__module__", None) != module.__name__ or not callable(
|
|
@@ -984,7 +1121,7 @@ class Container:
|
|
|
984
1121
|
|
|
985
1122
|
if decorator_args.wrapped:
|
|
986
1123
|
dependencies.append(
|
|
987
|
-
self.
|
|
1124
|
+
self._create_scanned_dependency(member=member, module=module)
|
|
988
1125
|
)
|
|
989
1126
|
continue
|
|
990
1127
|
|
|
@@ -992,17 +1129,24 @@ class Container:
|
|
|
992
1129
|
for parameter in get_typed_parameters(member):
|
|
993
1130
|
if is_marker(parameter.default):
|
|
994
1131
|
dependencies.append(
|
|
995
|
-
self.
|
|
1132
|
+
self._create_scanned_dependency(member=member, module=module)
|
|
996
1133
|
)
|
|
997
1134
|
continue
|
|
998
1135
|
|
|
999
1136
|
return dependencies
|
|
1000
1137
|
|
|
1001
|
-
def
|
|
1138
|
+
def _create_scanned_dependency(
|
|
1139
|
+
self, member: Any, module: ModuleType
|
|
1140
|
+
) -> ScannedDependency:
|
|
1002
1141
|
"""Create a `Dependency` object from the scanned member and module."""
|
|
1003
1142
|
if hasattr(member, "__wrapped__"):
|
|
1004
1143
|
member = member.__wrapped__
|
|
1005
|
-
return
|
|
1144
|
+
return ScannedDependency(member=member, module=module)
|
|
1145
|
+
|
|
1146
|
+
|
|
1147
|
+
############################
|
|
1148
|
+
# Decorators
|
|
1149
|
+
############################
|
|
1006
1150
|
|
|
1007
1151
|
|
|
1008
1152
|
def transient(target: T) -> T:
|