anydi 0.38.1__py3-none-any.whl → 0.38.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
anydi/_container.py CHANGED
@@ -10,23 +10,27 @@ import logging
10
10
  import pkgutil
11
11
  import threading
12
12
  import types
13
+ import uuid
13
14
  from collections import defaultdict
14
15
  from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
15
16
  from contextvars import ContextVar
16
17
  from types import ModuleType
17
- from typing import Any, Callable, TypeVar, Union, cast, overload
18
- from weakref import WeakKeyDictionary
18
+ from typing import Annotated, Any, Callable, TypeVar, Union, cast, overload
19
19
 
20
- from typing_extensions import Concatenate, ParamSpec, Self, final
20
+ from typing_extensions import Concatenate, ParamSpec, Self, final, get_args, get_origin
21
21
 
22
22
  from ._context import InstanceContext
23
- from ._provider import Provider
24
23
  from ._types import (
24
+ NOT_SET,
25
25
  AnyInterface,
26
- Dependency,
26
+ Event,
27
27
  InjectableDecoratorArgs,
28
28
  InstanceProxy,
29
+ Provider,
30
+ ProviderArgs,
29
31
  ProviderDecoratorArgs,
32
+ ProviderKind,
33
+ ScannedDependency,
30
34
  Scope,
31
35
  is_event_type,
32
36
  is_marker,
@@ -34,6 +38,7 @@ from ._types import (
34
38
  from ._utils import (
35
39
  AsyncRLock,
36
40
  get_full_qualname,
41
+ get_typed_annotation,
37
42
  get_typed_parameters,
38
43
  import_string,
39
44
  is_async_context_manager,
@@ -42,6 +47,12 @@ from ._utils import (
42
47
  run_async,
43
48
  )
44
49
 
50
+ try:
51
+ from types import NoneType
52
+ except ImportError:
53
+ NoneType = type(None) # type: ignore[misc]
54
+
55
+
45
56
  T = TypeVar("T", bound=Any)
46
57
  M = TypeVar("M", bound="Module")
47
58
  P = ParamSpec("P")
@@ -74,7 +85,6 @@ class Module(metaclass=ModuleMeta):
74
85
  """Configure the AnyDI container with providers and their dependencies."""
75
86
 
76
87
 
77
- # noinspection PyShadowingNames
78
88
  @final
79
89
  class Container:
80
90
  """AnyDI is a dependency injection container."""
@@ -82,7 +92,7 @@ class Container:
82
92
  def __init__(
83
93
  self,
84
94
  *,
85
- providers: Sequence[Provider] | None = None,
95
+ providers: Sequence[ProviderArgs] | None = None,
86
96
  modules: Sequence[Module | type[Module] | Callable[[Container], None] | str]
87
97
  | None = None,
88
98
  strict: bool = False,
@@ -104,20 +114,26 @@ class Container:
104
114
  )
105
115
  self._override_instances: dict[type[Any], Any] = {}
106
116
  self._unresolved_interfaces: set[type[Any]] = set()
107
- self._inject_cache: WeakKeyDictionary[
108
- Callable[..., Any], Callable[..., Any]
109
- ] = WeakKeyDictionary()
117
+ self._inject_cache: dict[Callable[..., Any], Callable[..., Any]] = {}
110
118
 
111
119
  # Register providers
112
120
  providers = providers or []
113
121
  for provider in providers:
114
- self._register_provider(provider, False)
122
+ self._register_provider(
123
+ provider.call,
124
+ provider.scope,
125
+ provider.interface,
126
+ )
115
127
 
116
128
  # Register modules
117
129
  modules = modules or []
118
130
  for module in modules:
119
131
  self.register_module(module)
120
132
 
133
+ ############################
134
+ # Properties
135
+ ############################
136
+
121
137
  @property
122
138
  def strict(self) -> bool:
123
139
  """Check if strict mode is enabled."""
@@ -143,9 +159,110 @@ class Container:
143
159
  """Get the logger instance."""
144
160
  return self._logger
145
161
 
146
- def is_registered(self, interface: AnyInterface) -> bool:
147
- """Check if a provider is registered for the specified interface."""
148
- return interface in self._providers
162
+ ############################
163
+ # Lifespan/Context Methods
164
+ ############################
165
+
166
+ def __enter__(self) -> Self:
167
+ """Enter the singleton context."""
168
+ self.start()
169
+ return self
170
+
171
+ def __exit__(
172
+ self,
173
+ exc_type: type[BaseException] | None,
174
+ exc_val: BaseException | None,
175
+ exc_tb: types.TracebackType | None,
176
+ ) -> Any:
177
+ """Exit the singleton context."""
178
+ return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
179
+
180
+ def start(self) -> None:
181
+ """Start the singleton context."""
182
+ # Resolve all singleton resources
183
+ for interface in self._resources.get("singleton", []):
184
+ self.resolve(interface)
185
+
186
+ def close(self) -> None:
187
+ """Close the singleton context."""
188
+ self._singleton_context.close()
189
+
190
+ async def __aenter__(self) -> Self:
191
+ """Enter the singleton context."""
192
+ await self.astart()
193
+ return self
194
+
195
+ async def __aexit__(
196
+ self,
197
+ exc_type: type[BaseException] | None,
198
+ exc_val: BaseException | None,
199
+ exc_tb: types.TracebackType | None,
200
+ ) -> bool:
201
+ """Exit the singleton context."""
202
+ return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)
203
+
204
+ async def astart(self) -> None:
205
+ """Start the singleton context asynchronously."""
206
+ for interface in self._resources.get("singleton", []):
207
+ await self.aresolve(interface)
208
+
209
+ async def aclose(self) -> None:
210
+ """Close the singleton context asynchronously."""
211
+ await self._singleton_context.aclose()
212
+
213
+ @contextlib.contextmanager
214
+ def request_context(self) -> Iterator[InstanceContext]:
215
+ """Obtain a context manager for the request-scoped context."""
216
+ context = InstanceContext()
217
+
218
+ token = self._request_context_var.set(context)
219
+
220
+ # Resolve all request resources
221
+ for interface in self._resources.get("request", []):
222
+ if not is_event_type(interface):
223
+ continue
224
+ self.resolve(interface)
225
+
226
+ with context:
227
+ yield context
228
+ self._request_context_var.reset(token)
229
+
230
+ @contextlib.asynccontextmanager
231
+ async def arequest_context(self) -> AsyncIterator[InstanceContext]:
232
+ """Obtain an async context manager for the request-scoped context."""
233
+ context = InstanceContext()
234
+
235
+ token = self._request_context_var.set(context)
236
+
237
+ for interface in self._resources.get("request", []):
238
+ if not is_event_type(interface):
239
+ continue
240
+ await self.aresolve(interface)
241
+
242
+ async with context:
243
+ yield context
244
+ self._request_context_var.reset(token)
245
+
246
+ def _get_request_context(self) -> InstanceContext:
247
+ """Get the current request context."""
248
+ request_context = self._request_context_var.get()
249
+ if request_context is None:
250
+ raise LookupError(
251
+ "The request context has not been started. Please ensure that "
252
+ "the request context is properly initialized before attempting "
253
+ "to use it."
254
+ )
255
+ return request_context
256
+
257
+ def _get_scoped_context(self, scope: Scope) -> InstanceContext:
258
+ """Get the instance context for the specified scope."""
259
+ if scope == "singleton":
260
+ return self._singleton_context
261
+ return self._get_request_context()
262
+
263
+ ############################
264
+ # Provider Methods
265
+ ############################
149
266
 
150
267
  def register(
151
268
  self,
@@ -156,26 +273,11 @@ class Container:
156
273
  override: bool = False,
157
274
  ) -> Provider:
158
275
  """Register a provider for the specified interface."""
159
- provider = Provider(call=call, scope=scope, interface=interface)
160
- return self._register_provider(provider, override)
276
+ return self._register_provider(call, scope, interface, override)
161
277
 
162
- def _register_provider(
163
- self, provider: Provider, override: bool, /, **defaults: Any
164
- ) -> Provider:
165
- """Register a provider."""
166
- if provider.interface in self._providers:
167
- if override:
168
- self._set_provider(provider)
169
- return provider
170
-
171
- raise LookupError(
172
- f"The provider interface `{get_full_qualname(provider.interface)}` "
173
- "already registered."
174
- )
175
-
176
- self._validate_sub_providers(provider, **defaults)
177
- self._set_provider(provider)
178
- return provider
278
+ def is_registered(self, interface: AnyInterface) -> bool:
279
+ """Check if a provider is registered for the specified interface."""
280
+ return interface in self._providers
179
281
 
180
282
  def unregister(self, interface: AnyInterface) -> None:
181
283
  """Unregister a provider by interface."""
@@ -199,254 +301,233 @@ class Container:
199
301
  # Cleanup provider references
200
302
  self._delete_provider(provider)
201
303
 
202
- def _get_provider(self, interface: AnyInterface) -> Provider:
203
- """Get provider by interface."""
204
- try:
205
- return self._providers[interface]
206
- except KeyError as exc:
207
- raise LookupError(
208
- f"The provider interface for `{get_full_qualname(interface)}` has "
209
- "not been registered. Please ensure that the provider interface is "
210
- "properly registered before attempting to use it."
211
- ) from exc
304
+ def provider(
305
+ self, *, scope: Scope, override: bool = False
306
+ ) -> Callable[[Callable[P, T]], Callable[P, T]]:
307
+ """Decorator to register a provider function with the specified scope."""
212
308
 
213
- def _get_or_register_provider(
214
- self, interface: AnyInterface, parent_scope: Scope | None, /, **defaults: Any
309
+ def decorator(call: Callable[P, T]) -> Callable[P, T]:
310
+ self._register_provider(call, scope, NOT_SET, override)
311
+ return call
312
+
313
+ return decorator
314
+
315
+ def _register_provider( # noqa: C901
316
+ self,
317
+ call: Callable[..., Any],
318
+ scope: Scope | None,
319
+ interface: Any = NOT_SET,
320
+ override: bool = False,
321
+ /,
322
+ **defaults: Any,
215
323
  ) -> Provider:
216
- """Get or register a provider by interface."""
217
- try:
218
- return self._get_provider(interface)
219
- except LookupError:
220
- if (
221
- not self.strict
222
- and inspect.isclass(interface)
223
- and not is_builtin_type(interface)
224
- and interface is not inspect.Parameter.empty
225
- ):
226
- # Try to get defined scope
227
- scope = getattr(interface, "__scope__", parent_scope)
228
- # Try to detect scope
229
- if scope is None:
230
- scope = self._detect_scope(interface, **defaults)
231
- scope = scope or self.default_scope
232
- provider = Provider(call=interface, scope=scope, interface=interface)
233
- return self._register_provider(provider, False, **defaults)
234
- raise
324
+ """Register a provider with the specified scope."""
325
+ name = get_full_qualname(call)
326
+ kind = ProviderKind.from_call(call)
327
+ detected_scope = scope
328
+
329
+ # Validate scope if it provided
330
+ if scope:
331
+ self._validate_provider_scope(scope, name, kind)
332
+
333
+ # Get the signature
334
+ globalns = getattr(call, "__globals__", {})
335
+ module = getattr(call, "__module__", None)
336
+ signature = inspect.signature(call, globals=globalns)
337
+
338
+ # Detect the interface
339
+ if interface is NOT_SET:
340
+ if kind == ProviderKind.CLASS:
341
+ interface = call
342
+ else:
343
+ interface = signature.return_annotation
344
+ if interface is inspect.Signature.empty:
345
+ interface = None
346
+ else:
347
+ interface = get_typed_annotation(interface, globalns, module)
348
+
349
+ # If the callable is an iterator, return the actual type
350
+ iterator_types = {Iterator, AsyncIterator}
351
+ if interface in iterator_types or get_origin(interface) in iterator_types:
352
+ if args := get_args(interface):
353
+ interface = args[0]
354
+ # If the callable is a generator, return the resource type
355
+ if interface in {None, NoneType}:
356
+ interface = type(f"Event_{uuid.uuid4().hex}", (Event,), {})
357
+ else:
358
+ raise TypeError(
359
+ f"Cannot use `{name}` resource type annotation "
360
+ "without actual type argument."
361
+ )
235
362
 
236
- def _set_provider(self, provider: Provider) -> None:
237
- """Set a provider by interface."""
238
- self._providers[provider.interface] = provider
239
- if provider.is_resource:
240
- self._resources[provider.scope].append(provider.interface)
363
+ # None interface is not allowed
364
+ if interface in {None, NoneType}:
365
+ raise TypeError(f"Missing `{name}` provider return annotation.")
241
366
 
242
- def _delete_provider(self, provider: Provider) -> None:
243
- """Delete a provider."""
244
- if provider.interface in self._providers:
245
- del self._providers[provider.interface]
246
- if provider.is_resource:
247
- self._resources[provider.scope].remove(provider.interface)
367
+ # Check for existing provider
368
+ if interface in self._providers and not override:
369
+ raise LookupError(
370
+ f"The provider interface `{get_full_qualname(interface)}` "
371
+ "already registered."
372
+ )
248
373
 
249
- def _validate_sub_providers(self, provider: Provider, /, **defaults: Any) -> None:
250
- """Validate the sub-providers of a provider."""
374
+ unresolved_parameter = None
375
+ parameters = []
376
+ scopes = {}
251
377
 
252
- for parameter in provider.parameters:
378
+ for parameter in signature.parameters.values():
253
379
  if parameter.annotation is inspect.Parameter.empty:
254
380
  raise TypeError(
255
- f"Missing provider `{provider}` "
381
+ f"Missing provider `{name}` "
256
382
  f"dependency `{parameter.name}` annotation."
257
383
  )
258
-
259
- try:
260
- sub_provider = self._get_or_register_provider(
261
- parameter.annotation, provider.scope
262
- )
263
- except LookupError:
264
- if self._parameter_has_default(parameter, **defaults):
265
- continue
266
-
267
- if provider.scope not in {"singleton", "transient"}:
268
- self._unresolved_interfaces.add(provider.interface)
269
- continue
270
- raise ValueError(
271
- f"The provider `{provider}` depends on `{parameter.name}` of type "
272
- f"`{get_full_qualname(parameter.annotation)}`, which "
273
- "has not been registered or set. To resolve this, ensure that "
274
- f"`{parameter.name}` is registered before attempting to use it."
275
- ) from None
276
-
277
- # Check scope compatibility
278
- if sub_provider.scope not in ALLOWED_SCOPES.get(provider.scope, []):
279
- raise ValueError(
280
- f"The provider `{provider}` with a `{provider.scope}` scope cannot "
281
- f"depend on `{sub_provider}` with a `{sub_provider.scope}` scope. "
282
- "Please ensure all providers are registered with matching scopes."
384
+ if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
385
+ raise TypeError(
386
+ "Positional-only parameters "
387
+ f"are not allowed in the provider `{name}`."
283
388
  )
284
389
 
285
- def _detect_scope(self, call: Callable[..., Any], **defaults: Any) -> Scope | None:
286
- """Detect the scope for a callable."""
287
- scopes = set()
390
+ parameter = parameter.replace(
391
+ annotation=get_typed_annotation(parameter.annotation, globalns, module)
392
+ )
288
393
 
289
- for parameter in get_typed_parameters(call):
290
394
  try:
291
395
  sub_provider = self._get_or_register_provider(
292
- parameter.annotation, None
396
+ parameter.annotation, scope
293
397
  )
294
398
  except LookupError:
295
399
  if self._parameter_has_default(parameter, **defaults):
296
400
  continue
297
- raise
298
- scope = sub_provider.scope
299
-
300
- if scope == "transient":
301
- return "transient"
302
- scopes.add(scope)
303
-
304
- # If all scopes are found, we can return based on priority order
305
- if {"transient", "request", "singleton"}.issubset(scopes):
306
- break # pragma: no cover
307
-
308
- # Determine scope based on priority
309
- if "request" in scopes:
310
- return "request"
311
- if "singleton" in scopes:
312
- return "singleton"
313
-
314
- return None
315
-
316
- def _parameter_has_default(
317
- self, parameter: inspect.Parameter, /, **defaults: Any
318
- ) -> bool:
319
- return (defaults and parameter.name in defaults) or (
320
- not self.strict and parameter.default is not inspect.Parameter.empty
321
- )
322
-
323
- def register_module(
324
- self, module: Module | type[Module] | Callable[[Container], None] | str
325
- ) -> None:
326
- """Register a module as a callable, module type, or module instance."""
327
- # Callable Module
328
- if inspect.isfunction(module):
329
- module(self)
330
- return
331
-
332
- # Module path
333
- if isinstance(module, str):
334
- module = import_string(module)
335
-
336
- # Class based Module or Module type
337
- if inspect.isclass(module) and issubclass(module, Module):
338
- module = module()
339
-
340
- if isinstance(module, Module):
341
- module.configure(self)
342
- for provider_name, decorator_args in module.providers:
343
- obj = getattr(module, provider_name)
344
- self.provider(
345
- scope=decorator_args.scope,
346
- override=decorator_args.override,
347
- )(obj)
348
-
349
- def __enter__(self) -> Self:
350
- """Enter the singleton context."""
351
- self.start()
352
- return self
353
-
354
- def __exit__(
355
- self,
356
- exc_type: type[BaseException] | None,
357
- exc_val: BaseException | None,
358
- exc_tb: types.TracebackType | None,
359
- ) -> Any:
360
- """Exit the singleton context."""
361
- return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
362
-
363
- def start(self) -> None:
364
- """Start the singleton context."""
365
- # Resolve all singleton resources
366
- for interface in self._resources.get("singleton", []):
367
- self.resolve(interface)
368
-
369
- def close(self) -> None:
370
- """Close the singleton context."""
371
- self._singleton_context.close()
372
-
373
- @contextlib.contextmanager
374
- def request_context(self) -> Iterator[InstanceContext]:
375
- """Obtain a context manager for the request-scoped context."""
376
- context = InstanceContext()
377
-
378
- token = self._request_context_var.set(context)
379
-
380
- # Resolve all request resources
381
- for interface in self._resources.get("request", []):
382
- if not is_event_type(interface):
401
+ unresolved_parameter = parameter
383
402
  continue
384
- self.resolve(interface)
385
403
 
386
- with context:
387
- yield context
388
- self._request_context_var.reset(token)
404
+ # Store first provider for each scope
405
+ if sub_provider.scope not in scopes:
406
+ scopes[sub_provider.scope] = sub_provider
389
407
 
390
- async def __aenter__(self) -> Self:
391
- """Enter the singleton context."""
392
- await self.astart()
393
- return self
408
+ parameters.append(parameter)
394
409
 
395
- async def __aexit__(
396
- self,
397
- exc_type: type[BaseException] | None,
398
- exc_val: BaseException | None,
399
- exc_tb: types.TracebackType | None,
400
- ) -> bool:
401
- """Exit the singleton context."""
402
- return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)
410
+ # Set detected scope
411
+ if detected_scope is None:
412
+ if "transient" in scopes:
413
+ detected_scope = "transient"
414
+ elif "request" in scopes:
415
+ detected_scope = "request"
416
+ elif "singleton" in scopes:
417
+ detected_scope = "singleton"
418
+ else:
419
+ detected_scope = self.default_scope
403
420
 
404
- async def astart(self) -> None:
405
- """Start the singleton context asynchronously."""
406
- for interface in self._resources.get("singleton", []):
407
- await self.aresolve(interface)
421
+ # Validate the provider scope after detection
422
+ if scope is None:
423
+ self._validate_provider_scope(detected_scope, name, kind)
408
424
 
409
- async def aclose(self) -> None:
410
- """Close the singleton context asynchronously."""
411
- await self._singleton_context.aclose()
425
+ # Check for unresolved parameters
426
+ if unresolved_parameter:
427
+ if detected_scope not in {"singleton", "transient"}:
428
+ self._unresolved_interfaces.add(interface)
429
+ else:
430
+ raise LookupError(
431
+ f"The provider `{name}` depends on `{unresolved_parameter.name}` "
432
+ f"of type `{get_full_qualname(unresolved_parameter.annotation)}`, "
433
+ "which has not been registered or set. To resolve this, ensure "
434
+ f"that `{unresolved_parameter.name}` is registered before "
435
+ f"attempting to use it."
436
+ ) from None
412
437
 
413
- @contextlib.asynccontextmanager
414
- async def arequest_context(self) -> AsyncIterator[InstanceContext]:
415
- """Obtain an async context manager for the request-scoped context."""
416
- context = InstanceContext()
438
+ # Check scope compatibility
439
+ for sub_provider in scopes.values():
440
+ if sub_provider.scope not in ALLOWED_SCOPES.get(detected_scope, []):
441
+ raise ValueError(
442
+ f"The provider `{name}` with a `{detected_scope}` scope cannot "
443
+ f"depend on `{sub_provider}` with a `{sub_provider.scope}` scope. "
444
+ "Please ensure all providers are registered with matching scopes."
445
+ )
417
446
 
418
- token = self._request_context_var.set(context)
447
+ provider = Provider(
448
+ call=call,
449
+ scope=detected_scope,
450
+ interface=interface,
451
+ name=name,
452
+ kind=kind,
453
+ parameters=parameters,
454
+ )
419
455
 
420
- for interface in self._resources.get("request", []):
421
- if not is_event_type(interface):
422
- continue
423
- await self.aresolve(interface)
456
+ self._set_provider(provider)
457
+ return provider
424
458
 
425
- async with context:
426
- yield context
427
- self._request_context_var.reset(token)
459
+ def _validate_provider_scope(
460
+ self, scope: Scope, name: str, kind: ProviderKind
461
+ ) -> None:
462
+ """Validate the provider scope."""
463
+ if scope not in (allowed_scopes := get_args(Scope)):
464
+ raise ValueError(
465
+ f"The provider `{name}` scope is invalid. Only the following "
466
+ f"scopes are supported: {', '.join(allowed_scopes)}. "
467
+ "Please use one of the supported scopes when registering a provider."
468
+ )
469
+ if (
470
+ kind in {ProviderKind.GENERATOR, ProviderKind.ASYNC_GENERATOR}
471
+ and scope == "transient"
472
+ ):
473
+ raise TypeError(
474
+ f"The resource provider `{name}` is attempting to register "
475
+ "with a transient scope, which is not allowed."
476
+ )
428
477
 
429
- def _get_request_context(self) -> InstanceContext:
430
- """Get the current request context."""
431
- request_context = self._request_context_var.get()
432
- if request_context is None:
478
+ def _get_provider(self, interface: AnyInterface) -> Provider:
479
+ """Get provider by interface."""
480
+ try:
481
+ return self._providers[interface]
482
+ except KeyError as exc:
433
483
  raise LookupError(
434
- "The request context has not been started. Please ensure that "
435
- "the request context is properly initialized before attempting "
436
- "to use it."
437
- )
438
- return request_context
484
+ f"The provider interface for `{get_full_qualname(interface)}` has "
485
+ "not been registered. Please ensure that the provider interface is "
486
+ "properly registered before attempting to use it."
487
+ ) from exc
439
488
 
440
- def reset(self) -> None:
441
- """Reset resolved instances."""
442
- for interface, provider in self._providers.items():
443
- if provider.scope == "transient":
444
- continue
445
- try:
446
- context = self._get_scoped_context(provider.scope)
447
- except LookupError:
448
- continue
449
- del context[interface]
489
+ def _get_or_register_provider(
490
+ self, interface: AnyInterface, parent_scope: Scope | None, /, **defaults: Any
491
+ ) -> Provider:
492
+ """Get or register a provider by interface."""
493
+ try:
494
+ return self._get_provider(interface)
495
+ except LookupError:
496
+ if self.strict or interface is inspect.Parameter.empty:
497
+ raise
498
+ if get_origin(interface) is Annotated and (args := get_args(interface)):
499
+ call = args[0]
500
+ else:
501
+ call = interface
502
+ if inspect.isclass(call) and not is_builtin_type(call):
503
+ # Try to get defined scope
504
+ scope = getattr(interface, "__scope__", parent_scope)
505
+ return self._register_provider(call, scope, interface, **defaults)
506
+ raise
507
+
508
+ def _set_provider(self, provider: Provider) -> None:
509
+ """Set a provider by interface."""
510
+ self._providers[provider.interface] = provider
511
+ if provider.is_resource:
512
+ self._resources[provider.scope].append(provider.interface)
513
+
514
+ def _delete_provider(self, provider: Provider) -> None:
515
+ """Delete a provider."""
516
+ if provider.interface in self._providers:
517
+ del self._providers[provider.interface]
518
+ if provider.is_resource:
519
+ self._resources[provider.scope].remove(provider.interface)
520
+
521
+ def _parameter_has_default(
522
+ self, parameter: inspect.Parameter, /, **defaults: Any
523
+ ) -> bool:
524
+ return (defaults and parameter.name in defaults) or (
525
+ not self.strict and parameter.default is not inspect.Parameter.empty
526
+ )
527
+
528
+ ############################
529
+ # Instance Methods
530
+ ############################
450
531
 
451
532
  @overload
452
533
  def resolve(self, interface: type[T]) -> T: ...
@@ -456,19 +537,7 @@ class Container:
456
537
 
457
538
  def resolve(self, interface: type[T]) -> T:
458
539
  """Resolve an instance by interface."""
459
- provider = self._get_or_register_provider(interface, None)
460
- if provider.scope == "transient":
461
- instance = self._create_instance(provider, None)
462
- else:
463
- context = self._get_scoped_context(provider.scope)
464
- if provider.scope == "singleton":
465
- with self._singleton_lock:
466
- instance = self._get_or_create_instance(provider, context)
467
- else:
468
- instance = self._get_or_create_instance(provider, context)
469
- if self.testing:
470
- instance = self._patch_test_resolver(provider.interface, instance)
471
- return cast(T, instance)
540
+ return self._resolve_or_create(interface, False)
472
541
 
473
542
  @overload
474
543
  async def aresolve(self, interface: type[T]) -> T: ...
@@ -478,22 +547,50 @@ class Container:
478
547
 
479
548
  async def aresolve(self, interface: type[T]) -> T:
480
549
  """Resolve an instance by interface asynchronously."""
481
- provider = self._get_or_register_provider(interface, None)
482
- if provider.scope == "transient":
483
- instance = await self._acreate_instance(provider, None)
484
- else:
485
- context = self._get_scoped_context(provider.scope)
486
- if provider.scope == "singleton":
487
- async with self._singleton_async_lock:
488
- instance = await self._aget_or_create_instance(provider, context)
489
- else:
490
- instance = await self._aget_or_create_instance(provider, context)
491
- if self.testing:
492
- instance = self._patch_test_resolver(interface, instance)
493
- return cast(T, instance)
550
+ return await self._aresolve_or_acreate(interface, False)
494
551
 
495
- def create(self, interface: type[T], **defaults: Any) -> T:
552
+ def create(self, interface: type[T], /, **defaults: Any) -> T:
496
553
  """Create an instance by interface."""
554
+ return self._resolve_or_create(interface, True, **defaults)
555
+
556
+ async def acreate(self, interface: type[T], /, **defaults: Any) -> T:
557
+ """Create an instance by interface asynchronously."""
558
+ return await self._aresolve_or_acreate(interface, True, **defaults)
559
+
560
+ def is_resolved(self, interface: AnyInterface) -> bool:
561
+ """Check if an instance by interface exists."""
562
+ try:
563
+ provider = self._get_provider(interface)
564
+ except LookupError:
565
+ return False
566
+ if provider.scope == "transient":
567
+ return False
568
+ context = self._get_scoped_context(provider.scope)
569
+ return interface in context
570
+
571
+ def release(self, interface: AnyInterface) -> None:
572
+ """Release an instance by interface."""
573
+ provider = self._get_provider(interface)
574
+ if provider.scope == "transient":
575
+ return None
576
+ context = self._get_scoped_context(provider.scope)
577
+ del context[interface]
578
+
579
+ def reset(self) -> None:
580
+ """Reset resolved instances."""
581
+ for interface, provider in self._providers.items():
582
+ if provider.scope == "transient":
583
+ continue
584
+ try:
585
+ context = self._get_scoped_context(provider.scope)
586
+ except LookupError:
587
+ continue
588
+ del context[interface]
589
+
590
+ def _resolve_or_create(
591
+ self, interface: type[T], create: bool, /, **defaults: Any
592
+ ) -> T:
593
+ """Internal method to handle instance resolution and creation."""
497
594
  provider = self._get_or_register_provider(interface, None, **defaults)
498
595
  if provider.scope == "transient":
499
596
  instance = self._create_instance(provider, None, **defaults)
@@ -501,15 +598,27 @@ class Container:
501
598
  context = self._get_scoped_context(provider.scope)
502
599
  if provider.scope == "singleton":
503
600
  with self._singleton_lock:
504
- instance = self._create_instance(provider, context, **defaults)
601
+ instance = (
602
+ self._get_or_create_instance(provider, context)
603
+ if not create
604
+ else self._create_instance(provider, context, **defaults)
605
+ )
505
606
  else:
506
- instance = self._create_instance(provider, context, **defaults)
607
+ instance = (
608
+ self._get_or_create_instance(provider, context)
609
+ if not create
610
+ else self._create_instance(provider, context, **defaults)
611
+ )
612
+
507
613
  if self.testing:
508
614
  instance = self._patch_test_resolver(provider.interface, instance)
615
+
509
616
  return cast(T, instance)
510
617
 
511
- async def acreate(self, interface: type[T], **defaults: Any) -> T:
512
- """Create an instance by interface."""
618
+ async def _aresolve_or_acreate(
619
+ self, interface: type[T], create: bool, /, **defaults: Any
620
+ ) -> T:
621
+ """Internal method to handle instance resolution and creation asynchronously."""
513
622
  provider = self._get_or_register_provider(interface, None, **defaults)
514
623
  if provider.scope == "transient":
515
624
  instance = await self._acreate_instance(provider, None, **defaults)
@@ -517,13 +626,21 @@ class Container:
517
626
  context = self._get_scoped_context(provider.scope)
518
627
  if provider.scope == "singleton":
519
628
  async with self._singleton_async_lock:
520
- instance = await self._acreate_instance(
521
- provider, context, **defaults
629
+ instance = (
630
+ await self._aget_or_create_instance(provider, context)
631
+ if not create
632
+ else await self._acreate_instance(provider, context, **defaults)
522
633
  )
523
634
  else:
524
- instance = await self._acreate_instance(provider, context, **defaults)
635
+ instance = (
636
+ await self._aget_or_create_instance(provider, context)
637
+ if not create
638
+ else await self._acreate_instance(provider, context, **defaults)
639
+ )
640
+
525
641
  if self.testing:
526
642
  instance = self._patch_test_resolver(provider.interface, instance)
643
+
527
644
  return cast(T, instance)
528
645
 
529
646
  def _get_or_create_instance(
@@ -700,26 +817,48 @@ class Container:
700
817
  def _resolve_parameter(
701
818
  self, provider: Provider, parameter: inspect.Parameter
702
819
  ) -> Any:
703
- self._validate_resolvable_parameter(parameter, call=provider.call)
820
+ self._validate_resolvable_parameter(provider, parameter)
704
821
  return self.resolve(parameter.annotation)
705
822
 
706
823
  async def _aresolve_parameter(
707
824
  self, provider: Provider, parameter: inspect.Parameter
708
825
  ) -> Any:
709
- self._validate_resolvable_parameter(parameter, call=provider.call)
826
+ self._validate_resolvable_parameter(provider, parameter)
710
827
  return await self.aresolve(parameter.annotation)
711
828
 
712
829
  def _validate_resolvable_parameter(
713
- self, parameter: inspect.Parameter, call: Callable[..., Any]
830
+ self, provider: Provider, parameter: inspect.Parameter
714
831
  ) -> None:
715
832
  """Ensure that the specified interface is resolved."""
716
833
  if parameter.annotation in self._unresolved_interfaces:
717
834
  raise LookupError(
718
835
  f"You are attempting to get the parameter `{parameter.name}` with the "
719
836
  f"annotation `{get_full_qualname(parameter.annotation)}` as a "
720
- f"dependency into `{get_full_qualname(call)}` which is not registered "
721
- "or set in the scoped context."
837
+ f"dependency into `{get_full_qualname(provider.call)}` which is "
838
+ "not registered or set in the scoped context."
839
+ )
840
+
841
+ @contextlib.contextmanager
842
+ def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
843
+ """
844
+ Override the provider for the specified interface with a specific instance.
845
+ """
846
+ if not self.testing:
847
+ raise RuntimeError(
848
+ "The `override` method can only be used in testing mode."
849
+ )
850
+ if not self.is_registered(interface) and self.strict:
851
+ raise LookupError(
852
+ f"The provider interface `{get_full_qualname(interface)}` "
853
+ "not registered."
722
854
  )
855
+ self._override_instances[interface] = instance
856
+ yield
857
+ del self._override_instances[interface]
858
+
859
+ ############################
860
+ # Testing Methods
861
+ ############################
723
862
 
724
863
  def _patch_test_resolver(self, interface: type[Any], instance: Any) -> Any:
725
864
  """Patch the test resolver for the instance."""
@@ -769,60 +908,9 @@ class Container:
769
908
 
770
909
  return instance
771
910
 
772
- def is_resolved(self, interface: AnyInterface) -> bool:
773
- """Check if an instance by interface exists."""
774
- try:
775
- provider = self._get_provider(interface)
776
- except LookupError:
777
- return False
778
- if provider.scope == "transient":
779
- return False
780
- context = self._get_scoped_context(provider.scope)
781
- return interface in context
782
-
783
- def release(self, interface: AnyInterface) -> None:
784
- """Release an instance by interface."""
785
- provider = self._get_provider(interface)
786
- if provider.scope == "transient":
787
- return None
788
- context = self._get_scoped_context(provider.scope)
789
- del context[interface]
790
-
791
- def _get_scoped_context(self, scope: Scope) -> InstanceContext:
792
- """Get the instance context for the specified scope."""
793
- if scope == "singleton":
794
- return self._singleton_context
795
- return self._get_request_context()
796
-
797
- @contextlib.contextmanager
798
- def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
799
- """
800
- Override the provider for the specified interface with a specific instance.
801
- """
802
- if not self.testing:
803
- raise RuntimeError(
804
- "The `override` method can only be used in testing mode."
805
- )
806
- if not self.is_registered(interface) and self.strict:
807
- raise LookupError(
808
- f"The provider interface `{get_full_qualname(interface)}` "
809
- "not registered."
810
- )
811
- self._override_instances[interface] = instance
812
- yield
813
- del self._override_instances[interface]
814
-
815
- def provider(
816
- self, *, scope: Scope, override: bool = False
817
- ) -> Callable[[Callable[P, T]], Callable[P, T]]:
818
- """Decorator to register a provider function with the specified scope."""
819
-
820
- def decorator(call: Callable[P, T]) -> Callable[P, T]:
821
- provider = Provider(call=call, scope=scope)
822
- self._register_provider(provider, override)
823
- return call
824
-
825
- return decorator
911
+ ############################
912
+ # Injector Methods
913
+ ############################
826
914
 
827
915
  @overload
828
916
  def inject(self, func: Callable[P, T]) -> Callable[P, T]: ...
@@ -842,6 +930,10 @@ class Container:
842
930
  return decorator
843
931
  return decorator(func)
844
932
 
933
+ def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
934
+ """Run the given function with injected dependencies."""
935
+ return self._inject(func)(*args, **kwargs)
936
+
845
937
  def _inject(self, call: Callable[P, T]) -> Callable[P, T]:
846
938
  """Inject dependencies into a callable."""
847
939
  if call in self._inject_cache:
@@ -909,9 +1001,39 @@ class Container:
909
1001
  f"`{get_full_qualname(parameter.annotation)}`."
910
1002
  )
911
1003
 
912
- def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
913
- """Run the given function with injected dependencies."""
914
- return self._inject(func)(*args, **kwargs)
1004
+ ############################
1005
+ # Module Methods
1006
+ ############################
1007
+
1008
+ def register_module(
1009
+ self, module: Module | type[Module] | Callable[[Container], None] | str
1010
+ ) -> None:
1011
+ """Register a module as a callable, module type, or module instance."""
1012
+ # Callable Module
1013
+ if inspect.isfunction(module):
1014
+ module(self)
1015
+ return
1016
+
1017
+ # Module path
1018
+ if isinstance(module, str):
1019
+ module = import_string(module)
1020
+
1021
+ # Class based Module or Module type
1022
+ if inspect.isclass(module) and issubclass(module, Module):
1023
+ module = module()
1024
+
1025
+ if isinstance(module, Module):
1026
+ module.configure(self)
1027
+ for provider_name, decorator_args in module.providers:
1028
+ obj = getattr(module, provider_name)
1029
+ self.provider(
1030
+ scope=decorator_args.scope,
1031
+ override=decorator_args.override,
1032
+ )(obj)
1033
+
1034
+ ############################
1035
+ # Scanner Methods
1036
+ ############################
915
1037
 
916
1038
  def scan(
917
1039
  self,
@@ -921,7 +1043,7 @@ class Container:
921
1043
  tags: Iterable[str] | None = None,
922
1044
  ) -> None:
923
1045
  """Scan packages or modules for decorated members and inject dependencies."""
924
- dependencies: list[Dependency] = []
1046
+ dependencies: list[ScannedDependency] = []
925
1047
 
926
1048
  if isinstance(packages, Iterable) and not isinstance(packages, str):
927
1049
  scan_packages: Iterable[ModuleType | str] = packages
@@ -940,7 +1062,7 @@ class Container:
940
1062
  package: ModuleType | str,
941
1063
  *,
942
1064
  tags: Iterable[str] | None = None,
943
- ) -> list[Dependency]:
1065
+ ) -> list[ScannedDependency]:
944
1066
  """Scan a package or module for decorated members."""
945
1067
  tags = tags or []
946
1068
  if isinstance(package, str):
@@ -951,7 +1073,7 @@ class Container:
951
1073
  if not package_path:
952
1074
  return self._scan_module(package, tags=tags)
953
1075
 
954
- dependencies: list[Dependency] = []
1076
+ dependencies: list[ScannedDependency] = []
955
1077
 
956
1078
  for module_info in pkgutil.walk_packages(
957
1079
  path=package_path, prefix=package.__name__ + "."
@@ -963,9 +1085,9 @@ class Container:
963
1085
 
964
1086
  def _scan_module(
965
1087
  self, module: ModuleType, *, tags: Iterable[str]
966
- ) -> list[Dependency]:
1088
+ ) -> list[ScannedDependency]:
967
1089
  """Scan a module for decorated members."""
968
- dependencies: list[Dependency] = []
1090
+ dependencies: list[ScannedDependency] = []
969
1091
 
970
1092
  for _, member in inspect.getmembers(module):
971
1093
  if getattr(member, "__module__", None) != module.__name__ or not callable(
@@ -988,7 +1110,7 @@ class Container:
988
1110
 
989
1111
  if decorator_args.wrapped:
990
1112
  dependencies.append(
991
- self._create_dependency(member=member, module=module)
1113
+ self._create_scanned_dependency(member=member, module=module)
992
1114
  )
993
1115
  continue
994
1116
 
@@ -996,17 +1118,24 @@ class Container:
996
1118
  for parameter in get_typed_parameters(member):
997
1119
  if is_marker(parameter.default):
998
1120
  dependencies.append(
999
- self._create_dependency(member=member, module=module)
1121
+ self._create_scanned_dependency(member=member, module=module)
1000
1122
  )
1001
1123
  continue
1002
1124
 
1003
1125
  return dependencies
1004
1126
 
1005
- def _create_dependency(self, member: Any, module: ModuleType) -> Dependency:
1127
+ def _create_scanned_dependency(
1128
+ self, member: Any, module: ModuleType
1129
+ ) -> ScannedDependency:
1006
1130
  """Create a `Dependency` object from the scanned member and module."""
1007
1131
  if hasattr(member, "__wrapped__"):
1008
1132
  member = member.__wrapped__
1009
- return Dependency(member=member, module=module)
1133
+ return ScannedDependency(member=member, module=module)
1134
+
1135
+
1136
+ ############################
1137
+ # Decorators
1138
+ ############################
1010
1139
 
1011
1140
 
1012
1141
  def transient(target: T) -> T: