anydi 0.33.1__py3-none-any.whl → 0.34.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anydi/__init__.py +9 -3
- anydi/_container.py +471 -76
- anydi/_context.py +29 -265
- anydi/_provider.py +32 -4
- anydi/_types.py +19 -3
- anydi/_utils.py +49 -12
- anydi/ext/django/middleware.py +4 -4
- anydi/ext/starlette/middleware.py +2 -2
- {anydi-0.33.1.dist-info → anydi-0.34.1.dist-info}/METADATA +2 -3
- {anydi-0.33.1.dist-info → anydi-0.34.1.dist-info}/RECORD +13 -15
- anydi/_module.py +0 -94
- anydi/_scanner.py +0 -171
- {anydi-0.33.1.dist-info → anydi-0.34.1.dist-info}/LICENSE +0 -0
- {anydi-0.33.1.dist-info → anydi-0.34.1.dist-info}/WHEEL +0 -0
- {anydi-0.33.1.dist-info → anydi-0.34.1.dist-info}/entry_points.txt +0 -0
anydi/_container.py
CHANGED
|
@@ -4,40 +4,76 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import contextlib
|
|
6
6
|
import functools
|
|
7
|
+
import importlib
|
|
7
8
|
import inspect
|
|
9
|
+
import pkgutil
|
|
10
|
+
import threading
|
|
8
11
|
import types
|
|
9
12
|
from collections import defaultdict
|
|
10
13
|
from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
|
|
11
14
|
from contextvars import ContextVar
|
|
12
|
-
from
|
|
15
|
+
from types import ModuleType
|
|
16
|
+
from typing import Any, Callable, TypeVar, Union, cast, overload
|
|
13
17
|
from weakref import WeakKeyDictionary
|
|
14
18
|
|
|
15
|
-
from typing_extensions import ParamSpec, Self, final
|
|
19
|
+
from typing_extensions import Concatenate, ParamSpec, Self, final
|
|
16
20
|
|
|
17
|
-
from ._context import
|
|
18
|
-
RequestContext,
|
|
19
|
-
ResourceScopedContext,
|
|
20
|
-
ScopedContext,
|
|
21
|
-
SingletonContext,
|
|
22
|
-
TransientContext,
|
|
23
|
-
)
|
|
21
|
+
from ._context import InstanceContext
|
|
24
22
|
from ._logger import logger
|
|
25
|
-
from ._module import Module, ModuleRegistry
|
|
26
23
|
from ._provider import Provider
|
|
27
|
-
from .
|
|
28
|
-
|
|
29
|
-
|
|
24
|
+
from ._types import (
|
|
25
|
+
AnyInterface,
|
|
26
|
+
Dependency,
|
|
27
|
+
DependencyWrapper,
|
|
28
|
+
InjectableDecoratorArgs,
|
|
29
|
+
ProviderDecoratorArgs,
|
|
30
|
+
Scope,
|
|
31
|
+
is_event_type,
|
|
32
|
+
is_marker,
|
|
33
|
+
)
|
|
34
|
+
from ._utils import (
|
|
35
|
+
AsyncRLock,
|
|
36
|
+
get_full_qualname,
|
|
37
|
+
get_typed_parameters,
|
|
38
|
+
import_string,
|
|
39
|
+
is_async_context_manager,
|
|
40
|
+
is_builtin_type,
|
|
41
|
+
is_context_manager,
|
|
42
|
+
run_async,
|
|
43
|
+
)
|
|
30
44
|
|
|
31
45
|
T = TypeVar("T", bound=Any)
|
|
46
|
+
M = TypeVar("M", bound="Module")
|
|
32
47
|
P = ParamSpec("P")
|
|
33
48
|
|
|
34
49
|
ALLOWED_SCOPES: dict[Scope, list[Scope]] = {
|
|
35
50
|
"singleton": ["singleton"],
|
|
36
51
|
"request": ["request", "singleton"],
|
|
37
|
-
"transient": ["transient", "
|
|
52
|
+
"transient": ["transient", "request", "singleton"],
|
|
38
53
|
}
|
|
39
54
|
|
|
40
55
|
|
|
56
|
+
class ModuleMeta(type):
|
|
57
|
+
"""A metaclass used for the Module base class."""
|
|
58
|
+
|
|
59
|
+
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Any:
|
|
60
|
+
attrs["providers"] = [
|
|
61
|
+
(name, getattr(value, "__provider__"))
|
|
62
|
+
for name, value in attrs.items()
|
|
63
|
+
if hasattr(value, "__provider__")
|
|
64
|
+
]
|
|
65
|
+
return super().__new__(cls, name, bases, attrs)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class Module(metaclass=ModuleMeta):
|
|
69
|
+
"""A base class for defining AnyDI modules."""
|
|
70
|
+
|
|
71
|
+
providers: list[tuple[str, ProviderDecoratorArgs]]
|
|
72
|
+
|
|
73
|
+
def configure(self, container: Container) -> None:
|
|
74
|
+
"""Configure the AnyDI container with providers and their dependencies."""
|
|
75
|
+
|
|
76
|
+
|
|
41
77
|
@final
|
|
42
78
|
class Container:
|
|
43
79
|
"""AnyDI is a dependency injection container."""
|
|
@@ -52,10 +88,11 @@ class Container:
|
|
|
52
88
|
testing: bool = False,
|
|
53
89
|
) -> None:
|
|
54
90
|
self._providers: dict[type[Any], Provider] = {}
|
|
55
|
-
self.
|
|
56
|
-
self._singleton_context =
|
|
57
|
-
self.
|
|
58
|
-
self.
|
|
91
|
+
self._resources: dict[str, list[type[Any]]] = defaultdict(list)
|
|
92
|
+
self._singleton_context = InstanceContext()
|
|
93
|
+
self._singleton_lock = threading.RLock()
|
|
94
|
+
self._singleton_async_lock = AsyncRLock()
|
|
95
|
+
self._request_context_var: ContextVar[InstanceContext | None] = ContextVar(
|
|
59
96
|
"request_context", default=None
|
|
60
97
|
)
|
|
61
98
|
self._override_instances: dict[type[Any], Any] = {}
|
|
@@ -66,10 +103,6 @@ class Container:
|
|
|
66
103
|
Callable[..., Any], Callable[..., Any]
|
|
67
104
|
] = WeakKeyDictionary()
|
|
68
105
|
|
|
69
|
-
# Components
|
|
70
|
-
self._modules = ModuleRegistry(self)
|
|
71
|
-
self._scanner = Scanner(self)
|
|
72
|
-
|
|
73
106
|
# Register providers
|
|
74
107
|
providers = providers or []
|
|
75
108
|
for provider in providers:
|
|
@@ -139,14 +172,14 @@ class Container:
|
|
|
139
172
|
|
|
140
173
|
provider = self._get_provider(interface)
|
|
141
174
|
|
|
142
|
-
# Cleanup
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
175
|
+
# Cleanup instance context
|
|
176
|
+
if provider.scope != "transient":
|
|
177
|
+
try:
|
|
178
|
+
context = self._get_scoped_context(provider.scope)
|
|
179
|
+
except LookupError:
|
|
180
|
+
pass
|
|
181
|
+
else:
|
|
182
|
+
del context[interface]
|
|
150
183
|
|
|
151
184
|
# Cleanup provider references
|
|
152
185
|
self._delete_provider(provider)
|
|
@@ -187,14 +220,14 @@ class Container:
|
|
|
187
220
|
"""Set a provider by interface."""
|
|
188
221
|
self._providers[provider.interface] = provider
|
|
189
222
|
if provider.is_resource:
|
|
190
|
-
self.
|
|
223
|
+
self._resources[provider.scope].append(provider.interface)
|
|
191
224
|
|
|
192
225
|
def _delete_provider(self, provider: Provider) -> None:
|
|
193
226
|
"""Delete a provider."""
|
|
194
227
|
if provider.interface in self._providers:
|
|
195
228
|
del self._providers[provider.interface]
|
|
196
229
|
if provider.is_resource:
|
|
197
|
-
self.
|
|
230
|
+
self._resources[provider.scope].remove(provider.interface)
|
|
198
231
|
|
|
199
232
|
def _validate_sub_providers(self, provider: Provider) -> None:
|
|
200
233
|
"""Validate the sub-providers of a provider."""
|
|
@@ -268,7 +301,27 @@ class Container:
|
|
|
268
301
|
self, module: Module | type[Module] | Callable[[Container], None] | str
|
|
269
302
|
) -> None:
|
|
270
303
|
"""Register a module as a callable, module type, or module instance."""
|
|
271
|
-
|
|
304
|
+
# Callable Module
|
|
305
|
+
if inspect.isfunction(module):
|
|
306
|
+
module(self)
|
|
307
|
+
return
|
|
308
|
+
|
|
309
|
+
# Module path
|
|
310
|
+
if isinstance(module, str):
|
|
311
|
+
module = import_string(module)
|
|
312
|
+
|
|
313
|
+
# Class based Module or Module type
|
|
314
|
+
if inspect.isclass(module) and issubclass(module, Module):
|
|
315
|
+
module = module()
|
|
316
|
+
|
|
317
|
+
if isinstance(module, Module):
|
|
318
|
+
module.configure(self)
|
|
319
|
+
for provider_name, decorator_args in module.providers:
|
|
320
|
+
obj = getattr(module, provider_name)
|
|
321
|
+
self.provider(
|
|
322
|
+
scope=decorator_args.scope,
|
|
323
|
+
override=decorator_args.override,
|
|
324
|
+
)(obj)
|
|
272
325
|
|
|
273
326
|
def __enter__(self) -> Self:
|
|
274
327
|
"""Enter the singleton context."""
|
|
@@ -280,23 +333,33 @@ class Container:
|
|
|
280
333
|
exc_type: type[BaseException] | None,
|
|
281
334
|
exc_val: BaseException | None,
|
|
282
335
|
exc_tb: types.TracebackType | None,
|
|
283
|
-
) ->
|
|
336
|
+
) -> Any:
|
|
284
337
|
"""Exit the singleton context."""
|
|
285
338
|
return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
|
|
286
339
|
|
|
287
340
|
def start(self) -> None:
|
|
288
341
|
"""Start the singleton context."""
|
|
289
|
-
|
|
342
|
+
# Resolve all singleton resources
|
|
343
|
+
for interface in self._resources.get("singleton", []):
|
|
344
|
+
self.resolve(interface)
|
|
290
345
|
|
|
291
346
|
def close(self) -> None:
|
|
292
347
|
"""Close the singleton context."""
|
|
293
348
|
self._singleton_context.close()
|
|
294
349
|
|
|
295
350
|
@contextlib.contextmanager
|
|
296
|
-
def request_context(self) -> Iterator[
|
|
351
|
+
def request_context(self) -> Iterator[InstanceContext]:
|
|
297
352
|
"""Obtain a context manager for the request-scoped context."""
|
|
298
|
-
context =
|
|
353
|
+
context = InstanceContext()
|
|
354
|
+
|
|
299
355
|
token = self._request_context_var.set(context)
|
|
356
|
+
|
|
357
|
+
# Resolve all request resources
|
|
358
|
+
for interface in self._resources.get("request", []):
|
|
359
|
+
if not is_event_type(interface):
|
|
360
|
+
continue
|
|
361
|
+
self.resolve(interface)
|
|
362
|
+
|
|
300
363
|
with context:
|
|
301
364
|
yield context
|
|
302
365
|
self._request_context_var.reset(token)
|
|
@@ -317,22 +380,30 @@ class Container:
|
|
|
317
380
|
|
|
318
381
|
async def astart(self) -> None:
|
|
319
382
|
"""Start the singleton context asynchronously."""
|
|
320
|
-
|
|
383
|
+
for interface in self._resources.get("singleton", []):
|
|
384
|
+
await self.aresolve(interface)
|
|
321
385
|
|
|
322
386
|
async def aclose(self) -> None:
|
|
323
387
|
"""Close the singleton context asynchronously."""
|
|
324
388
|
await self._singleton_context.aclose()
|
|
325
389
|
|
|
326
390
|
@contextlib.asynccontextmanager
|
|
327
|
-
async def arequest_context(self) -> AsyncIterator[
|
|
391
|
+
async def arequest_context(self) -> AsyncIterator[InstanceContext]:
|
|
328
392
|
"""Obtain an async context manager for the request-scoped context."""
|
|
329
|
-
context =
|
|
393
|
+
context = InstanceContext()
|
|
394
|
+
|
|
330
395
|
token = self._request_context_var.set(context)
|
|
396
|
+
|
|
397
|
+
for interface in self._resources.get("request", []):
|
|
398
|
+
if not is_event_type(interface):
|
|
399
|
+
continue
|
|
400
|
+
await self.aresolve(interface)
|
|
401
|
+
|
|
331
402
|
async with context:
|
|
332
403
|
yield context
|
|
333
404
|
self._request_context_var.reset(token)
|
|
334
405
|
|
|
335
|
-
def _get_request_context(self) ->
|
|
406
|
+
def _get_request_context(self) -> InstanceContext:
|
|
336
407
|
"""Get the current request context."""
|
|
337
408
|
request_context = self._request_context_var.get()
|
|
338
409
|
if request_context is None:
|
|
@@ -346,49 +417,248 @@ class Container:
|
|
|
346
417
|
def reset(self) -> None:
|
|
347
418
|
"""Reset resolved instances."""
|
|
348
419
|
for interface, provider in self._providers.items():
|
|
420
|
+
if provider.scope == "transient":
|
|
421
|
+
continue
|
|
349
422
|
try:
|
|
350
|
-
|
|
423
|
+
context = self._get_scoped_context(provider.scope)
|
|
351
424
|
except LookupError:
|
|
352
425
|
continue
|
|
353
|
-
|
|
354
|
-
scoped_context.delete(interface)
|
|
426
|
+
del context[interface]
|
|
355
427
|
|
|
356
428
|
@overload
|
|
357
|
-
def resolve(self, interface:
|
|
429
|
+
def resolve(self, interface: type[T]) -> T: ...
|
|
358
430
|
|
|
359
431
|
@overload
|
|
360
432
|
def resolve(self, interface: T) -> T: ...
|
|
361
433
|
|
|
362
|
-
def resolve(self, interface:
|
|
434
|
+
def resolve(self, interface: type[T]) -> T:
|
|
363
435
|
"""Resolve an instance by interface."""
|
|
364
436
|
if interface in self._override_instances:
|
|
365
437
|
return cast(T, self._override_instances[interface])
|
|
366
438
|
|
|
367
439
|
provider = self._get_or_register_provider(interface)
|
|
368
|
-
|
|
369
|
-
|
|
440
|
+
if provider.scope == "transient":
|
|
441
|
+
instance, created = self._create_instance(provider), True
|
|
442
|
+
else:
|
|
443
|
+
context = self._get_scoped_context(provider.scope)
|
|
444
|
+
if provider.scope == "singleton":
|
|
445
|
+
with self._singleton_lock:
|
|
446
|
+
instance, created = self._get_or_create_instance(
|
|
447
|
+
provider, context=context
|
|
448
|
+
)
|
|
449
|
+
else:
|
|
450
|
+
instance, created = self._get_or_create_instance(
|
|
451
|
+
provider, context=context
|
|
452
|
+
)
|
|
370
453
|
if self.testing and created:
|
|
371
454
|
self._patch_test_resolver(instance)
|
|
372
455
|
return cast(T, instance)
|
|
373
456
|
|
|
374
457
|
@overload
|
|
375
|
-
async def aresolve(self, interface:
|
|
458
|
+
async def aresolve(self, interface: type[T]) -> T: ...
|
|
376
459
|
|
|
377
460
|
@overload
|
|
378
461
|
async def aresolve(self, interface: T) -> T: ...
|
|
379
462
|
|
|
380
|
-
async def aresolve(self, interface:
|
|
463
|
+
async def aresolve(self, interface: type[T]) -> T:
|
|
381
464
|
"""Resolve an instance by interface asynchronously."""
|
|
382
465
|
if interface in self._override_instances:
|
|
383
466
|
return cast(T, self._override_instances[interface])
|
|
384
467
|
|
|
385
468
|
provider = self._get_or_register_provider(interface)
|
|
386
|
-
|
|
387
|
-
|
|
469
|
+
if provider.scope == "transient":
|
|
470
|
+
instance, created = await self._acreate_instance(provider), True
|
|
471
|
+
else:
|
|
472
|
+
context = self._get_scoped_context(provider.scope)
|
|
473
|
+
if provider.scope == "singleton":
|
|
474
|
+
async with self._singleton_async_lock:
|
|
475
|
+
instance, created = await self._aget_or_create_instance(
|
|
476
|
+
provider, context=context
|
|
477
|
+
)
|
|
478
|
+
else:
|
|
479
|
+
instance, created = await self._aget_or_create_instance(
|
|
480
|
+
provider, context=context
|
|
481
|
+
)
|
|
388
482
|
if self.testing and created:
|
|
389
483
|
self._patch_test_resolver(instance)
|
|
390
484
|
return cast(T, instance)
|
|
391
485
|
|
|
486
|
+
def _get_or_create_instance(
|
|
487
|
+
self, provider: Provider, context: InstanceContext
|
|
488
|
+
) -> tuple[Any, bool]:
|
|
489
|
+
"""Get an instance of a dependency from the scoped context."""
|
|
490
|
+
instance = context.get(provider.interface)
|
|
491
|
+
if instance is None:
|
|
492
|
+
instance = self._create_instance(provider, context=context)
|
|
493
|
+
context.set(provider.interface, instance)
|
|
494
|
+
return instance, True
|
|
495
|
+
return instance, False
|
|
496
|
+
|
|
497
|
+
async def _aget_or_create_instance(
|
|
498
|
+
self, provider: Provider, context: InstanceContext
|
|
499
|
+
) -> tuple[Any, bool]:
|
|
500
|
+
"""Get an async instance of a dependency from the scoped context."""
|
|
501
|
+
instance = context.get(provider.interface)
|
|
502
|
+
if instance is None:
|
|
503
|
+
instance = await self._acreate_instance(provider, context=context)
|
|
504
|
+
context.set(provider.interface, instance)
|
|
505
|
+
return instance, True
|
|
506
|
+
return instance, False
|
|
507
|
+
|
|
508
|
+
def _create_instance(
|
|
509
|
+
self,
|
|
510
|
+
provider: Provider,
|
|
511
|
+
context: InstanceContext | None = None,
|
|
512
|
+
) -> Any:
|
|
513
|
+
"""Create an instance using the provider."""
|
|
514
|
+
if provider.is_async:
|
|
515
|
+
raise TypeError(
|
|
516
|
+
f"The instance for the provider `{provider}` cannot be created in "
|
|
517
|
+
"synchronous mode."
|
|
518
|
+
)
|
|
519
|
+
|
|
520
|
+
args, kwargs = self._get_provided_args(provider, context=context)
|
|
521
|
+
|
|
522
|
+
if provider.is_generator:
|
|
523
|
+
if context is None:
|
|
524
|
+
raise ValueError("The context is required for generator providers.")
|
|
525
|
+
cm = contextlib.contextmanager(provider.call)(*args, **kwargs)
|
|
526
|
+
return context.enter(cm)
|
|
527
|
+
|
|
528
|
+
instance = provider.call(*args, **kwargs)
|
|
529
|
+
if context is not None and is_context_manager(instance):
|
|
530
|
+
context.enter(instance)
|
|
531
|
+
return instance
|
|
532
|
+
|
|
533
|
+
async def _acreate_instance(
|
|
534
|
+
self,
|
|
535
|
+
provider: Provider,
|
|
536
|
+
context: InstanceContext | None = None,
|
|
537
|
+
) -> Any:
|
|
538
|
+
"""Create an instance asynchronously using the provider."""
|
|
539
|
+
args, kwargs = await self._aget_provided_args(provider, context=context)
|
|
540
|
+
|
|
541
|
+
if provider.is_coroutine:
|
|
542
|
+
instance = await provider.call(*args, **kwargs)
|
|
543
|
+
if context is not None and is_async_context_manager(instance):
|
|
544
|
+
await context.aenter(instance)
|
|
545
|
+
return instance
|
|
546
|
+
|
|
547
|
+
if provider.is_async_generator:
|
|
548
|
+
if context is None:
|
|
549
|
+
raise ValueError(
|
|
550
|
+
"The async stack is required for async generator providers."
|
|
551
|
+
)
|
|
552
|
+
cm = contextlib.asynccontextmanager(provider.call)(*args, **kwargs)
|
|
553
|
+
return await context.aenter(cm)
|
|
554
|
+
|
|
555
|
+
if provider.is_generator:
|
|
556
|
+
|
|
557
|
+
def _create() -> Any:
|
|
558
|
+
if context is None:
|
|
559
|
+
raise ValueError("The stack is required for generator providers.")
|
|
560
|
+
cm = contextlib.contextmanager(provider.call)(*args, **kwargs)
|
|
561
|
+
return context.enter(cm)
|
|
562
|
+
|
|
563
|
+
return await run_async(_create)
|
|
564
|
+
|
|
565
|
+
instance = await run_async(provider.call, *args, **kwargs)
|
|
566
|
+
if context is not None and is_async_context_manager(instance):
|
|
567
|
+
await context.aenter(instance)
|
|
568
|
+
return instance
|
|
569
|
+
|
|
570
|
+
def _get_provided_args(
|
|
571
|
+
self,
|
|
572
|
+
provider: Provider,
|
|
573
|
+
context: InstanceContext | None,
|
|
574
|
+
*args: Any,
|
|
575
|
+
**kwargs: Any,
|
|
576
|
+
) -> tuple[list[Any], dict[str, Any]]:
|
|
577
|
+
"""Retrieve the arguments for a provider."""
|
|
578
|
+
provided_args: list[Any] = []
|
|
579
|
+
provided_kwargs: dict[str, Any] = {}
|
|
580
|
+
|
|
581
|
+
for parameter in provider.parameters:
|
|
582
|
+
if parameter.annotation in self._override_instances:
|
|
583
|
+
instance = self._override_instances[parameter.annotation]
|
|
584
|
+
elif context and parameter.annotation in context:
|
|
585
|
+
instance = context[parameter.annotation]
|
|
586
|
+
else:
|
|
587
|
+
try:
|
|
588
|
+
instance = self._resolve_parameter(provider, parameter)
|
|
589
|
+
except LookupError:
|
|
590
|
+
if parameter.default is inspect.Parameter.empty:
|
|
591
|
+
raise
|
|
592
|
+
instance = parameter.default
|
|
593
|
+
else:
|
|
594
|
+
if self.testing:
|
|
595
|
+
instance = DependencyWrapper(
|
|
596
|
+
interface=parameter.annotation, instance=instance
|
|
597
|
+
)
|
|
598
|
+
if parameter.kind == parameter.POSITIONAL_ONLY:
|
|
599
|
+
provided_args.append(instance)
|
|
600
|
+
else:
|
|
601
|
+
provided_kwargs[parameter.name] = instance
|
|
602
|
+
return provided_args, provided_kwargs
|
|
603
|
+
|
|
604
|
+
async def _aget_provided_args(
|
|
605
|
+
self,
|
|
606
|
+
provider: Provider,
|
|
607
|
+
context: InstanceContext | None,
|
|
608
|
+
*args: Any,
|
|
609
|
+
**kwargs: Any,
|
|
610
|
+
) -> tuple[list[Any], dict[str, Any]]:
|
|
611
|
+
"""Asynchronously retrieve the arguments for a provider."""
|
|
612
|
+
provided_args: list[Any] = []
|
|
613
|
+
provided_kwargs: dict[str, Any] = {}
|
|
614
|
+
|
|
615
|
+
for parameter in provider.parameters:
|
|
616
|
+
if parameter.annotation in self._override_instances:
|
|
617
|
+
instance = self._override_instances[parameter.annotation]
|
|
618
|
+
elif context and parameter.annotation in context:
|
|
619
|
+
instance = context[parameter.annotation]
|
|
620
|
+
else:
|
|
621
|
+
try:
|
|
622
|
+
instance = await self._aresolve_parameter(provider, parameter)
|
|
623
|
+
except LookupError:
|
|
624
|
+
if parameter.default is inspect.Parameter.empty:
|
|
625
|
+
raise
|
|
626
|
+
instance = parameter.default
|
|
627
|
+
else:
|
|
628
|
+
if self.testing:
|
|
629
|
+
instance = DependencyWrapper(
|
|
630
|
+
interface=parameter.annotation, instance=instance
|
|
631
|
+
)
|
|
632
|
+
if parameter.kind == parameter.POSITIONAL_ONLY:
|
|
633
|
+
provided_args.append(instance)
|
|
634
|
+
else:
|
|
635
|
+
provided_kwargs[parameter.name] = instance
|
|
636
|
+
return provided_args, provided_kwargs
|
|
637
|
+
|
|
638
|
+
def _resolve_parameter(
|
|
639
|
+
self, provider: Provider, parameter: inspect.Parameter
|
|
640
|
+
) -> Any:
|
|
641
|
+
self._validate_resolvable_parameter(parameter, call=provider.call)
|
|
642
|
+
return self.resolve(parameter.annotation)
|
|
643
|
+
|
|
644
|
+
async def _aresolve_parameter(
|
|
645
|
+
self, provider: Provider, parameter: inspect.Parameter
|
|
646
|
+
) -> Any:
|
|
647
|
+
self._validate_resolvable_parameter(parameter, call=provider.call)
|
|
648
|
+
return await self.aresolve(parameter.annotation)
|
|
649
|
+
|
|
650
|
+
def _validate_resolvable_parameter(
|
|
651
|
+
self, parameter: inspect.Parameter, call: Callable[..., Any]
|
|
652
|
+
) -> None:
|
|
653
|
+
"""Ensure that the specified interface is resolved."""
|
|
654
|
+
if parameter.annotation in self._unresolved_interfaces:
|
|
655
|
+
raise LookupError(
|
|
656
|
+
f"You are attempting to get the parameter `{parameter.name}` with the "
|
|
657
|
+
f"annotation `{get_full_qualname(parameter.annotation)}` as a "
|
|
658
|
+
f"dependency into `{get_full_qualname(call)}` which is not registered "
|
|
659
|
+
"or set in the scoped context."
|
|
660
|
+
)
|
|
661
|
+
|
|
392
662
|
def _patch_test_resolver(self, instance: Any) -> None:
|
|
393
663
|
"""Patch the test resolver for the instance."""
|
|
394
664
|
if not hasattr(instance, "__dict__"):
|
|
@@ -417,28 +687,25 @@ class Container:
|
|
|
417
687
|
try:
|
|
418
688
|
provider = self._get_provider(interface)
|
|
419
689
|
except LookupError:
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
return False
|
|
690
|
+
return False
|
|
691
|
+
if provider.scope == "transient":
|
|
692
|
+
return False
|
|
693
|
+
context = self._get_scoped_context(provider.scope)
|
|
694
|
+
return interface in context
|
|
426
695
|
|
|
427
696
|
def release(self, interface: AnyInterface) -> None:
|
|
428
697
|
"""Release an instance by interface."""
|
|
429
698
|
provider = self._get_provider(interface)
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
699
|
+
if provider.scope == "transient":
|
|
700
|
+
return None
|
|
701
|
+
context = self._get_scoped_context(provider.scope)
|
|
702
|
+
del context[interface]
|
|
433
703
|
|
|
434
|
-
def _get_scoped_context(self, scope: Scope) ->
|
|
435
|
-
"""Get the
|
|
704
|
+
def _get_scoped_context(self, scope: Scope) -> InstanceContext:
|
|
705
|
+
"""Get the instance context for the specified scope."""
|
|
436
706
|
if scope == "singleton":
|
|
437
707
|
return self._singleton_context
|
|
438
|
-
|
|
439
|
-
request_context = self._get_request_context()
|
|
440
|
-
return request_context
|
|
441
|
-
return self._transient_context
|
|
708
|
+
return self._get_request_context()
|
|
442
709
|
|
|
443
710
|
@contextlib.contextmanager
|
|
444
711
|
def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
|
|
@@ -477,19 +744,14 @@ class Container:
|
|
|
477
744
|
) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
|
|
478
745
|
"""Decorator to inject dependencies into a callable."""
|
|
479
746
|
|
|
480
|
-
def decorator(
|
|
481
|
-
call: Callable[P, T],
|
|
482
|
-
) -> Callable[P, T]:
|
|
747
|
+
def decorator(call: Callable[P, T]) -> Callable[P, T]:
|
|
483
748
|
return self._inject(call)
|
|
484
749
|
|
|
485
750
|
if func is None:
|
|
486
751
|
return decorator
|
|
487
752
|
return decorator(func)
|
|
488
753
|
|
|
489
|
-
def _inject(
|
|
490
|
-
self,
|
|
491
|
-
call: Callable[P, T],
|
|
492
|
-
) -> Callable[P, T]:
|
|
754
|
+
def _inject(self, call: Callable[P, T]) -> Callable[P, T]:
|
|
493
755
|
"""Inject dependencies into a callable."""
|
|
494
756
|
if call in self._inject_cache:
|
|
495
757
|
return cast(Callable[P, T], self._inject_cache[call])
|
|
@@ -563,12 +825,97 @@ class Container:
|
|
|
563
825
|
def scan(
|
|
564
826
|
self,
|
|
565
827
|
/,
|
|
566
|
-
packages:
|
|
828
|
+
packages: ModuleType | str | Iterable[ModuleType | str],
|
|
567
829
|
*,
|
|
568
830
|
tags: Iterable[str] | None = None,
|
|
569
831
|
) -> None:
|
|
570
832
|
"""Scan packages or modules for decorated members and inject dependencies."""
|
|
571
|
-
|
|
833
|
+
dependencies: list[Dependency] = []
|
|
834
|
+
|
|
835
|
+
if isinstance(packages, Iterable) and not isinstance(packages, str):
|
|
836
|
+
scan_packages: Iterable[ModuleType | str] = packages
|
|
837
|
+
else:
|
|
838
|
+
scan_packages = cast(Iterable[Union[ModuleType, str]], [packages])
|
|
839
|
+
|
|
840
|
+
for package in scan_packages:
|
|
841
|
+
dependencies.extend(self._scan_package(package, tags=tags))
|
|
842
|
+
|
|
843
|
+
for dependency in dependencies:
|
|
844
|
+
decorator = self.inject()(dependency.member)
|
|
845
|
+
setattr(dependency.module, dependency.member.__name__, decorator)
|
|
846
|
+
|
|
847
|
+
def _scan_package(
|
|
848
|
+
self,
|
|
849
|
+
package: ModuleType | str,
|
|
850
|
+
*,
|
|
851
|
+
tags: Iterable[str] | None = None,
|
|
852
|
+
) -> list[Dependency]:
|
|
853
|
+
"""Scan a package or module for decorated members."""
|
|
854
|
+
tags = tags or []
|
|
855
|
+
if isinstance(package, str):
|
|
856
|
+
package = importlib.import_module(package)
|
|
857
|
+
|
|
858
|
+
package_path = getattr(package, "__path__", None)
|
|
859
|
+
|
|
860
|
+
if not package_path:
|
|
861
|
+
return self._scan_module(package, tags=tags)
|
|
862
|
+
|
|
863
|
+
dependencies: list[Dependency] = []
|
|
864
|
+
|
|
865
|
+
for module_info in pkgutil.walk_packages(
|
|
866
|
+
path=package_path, prefix=package.__name__ + "."
|
|
867
|
+
):
|
|
868
|
+
module = importlib.import_module(module_info.name)
|
|
869
|
+
dependencies.extend(self._scan_module(module, tags=tags))
|
|
870
|
+
|
|
871
|
+
return dependencies
|
|
872
|
+
|
|
873
|
+
def _scan_module(
|
|
874
|
+
self, module: ModuleType, *, tags: Iterable[str]
|
|
875
|
+
) -> list[Dependency]:
|
|
876
|
+
"""Scan a module for decorated members."""
|
|
877
|
+
dependencies: list[Dependency] = []
|
|
878
|
+
|
|
879
|
+
for _, member in inspect.getmembers(module):
|
|
880
|
+
if getattr(member, "__module__", None) != module.__name__ or not callable(
|
|
881
|
+
member
|
|
882
|
+
):
|
|
883
|
+
continue
|
|
884
|
+
|
|
885
|
+
decorator_args: InjectableDecoratorArgs = getattr(
|
|
886
|
+
member,
|
|
887
|
+
"__injectable__",
|
|
888
|
+
InjectableDecoratorArgs(wrapped=False, tags=[]),
|
|
889
|
+
)
|
|
890
|
+
|
|
891
|
+
if tags and (
|
|
892
|
+
decorator_args.tags
|
|
893
|
+
and not set(decorator_args.tags).intersection(tags)
|
|
894
|
+
or not decorator_args.tags
|
|
895
|
+
):
|
|
896
|
+
continue
|
|
897
|
+
|
|
898
|
+
if decorator_args.wrapped:
|
|
899
|
+
dependencies.append(
|
|
900
|
+
self._create_dependency(member=member, module=module)
|
|
901
|
+
)
|
|
902
|
+
continue
|
|
903
|
+
|
|
904
|
+
# Get by Marker
|
|
905
|
+
for parameter in get_typed_parameters(member):
|
|
906
|
+
if is_marker(parameter.default):
|
|
907
|
+
dependencies.append(
|
|
908
|
+
self._create_dependency(member=member, module=module)
|
|
909
|
+
)
|
|
910
|
+
continue
|
|
911
|
+
|
|
912
|
+
return dependencies
|
|
913
|
+
|
|
914
|
+
def _create_dependency(self, member: Any, module: ModuleType) -> Dependency:
|
|
915
|
+
"""Create a `Dependency` object from the scanned member and module."""
|
|
916
|
+
if hasattr(member, "__wrapped__"):
|
|
917
|
+
member = member.__wrapped__
|
|
918
|
+
return Dependency(member=member, module=module)
|
|
572
919
|
|
|
573
920
|
|
|
574
921
|
def transient(target: T) -> T:
|
|
@@ -587,3 +934,51 @@ def singleton(target: T) -> T:
|
|
|
587
934
|
"""Decorator for marking a class as singleton scope."""
|
|
588
935
|
setattr(target, "__scope__", "singleton")
|
|
589
936
|
return target
|
|
937
|
+
|
|
938
|
+
|
|
939
|
+
def provider(
|
|
940
|
+
*, scope: Scope, override: bool = False
|
|
941
|
+
) -> Callable[[Callable[Concatenate[M, P], T]], Callable[Concatenate[M, P], T]]:
|
|
942
|
+
"""Decorator for marking a function or method as a provider in a AnyDI module."""
|
|
943
|
+
|
|
944
|
+
def decorator(
|
|
945
|
+
target: Callable[Concatenate[M, P], T],
|
|
946
|
+
) -> Callable[Concatenate[M, P], T]:
|
|
947
|
+
setattr(
|
|
948
|
+
target,
|
|
949
|
+
"__provider__",
|
|
950
|
+
ProviderDecoratorArgs(scope=scope, override=override),
|
|
951
|
+
)
|
|
952
|
+
return target
|
|
953
|
+
|
|
954
|
+
return decorator
|
|
955
|
+
|
|
956
|
+
|
|
957
|
+
@overload
|
|
958
|
+
def injectable(func: Callable[P, T]) -> Callable[P, T]: ...
|
|
959
|
+
|
|
960
|
+
|
|
961
|
+
@overload
|
|
962
|
+
def injectable(
|
|
963
|
+
*, tags: Iterable[str] | None = None
|
|
964
|
+
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
|
|
965
|
+
|
|
966
|
+
|
|
967
|
+
def injectable(
|
|
968
|
+
func: Callable[P, T] | None = None,
|
|
969
|
+
tags: Iterable[str] | None = None,
|
|
970
|
+
) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
|
|
971
|
+
"""Decorator for marking a function or method as requiring dependency injection."""
|
|
972
|
+
|
|
973
|
+
def decorator(inner: Callable[P, T]) -> Callable[P, T]:
|
|
974
|
+
setattr(
|
|
975
|
+
inner,
|
|
976
|
+
"__injectable__",
|
|
977
|
+
InjectableDecoratorArgs(wrapped=True, tags=tags),
|
|
978
|
+
)
|
|
979
|
+
return inner
|
|
980
|
+
|
|
981
|
+
if func is None:
|
|
982
|
+
return decorator
|
|
983
|
+
|
|
984
|
+
return decorator(func)
|