anydi 0.38.0__py3-none-any.whl → 0.38.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
anydi/_container.py CHANGED
@@ -10,23 +10,27 @@ import logging
10
10
  import pkgutil
11
11
  import threading
12
12
  import types
13
+ import uuid
13
14
  from collections import defaultdict
14
15
  from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
15
16
  from contextvars import ContextVar
16
17
  from types import ModuleType
17
- from typing import Any, Callable, TypeVar, Union, cast, overload
18
- from weakref import WeakKeyDictionary
18
+ from typing import Any, Callable, TypeVar, Union, cast, get_origin, overload
19
19
 
20
- from typing_extensions import Concatenate, ParamSpec, Self, final
20
+ from typing_extensions import Concatenate, ParamSpec, Self, final, get_args
21
21
 
22
22
  from ._context import InstanceContext
23
- from ._provider import Provider
24
23
  from ._types import (
24
+ NOT_SET,
25
25
  AnyInterface,
26
- Dependency,
26
+ Event,
27
27
  InjectableDecoratorArgs,
28
28
  InstanceProxy,
29
+ Provider,
30
+ ProviderArgs,
29
31
  ProviderDecoratorArgs,
32
+ ProviderKind,
33
+ ScannedDependency,
30
34
  Scope,
31
35
  is_event_type,
32
36
  is_marker,
@@ -34,6 +38,7 @@ from ._types import (
34
38
  from ._utils import (
35
39
  AsyncRLock,
36
40
  get_full_qualname,
41
+ get_typed_annotation,
37
42
  get_typed_parameters,
38
43
  import_string,
39
44
  is_async_context_manager,
@@ -42,6 +47,12 @@ from ._utils import (
42
47
  run_async,
43
48
  )
44
49
 
50
+ try:
51
+ from types import NoneType
52
+ except ImportError:
53
+ NoneType = type(None) # type: ignore[misc]
54
+
55
+
45
56
  T = TypeVar("T", bound=Any)
46
57
  M = TypeVar("M", bound="Module")
47
58
  P = ParamSpec("P")
@@ -74,7 +85,6 @@ class Module(metaclass=ModuleMeta):
74
85
  """Configure the AnyDI container with providers and their dependencies."""
75
86
 
76
87
 
77
- # noinspection PyShadowingNames
78
88
  @final
79
89
  class Container:
80
90
  """AnyDI is a dependency injection container."""
@@ -82,7 +92,7 @@ class Container:
82
92
  def __init__(
83
93
  self,
84
94
  *,
85
- providers: Sequence[Provider] | None = None,
95
+ providers: Sequence[ProviderArgs] | None = None,
86
96
  modules: Sequence[Module | type[Module] | Callable[[Container], None] | str]
87
97
  | None = None,
88
98
  strict: bool = False,
@@ -104,20 +114,27 @@ class Container:
104
114
  )
105
115
  self._override_instances: dict[type[Any], Any] = {}
106
116
  self._unresolved_interfaces: set[type[Any]] = set()
107
- self._inject_cache: WeakKeyDictionary[
108
- Callable[..., Any], Callable[..., Any]
109
- ] = WeakKeyDictionary()
117
+ self._inject_cache: dict[Callable[..., Any], Callable[..., Any]] = {}
110
118
 
111
119
  # Register providers
112
120
  providers = providers or []
113
121
  for provider in providers:
114
- self._register_provider(provider, False)
122
+ _provider = self._create_provider(
123
+ call=provider.call,
124
+ scope=provider.scope,
125
+ interface=provider.interface,
126
+ )
127
+ self._register_provider(_provider, False)
115
128
 
116
129
  # Register modules
117
130
  modules = modules or []
118
131
  for module in modules:
119
132
  self.register_module(module)
120
133
 
134
+ ############################
135
+ # Properties
136
+ ############################
137
+
121
138
  @property
122
139
  def strict(self) -> bool:
123
140
  """Check if strict mode is enabled."""
@@ -143,9 +160,110 @@ class Container:
143
160
  """Get the logger instance."""
144
161
  return self._logger
145
162
 
146
- def is_registered(self, interface: AnyInterface) -> bool:
147
- """Check if a provider is registered for the specified interface."""
148
- return interface in self._providers
163
+ ############################
164
+ # Lifespan/Context Methods
165
+ ############################
166
+
167
+ def __enter__(self) -> Self:
168
+ """Enter the singleton context."""
169
+ self.start()
170
+ return self
171
+
172
+ def __exit__(
173
+ self,
174
+ exc_type: type[BaseException] | None,
175
+ exc_val: BaseException | None,
176
+ exc_tb: types.TracebackType | None,
177
+ ) -> Any:
178
+ """Exit the singleton context."""
179
+ return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
180
+
181
+ def start(self) -> None:
182
+ """Start the singleton context."""
183
+ # Resolve all singleton resources
184
+ for interface in self._resources.get("singleton", []):
185
+ self.resolve(interface)
186
+
187
+ def close(self) -> None:
188
+ """Close the singleton context."""
189
+ self._singleton_context.close()
190
+
191
+ async def __aenter__(self) -> Self:
192
+ """Enter the singleton context."""
193
+ await self.astart()
194
+ return self
195
+
196
+ async def __aexit__(
197
+ self,
198
+ exc_type: type[BaseException] | None,
199
+ exc_val: BaseException | None,
200
+ exc_tb: types.TracebackType | None,
201
+ ) -> bool:
202
+ """Exit the singleton context."""
203
+ return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)
204
+
205
+ async def astart(self) -> None:
206
+ """Start the singleton context asynchronously."""
207
+ for interface in self._resources.get("singleton", []):
208
+ await self.aresolve(interface)
209
+
210
+ async def aclose(self) -> None:
211
+ """Close the singleton context asynchronously."""
212
+ await self._singleton_context.aclose()
213
+
214
+ @contextlib.contextmanager
215
+ def request_context(self) -> Iterator[InstanceContext]:
216
+ """Obtain a context manager for the request-scoped context."""
217
+ context = InstanceContext()
218
+
219
+ token = self._request_context_var.set(context)
220
+
221
+ # Resolve all request resources
222
+ for interface in self._resources.get("request", []):
223
+ if not is_event_type(interface):
224
+ continue
225
+ self.resolve(interface)
226
+
227
+ with context:
228
+ yield context
229
+ self._request_context_var.reset(token)
230
+
231
+ @contextlib.asynccontextmanager
232
+ async def arequest_context(self) -> AsyncIterator[InstanceContext]:
233
+ """Obtain an async context manager for the request-scoped context."""
234
+ context = InstanceContext()
235
+
236
+ token = self._request_context_var.set(context)
237
+
238
+ for interface in self._resources.get("request", []):
239
+ if not is_event_type(interface):
240
+ continue
241
+ await self.aresolve(interface)
242
+
243
+ async with context:
244
+ yield context
245
+ self._request_context_var.reset(token)
246
+
247
+ def _get_request_context(self) -> InstanceContext:
248
+ """Get the current request context."""
249
+ request_context = self._request_context_var.get()
250
+ if request_context is None:
251
+ raise LookupError(
252
+ "The request context has not been started. Please ensure that "
253
+ "the request context is properly initialized before attempting "
254
+ "to use it."
255
+ )
256
+ return request_context
257
+
258
+ def _get_scoped_context(self, scope: Scope) -> InstanceContext:
259
+ """Get the instance context for the specified scope."""
260
+ if scope == "singleton":
261
+ return self._singleton_context
262
+ return self._get_request_context()
263
+
264
+ ############################
265
+ # Provider Methods
266
+ ############################
149
267
 
150
268
  def register(
151
269
  self,
@@ -156,26 +274,12 @@ class Container:
156
274
  override: bool = False,
157
275
  ) -> Provider:
158
276
  """Register a provider for the specified interface."""
159
- provider = Provider(call=call, scope=scope, interface=interface)
277
+ provider = self._create_provider(call=call, scope=scope, interface=interface)
160
278
  return self._register_provider(provider, override)
161
279
 
162
- def _register_provider(
163
- self, provider: Provider, override: bool, /, **defaults: Any
164
- ) -> Provider:
165
- """Register a provider."""
166
- if provider.interface in self._providers:
167
- if override:
168
- self._set_provider(provider)
169
- return provider
170
-
171
- raise LookupError(
172
- f"The provider interface `{get_full_qualname(provider.interface)}` "
173
- "already registered."
174
- )
175
-
176
- self._validate_sub_providers(provider, **defaults)
177
- self._set_provider(provider)
178
- return provider
280
+ def is_registered(self, interface: AnyInterface) -> bool:
281
+ """Check if a provider is registered for the specified interface."""
282
+ return interface in self._providers
179
283
 
180
284
  def unregister(self, interface: AnyInterface) -> None:
181
285
  """Unregister a provider by interface."""
@@ -199,6 +303,120 @@ class Container:
199
303
  # Cleanup provider references
200
304
  self._delete_provider(provider)
201
305
 
306
+ def provider(
307
+ self, *, scope: Scope, override: bool = False
308
+ ) -> Callable[[Callable[P, T]], Callable[P, T]]:
309
+ """Decorator to register a provider function with the specified scope."""
310
+
311
+ def decorator(call: Callable[P, T]) -> Callable[P, T]:
312
+ provider = self._create_provider(call=call, scope=scope)
313
+ self._register_provider(provider, override)
314
+ return call
315
+
316
+ return decorator
317
+
318
+ def _create_provider( # noqa: C901
319
+ self, call: Callable[..., Any], *, scope: Scope, interface: Any = NOT_SET
320
+ ) -> Provider:
321
+ name = get_full_qualname(call)
322
+
323
+ # Detect the kind of callable provider
324
+ kind = ProviderKind.from_call(call)
325
+
326
+ # Validate the scope of the provider
327
+ if scope not in get_args(Scope):
328
+ raise ValueError(
329
+ f"The provider `{name}` scope is invalid. Only the following "
330
+ f"scopes are supported: {', '.join(get_args(Scope))}. "
331
+ "Please use one of the supported scopes when registering a provider."
332
+ )
333
+
334
+ # Validate the scope of the provider
335
+ if (
336
+ kind in {ProviderKind.GENERATOR, ProviderKind.ASYNC_GENERATOR}
337
+ and scope == "transient"
338
+ ):
339
+ raise TypeError(
340
+ f"The resource provider `{name}` is attempting to register "
341
+ "with a transient scope, which is not allowed."
342
+ )
343
+
344
+ # Get the signature
345
+ globalns = getattr(call, "__globals__", {})
346
+ signature = inspect.signature(call, globals=globalns)
347
+
348
+ # Detect the interface
349
+ if kind == ProviderKind.CLASS:
350
+ interface = call
351
+ else:
352
+ if interface is NOT_SET:
353
+ interface = signature.return_annotation
354
+ if interface is inspect.Signature.empty:
355
+ interface = None
356
+ else:
357
+ interface = get_typed_annotation(interface, globalns)
358
+
359
+ # If the callable is an iterator, return the actual type
360
+ iterator_types = {Iterator, AsyncIterator}
361
+ if interface in iterator_types or get_origin(interface) in iterator_types:
362
+ if args := get_args(interface):
363
+ interface = args[0]
364
+ # If the callable is a generator, return the resource type
365
+ if interface in {None, NoneType}:
366
+ interface = type(f"Event_{uuid.uuid4().hex}", (Event,), {})
367
+ else:
368
+ raise TypeError(
369
+ f"Cannot use `{name}` resource type annotation "
370
+ "without actual type argument."
371
+ )
372
+
373
+ # None interface is not allowed
374
+ if interface in {None, NoneType}:
375
+ raise TypeError(f"Missing `{name}` provider return annotation.")
376
+
377
+ # Detect the parameters
378
+ parameters = []
379
+ for parameter in signature.parameters.values():
380
+ if parameter.annotation is inspect.Parameter.empty:
381
+ raise TypeError(
382
+ f"Missing provider `{name}` "
383
+ f"dependency `{parameter.name}` annotation."
384
+ )
385
+ if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
386
+ raise TypeError(
387
+ "Positional-only parameters "
388
+ f"are not allowed in the provider `{name}`."
389
+ )
390
+ annotation = get_typed_annotation(parameter.annotation, globalns)
391
+ parameters.append(parameter.replace(annotation=annotation))
392
+
393
+ return Provider(
394
+ call=call,
395
+ scope=scope,
396
+ interface=interface,
397
+ name=name,
398
+ kind=kind,
399
+ parameters=parameters,
400
+ )
401
+
402
+ def _register_provider(
403
+ self, provider: Provider, override: bool, /, **defaults: Any
404
+ ) -> Provider:
405
+ """Register a provider."""
406
+ if provider.interface in self._providers:
407
+ if override:
408
+ self._set_provider(provider)
409
+ return provider
410
+
411
+ raise LookupError(
412
+ f"The provider interface `{get_full_qualname(provider.interface)}` "
413
+ "already registered."
414
+ )
415
+
416
+ self._validate_sub_providers(provider, **defaults)
417
+ self._set_provider(provider)
418
+ return provider
419
+
202
420
  def _get_provider(self, interface: AnyInterface) -> Provider:
203
421
  """Get provider by interface."""
204
422
  try:
@@ -227,9 +445,11 @@ class Container:
227
445
  scope = getattr(interface, "__scope__", parent_scope)
228
446
  # Try to detect scope
229
447
  if scope is None:
230
- scope = self._detect_scope(interface, **defaults)
448
+ scope = self._detect_provider_scope(interface, **defaults)
231
449
  scope = scope or self.default_scope
232
- provider = Provider(call=interface, scope=scope, interface=interface)
450
+ provider = self._create_provider(
451
+ call=interface, scope=scope, interface=interface
452
+ )
233
453
  return self._register_provider(provider, False, **defaults)
234
454
  raise
235
455
 
@@ -250,12 +470,6 @@ class Container:
250
470
  """Validate the sub-providers of a provider."""
251
471
 
252
472
  for parameter in provider.parameters:
253
- if parameter.annotation is inspect.Parameter.empty:
254
- raise TypeError(
255
- f"Missing provider `{provider}` "
256
- f"dependency `{parameter.name}` annotation."
257
- )
258
-
259
473
  try:
260
474
  sub_provider = self._get_or_register_provider(
261
475
  parameter.annotation, provider.scope
@@ -282,7 +496,9 @@ class Container:
282
496
  "Please ensure all providers are registered with matching scopes."
283
497
  )
284
498
 
285
- def _detect_scope(self, call: Callable[..., Any], **defaults: Any) -> Scope | None:
499
+ def _detect_provider_scope(
500
+ self, call: Callable[..., Any], /, **defaults: Any
501
+ ) -> Scope | None:
286
502
  """Detect the scope for a callable."""
287
503
  scopes = set()
288
504
 
@@ -320,122 +536,56 @@ class Container:
320
536
  not self.strict and parameter.default is not inspect.Parameter.empty
321
537
  )
322
538
 
323
- def register_module(
324
- self, module: Module | type[Module] | Callable[[Container], None] | str
325
- ) -> None:
326
- """Register a module as a callable, module type, or module instance."""
327
- # Callable Module
328
- if inspect.isfunction(module):
329
- module(self)
330
- return
331
-
332
- # Module path
333
- if isinstance(module, str):
334
- module = import_string(module)
335
-
336
- # Class based Module or Module type
337
- if inspect.isclass(module) and issubclass(module, Module):
338
- module = module()
339
-
340
- if isinstance(module, Module):
341
- module.configure(self)
342
- for provider_name, decorator_args in module.providers:
343
- obj = getattr(module, provider_name)
344
- self.provider(
345
- scope=decorator_args.scope,
346
- override=decorator_args.override,
347
- )(obj)
348
-
349
- def __enter__(self) -> Self:
350
- """Enter the singleton context."""
351
- self.start()
352
- return self
353
-
354
- def __exit__(
355
- self,
356
- exc_type: type[BaseException] | None,
357
- exc_val: BaseException | None,
358
- exc_tb: types.TracebackType | None,
359
- ) -> Any:
360
- """Exit the singleton context."""
361
- return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
362
-
363
- def start(self) -> None:
364
- """Start the singleton context."""
365
- # Resolve all singleton resources
366
- for interface in self._resources.get("singleton", []):
367
- self.resolve(interface)
368
-
369
- def close(self) -> None:
370
- """Close the singleton context."""
371
- self._singleton_context.close()
372
-
373
- @contextlib.contextmanager
374
- def request_context(self) -> Iterator[InstanceContext]:
375
- """Obtain a context manager for the request-scoped context."""
376
- context = InstanceContext()
377
-
378
- token = self._request_context_var.set(context)
539
+ ############################
540
+ # Instance Methods
541
+ ############################
379
542
 
380
- # Resolve all request resources
381
- for interface in self._resources.get("request", []):
382
- if not is_event_type(interface):
383
- continue
384
- self.resolve(interface)
385
-
386
- with context:
387
- yield context
388
- self._request_context_var.reset(token)
543
+ @overload
544
+ def resolve(self, interface: type[T]) -> T: ...
389
545
 
390
- async def __aenter__(self) -> Self:
391
- """Enter the singleton context."""
392
- await self.astart()
393
- return self
546
+ @overload
547
+ def resolve(self, interface: T) -> T: ...
394
548
 
395
- async def __aexit__(
396
- self,
397
- exc_type: type[BaseException] | None,
398
- exc_val: BaseException | None,
399
- exc_tb: types.TracebackType | None,
400
- ) -> bool:
401
- """Exit the singleton context."""
402
- return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)
549
+ def resolve(self, interface: type[T]) -> T:
550
+ """Resolve an instance by interface."""
551
+ return self._resolve_or_create(interface, False)
403
552
 
404
- async def astart(self) -> None:
405
- """Start the singleton context asynchronously."""
406
- for interface in self._resources.get("singleton", []):
407
- await self.aresolve(interface)
553
+ @overload
554
+ async def aresolve(self, interface: type[T]) -> T: ...
408
555
 
409
- async def aclose(self) -> None:
410
- """Close the singleton context asynchronously."""
411
- await self._singleton_context.aclose()
556
+ @overload
557
+ async def aresolve(self, interface: T) -> T: ...
412
558
 
413
- @contextlib.asynccontextmanager
414
- async def arequest_context(self) -> AsyncIterator[InstanceContext]:
415
- """Obtain an async context manager for the request-scoped context."""
416
- context = InstanceContext()
559
+ async def aresolve(self, interface: type[T]) -> T:
560
+ """Resolve an instance by interface asynchronously."""
561
+ return await self._aresolve_or_acreate(interface, False)
417
562
 
418
- token = self._request_context_var.set(context)
563
+ def create(self, interface: type[T], /, **defaults: Any) -> T:
564
+ """Create an instance by interface."""
565
+ return self._resolve_or_create(interface, True, **defaults)
419
566
 
420
- for interface in self._resources.get("request", []):
421
- if not is_event_type(interface):
422
- continue
423
- await self.aresolve(interface)
567
+ async def acreate(self, interface: type[T], /, **defaults: Any) -> T:
568
+ """Create an instance by interface asynchronously."""
569
+ return await self._aresolve_or_acreate(interface, True, **defaults)
424
570
 
425
- async with context:
426
- yield context
427
- self._request_context_var.reset(token)
571
+ def is_resolved(self, interface: AnyInterface) -> bool:
572
+ """Check if an instance by interface exists."""
573
+ try:
574
+ provider = self._get_provider(interface)
575
+ except LookupError:
576
+ return False
577
+ if provider.scope == "transient":
578
+ return False
579
+ context = self._get_scoped_context(provider.scope)
580
+ return interface in context
428
581
 
429
- def _get_request_context(self) -> InstanceContext:
430
- """Get the current request context."""
431
- request_context = self._request_context_var.get()
432
- if request_context is None:
433
- raise LookupError(
434
- "The request context has not been started. Please ensure that "
435
- "the request context is properly initialized before attempting "
436
- "to use it."
437
- )
438
- return request_context
582
+ def release(self, interface: AnyInterface) -> None:
583
+ """Release an instance by interface."""
584
+ provider = self._get_provider(interface)
585
+ if provider.scope == "transient":
586
+ return None
587
+ context = self._get_scoped_context(provider.scope)
588
+ del context[interface]
439
589
 
440
590
  def reset(self) -> None:
441
591
  """Reset resolved instances."""
@@ -448,66 +598,38 @@ class Container:
448
598
  continue
449
599
  del context[interface]
450
600
 
451
- @overload
452
- def resolve(self, interface: type[T]) -> T: ...
453
-
454
- @overload
455
- def resolve(self, interface: T) -> T: ...
456
-
457
- def resolve(self, interface: type[T]) -> T:
458
- """Resolve an instance by interface."""
459
- provider = self._get_or_register_provider(interface, None)
601
+ def _resolve_or_create(
602
+ self, interface: type[T], create: bool, /, **defaults: Any
603
+ ) -> T:
604
+ """Internal method to handle instance resolution and creation."""
605
+ provider = self._get_or_register_provider(interface, None, **defaults)
460
606
  if provider.scope == "transient":
461
- instance = self._create_instance(provider, None)
607
+ instance = self._create_instance(provider, None, **defaults)
462
608
  else:
463
609
  context = self._get_scoped_context(provider.scope)
464
610
  if provider.scope == "singleton":
465
611
  with self._singleton_lock:
466
- instance = self._get_or_create_instance(provider, context)
612
+ instance = (
613
+ self._get_or_create_instance(provider, context)
614
+ if not create
615
+ else self._create_instance(provider, context, **defaults)
616
+ )
467
617
  else:
468
- instance = self._get_or_create_instance(provider, context)
469
- if self.testing:
470
- instance = self._patch_test_resolver(provider.interface, instance)
471
- return cast(T, instance)
472
-
473
- @overload
474
- async def aresolve(self, interface: type[T]) -> T: ...
475
-
476
- @overload
477
- async def aresolve(self, interface: T) -> T: ...
618
+ instance = (
619
+ self._get_or_create_instance(provider, context)
620
+ if not create
621
+ else self._create_instance(provider, context, **defaults)
622
+ )
478
623
 
479
- async def aresolve(self, interface: type[T]) -> T:
480
- """Resolve an instance by interface asynchronously."""
481
- provider = self._get_or_register_provider(interface, None)
482
- if provider.scope == "transient":
483
- instance = await self._acreate_instance(provider, None)
484
- else:
485
- context = self._get_scoped_context(provider.scope)
486
- if provider.scope == "singleton":
487
- async with self._singleton_async_lock:
488
- instance = await self._aget_or_create_instance(provider, context)
489
- else:
490
- instance = await self._aget_or_create_instance(provider, context)
491
624
  if self.testing:
492
- instance = self._patch_test_resolver(interface, instance)
493
- return cast(T, instance)
625
+ instance = self._patch_test_resolver(provider.interface, instance)
494
626
 
495
- def create(self, interface: type[T], **defaults: Any) -> T:
496
- """Create an instance by interface."""
497
- provider = self._get_or_register_provider(interface, None, **defaults)
498
- if provider.scope == "transient":
499
- instance = self._create_instance(provider, None, **defaults)
500
- else:
501
- context = self._get_scoped_context(provider.scope)
502
- if provider.scope == "singleton":
503
- with self._singleton_lock:
504
- instance = self._create_instance(provider, context, **defaults)
505
- else:
506
- instance = self._create_instance(provider, context, **defaults)
507
627
  return cast(T, instance)
508
628
 
509
- async def acreate(self, interface: type[T], **defaults: Any) -> T:
510
- """Create an instance by interface."""
629
+ async def _aresolve_or_acreate(
630
+ self, interface: type[T], create: bool, /, **defaults: Any
631
+ ) -> T:
632
+ """Internal method to handle instance resolution and creation asynchronously."""
511
633
  provider = self._get_or_register_provider(interface, None, **defaults)
512
634
  if provider.scope == "transient":
513
635
  instance = await self._acreate_instance(provider, None, **defaults)
@@ -515,11 +637,21 @@ class Container:
515
637
  context = self._get_scoped_context(provider.scope)
516
638
  if provider.scope == "singleton":
517
639
  async with self._singleton_async_lock:
518
- instance = await self._acreate_instance(
519
- provider, context, **defaults
640
+ instance = (
641
+ await self._aget_or_create_instance(provider, context)
642
+ if not create
643
+ else await self._acreate_instance(provider, context, **defaults)
520
644
  )
521
645
  else:
522
- instance = await self._acreate_instance(provider, context, **defaults)
646
+ instance = (
647
+ await self._aget_or_create_instance(provider, context)
648
+ if not create
649
+ else await self._acreate_instance(provider, context, **defaults)
650
+ )
651
+
652
+ if self.testing:
653
+ instance = self._patch_test_resolver(provider.interface, instance)
654
+
523
655
  return cast(T, instance)
524
656
 
525
657
  def _get_or_create_instance(
@@ -696,26 +828,48 @@ class Container:
696
828
  def _resolve_parameter(
697
829
  self, provider: Provider, parameter: inspect.Parameter
698
830
  ) -> Any:
699
- self._validate_resolvable_parameter(parameter, call=provider.call)
831
+ self._validate_resolvable_parameter(provider, parameter)
700
832
  return self.resolve(parameter.annotation)
701
833
 
702
834
  async def _aresolve_parameter(
703
835
  self, provider: Provider, parameter: inspect.Parameter
704
836
  ) -> Any:
705
- self._validate_resolvable_parameter(parameter, call=provider.call)
837
+ self._validate_resolvable_parameter(provider, parameter)
706
838
  return await self.aresolve(parameter.annotation)
707
839
 
708
840
  def _validate_resolvable_parameter(
709
- self, parameter: inspect.Parameter, call: Callable[..., Any]
841
+ self, provider: Provider, parameter: inspect.Parameter
710
842
  ) -> None:
711
843
  """Ensure that the specified interface is resolved."""
712
844
  if parameter.annotation in self._unresolved_interfaces:
713
845
  raise LookupError(
714
846
  f"You are attempting to get the parameter `{parameter.name}` with the "
715
847
  f"annotation `{get_full_qualname(parameter.annotation)}` as a "
716
- f"dependency into `{get_full_qualname(call)}` which is not registered "
717
- "or set in the scoped context."
848
+ f"dependency into `{get_full_qualname(provider.call)}` which is "
849
+ "not registered or set in the scoped context."
850
+ )
851
+
852
+ @contextlib.contextmanager
853
+ def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
854
+ """
855
+ Override the provider for the specified interface with a specific instance.
856
+ """
857
+ if not self.testing:
858
+ raise RuntimeError(
859
+ "The `override` method can only be used in testing mode."
718
860
  )
861
+ if not self.is_registered(interface) and self.strict:
862
+ raise LookupError(
863
+ f"The provider interface `{get_full_qualname(interface)}` "
864
+ "not registered."
865
+ )
866
+ self._override_instances[interface] = instance
867
+ yield
868
+ del self._override_instances[interface]
869
+
870
+ ############################
871
+ # Testing Methods
872
+ ############################
719
873
 
720
874
  def _patch_test_resolver(self, interface: type[Any], instance: Any) -> Any:
721
875
  """Patch the test resolver for the instance."""
@@ -765,60 +919,9 @@ class Container:
765
919
 
766
920
  return instance
767
921
 
768
- def is_resolved(self, interface: AnyInterface) -> bool:
769
- """Check if an instance by interface exists."""
770
- try:
771
- provider = self._get_provider(interface)
772
- except LookupError:
773
- return False
774
- if provider.scope == "transient":
775
- return False
776
- context = self._get_scoped_context(provider.scope)
777
- return interface in context
778
-
779
- def release(self, interface: AnyInterface) -> None:
780
- """Release an instance by interface."""
781
- provider = self._get_provider(interface)
782
- if provider.scope == "transient":
783
- return None
784
- context = self._get_scoped_context(provider.scope)
785
- del context[interface]
786
-
787
- def _get_scoped_context(self, scope: Scope) -> InstanceContext:
788
- """Get the instance context for the specified scope."""
789
- if scope == "singleton":
790
- return self._singleton_context
791
- return self._get_request_context()
792
-
793
- @contextlib.contextmanager
794
- def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
795
- """
796
- Override the provider for the specified interface with a specific instance.
797
- """
798
- if not self.testing:
799
- raise RuntimeError(
800
- "The `override` method can only be used in testing mode."
801
- )
802
- if not self.is_registered(interface) and self.strict:
803
- raise LookupError(
804
- f"The provider interface `{get_full_qualname(interface)}` "
805
- "not registered."
806
- )
807
- self._override_instances[interface] = instance
808
- yield
809
- del self._override_instances[interface]
810
-
811
- def provider(
812
- self, *, scope: Scope, override: bool = False
813
- ) -> Callable[[Callable[P, T]], Callable[P, T]]:
814
- """Decorator to register a provider function with the specified scope."""
815
-
816
- def decorator(call: Callable[P, T]) -> Callable[P, T]:
817
- provider = Provider(call=call, scope=scope)
818
- self._register_provider(provider, override)
819
- return call
820
-
821
- return decorator
922
+ ############################
923
+ # Injector Methods
924
+ ############################
822
925
 
823
926
  @overload
824
927
  def inject(self, func: Callable[P, T]) -> Callable[P, T]: ...
@@ -838,6 +941,10 @@ class Container:
838
941
  return decorator
839
942
  return decorator(func)
840
943
 
944
+ def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
945
+ """Run the given function with injected dependencies."""
946
+ return self._inject(func)(*args, **kwargs)
947
+
841
948
  def _inject(self, call: Callable[P, T]) -> Callable[P, T]:
842
949
  """Inject dependencies into a callable."""
843
950
  if call in self._inject_cache:
@@ -905,9 +1012,39 @@ class Container:
905
1012
  f"`{get_full_qualname(parameter.annotation)}`."
906
1013
  )
907
1014
 
908
- def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
909
- """Run the given function with injected dependencies."""
910
- return self._inject(func)(*args, **kwargs)
1015
+ ############################
1016
+ # Module Methods
1017
+ ############################
1018
+
1019
+ def register_module(
1020
+ self, module: Module | type[Module] | Callable[[Container], None] | str
1021
+ ) -> None:
1022
+ """Register a module as a callable, module type, or module instance."""
1023
+ # Callable Module
1024
+ if inspect.isfunction(module):
1025
+ module(self)
1026
+ return
1027
+
1028
+ # Module path
1029
+ if isinstance(module, str):
1030
+ module = import_string(module)
1031
+
1032
+ # Class based Module or Module type
1033
+ if inspect.isclass(module) and issubclass(module, Module):
1034
+ module = module()
1035
+
1036
+ if isinstance(module, Module):
1037
+ module.configure(self)
1038
+ for provider_name, decorator_args in module.providers:
1039
+ obj = getattr(module, provider_name)
1040
+ self.provider(
1041
+ scope=decorator_args.scope,
1042
+ override=decorator_args.override,
1043
+ )(obj)
1044
+
1045
+ ############################
1046
+ # Scanner Methods
1047
+ ############################
911
1048
 
912
1049
  def scan(
913
1050
  self,
@@ -917,7 +1054,7 @@ class Container:
917
1054
  tags: Iterable[str] | None = None,
918
1055
  ) -> None:
919
1056
  """Scan packages or modules for decorated members and inject dependencies."""
920
- dependencies: list[Dependency] = []
1057
+ dependencies: list[ScannedDependency] = []
921
1058
 
922
1059
  if isinstance(packages, Iterable) and not isinstance(packages, str):
923
1060
  scan_packages: Iterable[ModuleType | str] = packages
@@ -936,7 +1073,7 @@ class Container:
936
1073
  package: ModuleType | str,
937
1074
  *,
938
1075
  tags: Iterable[str] | None = None,
939
- ) -> list[Dependency]:
1076
+ ) -> list[ScannedDependency]:
940
1077
  """Scan a package or module for decorated members."""
941
1078
  tags = tags or []
942
1079
  if isinstance(package, str):
@@ -947,7 +1084,7 @@ class Container:
947
1084
  if not package_path:
948
1085
  return self._scan_module(package, tags=tags)
949
1086
 
950
- dependencies: list[Dependency] = []
1087
+ dependencies: list[ScannedDependency] = []
951
1088
 
952
1089
  for module_info in pkgutil.walk_packages(
953
1090
  path=package_path, prefix=package.__name__ + "."
@@ -959,9 +1096,9 @@ class Container:
959
1096
 
960
1097
  def _scan_module(
961
1098
  self, module: ModuleType, *, tags: Iterable[str]
962
- ) -> list[Dependency]:
1099
+ ) -> list[ScannedDependency]:
963
1100
  """Scan a module for decorated members."""
964
- dependencies: list[Dependency] = []
1101
+ dependencies: list[ScannedDependency] = []
965
1102
 
966
1103
  for _, member in inspect.getmembers(module):
967
1104
  if getattr(member, "__module__", None) != module.__name__ or not callable(
@@ -984,7 +1121,7 @@ class Container:
984
1121
 
985
1122
  if decorator_args.wrapped:
986
1123
  dependencies.append(
987
- self._create_dependency(member=member, module=module)
1124
+ self._create_scanned_dependency(member=member, module=module)
988
1125
  )
989
1126
  continue
990
1127
 
@@ -992,17 +1129,24 @@ class Container:
992
1129
  for parameter in get_typed_parameters(member):
993
1130
  if is_marker(parameter.default):
994
1131
  dependencies.append(
995
- self._create_dependency(member=member, module=module)
1132
+ self._create_scanned_dependency(member=member, module=module)
996
1133
  )
997
1134
  continue
998
1135
 
999
1136
  return dependencies
1000
1137
 
1001
- def _create_dependency(self, member: Any, module: ModuleType) -> Dependency:
1138
+ def _create_scanned_dependency(
1139
+ self, member: Any, module: ModuleType
1140
+ ) -> ScannedDependency:
1002
1141
  """Create a `Dependency` object from the scanned member and module."""
1003
1142
  if hasattr(member, "__wrapped__"):
1004
1143
  member = member.__wrapped__
1005
- return Dependency(member=member, module=module)
1144
+ return ScannedDependency(member=member, module=module)
1145
+
1146
+
1147
+ ############################
1148
+ # Decorators
1149
+ ############################
1006
1150
 
1007
1151
 
1008
1152
  def transient(target: T) -> T: