anydi 0.34.0__py3-none-any.whl → 0.34.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anydi/__init__.py +9 -3
- anydi/_container.py +466 -82
- anydi/_context.py +29 -268
- anydi/_provider.py +32 -4
- anydi/_types.py +19 -3
- anydi/ext/django/middleware.py +4 -4
- anydi/ext/starlette/middleware.py +2 -2
- {anydi-0.34.0.dist-info → anydi-0.34.1.dist-info}/METADATA +2 -2
- {anydi-0.34.0.dist-info → anydi-0.34.1.dist-info}/RECORD +12 -14
- anydi/_module.py +0 -94
- anydi/_scanner.py +0 -171
- {anydi-0.34.0.dist-info → anydi-0.34.1.dist-info}/LICENSE +0 -0
- {anydi-0.34.0.dist-info → anydi-0.34.1.dist-info}/WHEEL +0 -0
- {anydi-0.34.0.dist-info → anydi-0.34.1.dist-info}/entry_points.txt +0 -0
anydi/_container.py
CHANGED
|
@@ -4,41 +4,76 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import contextlib
|
|
6
6
|
import functools
|
|
7
|
+
import importlib
|
|
7
8
|
import inspect
|
|
9
|
+
import pkgutil
|
|
8
10
|
import threading
|
|
9
11
|
import types
|
|
10
12
|
from collections import defaultdict
|
|
11
13
|
from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
|
|
12
14
|
from contextvars import ContextVar
|
|
13
|
-
from
|
|
15
|
+
from types import ModuleType
|
|
16
|
+
from typing import Any, Callable, TypeVar, Union, cast, overload
|
|
14
17
|
from weakref import WeakKeyDictionary
|
|
15
18
|
|
|
16
|
-
from typing_extensions import ParamSpec, Self, final
|
|
19
|
+
from typing_extensions import Concatenate, ParamSpec, Self, final
|
|
17
20
|
|
|
18
|
-
from ._context import
|
|
19
|
-
RequestContext,
|
|
20
|
-
ResourceScopedContext,
|
|
21
|
-
ScopedContext,
|
|
22
|
-
SingletonContext,
|
|
23
|
-
TransientContext,
|
|
24
|
-
)
|
|
21
|
+
from ._context import InstanceContext
|
|
25
22
|
from ._logger import logger
|
|
26
|
-
from ._module import Module, ModuleRegistry
|
|
27
23
|
from ._provider import Provider
|
|
28
|
-
from .
|
|
29
|
-
|
|
30
|
-
|
|
24
|
+
from ._types import (
|
|
25
|
+
AnyInterface,
|
|
26
|
+
Dependency,
|
|
27
|
+
DependencyWrapper,
|
|
28
|
+
InjectableDecoratorArgs,
|
|
29
|
+
ProviderDecoratorArgs,
|
|
30
|
+
Scope,
|
|
31
|
+
is_event_type,
|
|
32
|
+
is_marker,
|
|
33
|
+
)
|
|
34
|
+
from ._utils import (
|
|
35
|
+
AsyncRLock,
|
|
36
|
+
get_full_qualname,
|
|
37
|
+
get_typed_parameters,
|
|
38
|
+
import_string,
|
|
39
|
+
is_async_context_manager,
|
|
40
|
+
is_builtin_type,
|
|
41
|
+
is_context_manager,
|
|
42
|
+
run_async,
|
|
43
|
+
)
|
|
31
44
|
|
|
32
45
|
T = TypeVar("T", bound=Any)
|
|
46
|
+
M = TypeVar("M", bound="Module")
|
|
33
47
|
P = ParamSpec("P")
|
|
34
48
|
|
|
35
49
|
ALLOWED_SCOPES: dict[Scope, list[Scope]] = {
|
|
36
50
|
"singleton": ["singleton"],
|
|
37
51
|
"request": ["request", "singleton"],
|
|
38
|
-
"transient": ["transient", "
|
|
52
|
+
"transient": ["transient", "request", "singleton"],
|
|
39
53
|
}
|
|
40
54
|
|
|
41
55
|
|
|
56
|
+
class ModuleMeta(type):
|
|
57
|
+
"""A metaclass used for the Module base class."""
|
|
58
|
+
|
|
59
|
+
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Any:
|
|
60
|
+
attrs["providers"] = [
|
|
61
|
+
(name, getattr(value, "__provider__"))
|
|
62
|
+
for name, value in attrs.items()
|
|
63
|
+
if hasattr(value, "__provider__")
|
|
64
|
+
]
|
|
65
|
+
return super().__new__(cls, name, bases, attrs)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class Module(metaclass=ModuleMeta):
|
|
69
|
+
"""A base class for defining AnyDI modules."""
|
|
70
|
+
|
|
71
|
+
providers: list[tuple[str, ProviderDecoratorArgs]]
|
|
72
|
+
|
|
73
|
+
def configure(self, container: Container) -> None:
|
|
74
|
+
"""Configure the AnyDI container with providers and their dependencies."""
|
|
75
|
+
|
|
76
|
+
|
|
42
77
|
@final
|
|
43
78
|
class Container:
|
|
44
79
|
"""AnyDI is a dependency injection container."""
|
|
@@ -53,12 +88,11 @@ class Container:
|
|
|
53
88
|
testing: bool = False,
|
|
54
89
|
) -> None:
|
|
55
90
|
self._providers: dict[type[Any], Provider] = {}
|
|
56
|
-
self.
|
|
57
|
-
self._singleton_context =
|
|
91
|
+
self._resources: dict[str, list[type[Any]]] = defaultdict(list)
|
|
92
|
+
self._singleton_context = InstanceContext()
|
|
58
93
|
self._singleton_lock = threading.RLock()
|
|
59
94
|
self._singleton_async_lock = AsyncRLock()
|
|
60
|
-
self.
|
|
61
|
-
self._request_context_var: ContextVar[RequestContext | None] = ContextVar(
|
|
95
|
+
self._request_context_var: ContextVar[InstanceContext | None] = ContextVar(
|
|
62
96
|
"request_context", default=None
|
|
63
97
|
)
|
|
64
98
|
self._override_instances: dict[type[Any], Any] = {}
|
|
@@ -69,10 +103,6 @@ class Container:
|
|
|
69
103
|
Callable[..., Any], Callable[..., Any]
|
|
70
104
|
] = WeakKeyDictionary()
|
|
71
105
|
|
|
72
|
-
# Components
|
|
73
|
-
self._modules = ModuleRegistry(self)
|
|
74
|
-
self._scanner = Scanner(self)
|
|
75
|
-
|
|
76
106
|
# Register providers
|
|
77
107
|
providers = providers or []
|
|
78
108
|
for provider in providers:
|
|
@@ -142,14 +172,14 @@ class Container:
|
|
|
142
172
|
|
|
143
173
|
provider = self._get_provider(interface)
|
|
144
174
|
|
|
145
|
-
# Cleanup
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
175
|
+
# Cleanup instance context
|
|
176
|
+
if provider.scope != "transient":
|
|
177
|
+
try:
|
|
178
|
+
context = self._get_scoped_context(provider.scope)
|
|
179
|
+
except LookupError:
|
|
180
|
+
pass
|
|
181
|
+
else:
|
|
182
|
+
del context[interface]
|
|
153
183
|
|
|
154
184
|
# Cleanup provider references
|
|
155
185
|
self._delete_provider(provider)
|
|
@@ -190,14 +220,14 @@ class Container:
|
|
|
190
220
|
"""Set a provider by interface."""
|
|
191
221
|
self._providers[provider.interface] = provider
|
|
192
222
|
if provider.is_resource:
|
|
193
|
-
self.
|
|
223
|
+
self._resources[provider.scope].append(provider.interface)
|
|
194
224
|
|
|
195
225
|
def _delete_provider(self, provider: Provider) -> None:
|
|
196
226
|
"""Delete a provider."""
|
|
197
227
|
if provider.interface in self._providers:
|
|
198
228
|
del self._providers[provider.interface]
|
|
199
229
|
if provider.is_resource:
|
|
200
|
-
self.
|
|
230
|
+
self._resources[provider.scope].remove(provider.interface)
|
|
201
231
|
|
|
202
232
|
def _validate_sub_providers(self, provider: Provider) -> None:
|
|
203
233
|
"""Validate the sub-providers of a provider."""
|
|
@@ -271,7 +301,27 @@ class Container:
|
|
|
271
301
|
self, module: Module | type[Module] | Callable[[Container], None] | str
|
|
272
302
|
) -> None:
|
|
273
303
|
"""Register a module as a callable, module type, or module instance."""
|
|
274
|
-
|
|
304
|
+
# Callable Module
|
|
305
|
+
if inspect.isfunction(module):
|
|
306
|
+
module(self)
|
|
307
|
+
return
|
|
308
|
+
|
|
309
|
+
# Module path
|
|
310
|
+
if isinstance(module, str):
|
|
311
|
+
module = import_string(module)
|
|
312
|
+
|
|
313
|
+
# Class based Module or Module type
|
|
314
|
+
if inspect.isclass(module) and issubclass(module, Module):
|
|
315
|
+
module = module()
|
|
316
|
+
|
|
317
|
+
if isinstance(module, Module):
|
|
318
|
+
module.configure(self)
|
|
319
|
+
for provider_name, decorator_args in module.providers:
|
|
320
|
+
obj = getattr(module, provider_name)
|
|
321
|
+
self.provider(
|
|
322
|
+
scope=decorator_args.scope,
|
|
323
|
+
override=decorator_args.override,
|
|
324
|
+
)(obj)
|
|
275
325
|
|
|
276
326
|
def __enter__(self) -> Self:
|
|
277
327
|
"""Enter the singleton context."""
|
|
@@ -283,23 +333,33 @@ class Container:
|
|
|
283
333
|
exc_type: type[BaseException] | None,
|
|
284
334
|
exc_val: BaseException | None,
|
|
285
335
|
exc_tb: types.TracebackType | None,
|
|
286
|
-
) ->
|
|
336
|
+
) -> Any:
|
|
287
337
|
"""Exit the singleton context."""
|
|
288
338
|
return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
|
|
289
339
|
|
|
290
340
|
def start(self) -> None:
|
|
291
341
|
"""Start the singleton context."""
|
|
292
|
-
|
|
342
|
+
# Resolve all singleton resources
|
|
343
|
+
for interface in self._resources.get("singleton", []):
|
|
344
|
+
self.resolve(interface)
|
|
293
345
|
|
|
294
346
|
def close(self) -> None:
|
|
295
347
|
"""Close the singleton context."""
|
|
296
348
|
self._singleton_context.close()
|
|
297
349
|
|
|
298
350
|
@contextlib.contextmanager
|
|
299
|
-
def request_context(self) -> Iterator[
|
|
351
|
+
def request_context(self) -> Iterator[InstanceContext]:
|
|
300
352
|
"""Obtain a context manager for the request-scoped context."""
|
|
301
|
-
context =
|
|
353
|
+
context = InstanceContext()
|
|
354
|
+
|
|
302
355
|
token = self._request_context_var.set(context)
|
|
356
|
+
|
|
357
|
+
# Resolve all request resources
|
|
358
|
+
for interface in self._resources.get("request", []):
|
|
359
|
+
if not is_event_type(interface):
|
|
360
|
+
continue
|
|
361
|
+
self.resolve(interface)
|
|
362
|
+
|
|
303
363
|
with context:
|
|
304
364
|
yield context
|
|
305
365
|
self._request_context_var.reset(token)
|
|
@@ -320,22 +380,30 @@ class Container:
|
|
|
320
380
|
|
|
321
381
|
async def astart(self) -> None:
|
|
322
382
|
"""Start the singleton context asynchronously."""
|
|
323
|
-
|
|
383
|
+
for interface in self._resources.get("singleton", []):
|
|
384
|
+
await self.aresolve(interface)
|
|
324
385
|
|
|
325
386
|
async def aclose(self) -> None:
|
|
326
387
|
"""Close the singleton context asynchronously."""
|
|
327
388
|
await self._singleton_context.aclose()
|
|
328
389
|
|
|
329
390
|
@contextlib.asynccontextmanager
|
|
330
|
-
async def arequest_context(self) -> AsyncIterator[
|
|
391
|
+
async def arequest_context(self) -> AsyncIterator[InstanceContext]:
|
|
331
392
|
"""Obtain an async context manager for the request-scoped context."""
|
|
332
|
-
context =
|
|
393
|
+
context = InstanceContext()
|
|
394
|
+
|
|
333
395
|
token = self._request_context_var.set(context)
|
|
396
|
+
|
|
397
|
+
for interface in self._resources.get("request", []):
|
|
398
|
+
if not is_event_type(interface):
|
|
399
|
+
continue
|
|
400
|
+
await self.aresolve(interface)
|
|
401
|
+
|
|
334
402
|
async with context:
|
|
335
403
|
yield context
|
|
336
404
|
self._request_context_var.reset(token)
|
|
337
405
|
|
|
338
|
-
def _get_request_context(self) ->
|
|
406
|
+
def _get_request_context(self) -> InstanceContext:
|
|
339
407
|
"""Get the current request context."""
|
|
340
408
|
request_context = self._request_context_var.get()
|
|
341
409
|
if request_context is None:
|
|
@@ -349,57 +417,248 @@ class Container:
|
|
|
349
417
|
def reset(self) -> None:
|
|
350
418
|
"""Reset resolved instances."""
|
|
351
419
|
for interface, provider in self._providers.items():
|
|
420
|
+
if provider.scope == "transient":
|
|
421
|
+
continue
|
|
352
422
|
try:
|
|
353
|
-
|
|
423
|
+
context = self._get_scoped_context(provider.scope)
|
|
354
424
|
except LookupError:
|
|
355
425
|
continue
|
|
356
|
-
|
|
357
|
-
scoped_context.delete(interface)
|
|
426
|
+
del context[interface]
|
|
358
427
|
|
|
359
428
|
@overload
|
|
360
|
-
def resolve(self, interface:
|
|
429
|
+
def resolve(self, interface: type[T]) -> T: ...
|
|
361
430
|
|
|
362
431
|
@overload
|
|
363
432
|
def resolve(self, interface: T) -> T: ...
|
|
364
433
|
|
|
365
|
-
def resolve(self, interface:
|
|
434
|
+
def resolve(self, interface: type[T]) -> T:
|
|
366
435
|
"""Resolve an instance by interface."""
|
|
367
436
|
if interface in self._override_instances:
|
|
368
437
|
return cast(T, self._override_instances[interface])
|
|
369
438
|
|
|
370
439
|
provider = self._get_or_register_provider(interface)
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
with self._singleton_lock:
|
|
374
|
-
instance, created = scoped_context.get_or_create(provider)
|
|
440
|
+
if provider.scope == "transient":
|
|
441
|
+
instance, created = self._create_instance(provider), True
|
|
375
442
|
else:
|
|
376
|
-
|
|
443
|
+
context = self._get_scoped_context(provider.scope)
|
|
444
|
+
if provider.scope == "singleton":
|
|
445
|
+
with self._singleton_lock:
|
|
446
|
+
instance, created = self._get_or_create_instance(
|
|
447
|
+
provider, context=context
|
|
448
|
+
)
|
|
449
|
+
else:
|
|
450
|
+
instance, created = self._get_or_create_instance(
|
|
451
|
+
provider, context=context
|
|
452
|
+
)
|
|
377
453
|
if self.testing and created:
|
|
378
454
|
self._patch_test_resolver(instance)
|
|
379
455
|
return cast(T, instance)
|
|
380
456
|
|
|
381
457
|
@overload
|
|
382
|
-
async def aresolve(self, interface:
|
|
458
|
+
async def aresolve(self, interface: type[T]) -> T: ...
|
|
383
459
|
|
|
384
460
|
@overload
|
|
385
461
|
async def aresolve(self, interface: T) -> T: ...
|
|
386
462
|
|
|
387
|
-
async def aresolve(self, interface:
|
|
463
|
+
async def aresolve(self, interface: type[T]) -> T:
|
|
388
464
|
"""Resolve an instance by interface asynchronously."""
|
|
389
465
|
if interface in self._override_instances:
|
|
390
466
|
return cast(T, self._override_instances[interface])
|
|
391
467
|
|
|
392
468
|
provider = self._get_or_register_provider(interface)
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
async with self._singleton_async_lock:
|
|
396
|
-
instance, created = await scoped_context.aget_or_create(provider)
|
|
469
|
+
if provider.scope == "transient":
|
|
470
|
+
instance, created = await self._acreate_instance(provider), True
|
|
397
471
|
else:
|
|
398
|
-
|
|
472
|
+
context = self._get_scoped_context(provider.scope)
|
|
473
|
+
if provider.scope == "singleton":
|
|
474
|
+
async with self._singleton_async_lock:
|
|
475
|
+
instance, created = await self._aget_or_create_instance(
|
|
476
|
+
provider, context=context
|
|
477
|
+
)
|
|
478
|
+
else:
|
|
479
|
+
instance, created = await self._aget_or_create_instance(
|
|
480
|
+
provider, context=context
|
|
481
|
+
)
|
|
399
482
|
if self.testing and created:
|
|
400
483
|
self._patch_test_resolver(instance)
|
|
401
484
|
return cast(T, instance)
|
|
402
485
|
|
|
486
|
+
def _get_or_create_instance(
|
|
487
|
+
self, provider: Provider, context: InstanceContext
|
|
488
|
+
) -> tuple[Any, bool]:
|
|
489
|
+
"""Get an instance of a dependency from the scoped context."""
|
|
490
|
+
instance = context.get(provider.interface)
|
|
491
|
+
if instance is None:
|
|
492
|
+
instance = self._create_instance(provider, context=context)
|
|
493
|
+
context.set(provider.interface, instance)
|
|
494
|
+
return instance, True
|
|
495
|
+
return instance, False
|
|
496
|
+
|
|
497
|
+
async def _aget_or_create_instance(
|
|
498
|
+
self, provider: Provider, context: InstanceContext
|
|
499
|
+
) -> tuple[Any, bool]:
|
|
500
|
+
"""Get an async instance of a dependency from the scoped context."""
|
|
501
|
+
instance = context.get(provider.interface)
|
|
502
|
+
if instance is None:
|
|
503
|
+
instance = await self._acreate_instance(provider, context=context)
|
|
504
|
+
context.set(provider.interface, instance)
|
|
505
|
+
return instance, True
|
|
506
|
+
return instance, False
|
|
507
|
+
|
|
508
|
+
def _create_instance(
|
|
509
|
+
self,
|
|
510
|
+
provider: Provider,
|
|
511
|
+
context: InstanceContext | None = None,
|
|
512
|
+
) -> Any:
|
|
513
|
+
"""Create an instance using the provider."""
|
|
514
|
+
if provider.is_async:
|
|
515
|
+
raise TypeError(
|
|
516
|
+
f"The instance for the provider `{provider}` cannot be created in "
|
|
517
|
+
"synchronous mode."
|
|
518
|
+
)
|
|
519
|
+
|
|
520
|
+
args, kwargs = self._get_provided_args(provider, context=context)
|
|
521
|
+
|
|
522
|
+
if provider.is_generator:
|
|
523
|
+
if context is None:
|
|
524
|
+
raise ValueError("The context is required for generator providers.")
|
|
525
|
+
cm = contextlib.contextmanager(provider.call)(*args, **kwargs)
|
|
526
|
+
return context.enter(cm)
|
|
527
|
+
|
|
528
|
+
instance = provider.call(*args, **kwargs)
|
|
529
|
+
if context is not None and is_context_manager(instance):
|
|
530
|
+
context.enter(instance)
|
|
531
|
+
return instance
|
|
532
|
+
|
|
533
|
+
async def _acreate_instance(
|
|
534
|
+
self,
|
|
535
|
+
provider: Provider,
|
|
536
|
+
context: InstanceContext | None = None,
|
|
537
|
+
) -> Any:
|
|
538
|
+
"""Create an instance asynchronously using the provider."""
|
|
539
|
+
args, kwargs = await self._aget_provided_args(provider, context=context)
|
|
540
|
+
|
|
541
|
+
if provider.is_coroutine:
|
|
542
|
+
instance = await provider.call(*args, **kwargs)
|
|
543
|
+
if context is not None and is_async_context_manager(instance):
|
|
544
|
+
await context.aenter(instance)
|
|
545
|
+
return instance
|
|
546
|
+
|
|
547
|
+
if provider.is_async_generator:
|
|
548
|
+
if context is None:
|
|
549
|
+
raise ValueError(
|
|
550
|
+
"The async stack is required for async generator providers."
|
|
551
|
+
)
|
|
552
|
+
cm = contextlib.asynccontextmanager(provider.call)(*args, **kwargs)
|
|
553
|
+
return await context.aenter(cm)
|
|
554
|
+
|
|
555
|
+
if provider.is_generator:
|
|
556
|
+
|
|
557
|
+
def _create() -> Any:
|
|
558
|
+
if context is None:
|
|
559
|
+
raise ValueError("The stack is required for generator providers.")
|
|
560
|
+
cm = contextlib.contextmanager(provider.call)(*args, **kwargs)
|
|
561
|
+
return context.enter(cm)
|
|
562
|
+
|
|
563
|
+
return await run_async(_create)
|
|
564
|
+
|
|
565
|
+
instance = await run_async(provider.call, *args, **kwargs)
|
|
566
|
+
if context is not None and is_async_context_manager(instance):
|
|
567
|
+
await context.aenter(instance)
|
|
568
|
+
return instance
|
|
569
|
+
|
|
570
|
+
def _get_provided_args(
|
|
571
|
+
self,
|
|
572
|
+
provider: Provider,
|
|
573
|
+
context: InstanceContext | None,
|
|
574
|
+
*args: Any,
|
|
575
|
+
**kwargs: Any,
|
|
576
|
+
) -> tuple[list[Any], dict[str, Any]]:
|
|
577
|
+
"""Retrieve the arguments for a provider."""
|
|
578
|
+
provided_args: list[Any] = []
|
|
579
|
+
provided_kwargs: dict[str, Any] = {}
|
|
580
|
+
|
|
581
|
+
for parameter in provider.parameters:
|
|
582
|
+
if parameter.annotation in self._override_instances:
|
|
583
|
+
instance = self._override_instances[parameter.annotation]
|
|
584
|
+
elif context and parameter.annotation in context:
|
|
585
|
+
instance = context[parameter.annotation]
|
|
586
|
+
else:
|
|
587
|
+
try:
|
|
588
|
+
instance = self._resolve_parameter(provider, parameter)
|
|
589
|
+
except LookupError:
|
|
590
|
+
if parameter.default is inspect.Parameter.empty:
|
|
591
|
+
raise
|
|
592
|
+
instance = parameter.default
|
|
593
|
+
else:
|
|
594
|
+
if self.testing:
|
|
595
|
+
instance = DependencyWrapper(
|
|
596
|
+
interface=parameter.annotation, instance=instance
|
|
597
|
+
)
|
|
598
|
+
if parameter.kind == parameter.POSITIONAL_ONLY:
|
|
599
|
+
provided_args.append(instance)
|
|
600
|
+
else:
|
|
601
|
+
provided_kwargs[parameter.name] = instance
|
|
602
|
+
return provided_args, provided_kwargs
|
|
603
|
+
|
|
604
|
+
async def _aget_provided_args(
|
|
605
|
+
self,
|
|
606
|
+
provider: Provider,
|
|
607
|
+
context: InstanceContext | None,
|
|
608
|
+
*args: Any,
|
|
609
|
+
**kwargs: Any,
|
|
610
|
+
) -> tuple[list[Any], dict[str, Any]]:
|
|
611
|
+
"""Asynchronously retrieve the arguments for a provider."""
|
|
612
|
+
provided_args: list[Any] = []
|
|
613
|
+
provided_kwargs: dict[str, Any] = {}
|
|
614
|
+
|
|
615
|
+
for parameter in provider.parameters:
|
|
616
|
+
if parameter.annotation in self._override_instances:
|
|
617
|
+
instance = self._override_instances[parameter.annotation]
|
|
618
|
+
elif context and parameter.annotation in context:
|
|
619
|
+
instance = context[parameter.annotation]
|
|
620
|
+
else:
|
|
621
|
+
try:
|
|
622
|
+
instance = await self._aresolve_parameter(provider, parameter)
|
|
623
|
+
except LookupError:
|
|
624
|
+
if parameter.default is inspect.Parameter.empty:
|
|
625
|
+
raise
|
|
626
|
+
instance = parameter.default
|
|
627
|
+
else:
|
|
628
|
+
if self.testing:
|
|
629
|
+
instance = DependencyWrapper(
|
|
630
|
+
interface=parameter.annotation, instance=instance
|
|
631
|
+
)
|
|
632
|
+
if parameter.kind == parameter.POSITIONAL_ONLY:
|
|
633
|
+
provided_args.append(instance)
|
|
634
|
+
else:
|
|
635
|
+
provided_kwargs[parameter.name] = instance
|
|
636
|
+
return provided_args, provided_kwargs
|
|
637
|
+
|
|
638
|
+
def _resolve_parameter(
|
|
639
|
+
self, provider: Provider, parameter: inspect.Parameter
|
|
640
|
+
) -> Any:
|
|
641
|
+
self._validate_resolvable_parameter(parameter, call=provider.call)
|
|
642
|
+
return self.resolve(parameter.annotation)
|
|
643
|
+
|
|
644
|
+
async def _aresolve_parameter(
|
|
645
|
+
self, provider: Provider, parameter: inspect.Parameter
|
|
646
|
+
) -> Any:
|
|
647
|
+
self._validate_resolvable_parameter(parameter, call=provider.call)
|
|
648
|
+
return await self.aresolve(parameter.annotation)
|
|
649
|
+
|
|
650
|
+
def _validate_resolvable_parameter(
|
|
651
|
+
self, parameter: inspect.Parameter, call: Callable[..., Any]
|
|
652
|
+
) -> None:
|
|
653
|
+
"""Ensure that the specified interface is resolved."""
|
|
654
|
+
if parameter.annotation in self._unresolved_interfaces:
|
|
655
|
+
raise LookupError(
|
|
656
|
+
f"You are attempting to get the parameter `{parameter.name}` with the "
|
|
657
|
+
f"annotation `{get_full_qualname(parameter.annotation)}` as a "
|
|
658
|
+
f"dependency into `{get_full_qualname(call)}` which is not registered "
|
|
659
|
+
"or set in the scoped context."
|
|
660
|
+
)
|
|
661
|
+
|
|
403
662
|
def _patch_test_resolver(self, instance: Any) -> None:
|
|
404
663
|
"""Patch the test resolver for the instance."""
|
|
405
664
|
if not hasattr(instance, "__dict__"):
|
|
@@ -428,28 +687,25 @@ class Container:
|
|
|
428
687
|
try:
|
|
429
688
|
provider = self._get_provider(interface)
|
|
430
689
|
except LookupError:
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
return False
|
|
690
|
+
return False
|
|
691
|
+
if provider.scope == "transient":
|
|
692
|
+
return False
|
|
693
|
+
context = self._get_scoped_context(provider.scope)
|
|
694
|
+
return interface in context
|
|
437
695
|
|
|
438
696
|
def release(self, interface: AnyInterface) -> None:
|
|
439
697
|
"""Release an instance by interface."""
|
|
440
698
|
provider = self._get_provider(interface)
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
699
|
+
if provider.scope == "transient":
|
|
700
|
+
return None
|
|
701
|
+
context = self._get_scoped_context(provider.scope)
|
|
702
|
+
del context[interface]
|
|
444
703
|
|
|
445
|
-
def _get_scoped_context(self, scope: Scope) ->
|
|
446
|
-
"""Get the
|
|
704
|
+
def _get_scoped_context(self, scope: Scope) -> InstanceContext:
|
|
705
|
+
"""Get the instance context for the specified scope."""
|
|
447
706
|
if scope == "singleton":
|
|
448
707
|
return self._singleton_context
|
|
449
|
-
|
|
450
|
-
request_context = self._get_request_context()
|
|
451
|
-
return request_context
|
|
452
|
-
return self._transient_context
|
|
708
|
+
return self._get_request_context()
|
|
453
709
|
|
|
454
710
|
@contextlib.contextmanager
|
|
455
711
|
def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
|
|
@@ -488,19 +744,14 @@ class Container:
|
|
|
488
744
|
) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
|
|
489
745
|
"""Decorator to inject dependencies into a callable."""
|
|
490
746
|
|
|
491
|
-
def decorator(
|
|
492
|
-
call: Callable[P, T],
|
|
493
|
-
) -> Callable[P, T]:
|
|
747
|
+
def decorator(call: Callable[P, T]) -> Callable[P, T]:
|
|
494
748
|
return self._inject(call)
|
|
495
749
|
|
|
496
750
|
if func is None:
|
|
497
751
|
return decorator
|
|
498
752
|
return decorator(func)
|
|
499
753
|
|
|
500
|
-
def _inject(
|
|
501
|
-
self,
|
|
502
|
-
call: Callable[P, T],
|
|
503
|
-
) -> Callable[P, T]:
|
|
754
|
+
def _inject(self, call: Callable[P, T]) -> Callable[P, T]:
|
|
504
755
|
"""Inject dependencies into a callable."""
|
|
505
756
|
if call in self._inject_cache:
|
|
506
757
|
return cast(Callable[P, T], self._inject_cache[call])
|
|
@@ -574,12 +825,97 @@ class Container:
|
|
|
574
825
|
def scan(
|
|
575
826
|
self,
|
|
576
827
|
/,
|
|
577
|
-
packages:
|
|
828
|
+
packages: ModuleType | str | Iterable[ModuleType | str],
|
|
578
829
|
*,
|
|
579
830
|
tags: Iterable[str] | None = None,
|
|
580
831
|
) -> None:
|
|
581
832
|
"""Scan packages or modules for decorated members and inject dependencies."""
|
|
582
|
-
|
|
833
|
+
dependencies: list[Dependency] = []
|
|
834
|
+
|
|
835
|
+
if isinstance(packages, Iterable) and not isinstance(packages, str):
|
|
836
|
+
scan_packages: Iterable[ModuleType | str] = packages
|
|
837
|
+
else:
|
|
838
|
+
scan_packages = cast(Iterable[Union[ModuleType, str]], [packages])
|
|
839
|
+
|
|
840
|
+
for package in scan_packages:
|
|
841
|
+
dependencies.extend(self._scan_package(package, tags=tags))
|
|
842
|
+
|
|
843
|
+
for dependency in dependencies:
|
|
844
|
+
decorator = self.inject()(dependency.member)
|
|
845
|
+
setattr(dependency.module, dependency.member.__name__, decorator)
|
|
846
|
+
|
|
847
|
+
def _scan_package(
|
|
848
|
+
self,
|
|
849
|
+
package: ModuleType | str,
|
|
850
|
+
*,
|
|
851
|
+
tags: Iterable[str] | None = None,
|
|
852
|
+
) -> list[Dependency]:
|
|
853
|
+
"""Scan a package or module for decorated members."""
|
|
854
|
+
tags = tags or []
|
|
855
|
+
if isinstance(package, str):
|
|
856
|
+
package = importlib.import_module(package)
|
|
857
|
+
|
|
858
|
+
package_path = getattr(package, "__path__", None)
|
|
859
|
+
|
|
860
|
+
if not package_path:
|
|
861
|
+
return self._scan_module(package, tags=tags)
|
|
862
|
+
|
|
863
|
+
dependencies: list[Dependency] = []
|
|
864
|
+
|
|
865
|
+
for module_info in pkgutil.walk_packages(
|
|
866
|
+
path=package_path, prefix=package.__name__ + "."
|
|
867
|
+
):
|
|
868
|
+
module = importlib.import_module(module_info.name)
|
|
869
|
+
dependencies.extend(self._scan_module(module, tags=tags))
|
|
870
|
+
|
|
871
|
+
return dependencies
|
|
872
|
+
|
|
873
|
+
def _scan_module(
|
|
874
|
+
self, module: ModuleType, *, tags: Iterable[str]
|
|
875
|
+
) -> list[Dependency]:
|
|
876
|
+
"""Scan a module for decorated members."""
|
|
877
|
+
dependencies: list[Dependency] = []
|
|
878
|
+
|
|
879
|
+
for _, member in inspect.getmembers(module):
|
|
880
|
+
if getattr(member, "__module__", None) != module.__name__ or not callable(
|
|
881
|
+
member
|
|
882
|
+
):
|
|
883
|
+
continue
|
|
884
|
+
|
|
885
|
+
decorator_args: InjectableDecoratorArgs = getattr(
|
|
886
|
+
member,
|
|
887
|
+
"__injectable__",
|
|
888
|
+
InjectableDecoratorArgs(wrapped=False, tags=[]),
|
|
889
|
+
)
|
|
890
|
+
|
|
891
|
+
if tags and (
|
|
892
|
+
decorator_args.tags
|
|
893
|
+
and not set(decorator_args.tags).intersection(tags)
|
|
894
|
+
or not decorator_args.tags
|
|
895
|
+
):
|
|
896
|
+
continue
|
|
897
|
+
|
|
898
|
+
if decorator_args.wrapped:
|
|
899
|
+
dependencies.append(
|
|
900
|
+
self._create_dependency(member=member, module=module)
|
|
901
|
+
)
|
|
902
|
+
continue
|
|
903
|
+
|
|
904
|
+
# Get by Marker
|
|
905
|
+
for parameter in get_typed_parameters(member):
|
|
906
|
+
if is_marker(parameter.default):
|
|
907
|
+
dependencies.append(
|
|
908
|
+
self._create_dependency(member=member, module=module)
|
|
909
|
+
)
|
|
910
|
+
continue
|
|
911
|
+
|
|
912
|
+
return dependencies
|
|
913
|
+
|
|
914
|
+
def _create_dependency(self, member: Any, module: ModuleType) -> Dependency:
|
|
915
|
+
"""Create a `Dependency` object from the scanned member and module."""
|
|
916
|
+
if hasattr(member, "__wrapped__"):
|
|
917
|
+
member = member.__wrapped__
|
|
918
|
+
return Dependency(member=member, module=module)
|
|
583
919
|
|
|
584
920
|
|
|
585
921
|
def transient(target: T) -> T:
|
|
@@ -598,3 +934,51 @@ def singleton(target: T) -> T:
|
|
|
598
934
|
"""Decorator for marking a class as singleton scope."""
|
|
599
935
|
setattr(target, "__scope__", "singleton")
|
|
600
936
|
return target
|
|
937
|
+
|
|
938
|
+
|
|
939
|
+
def provider(
|
|
940
|
+
*, scope: Scope, override: bool = False
|
|
941
|
+
) -> Callable[[Callable[Concatenate[M, P], T]], Callable[Concatenate[M, P], T]]:
|
|
942
|
+
"""Decorator for marking a function or method as a provider in a AnyDI module."""
|
|
943
|
+
|
|
944
|
+
def decorator(
|
|
945
|
+
target: Callable[Concatenate[M, P], T],
|
|
946
|
+
) -> Callable[Concatenate[M, P], T]:
|
|
947
|
+
setattr(
|
|
948
|
+
target,
|
|
949
|
+
"__provider__",
|
|
950
|
+
ProviderDecoratorArgs(scope=scope, override=override),
|
|
951
|
+
)
|
|
952
|
+
return target
|
|
953
|
+
|
|
954
|
+
return decorator
|
|
955
|
+
|
|
956
|
+
|
|
957
|
+
@overload
|
|
958
|
+
def injectable(func: Callable[P, T]) -> Callable[P, T]: ...
|
|
959
|
+
|
|
960
|
+
|
|
961
|
+
@overload
|
|
962
|
+
def injectable(
|
|
963
|
+
*, tags: Iterable[str] | None = None
|
|
964
|
+
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
|
|
965
|
+
|
|
966
|
+
|
|
967
|
+
def injectable(
|
|
968
|
+
func: Callable[P, T] | None = None,
|
|
969
|
+
tags: Iterable[str] | None = None,
|
|
970
|
+
) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
|
|
971
|
+
"""Decorator for marking a function or method as requiring dependency injection."""
|
|
972
|
+
|
|
973
|
+
def decorator(inner: Callable[P, T]) -> Callable[P, T]:
|
|
974
|
+
setattr(
|
|
975
|
+
inner,
|
|
976
|
+
"__injectable__",
|
|
977
|
+
InjectableDecoratorArgs(wrapped=True, tags=tags),
|
|
978
|
+
)
|
|
979
|
+
return inner
|
|
980
|
+
|
|
981
|
+
if func is None:
|
|
982
|
+
return decorator
|
|
983
|
+
|
|
984
|
+
return decorator(func)
|