anydi 0.54.1__py3-none-any.whl → 0.55.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anydi/__init__.py +1 -2
- anydi/{_async.py → _async_lock.py} +3 -13
- anydi/_container.py +85 -366
- anydi/_context.py +27 -9
- anydi/_decorators.py +1 -1
- anydi/_provider.py +9 -37
- anydi/_resolver.py +571 -0
- anydi/{_scan.py → _scanner.py} +1 -1
- anydi/{_typing.py → _types.py} +22 -21
- anydi/ext/fastapi.py +1 -1
- anydi/ext/faststream.py +1 -1
- anydi/testing.py +14 -48
- {anydi-0.54.1.dist-info → anydi-0.55.1.dist-info}/METADATA +10 -16
- anydi-0.55.1.dist-info/RECORD +24 -0
- anydi/_scope.py +0 -9
- anydi-0.54.1.dist-info/RECORD +0 -24
- {anydi-0.54.1.dist-info → anydi-0.55.1.dist-info}/WHEEL +0 -0
- {anydi-0.54.1.dist-info → anydi-0.55.1.dist-info}/entry_points.txt +0 -0
anydi/__init__.py
CHANGED
|
@@ -4,8 +4,7 @@ from ._container import Container
|
|
|
4
4
|
from ._decorators import injectable, provided, provider, request, singleton, transient
|
|
5
5
|
from ._module import Module
|
|
6
6
|
from ._provider import ProviderDef as Provider
|
|
7
|
-
from .
|
|
8
|
-
from ._typing import Inject
|
|
7
|
+
from ._types import Inject, Scope
|
|
9
8
|
|
|
10
9
|
# Alias for dependency auto marker
|
|
11
10
|
# TODO: deprecate it
|
|
@@ -1,18 +1,8 @@
|
|
|
1
|
-
import functools
|
|
2
|
-
from collections.abc import Callable
|
|
3
1
|
from types import TracebackType
|
|
4
|
-
from typing import Any
|
|
2
|
+
from typing import Any
|
|
5
3
|
|
|
6
|
-
import anyio
|
|
7
|
-
from typing_extensions import
|
|
8
|
-
|
|
9
|
-
T = TypeVar("T")
|
|
10
|
-
P = ParamSpec("P")
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
async def run_sync(func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
|
|
14
|
-
"""Runs the given function asynchronously using the `anyio` library."""
|
|
15
|
-
return await anyio.to_thread.run_sync(functools.partial(func, *args, **kwargs))
|
|
4
|
+
import anyio
|
|
5
|
+
from typing_extensions import Self
|
|
16
6
|
|
|
17
7
|
|
|
18
8
|
class AsyncRLock:
|
anydi/_container.py
CHANGED
|
@@ -15,18 +15,16 @@ from typing import Annotated, Any, TypeVar, cast, get_args, get_origin, overload
|
|
|
15
15
|
|
|
16
16
|
from typing_extensions import ParamSpec, Self, type_repr
|
|
17
17
|
|
|
18
|
-
from ._async import run_sync
|
|
19
18
|
from ._context import InstanceContext
|
|
20
19
|
from ._decorators import is_provided
|
|
21
20
|
from ._module import ModuleDef, ModuleRegistrar
|
|
22
21
|
from ._provider import Provider, ProviderDef, ProviderKind, ProviderParameter
|
|
23
|
-
from .
|
|
24
|
-
from .
|
|
25
|
-
from .
|
|
22
|
+
from ._resolver import Resolver
|
|
23
|
+
from ._scanner import PackageOrIterable, Scanner
|
|
24
|
+
from ._types import (
|
|
26
25
|
NOT_SET,
|
|
27
26
|
Event,
|
|
28
|
-
|
|
29
|
-
is_context_manager,
|
|
27
|
+
Scope,
|
|
30
28
|
is_event_type,
|
|
31
29
|
is_inject_marker,
|
|
32
30
|
is_iterator_type,
|
|
@@ -36,6 +34,12 @@ from ._typing import (
|
|
|
36
34
|
T = TypeVar("T", bound=Any)
|
|
37
35
|
P = ParamSpec("P")
|
|
38
36
|
|
|
37
|
+
ALLOWED_SCOPES: dict[Scope, list[Scope]] = {
|
|
38
|
+
"singleton": ["singleton"],
|
|
39
|
+
"request": ["request", "singleton"],
|
|
40
|
+
"transient": ["transient", "request", "singleton"],
|
|
41
|
+
}
|
|
42
|
+
|
|
39
43
|
|
|
40
44
|
class Container:
|
|
41
45
|
"""AnyDI is a dependency injection container."""
|
|
@@ -54,10 +58,10 @@ class Container:
|
|
|
54
58
|
self._request_context_var: ContextVar[InstanceContext | None] = ContextVar(
|
|
55
59
|
"request_context", default=None
|
|
56
60
|
)
|
|
57
|
-
self._unresolved_interfaces: set[Any] = set()
|
|
58
61
|
self._inject_cache: dict[Callable[..., Any], Callable[..., Any]] = {}
|
|
59
62
|
|
|
60
63
|
# Components
|
|
64
|
+
self._resolver = Resolver(self)
|
|
61
65
|
self._modules = ModuleRegistrar(self)
|
|
62
66
|
self._scanner = Scanner(self)
|
|
63
67
|
|
|
@@ -75,9 +79,7 @@ class Container:
|
|
|
75
79
|
for module in modules:
|
|
76
80
|
self.register_module(module)
|
|
77
81
|
|
|
78
|
-
|
|
79
|
-
# Properties
|
|
80
|
-
############################
|
|
82
|
+
# == Container Properties ==
|
|
81
83
|
|
|
82
84
|
@property
|
|
83
85
|
def providers(self) -> dict[type[Any], Provider]:
|
|
@@ -89,9 +91,7 @@ class Container:
|
|
|
89
91
|
"""Get the logger instance."""
|
|
90
92
|
return self._logger
|
|
91
93
|
|
|
92
|
-
|
|
93
|
-
# Lifespan/Context Methods
|
|
94
|
-
############################
|
|
94
|
+
# == Context & Lifespan Management ==
|
|
95
95
|
|
|
96
96
|
def __enter__(self) -> Self:
|
|
97
97
|
"""Enter the singleton context."""
|
|
@@ -190,9 +190,7 @@ class Container:
|
|
|
190
190
|
return self._singleton_context
|
|
191
191
|
return self._get_request_context()
|
|
192
192
|
|
|
193
|
-
|
|
194
|
-
# Provider Methods
|
|
195
|
-
############################
|
|
193
|
+
# == Provider Registry ==
|
|
196
194
|
|
|
197
195
|
def register(
|
|
198
196
|
self,
|
|
@@ -251,22 +249,23 @@ class Container:
|
|
|
251
249
|
scope: Scope,
|
|
252
250
|
interface: Any = NOT_SET,
|
|
253
251
|
override: bool = False,
|
|
254
|
-
|
|
255
|
-
**defaults: Any,
|
|
252
|
+
defaults: dict[str, Any] | None = None,
|
|
256
253
|
) -> Provider:
|
|
257
254
|
"""Register a provider with the specified scope."""
|
|
258
255
|
name = type_repr(call)
|
|
259
256
|
kind = ProviderKind.from_call(call)
|
|
257
|
+
is_class = kind == ProviderKind.CLASS
|
|
258
|
+
is_resource = kind in (ProviderKind.GENERATOR, ProviderKind.ASYNC_GENERATOR)
|
|
260
259
|
|
|
261
260
|
# Validate scope if it provided
|
|
262
|
-
self._validate_provider_scope(scope, name,
|
|
261
|
+
self._validate_provider_scope(scope, name, is_resource)
|
|
263
262
|
|
|
264
263
|
# Get the signature
|
|
265
264
|
signature = inspect.signature(call, eval_str=True)
|
|
266
265
|
|
|
267
266
|
# Detect the interface
|
|
268
267
|
if interface is NOT_SET:
|
|
269
|
-
if
|
|
268
|
+
if is_class:
|
|
270
269
|
interface = call
|
|
271
270
|
else:
|
|
272
271
|
interface = signature.return_annotation
|
|
@@ -313,10 +312,17 @@ class Container:
|
|
|
313
312
|
f"are not allowed in the provider `{name}`."
|
|
314
313
|
)
|
|
315
314
|
|
|
315
|
+
default = (
|
|
316
|
+
parameter.default
|
|
317
|
+
if parameter.default is not inspect.Parameter.empty
|
|
318
|
+
else NOT_SET
|
|
319
|
+
)
|
|
320
|
+
has_default = default is not NOT_SET
|
|
321
|
+
|
|
316
322
|
try:
|
|
317
323
|
sub_provider = self._get_or_register_provider(parameter.annotation)
|
|
318
324
|
except LookupError as exc:
|
|
319
|
-
if
|
|
325
|
+
if parameter.name in defaults if defaults else False or has_default:
|
|
320
326
|
continue
|
|
321
327
|
unresolved_parameter = parameter
|
|
322
328
|
unresolved_exc = exc
|
|
@@ -326,17 +332,12 @@ class Container:
|
|
|
326
332
|
if sub_provider.scope not in scopes:
|
|
327
333
|
scopes[sub_provider.scope] = sub_provider
|
|
328
334
|
|
|
329
|
-
default = (
|
|
330
|
-
parameter.default
|
|
331
|
-
if parameter.default is not inspect.Parameter.empty
|
|
332
|
-
else NOT_SET
|
|
333
|
-
)
|
|
334
335
|
parameters.append(
|
|
335
336
|
ProviderParameter(
|
|
336
337
|
name=parameter.name,
|
|
337
338
|
annotation=parameter.annotation,
|
|
338
339
|
default=default,
|
|
339
|
-
has_default=
|
|
340
|
+
has_default=has_default,
|
|
340
341
|
provider=sub_provider,
|
|
341
342
|
shared_scope=sub_provider.scope == scope and scope != "transient",
|
|
342
343
|
)
|
|
@@ -345,7 +346,7 @@ class Container:
|
|
|
345
346
|
# Check for unresolved parameters
|
|
346
347
|
if unresolved_parameter:
|
|
347
348
|
if scope not in ("singleton", "transient"):
|
|
348
|
-
self.
|
|
349
|
+
self._resolver.add_unresolved(interface)
|
|
349
350
|
else:
|
|
350
351
|
raise LookupError(
|
|
351
352
|
f"The provider `{name}` depends on `{unresolved_parameter.name}` "
|
|
@@ -364,28 +365,38 @@ class Container:
|
|
|
364
365
|
"Please ensure all providers are registered with matching scopes."
|
|
365
366
|
)
|
|
366
367
|
|
|
368
|
+
is_coroutine = kind == ProviderKind.COROUTINE
|
|
369
|
+
is_generator = kind == ProviderKind.GENERATOR
|
|
370
|
+
is_async_generator = kind == ProviderKind.ASYNC_GENERATOR
|
|
371
|
+
is_async = is_coroutine or is_async_generator
|
|
372
|
+
|
|
367
373
|
provider = Provider(
|
|
368
374
|
call=call,
|
|
369
375
|
scope=scope,
|
|
370
376
|
interface=interface,
|
|
371
377
|
name=name,
|
|
372
|
-
kind=kind,
|
|
373
378
|
parameters=tuple(parameters),
|
|
379
|
+
is_class=is_class,
|
|
380
|
+
is_coroutine=is_coroutine,
|
|
381
|
+
is_generator=is_generator,
|
|
382
|
+
is_async_generator=is_async_generator,
|
|
383
|
+
is_async=is_async,
|
|
384
|
+
is_resource=is_resource,
|
|
374
385
|
)
|
|
375
386
|
|
|
376
387
|
self._set_provider(provider)
|
|
377
388
|
return provider
|
|
378
389
|
|
|
379
390
|
@staticmethod
|
|
380
|
-
def _validate_provider_scope(scope: Scope, name: str,
|
|
391
|
+
def _validate_provider_scope(scope: Scope, name: str, is_resource: bool) -> None:
|
|
381
392
|
"""Validate the provider scope."""
|
|
382
|
-
if scope not in
|
|
393
|
+
if scope not in ALLOWED_SCOPES:
|
|
383
394
|
raise ValueError(
|
|
384
395
|
f"The provider `{name}` scope is invalid. Only the following "
|
|
385
|
-
f"scopes are supported: {', '.join(
|
|
396
|
+
f"scopes are supported: {', '.join(ALLOWED_SCOPES)}. "
|
|
386
397
|
"Please use one of the supported scopes when registering a provider."
|
|
387
398
|
)
|
|
388
|
-
if scope == "transient" and
|
|
399
|
+
if scope == "transient" and is_resource:
|
|
389
400
|
raise TypeError(
|
|
390
401
|
f"The resource provider `{name}` is attempting to register "
|
|
391
402
|
"with a transient scope, which is not allowed."
|
|
@@ -402,7 +413,9 @@ class Container:
|
|
|
402
413
|
"properly registered before attempting to use it."
|
|
403
414
|
) from exc
|
|
404
415
|
|
|
405
|
-
def _get_or_register_provider(
|
|
416
|
+
def _get_or_register_provider(
|
|
417
|
+
self, interface: Any, defaults: dict[str, Any] | None = None
|
|
418
|
+
) -> Provider:
|
|
406
419
|
"""Get or register a provider by interface."""
|
|
407
420
|
try:
|
|
408
421
|
return self._providers[interface]
|
|
@@ -412,7 +425,8 @@ class Container:
|
|
|
412
425
|
interface,
|
|
413
426
|
interface.__provided__["scope"],
|
|
414
427
|
NOT_SET,
|
|
415
|
-
|
|
428
|
+
False,
|
|
429
|
+
defaults,
|
|
416
430
|
)
|
|
417
431
|
raise LookupError(
|
|
418
432
|
f"The provider interface `{type_repr(interface)}` is either not "
|
|
@@ -434,17 +448,7 @@ class Container:
|
|
|
434
448
|
if provider.is_resource:
|
|
435
449
|
self._resources[provider.scope].remove(provider.interface)
|
|
436
450
|
|
|
437
|
-
|
|
438
|
-
def _parameter_has_default(
|
|
439
|
-
parameter: inspect.Parameter, /, **defaults: Any
|
|
440
|
-
) -> bool:
|
|
441
|
-
has_default_in_kwargs = parameter.name in defaults if defaults else False
|
|
442
|
-
has_default = parameter.default is not inspect.Parameter.empty
|
|
443
|
-
return has_default_in_kwargs or has_default
|
|
444
|
-
|
|
445
|
-
############################
|
|
446
|
-
# Instance Methods
|
|
447
|
-
############################
|
|
451
|
+
# == Instance Resolution ==
|
|
448
452
|
|
|
449
453
|
@overload
|
|
450
454
|
def resolve(self, interface: type[T]) -> T: ...
|
|
@@ -453,8 +457,14 @@ class Container:
|
|
|
453
457
|
def resolve(self, interface: T) -> T: ... # type: ignore
|
|
454
458
|
|
|
455
459
|
def resolve(self, interface: type[T]) -> T:
|
|
456
|
-
"""Resolve an instance by interface."""
|
|
457
|
-
|
|
460
|
+
"""Resolve an instance by interface using compiled sync resolver."""
|
|
461
|
+
cached = self._resolver.get_cached(interface, is_async=False)
|
|
462
|
+
if cached is not None:
|
|
463
|
+
return cached.resolve(self)
|
|
464
|
+
|
|
465
|
+
provider = self._get_or_register_provider(interface)
|
|
466
|
+
compiled = self._resolver.compile(provider, is_async=False)
|
|
467
|
+
return compiled.resolve(self)
|
|
458
468
|
|
|
459
469
|
@overload
|
|
460
470
|
async def aresolve(self, interface: type[T]) -> T: ...
|
|
@@ -464,15 +474,35 @@ class Container:
|
|
|
464
474
|
|
|
465
475
|
async def aresolve(self, interface: type[T]) -> T:
|
|
466
476
|
"""Resolve an instance by interface asynchronously."""
|
|
467
|
-
|
|
477
|
+
cached = self._resolver.get_cached(interface, is_async=True)
|
|
478
|
+
if cached is not None:
|
|
479
|
+
return await cached.resolve(self)
|
|
480
|
+
|
|
481
|
+
provider = self._get_or_register_provider(interface)
|
|
482
|
+
compiled = self._resolver.compile(provider, is_async=True)
|
|
483
|
+
return await compiled.resolve(self)
|
|
468
484
|
|
|
469
485
|
def create(self, interface: type[T], /, **defaults: Any) -> T:
|
|
470
486
|
"""Create an instance by interface."""
|
|
471
|
-
|
|
487
|
+
if not defaults:
|
|
488
|
+
cached = self._resolver.get_cached(interface, is_async=False)
|
|
489
|
+
if cached is not None:
|
|
490
|
+
return cached.create(self, None)
|
|
491
|
+
|
|
492
|
+
provider = self._get_or_register_provider(interface, defaults)
|
|
493
|
+
compiled = self._resolver.compile(provider, is_async=False)
|
|
494
|
+
return compiled.create(self, defaults or None)
|
|
472
495
|
|
|
473
496
|
async def acreate(self, interface: type[T], /, **defaults: Any) -> T:
|
|
474
497
|
"""Create an instance by interface asynchronously."""
|
|
475
|
-
|
|
498
|
+
if not defaults:
|
|
499
|
+
cached = self._resolver.get_cached(interface, is_async=True)
|
|
500
|
+
if cached is not None:
|
|
501
|
+
return await cached.create(self, None)
|
|
502
|
+
|
|
503
|
+
provider = self._get_or_register_provider(interface, defaults)
|
|
504
|
+
compiled = self._resolver.compile(provider, is_async=True)
|
|
505
|
+
return await compiled.create(self, defaults or None)
|
|
476
506
|
|
|
477
507
|
def is_resolved(self, interface: Any) -> bool:
|
|
478
508
|
"""Check if an instance by interface exists."""
|
|
@@ -504,312 +534,7 @@ class Container:
|
|
|
504
534
|
continue
|
|
505
535
|
del context[interface]
|
|
506
536
|
|
|
507
|
-
|
|
508
|
-
self, interface: Any, create: bool, /, **defaults: Any
|
|
509
|
-
) -> Any:
|
|
510
|
-
"""Internal method to handle instance resolution and creation."""
|
|
511
|
-
provider = self._get_or_register_provider(interface, **defaults)
|
|
512
|
-
return self._resolve_with_provider(provider, create, **defaults)
|
|
513
|
-
|
|
514
|
-
async def _aresolve_or_create(
|
|
515
|
-
self, interface: Any, create: bool, /, **defaults: Any
|
|
516
|
-
) -> Any:
|
|
517
|
-
"""Internal method to handle instance resolution and creation asynchronously."""
|
|
518
|
-
provider = self._get_or_register_provider(interface, **defaults)
|
|
519
|
-
return await self._aresolve_with_provider(provider, create, **defaults)
|
|
520
|
-
|
|
521
|
-
def _resolve_with_provider(
|
|
522
|
-
self, provider: Provider, create: bool, /, **defaults: Any
|
|
523
|
-
) -> Any:
|
|
524
|
-
if provider.scope == "transient":
|
|
525
|
-
return self._create_instance(provider, None, **defaults)
|
|
526
|
-
|
|
527
|
-
if provider.scope == "request":
|
|
528
|
-
context = self._get_request_context()
|
|
529
|
-
if not create:
|
|
530
|
-
cached = context.get(provider.interface, NOT_SET)
|
|
531
|
-
if cached is not NOT_SET:
|
|
532
|
-
return cached
|
|
533
|
-
if not create:
|
|
534
|
-
return self._get_or_create_instance(provider, context)
|
|
535
|
-
return self._create_instance(provider, context, **defaults)
|
|
536
|
-
|
|
537
|
-
context = self._get_instance_context(provider.scope)
|
|
538
|
-
if not create:
|
|
539
|
-
cached = context.get(provider.interface, NOT_SET)
|
|
540
|
-
if cached is not NOT_SET:
|
|
541
|
-
return cached
|
|
542
|
-
with context.lock():
|
|
543
|
-
return (
|
|
544
|
-
self._get_or_create_instance(provider, context)
|
|
545
|
-
if not create
|
|
546
|
-
else self._create_instance(provider, context, **defaults)
|
|
547
|
-
)
|
|
548
|
-
|
|
549
|
-
async def _aresolve_with_provider(
|
|
550
|
-
self, provider: Provider, create: bool, /, **defaults: Any
|
|
551
|
-
) -> Any:
|
|
552
|
-
if provider.scope == "transient":
|
|
553
|
-
return await self._acreate_instance(provider, None, **defaults)
|
|
554
|
-
|
|
555
|
-
if provider.scope == "request":
|
|
556
|
-
context = self._get_request_context()
|
|
557
|
-
if not create:
|
|
558
|
-
cached = context.get(provider.interface, NOT_SET)
|
|
559
|
-
if cached is not NOT_SET:
|
|
560
|
-
return cached
|
|
561
|
-
if not create:
|
|
562
|
-
return await self._aget_or_create_instance(provider, context)
|
|
563
|
-
return await self._acreate_instance(provider, context, **defaults)
|
|
564
|
-
|
|
565
|
-
context = self._get_instance_context(provider.scope)
|
|
566
|
-
if not create:
|
|
567
|
-
cached = context.get(provider.interface, NOT_SET)
|
|
568
|
-
if cached is not NOT_SET:
|
|
569
|
-
return cached
|
|
570
|
-
async with context.alock():
|
|
571
|
-
return (
|
|
572
|
-
await self._aget_or_create_instance(provider, context)
|
|
573
|
-
if not create
|
|
574
|
-
else await self._acreate_instance(provider, context, **defaults)
|
|
575
|
-
)
|
|
576
|
-
|
|
577
|
-
def _get_or_create_instance(
|
|
578
|
-
self, provider: Provider, context: InstanceContext
|
|
579
|
-
) -> Any:
|
|
580
|
-
"""Get an instance of a dependency from the scoped context."""
|
|
581
|
-
instance = context.get(provider.interface, NOT_SET)
|
|
582
|
-
if instance is NOT_SET:
|
|
583
|
-
instance = self._create_instance(provider, context)
|
|
584
|
-
context.set(provider.interface, instance)
|
|
585
|
-
return instance
|
|
586
|
-
return instance
|
|
587
|
-
|
|
588
|
-
async def _aget_or_create_instance(
|
|
589
|
-
self, provider: Provider, context: InstanceContext
|
|
590
|
-
) -> Any:
|
|
591
|
-
"""Get an async instance of a dependency from the scoped context."""
|
|
592
|
-
instance = context.get(provider.interface, NOT_SET)
|
|
593
|
-
if instance is NOT_SET:
|
|
594
|
-
instance = await self._acreate_instance(provider, context)
|
|
595
|
-
context.set(provider.interface, instance)
|
|
596
|
-
return instance
|
|
597
|
-
return instance
|
|
598
|
-
|
|
599
|
-
def _create_instance(
|
|
600
|
-
self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
|
|
601
|
-
) -> Any:
|
|
602
|
-
"""Create an instance using the provider."""
|
|
603
|
-
if provider.is_async:
|
|
604
|
-
raise TypeError(
|
|
605
|
-
f"The instance for the provider `{provider}` cannot be created in "
|
|
606
|
-
"synchronous mode."
|
|
607
|
-
)
|
|
608
|
-
|
|
609
|
-
provider_kwargs = self._get_provided_kwargs(
|
|
610
|
-
provider, context, defaults=defaults if defaults else None
|
|
611
|
-
)
|
|
612
|
-
|
|
613
|
-
if provider.is_generator:
|
|
614
|
-
if context is None:
|
|
615
|
-
raise ValueError("The context is required for generator providers.")
|
|
616
|
-
cm = contextlib.contextmanager(provider.call)(**provider_kwargs)
|
|
617
|
-
return context.enter(cm)
|
|
618
|
-
|
|
619
|
-
instance = provider.call(**provider_kwargs)
|
|
620
|
-
if context is not None and provider.is_class and is_context_manager(instance):
|
|
621
|
-
context.enter(instance)
|
|
622
|
-
return instance
|
|
623
|
-
|
|
624
|
-
async def _acreate_instance(
|
|
625
|
-
self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
|
|
626
|
-
) -> Any:
|
|
627
|
-
"""Create an instance asynchronously using the provider."""
|
|
628
|
-
provider_kwargs = await self._aget_provided_kwargs(
|
|
629
|
-
provider, context, defaults=defaults if defaults else None
|
|
630
|
-
)
|
|
631
|
-
|
|
632
|
-
if provider.is_coroutine:
|
|
633
|
-
return await provider.call(**provider_kwargs)
|
|
634
|
-
|
|
635
|
-
if provider.is_async_generator:
|
|
636
|
-
if context is None:
|
|
637
|
-
raise ValueError(
|
|
638
|
-
"The async stack is required for async generator providers."
|
|
639
|
-
)
|
|
640
|
-
cm = contextlib.asynccontextmanager(provider.call)(**provider_kwargs)
|
|
641
|
-
return await context.aenter(cm)
|
|
642
|
-
|
|
643
|
-
if provider.is_generator:
|
|
644
|
-
|
|
645
|
-
def _create() -> Any:
|
|
646
|
-
if context is None:
|
|
647
|
-
raise ValueError("The stack is required for generator providers.")
|
|
648
|
-
cm = contextlib.contextmanager(provider.call)(**provider_kwargs)
|
|
649
|
-
return context.enter(cm)
|
|
650
|
-
|
|
651
|
-
return await run_sync(_create)
|
|
652
|
-
|
|
653
|
-
instance = await run_sync(provider.call, **provider_kwargs)
|
|
654
|
-
if (
|
|
655
|
-
context is not None
|
|
656
|
-
and provider.is_class
|
|
657
|
-
and is_async_context_manager(instance)
|
|
658
|
-
):
|
|
659
|
-
await context.aenter(instance)
|
|
660
|
-
return instance
|
|
661
|
-
|
|
662
|
-
def _get_provided_kwargs(
|
|
663
|
-
self,
|
|
664
|
-
provider: Provider,
|
|
665
|
-
context: InstanceContext | None,
|
|
666
|
-
/,
|
|
667
|
-
defaults: dict[str, Any] | None = None,
|
|
668
|
-
) -> dict[str, Any]:
|
|
669
|
-
"""Retrieve the arguments for a provider."""
|
|
670
|
-
if not provider.parameters:
|
|
671
|
-
return defaults if defaults else {}
|
|
672
|
-
|
|
673
|
-
provided_kwargs = dict(defaults) if defaults else {}
|
|
674
|
-
for parameter in provider.parameters:
|
|
675
|
-
provided_kwargs[parameter.name] = self._get_provider_instance(
|
|
676
|
-
provider,
|
|
677
|
-
parameter,
|
|
678
|
-
context,
|
|
679
|
-
defaults=defaults,
|
|
680
|
-
)
|
|
681
|
-
return provided_kwargs
|
|
682
|
-
|
|
683
|
-
def _get_provider_instance( # noqa: C901
|
|
684
|
-
self,
|
|
685
|
-
provider: Provider,
|
|
686
|
-
parameter: ProviderParameter,
|
|
687
|
-
context: InstanceContext | None,
|
|
688
|
-
/,
|
|
689
|
-
*,
|
|
690
|
-
defaults: dict[str, Any] | None = None,
|
|
691
|
-
) -> Any:
|
|
692
|
-
"""Retrieve an instance of a dependency from the scoped context."""
|
|
693
|
-
|
|
694
|
-
if defaults and parameter.name in defaults:
|
|
695
|
-
return defaults[parameter.name]
|
|
696
|
-
|
|
697
|
-
sub_provider = parameter.provider
|
|
698
|
-
|
|
699
|
-
if context and parameter.shared_scope and sub_provider is not None:
|
|
700
|
-
existing = context.get(sub_provider.interface, NOT_SET)
|
|
701
|
-
if existing is not NOT_SET:
|
|
702
|
-
return existing
|
|
703
|
-
|
|
704
|
-
if context:
|
|
705
|
-
cached = context.get(parameter.annotation, NOT_SET)
|
|
706
|
-
if cached is not NOT_SET:
|
|
707
|
-
return cached
|
|
708
|
-
|
|
709
|
-
sub_provider = parameter.provider
|
|
710
|
-
|
|
711
|
-
if sub_provider:
|
|
712
|
-
if sub_provider.scope == "transient":
|
|
713
|
-
return self._create_instance(sub_provider, None)
|
|
714
|
-
if sub_provider.scope == "singleton" and sub_provider is not provider:
|
|
715
|
-
return self._resolve_with_provider(sub_provider, False)
|
|
716
|
-
|
|
717
|
-
try:
|
|
718
|
-
return self._resolve_parameter(provider, parameter)
|
|
719
|
-
except LookupError:
|
|
720
|
-
if not parameter.has_default:
|
|
721
|
-
raise
|
|
722
|
-
return parameter.default
|
|
723
|
-
|
|
724
|
-
async def _aget_provided_kwargs(
|
|
725
|
-
self,
|
|
726
|
-
provider: Provider,
|
|
727
|
-
context: InstanceContext | None,
|
|
728
|
-
/,
|
|
729
|
-
defaults: dict[str, Any] | None = None,
|
|
730
|
-
) -> dict[str, Any]:
|
|
731
|
-
"""Asynchronously retrieve the arguments for a provider."""
|
|
732
|
-
if not provider.parameters:
|
|
733
|
-
return defaults if defaults else {}
|
|
734
|
-
|
|
735
|
-
provided_kwargs = dict(defaults) if defaults else {}
|
|
736
|
-
for parameter in provider.parameters:
|
|
737
|
-
provided_kwargs[parameter.name] = await self._aget_provider_instance(
|
|
738
|
-
provider,
|
|
739
|
-
parameter,
|
|
740
|
-
context,
|
|
741
|
-
defaults=defaults,
|
|
742
|
-
)
|
|
743
|
-
return provided_kwargs
|
|
744
|
-
|
|
745
|
-
async def _aget_provider_instance( # noqa: C901
|
|
746
|
-
self,
|
|
747
|
-
provider: Provider,
|
|
748
|
-
parameter: ProviderParameter,
|
|
749
|
-
context: InstanceContext | None,
|
|
750
|
-
/,
|
|
751
|
-
*,
|
|
752
|
-
defaults: dict[str, Any] | None = None,
|
|
753
|
-
) -> Any:
|
|
754
|
-
"""Asynchronously retrieve an instance of a dependency from the context."""
|
|
755
|
-
|
|
756
|
-
if defaults and parameter.name in defaults:
|
|
757
|
-
return defaults[parameter.name]
|
|
758
|
-
|
|
759
|
-
sub_provider = parameter.provider
|
|
760
|
-
|
|
761
|
-
if context and parameter.shared_scope and sub_provider is not None:
|
|
762
|
-
existing = context.get(sub_provider.interface, NOT_SET)
|
|
763
|
-
if existing is not NOT_SET:
|
|
764
|
-
return existing
|
|
765
|
-
|
|
766
|
-
if context:
|
|
767
|
-
cached = context.get(parameter.annotation, NOT_SET)
|
|
768
|
-
if cached is not NOT_SET:
|
|
769
|
-
return cached
|
|
770
|
-
|
|
771
|
-
sub_provider = parameter.provider
|
|
772
|
-
|
|
773
|
-
if sub_provider:
|
|
774
|
-
if sub_provider.scope == "transient":
|
|
775
|
-
return await self._acreate_instance(sub_provider, None)
|
|
776
|
-
if sub_provider.scope == "singleton" and sub_provider is not provider:
|
|
777
|
-
return await self._aresolve_with_provider(sub_provider, False)
|
|
778
|
-
|
|
779
|
-
try:
|
|
780
|
-
return await self._aresolve_parameter(provider, parameter)
|
|
781
|
-
except LookupError:
|
|
782
|
-
if not parameter.has_default:
|
|
783
|
-
raise
|
|
784
|
-
return parameter.default
|
|
785
|
-
|
|
786
|
-
def _resolve_parameter(
|
|
787
|
-
self, provider: Provider, parameter: ProviderParameter
|
|
788
|
-
) -> Any:
|
|
789
|
-
self._validate_resolvable_parameter(provider, parameter)
|
|
790
|
-
return self._resolve_or_create(parameter.annotation, False)
|
|
791
|
-
|
|
792
|
-
async def _aresolve_parameter(
|
|
793
|
-
self, provider: Provider, parameter: ProviderParameter
|
|
794
|
-
) -> Any:
|
|
795
|
-
self._validate_resolvable_parameter(provider, parameter)
|
|
796
|
-
return await self._aresolve_or_create(parameter.annotation, False)
|
|
797
|
-
|
|
798
|
-
def _validate_resolvable_parameter(
|
|
799
|
-
self, provider: Provider, parameter: ProviderParameter
|
|
800
|
-
) -> None:
|
|
801
|
-
"""Ensure that the specified interface is resolved."""
|
|
802
|
-
if parameter.annotation in self._unresolved_interfaces:
|
|
803
|
-
raise LookupError(
|
|
804
|
-
f"You are attempting to get the parameter `{parameter.name}` with the "
|
|
805
|
-
f"annotation `{type_repr(parameter.annotation)}` as a "
|
|
806
|
-
f"dependency into `{type_repr(provider.call)}` which is "
|
|
807
|
-
"not registered or set in the scoped context."
|
|
808
|
-
)
|
|
809
|
-
|
|
810
|
-
############################
|
|
811
|
-
# Injector Methods
|
|
812
|
-
############################
|
|
537
|
+
# == Injection Utilities ==
|
|
813
538
|
|
|
814
539
|
@overload
|
|
815
540
|
def inject(self, func: Callable[P, T]) -> Callable[P, T]: ...
|
|
@@ -933,26 +658,20 @@ class Container:
|
|
|
933
658
|
|
|
934
659
|
return interface, True
|
|
935
660
|
|
|
936
|
-
|
|
937
|
-
# Module Methods
|
|
938
|
-
############################
|
|
661
|
+
# == Module Registration ==
|
|
939
662
|
|
|
940
663
|
def register_module(self, module: ModuleDef) -> None:
|
|
941
664
|
"""Register a module as a callable, module type, or module instance."""
|
|
942
665
|
self._modules.register(module)
|
|
943
666
|
|
|
944
|
-
|
|
945
|
-
# Scanner Methods
|
|
946
|
-
############################
|
|
667
|
+
# == Package Scanning ==
|
|
947
668
|
|
|
948
669
|
def scan(
|
|
949
670
|
self, /, packages: PackageOrIterable, *, tags: Iterable[str] | None = None
|
|
950
671
|
) -> None:
|
|
951
672
|
self._scanner.scan(packages=packages, tags=tags)
|
|
952
673
|
|
|
953
|
-
|
|
954
|
-
# Testing
|
|
955
|
-
############################
|
|
674
|
+
# == Testing ==
|
|
956
675
|
|
|
957
676
|
@contextlib.contextmanager
|
|
958
677
|
def override(self, interface: Any, instance: Any) -> Iterator[None]:
|