anydi 0.22.1__py3-none-any.whl → 0.37.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
anydi/_container.py CHANGED
@@ -3,104 +3,115 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import contextlib
6
+ import functools
7
+ import importlib
6
8
  import inspect
9
+ import logging
10
+ import pkgutil
11
+ import threading
7
12
  import types
8
- import uuid
13
+ from collections import defaultdict
14
+ from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
9
15
  from contextvars import ContextVar
10
- from functools import wraps
11
- from typing import (
12
- Any,
13
- AsyncContextManager,
14
- AsyncIterator,
15
- Awaitable,
16
- Callable,
17
- ContextManager,
18
- Dict,
19
- Iterable,
20
- Iterator,
21
- List,
22
- Mapping,
23
- Optional,
24
- Sequence,
25
- Type,
26
- TypeVar,
27
- Union,
28
- cast,
29
- overload,
16
+ from types import ModuleType
17
+ from typing import Any, Callable, TypeVar, Union, cast, overload
18
+ from weakref import WeakKeyDictionary
19
+
20
+ from typing_extensions import Concatenate, ParamSpec, Self, final
21
+
22
+ from ._context import InstanceContext
23
+ from ._provider import Provider
24
+ from ._types import (
25
+ AnyInterface,
26
+ Dependency,
27
+ InjectableDecoratorArgs,
28
+ InstanceProxy,
29
+ ProviderDecoratorArgs,
30
+ Scope,
31
+ is_event_type,
32
+ is_marker,
30
33
  )
31
-
32
- from typing_extensions import Annotated, ParamSpec, final, get_args, get_origin
33
-
34
- try:
35
- from types import NoneType
36
- except ImportError:
37
- NoneType = type(None) # type: ignore[misc]
38
-
39
-
40
- from ._context import (
41
- RequestContext,
42
- ResourceScopedContext,
43
- ScopedContext,
44
- SingletonContext,
45
- TransientContext,
34
+ from ._utils import (
35
+ AsyncRLock,
36
+ get_full_qualname,
37
+ get_typed_parameters,
38
+ import_string,
39
+ is_async_context_manager,
40
+ is_builtin_type,
41
+ is_context_manager,
42
+ run_async,
46
43
  )
47
- from ._logger import logger
48
- from ._module import Module, ModuleRegistry
49
- from ._scanner import Scanner
50
- from ._types import AnyInterface, Interface, Marker, Provider, Scope
51
- from ._utils import get_full_qualname, get_signature, is_builtin_type
52
44
 
53
45
  T = TypeVar("T", bound=Any)
46
+ M = TypeVar("M", bound="Module")
54
47
  P = ParamSpec("P")
55
48
 
56
- ALLOWED_SCOPES: Dict[Scope, List[Scope]] = {
49
+ ALLOWED_SCOPES: dict[Scope, list[Scope]] = {
57
50
  "singleton": ["singleton"],
58
51
  "request": ["request", "singleton"],
59
- "transient": ["transient", "singleton", "request"],
52
+ "transient": ["transient", "request", "singleton"],
60
53
  }
61
54
 
62
55
 
56
+ class ModuleMeta(type):
57
+ """A metaclass used for the Module base class."""
58
+
59
+ def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Any:
60
+ attrs["providers"] = [
61
+ (name, getattr(value, "__provider__"))
62
+ for name, value in attrs.items()
63
+ if hasattr(value, "__provider__")
64
+ ]
65
+ return super().__new__(cls, name, bases, attrs)
66
+
67
+
68
+ class Module(metaclass=ModuleMeta):
69
+ """A base class for defining AnyDI modules."""
70
+
71
+ providers: list[tuple[str, ProviderDecoratorArgs]]
72
+
73
+ def configure(self, container: Container) -> None:
74
+ """Configure the AnyDI container with providers and their dependencies."""
75
+
76
+
77
+ # noinspection PyShadowingNames
63
78
  @final
64
79
  class Container:
65
- """AnyDI is a dependency injection container.
66
-
67
- Args:
68
- modules: Optional sequence of modules to register during initialization.
69
- """
80
+ """AnyDI is a dependency injection container."""
70
81
 
71
82
  def __init__(
72
83
  self,
73
84
  *,
74
- providers: Optional[Mapping[Type[Any], Provider]] = None,
75
- modules: Optional[
76
- Sequence[Union[Module, Type[Module], Callable[[Container], None]]]
77
- ] = None,
85
+ providers: Sequence[Provider] | None = None,
86
+ modules: Sequence[Module | type[Module] | Callable[[Container], None] | str]
87
+ | None = None,
78
88
  strict: bool = False,
89
+ default_scope: Scope = "transient",
90
+ testing: bool = False,
91
+ logger: logging.Logger | None = None,
79
92
  ) -> None:
80
- """Initialize the AnyDI instance.
81
-
82
- Args:
83
- providers: Optional mapping of providers to register during initialization.
84
- modules: Optional sequence of modules to register during initialization.
85
- strict: Whether to enable strict mode. Defaults to False.
86
- """
87
- self._providers: Dict[Type[Any], Provider] = {}
88
- self._singleton_context = SingletonContext(self)
89
- self._transient_context = TransientContext(self)
90
- self._request_context_var: ContextVar[Optional[RequestContext]] = ContextVar(
93
+ self._providers: dict[type[Any], Provider] = {}
94
+ self._strict = strict
95
+ self._default_scope = default_scope
96
+ self._testing = testing
97
+ self._logger = logger or logging.getLogger(__name__)
98
+ self._resources: dict[str, list[type[Any]]] = defaultdict(list)
99
+ self._singleton_context = InstanceContext()
100
+ self._singleton_lock = threading.RLock()
101
+ self._singleton_async_lock = AsyncRLock()
102
+ self._request_context_var: ContextVar[InstanceContext | None] = ContextVar(
91
103
  "request_context", default=None
92
104
  )
93
- self._override_instances: Dict[Type[Any], Any] = {}
94
- self._strict = strict
95
-
96
- # Components
97
- self._modules = ModuleRegistry(self)
98
- self._scanner = Scanner(self)
105
+ self._override_instances: dict[type[Any], Any] = {}
106
+ self._unresolved_interfaces: set[type[Any]] = set()
107
+ self._inject_cache: WeakKeyDictionary[
108
+ Callable[..., Any], Callable[..., Any]
109
+ ] = WeakKeyDictionary()
99
110
 
100
111
  # Register providers
101
- providers = providers or {}
102
- for interface, provider in providers.items():
103
- self.register(interface, provider.obj, scope=provider.scope)
112
+ providers = providers or []
113
+ for provider in providers:
114
+ self._register_provider(provider, False)
104
115
 
105
116
  # Register modules
106
117
  modules = modules or []
@@ -109,353 +120,314 @@ class Container:
109
120
 
110
121
  @property
111
122
  def strict(self) -> bool:
112
- """Check if strict mode is enabled.
113
-
114
- Returns:
115
- True if strict mode is enabled, False otherwise.
116
- """
123
+ """Check if strict mode is enabled."""
117
124
  return self._strict
118
125
 
119
126
  @property
120
- def providers(self) -> Dict[Type[Any], Provider]:
121
- """Get the registered providers.
127
+ def default_scope(self) -> Scope:
128
+ """Get the default scope."""
129
+ return self._default_scope
122
130
 
123
- Returns:
124
- A dictionary containing the registered providers.
125
- """
126
- return self._providers
131
+ @property
132
+ def testing(self) -> bool:
133
+ """Check if testing mode is enabled."""
134
+ return self._testing
127
135
 
128
- def is_registered(self, interface: AnyInterface) -> bool:
129
- """Check if a provider is registered for the specified interface.
136
+ @property
137
+ def providers(self) -> dict[type[Any], Provider]:
138
+ """Get the registered providers."""
139
+ return self._providers
130
140
 
131
- Args:
132
- interface: The interface to check for a registered provider.
141
+ @property
142
+ def logger(self) -> logging.Logger:
143
+ """Get the logger instance."""
144
+ return self._logger
133
145
 
134
- Returns:
135
- True if a provider is registered for the interface, False otherwise.
136
- """
146
+ def is_registered(self, interface: AnyInterface) -> bool:
147
+ """Check if a provider is registered for the specified interface."""
137
148
  return interface in self._providers
138
149
 
139
150
  def register(
140
151
  self,
141
152
  interface: AnyInterface,
142
- obj: Callable[..., Any],
153
+ call: Callable[..., Any],
143
154
  *,
144
155
  scope: Scope,
145
156
  override: bool = False,
146
157
  ) -> Provider:
147
- """Register a provider for the specified interface.
148
-
149
- Args:
150
- interface: The interface for which the provider is being registered.
151
- obj: The provider function or callable object.
152
- scope: The scope of the provider.
153
- override: If True, override an existing provider for the interface
154
- if one is already registered. Defaults to False.
155
-
156
- Returns:
157
- The registered provider.
158
-
159
- Raises:
160
- LookupError: If a provider for the interface is already registered
161
- and override is False.
162
-
163
- Notes:
164
- - If the provider is a resource or an asynchronous resource, and the
165
- interface is None, an Event type will be automatically created and used
166
- as the interface.
167
- - The provider will be validated for its scope, type, and matching scopes.
168
- """
169
- provider = Provider(obj=obj, scope=scope)
170
-
171
- # Create Event type
172
- if provider.is_resource and (interface is NoneType or interface is None):
173
- interface = type(f"Event{uuid.uuid4().hex}", (), {})
158
+ """Register a provider for the specified interface."""
159
+ provider = Provider(call=call, scope=scope, interface=interface)
160
+ return self._register_provider(provider, override)
174
161
 
175
- if interface in self._providers:
162
+ def _register_provider(
163
+ self, provider: Provider, override: bool, /, **defaults: Any
164
+ ) -> Provider:
165
+ """Register a provider."""
166
+ if provider.interface in self._providers:
176
167
  if override:
177
- self._providers[interface] = provider
168
+ self._set_provider(provider)
178
169
  return provider
179
170
 
180
171
  raise LookupError(
181
- f"The provider interface `{get_full_qualname(interface)}` "
172
+ f"The provider interface `{get_full_qualname(provider.interface)}` "
182
173
  "already registered."
183
174
  )
184
175
 
185
- # Validate provider
186
- self._validate_provider_scope(provider)
187
- self._validate_provider_type(provider)
188
- self._validate_provider_match_scopes(interface, provider)
189
-
190
- self._providers[interface] = provider
176
+ self._validate_sub_providers(provider, **defaults)
177
+ self._set_provider(provider)
191
178
  return provider
192
179
 
193
180
  def unregister(self, interface: AnyInterface) -> None:
194
- """Unregister a provider by interface.
195
-
196
- Args:
197
- interface: The interface of the provider to unregister.
198
-
199
- Raises:
200
- LookupError: If the provider interface is not registered.
201
-
202
- Notes:
203
- - The method cleans up any scoped context instance associated with
204
- the provider's scope.
205
- - The method removes the provider reference from the internal dictionary
206
- of registered providers.
207
- """
181
+ """Unregister a provider by interface."""
208
182
  if not self.is_registered(interface):
209
183
  raise LookupError(
210
184
  "The provider interface "
211
185
  f"`{get_full_qualname(interface)}` not registered."
212
186
  )
213
187
 
214
- provider = self._get_or_register_provider(interface)
188
+ provider = self._get_provider(interface)
215
189
 
216
- # Cleanup scoped context instance
217
- try:
218
- scoped_context = self._get_scoped_context(provider.scope)
219
- except LookupError:
220
- pass
221
- else:
222
- if isinstance(scoped_context, ResourceScopedContext):
223
- scoped_context.delete(interface)
190
+ # Cleanup instance context
191
+ if provider.scope != "transient":
192
+ try:
193
+ context = self._get_scoped_context(provider.scope)
194
+ except LookupError:
195
+ pass
196
+ else:
197
+ del context[interface]
224
198
 
225
199
  # Cleanup provider references
226
- self._providers.pop(interface, None)
227
-
228
- def _get_or_register_provider(self, interface: AnyInterface) -> Provider:
229
- """Get or register a provider by interface.
230
-
231
- Args:
232
- interface: The interface for which to retrieve the provider.
200
+ self._delete_provider(provider)
233
201
 
234
- Returns:
235
- Provider: The provider object associated with the interface.
236
-
237
- Raises:
238
- LookupError: If the provider interface has not been registered.
239
- """
202
+ def _get_provider(self, interface: AnyInterface) -> Provider:
203
+ """Get provider by interface."""
240
204
  try:
241
205
  return self._providers[interface]
242
206
  except KeyError as exc:
207
+ raise LookupError(
208
+ f"The provider interface for `{get_full_qualname(interface)}` has "
209
+ "not been registered. Please ensure that the provider interface is "
210
+ "properly registered before attempting to use it."
211
+ ) from exc
212
+
213
+ def _get_or_register_provider(
214
+ self, interface: AnyInterface, parent_scope: Scope | None, /, **defaults: Any
215
+ ) -> Provider:
216
+ """Get or register a provider by interface."""
217
+ try:
218
+ return self._get_provider(interface)
219
+ except LookupError:
243
220
  if (
244
221
  not self.strict
245
222
  and inspect.isclass(interface)
246
223
  and not is_builtin_type(interface)
224
+ and interface is not inspect.Parameter.empty
247
225
  ):
248
226
  # Try to get defined scope
249
- scope = getattr(interface, "__scope__", None)
227
+ scope = getattr(interface, "__scope__", parent_scope)
250
228
  # Try to detect scope
251
229
  if scope is None:
252
- scope = self._detect_scope(interface)
253
- if scope is None:
254
- raise TypeError(
255
- "Unable to automatically register the provider interface for "
256
- f"`{get_full_qualname(interface)}` because the scope detection "
257
- "failed. Please resolve this issue by using "
258
- "the appropriate scope decorator."
259
- ) from exc
260
- return self.register(interface, interface, scope=scope)
261
- raise LookupError(
262
- f"The provider interface for `{get_full_qualname(interface)}` has "
263
- "not been registered. Please ensure that the provider interface is "
264
- "properly registered before attempting to use it."
265
- ) from exc
266
-
267
- def _validate_provider_scope(self, provider: Provider) -> None:
268
- """Validate the scope of a provider.
269
-
270
- Args:
271
- provider: The provider to validate.
272
-
273
- Raises:
274
- ValueError: If the scope provided is invalid.
275
- """
276
- if provider.scope not in get_args(Scope):
277
- raise ValueError(
278
- "The scope provided is invalid. Only the following scopes are "
279
- f"supported: {', '.join(get_args(Scope))}. Please use one of the "
280
- "supported scopes when registering a provider."
281
- )
282
-
283
- def _validate_provider_type(self, provider: Provider) -> None:
284
- """Validate the type of provider.
285
-
286
- Args:
287
- provider: The provider to validate.
288
-
289
- Raises:
290
- TypeError: If the provider has an invalid type.
291
- """
292
- if provider.is_function or provider.is_class:
293
- return
294
-
230
+ scope = self._detect_scope(interface, **defaults)
231
+ scope = scope or self.default_scope
232
+ provider = Provider(call=interface, scope=scope, interface=interface)
233
+ return self._register_provider(provider, False, **defaults)
234
+ raise
235
+
236
+ def _set_provider(self, provider: Provider) -> None:
237
+ """Set a provider by interface."""
238
+ self._providers[provider.interface] = provider
295
239
  if provider.is_resource:
296
- if provider.scope == "transient":
297
- raise TypeError(
298
- f"The resource provider `{provider}` is attempting to register "
299
- "with a transient scope, which is not allowed. Please update the "
300
- "provider's scope to an appropriate value before registering it."
301
- )
302
- return
303
-
304
- raise TypeError(
305
- f"The provider `{provider.obj}` is invalid because it is not a callable "
306
- "object. Only callable providers are allowed. Please update the provider "
307
- "to a callable object before attempting to register it."
308
- )
309
-
310
- def _validate_provider_match_scopes(
311
- self, interface: AnyInterface, provider: Provider
312
- ) -> None:
313
- """Validate that the provider and its dependencies have matching scopes.
240
+ self._resources[provider.scope].append(provider.interface)
314
241
 
315
- Args:
316
- interface: The interface associated with the provider.
317
- provider: The provider to validate.
242
+ def _delete_provider(self, provider: Provider) -> None:
243
+ """Delete a provider."""
244
+ if provider.interface in self._providers:
245
+ del self._providers[provider.interface]
246
+ if provider.is_resource:
247
+ self._resources[provider.scope].remove(provider.interface)
318
248
 
319
- Raises:
320
- ValueError: If the provider and its dependencies have mismatched scopes.
321
- TypeError: If a dependency is missing an annotation.
322
- """
323
- related_providers = []
249
+ def _validate_sub_providers(self, provider: Provider, /, **defaults: Any) -> None:
250
+ """Validate the sub-providers of a provider."""
324
251
 
325
- for parameter in provider.parameters.values():
326
- if parameter.annotation is inspect._empty: # noqa
252
+ for parameter in provider.parameters:
253
+ if parameter.annotation is inspect.Parameter.empty:
327
254
  raise TypeError(
328
255
  f"Missing provider `{provider}` "
329
256
  f"dependency `{parameter.name}` annotation."
330
257
  )
258
+
331
259
  try:
332
- sub_provider = self._get_or_register_provider(parameter.annotation)
260
+ sub_provider = self._get_or_register_provider(
261
+ parameter.annotation, provider.scope
262
+ )
333
263
  except LookupError:
334
- raise LookupError(
335
- f"The provider `{get_full_qualname(provider.obj)}` depends on "
336
- f"`{parameter.name}` of type "
264
+ if self._parameter_has_default(parameter, **defaults):
265
+ continue
266
+
267
+ if provider.scope not in {"singleton", "transient"}:
268
+ self._unresolved_interfaces.add(provider.interface)
269
+ continue
270
+ raise ValueError(
271
+ f"The provider `{provider}` depends on `{parameter.name}` of type "
337
272
  f"`{get_full_qualname(parameter.annotation)}`, which "
338
- "has not been registered. To resolve this, ensure that "
273
+ "has not been registered or set. To resolve this, ensure that "
339
274
  f"`{parameter.name}` is registered before attempting to use it."
340
275
  ) from None
341
- related_providers.append(sub_provider)
342
276
 
343
- for related_provider in related_providers:
344
- left_scope, right_scope = related_provider.scope, provider.scope
345
- allowed_scopes = ALLOWED_SCOPES.get(right_scope) or []
346
- if left_scope not in allowed_scopes:
277
+ # Check scope compatibility
278
+ if sub_provider.scope not in ALLOWED_SCOPES.get(provider.scope, []):
347
279
  raise ValueError(
348
- f"The provider `{provider}` with a {provider.scope} scope was "
349
- "attempted to be registered with the provider "
350
- f"`{related_provider}` with a `{related_provider.scope}` scope, "
351
- "which is not allowed. Please ensure that all providers are "
352
- "registered with matching scopes."
280
+ f"The provider `{provider}` with a `{provider.scope}` scope cannot "
281
+ f"depend on `{sub_provider}` with a `{sub_provider.scope}` scope. "
282
+ "Please ensure all providers are registered with matching scopes."
353
283
  )
354
284
 
355
- def _detect_scope(self, obj: Callable[..., Any]) -> Optional[Scope]:
356
- """Detect the scope for a provider.
285
+ def _detect_scope(self, call: Callable[..., Any], **defaults: Any) -> Scope | None:
286
+ """Detect the scope for a callable."""
287
+ scopes = set()
357
288
 
358
- Args:
359
- obj: The provider to detect the auto scope for.
360
- Returns:
361
- The auto scope, or None if the auto scope cannot be detected.
362
- """
363
- has_transient, has_request, has_singleton = False, False, False
364
- for parameter in get_signature(obj).parameters.values():
365
- sub_provider = self._get_or_register_provider(parameter.annotation)
366
- if not has_transient and sub_provider.scope == "transient":
367
- has_transient = True
368
- if not has_request and sub_provider.scope == "request":
369
- has_request = True
370
- if not has_singleton and sub_provider.scope == "singleton":
371
- has_singleton = True
372
- if has_transient:
373
- return "transient"
374
- if has_request:
289
+ for parameter in get_typed_parameters(call):
290
+ try:
291
+ sub_provider = self._get_or_register_provider(
292
+ parameter.annotation, None
293
+ )
294
+ except LookupError:
295
+ if self._parameter_has_default(parameter, **defaults):
296
+ continue
297
+ raise
298
+ scope = sub_provider.scope
299
+
300
+ if scope == "transient":
301
+ return "transient"
302
+ scopes.add(scope)
303
+
304
+ # If all scopes are found, we can return based on priority order
305
+ if {"transient", "request", "singleton"}.issubset(scopes):
306
+ break # pragma: no cover
307
+
308
+ # Determine scope based on priority
309
+ if "request" in scopes:
375
310
  return "request"
376
- if has_singleton:
311
+ if "singleton" in scopes:
377
312
  return "singleton"
313
+
378
314
  return None
379
315
 
316
+ def _parameter_has_default(
317
+ self, parameter: inspect.Parameter, /, **defaults: Any
318
+ ) -> bool:
319
+ return (defaults and parameter.name in defaults) or (
320
+ not self.strict and parameter.default is not inspect.Parameter.empty
321
+ )
322
+
380
323
  def register_module(
381
- self, module: Union[Module, Type[Module], Callable[[Container], None]]
324
+ self, module: Module | type[Module] | Callable[[Container], None] | str
382
325
  ) -> None:
383
- """Register a module as a callable, module type, or module instance.
326
+ """Register a module as a callable, module type, or module instance."""
327
+ # Callable Module
328
+ if inspect.isfunction(module):
329
+ module(self)
330
+ return
384
331
 
385
- Args:
386
- module: The module to register.
387
- """
388
- self._modules.register(module)
332
+ # Module path
333
+ if isinstance(module, str):
334
+ module = import_string(module)
335
+
336
+ # Class based Module or Module type
337
+ if inspect.isclass(module) and issubclass(module, Module):
338
+ module = module()
339
+
340
+ if isinstance(module, Module):
341
+ module.configure(self)
342
+ for provider_name, decorator_args in module.providers:
343
+ obj = getattr(module, provider_name)
344
+ self.provider(
345
+ scope=decorator_args.scope,
346
+ override=decorator_args.override,
347
+ )(obj)
348
+
349
+ def __enter__(self) -> Self:
350
+ """Enter the singleton context."""
351
+ self.start()
352
+ return self
353
+
354
+ def __exit__(
355
+ self,
356
+ exc_type: type[BaseException] | None,
357
+ exc_val: BaseException | None,
358
+ exc_tb: types.TracebackType | None,
359
+ ) -> Any:
360
+ """Exit the singleton context."""
361
+ return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
389
362
 
390
363
  def start(self) -> None:
391
364
  """Start the singleton context."""
392
- for interface, provider in self._providers.items():
393
- if provider.scope == "singleton":
394
- self.resolve(interface) # noqa
365
+ # Resolve all singleton resources
366
+ for interface in self._resources.get("singleton", []):
367
+ self.resolve(interface)
395
368
 
396
369
  def close(self) -> None:
397
370
  """Close the singleton context."""
398
371
  self._singleton_context.close()
399
372
 
400
- def request_context(self) -> ContextManager[None]:
401
- """Obtain a context manager for the request-scoped context.
373
+ @contextlib.contextmanager
374
+ def request_context(self) -> Iterator[InstanceContext]:
375
+ """Obtain a context manager for the request-scoped context."""
376
+ context = InstanceContext()
402
377
 
403
- Returns:
404
- A context manager for the request-scoped context.
405
- """
406
- return contextlib.contextmanager(self._request_context)()
378
+ token = self._request_context_var.set(context)
407
379
 
408
- def _request_context(self) -> Iterator[None]:
409
- """Internal method that manages the request-scoped context.
380
+ # Resolve all request resources
381
+ for interface in self._resources.get("request", []):
382
+ if not is_event_type(interface):
383
+ continue
384
+ self.resolve(interface)
410
385
 
411
- Yields:
412
- Yield control to the code block within the request context.
413
- """
414
- context = RequestContext(self)
415
- token = self._request_context_var.set(context)
416
386
  with context:
417
- yield
387
+ yield context
418
388
  self._request_context_var.reset(token)
419
389
 
390
+ async def __aenter__(self) -> Self:
391
+ """Enter the singleton context."""
392
+ await self.astart()
393
+ return self
394
+
395
+ async def __aexit__(
396
+ self,
397
+ exc_type: type[BaseException] | None,
398
+ exc_val: BaseException | None,
399
+ exc_tb: types.TracebackType | None,
400
+ ) -> bool:
401
+ """Exit the singleton context."""
402
+ return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)
403
+
420
404
  async def astart(self) -> None:
421
405
  """Start the singleton context asynchronously."""
422
- for interface, provider in self._providers.items():
423
- if provider.scope == "singleton":
424
- await self.aresolve(interface) # noqa
406
+ for interface in self._resources.get("singleton", []):
407
+ await self.aresolve(interface)
425
408
 
426
409
  async def aclose(self) -> None:
427
410
  """Close the singleton context asynchronously."""
428
411
  await self._singleton_context.aclose()
429
412
 
430
- def arequest_context(self) -> AsyncContextManager[None]:
431
- """Obtain an async context manager for the request-scoped context.
413
+ @contextlib.asynccontextmanager
414
+ async def arequest_context(self) -> AsyncIterator[InstanceContext]:
415
+ """Obtain an async context manager for the request-scoped context."""
416
+ context = InstanceContext()
432
417
 
433
- Returns:
434
- An async context manager for the request-scoped context.
435
- """
436
- return contextlib.asynccontextmanager(self._arequest_context)()
418
+ token = self._request_context_var.set(context)
437
419
 
438
- async def _arequest_context(self) -> AsyncIterator[None]:
439
- """Internal method that manages the async request-scoped context.
420
+ for interface in self._resources.get("request", []):
421
+ if not is_event_type(interface):
422
+ continue
423
+ await self.aresolve(interface)
440
424
 
441
- Yields:
442
- Yield control to the code block within the request context.
443
- """
444
- context = RequestContext(self)
445
- token = self._request_context_var.set(context)
446
425
  async with context:
447
- yield
426
+ yield context
448
427
  self._request_context_var.reset(token)
449
428
 
450
- def _get_request_context(self) -> RequestContext:
451
- """Get the current request context.
452
-
453
- Returns:
454
- RequestContext: The current request context.
455
-
456
- Raises:
457
- LookupError: If the request context has not been started.
458
- """
429
+ def _get_request_context(self) -> InstanceContext:
430
+ """Get the current request context."""
459
431
  request_context = self._request_context_var.get()
460
432
  if request_context is None:
461
433
  raise LookupError(
@@ -468,123 +440,365 @@ class Container:
468
440
  def reset(self) -> None:
469
441
  """Reset resolved instances."""
470
442
  for interface, provider in self._providers.items():
471
- scoped_context = self._get_scoped_context(provider.scope)
472
- if isinstance(scoped_context, ResourceScopedContext):
473
- scoped_context.delete(interface)
443
+ if provider.scope == "transient":
444
+ continue
445
+ try:
446
+ context = self._get_scoped_context(provider.scope)
447
+ except LookupError:
448
+ continue
449
+ del context[interface]
474
450
 
475
451
  @overload
476
- def resolve(self, interface: Interface[T]) -> T: ...
452
+ def resolve(self, interface: type[T]) -> T: ...
477
453
 
478
454
  @overload
479
455
  def resolve(self, interface: T) -> T: ...
480
456
 
481
- def resolve(self, interface: Interface[T]) -> T:
482
- """Resolve an instance by interface.
457
+ def resolve(self, interface: type[T]) -> T:
458
+ """Resolve an instance by interface."""
459
+ provider = self._get_or_register_provider(interface, None)
460
+ if provider.scope == "transient":
461
+ instance = self._create_instance(provider, None)
462
+ else:
463
+ context = self._get_scoped_context(provider.scope)
464
+ if provider.scope == "singleton":
465
+ with self._singleton_lock:
466
+ instance = self._get_or_create_instance(provider, context)
467
+ else:
468
+ instance = self._get_or_create_instance(provider, context)
469
+ if self.testing:
470
+ instance = self._patch_test_resolver(provider.interface, instance)
471
+ return cast(T, instance)
483
472
 
484
- Args:
485
- interface: The interface type.
473
+ @overload
474
+ async def aresolve(self, interface: type[T]) -> T: ...
486
475
 
487
- Returns:
488
- The instance of the interface.
476
+ @overload
477
+ async def aresolve(self, interface: T) -> T: ...
489
478
 
490
- Raises:
491
- LookupError: If the provider for the interface is not registered.
492
- """
493
- if interface in self._override_instances:
494
- return cast(T, self._override_instances[interface])
479
+ async def aresolve(self, interface: type[T]) -> T:
480
+ """Resolve an instance by interface asynchronously."""
481
+ provider = self._get_or_register_provider(interface, None)
482
+ if provider.scope == "transient":
483
+ instance = await self._acreate_instance(provider, None)
484
+ else:
485
+ context = self._get_scoped_context(provider.scope)
486
+ if provider.scope == "singleton":
487
+ async with self._singleton_async_lock:
488
+ instance = await self._aget_or_create_instance(provider, context)
489
+ else:
490
+ instance = await self._aget_or_create_instance(provider, context)
491
+ if self.testing:
492
+ instance = self._patch_test_resolver(interface, instance)
493
+ return cast(T, instance)
494
+
495
+ def create(self, interface: type[T], **defaults: Any) -> T:
496
+ """Create an instance by interface."""
497
+ provider = self._get_or_register_provider(interface, None, **defaults)
498
+ if provider.scope == "transient":
499
+ instance = self._create_instance(provider, None, **defaults)
500
+ else:
501
+ context = self._get_scoped_context(provider.scope)
502
+ if provider.scope == "singleton":
503
+ with self._singleton_lock:
504
+ instance = self._create_instance(provider, context, **defaults)
505
+ else:
506
+ instance = self._create_instance(provider, context, **defaults)
507
+ return cast(T, instance)
508
+
509
+ async def acreate(self, interface: type[T], **defaults: Any) -> T:
510
+ """Create an instance by interface."""
511
+ provider = self._get_or_register_provider(interface, None, **defaults)
512
+ if provider.scope == "transient":
513
+ instance = await self._acreate_instance(provider, None, **defaults)
514
+ else:
515
+ context = self._get_scoped_context(provider.scope)
516
+ if provider.scope == "singleton":
517
+ async with self._singleton_async_lock:
518
+ instance = await self._acreate_instance(
519
+ provider, context, **defaults
520
+ )
521
+ else:
522
+ instance = await self._acreate_instance(provider, context, **defaults)
523
+ return cast(T, instance)
524
+
525
+ def _get_or_create_instance(
526
+ self, provider: Provider, context: InstanceContext
527
+ ) -> Any:
528
+ """Get an instance of a dependency from the scoped context."""
529
+ instance = context.get(provider.interface)
530
+ if instance is None:
531
+ instance = self._create_instance(provider, context)
532
+ context.set(provider.interface, instance)
533
+ return instance
534
+ return instance
535
+
536
+ async def _aget_or_create_instance(
537
+ self, provider: Provider, context: InstanceContext
538
+ ) -> Any:
539
+ """Get an async instance of a dependency from the scoped context."""
540
+ instance = context.get(provider.interface)
541
+ if instance is None:
542
+ instance = await self._acreate_instance(provider, context)
543
+ context.set(provider.interface, instance)
544
+ return instance
545
+ return instance
546
+
547
+ def _create_instance(
548
+ self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
549
+ ) -> Any:
550
+ """Create an instance using the provider."""
551
+ if provider.is_async:
552
+ raise TypeError(
553
+ f"The instance for the provider `{provider}` cannot be created in "
554
+ "synchronous mode."
555
+ )
495
556
 
496
- provider = self._get_or_register_provider(interface)
497
- scoped_context = self._get_scoped_context(provider.scope)
498
- return scoped_context.get(interface, provider)
557
+ provider_kwargs = self._get_provided_kwargs(provider, context, **defaults)
558
+
559
+ if provider.is_generator:
560
+ if context is None:
561
+ raise ValueError("The context is required for generator providers.")
562
+ cm = contextlib.contextmanager(provider.call)(**provider_kwargs)
563
+ return context.enter(cm)
564
+
565
+ instance = provider.call(**provider_kwargs)
566
+ if context is not None and provider.is_class and is_context_manager(instance):
567
+ context.enter(instance)
568
+ return instance
569
+
570
+ async def _acreate_instance(
571
+ self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
572
+ ) -> Any:
573
+ """Create an instance asynchronously using the provider."""
574
+ provider_kwargs = await self._aget_provided_kwargs(
575
+ provider, context, **defaults
576
+ )
499
577
 
500
- @overload
501
- async def aresolve(self, interface: Interface[T]) -> T: ...
578
+ if provider.is_coroutine:
579
+ return await provider.call(**provider_kwargs)
502
580
 
503
- @overload
504
- async def aresolve(self, interface: T) -> T: ...
581
+ if provider.is_async_generator:
582
+ if context is None:
583
+ raise ValueError(
584
+ "The async stack is required for async generator providers."
585
+ )
586
+ cm = contextlib.asynccontextmanager(provider.call)(**provider_kwargs)
587
+ return await context.aenter(cm)
588
+
589
+ if provider.is_generator:
590
+
591
+ def _create() -> Any:
592
+ if context is None:
593
+ raise ValueError("The stack is required for generator providers.")
594
+ cm = contextlib.contextmanager(provider.call)(**provider_kwargs)
595
+ return context.enter(cm)
596
+
597
+ return await run_async(_create)
598
+
599
+ instance = await run_async(provider.call, **provider_kwargs)
600
+ if (
601
+ context is not None
602
+ and provider.is_class
603
+ and is_async_context_manager(instance)
604
+ ):
605
+ await context.aenter(instance)
606
+ return instance
607
+
608
+ def _get_provided_kwargs(
609
+ self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
610
+ ) -> dict[str, Any]:
611
+ """Retrieve the arguments for a provider."""
612
+ provided_kwargs = {}
613
+ for parameter in provider.parameters:
614
+ instance = self._get_provider_instance(
615
+ provider, parameter, context, **defaults
616
+ )
617
+ provided_kwargs[parameter.name] = instance
618
+ return {**defaults, **provided_kwargs}
505
619
 
506
- async def aresolve(self, interface: Interface[T]) -> T:
507
- """Resolve an instance by interface asynchronously.
620
+ def _get_provider_instance(
621
+ self,
622
+ provider: Provider,
623
+ parameter: inspect.Parameter,
624
+ context: InstanceContext | None,
625
+ /,
626
+ **defaults: Any,
627
+ ) -> Any:
628
+ """Retrieve an instance of a dependency from the scoped context."""
508
629
 
509
- Args:
510
- interface: The interface type.
630
+ # Try to get instance from defaults
631
+ if parameter.name in defaults:
632
+ return defaults[parameter.name]
511
633
 
512
- Returns:
513
- The instance of the interface.
634
+ # Try to get instance from context
635
+ elif context and parameter.annotation in context:
636
+ instance = context[parameter.annotation]
514
637
 
515
- Raises:
516
- LookupError: If the provider for the interface is not registered.
517
- """
518
- if interface in self._override_instances:
519
- return cast(T, self._override_instances[interface])
638
+ # Resolve new instance
639
+ else:
640
+ try:
641
+ instance = self._resolve_parameter(provider, parameter)
642
+ except LookupError:
643
+ if parameter.default is inspect.Parameter.empty:
644
+ raise
645
+ return parameter.default
646
+
647
+ # Wrap the instance in a proxy for testing
648
+ if self.testing:
649
+ return InstanceProxy(instance, interface=parameter.annotation)
650
+ return instance
651
+
652
+ async def _aget_provided_kwargs(
653
+ self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
654
+ ) -> dict[str, Any]:
655
+ """Asynchronously retrieve the arguments for a provider."""
656
+ provided_kwargs = {}
657
+ for parameter in provider.parameters:
658
+ instance = await self._aget_provider_instance(
659
+ provider, parameter, context, **defaults
660
+ )
661
+ provided_kwargs[parameter.name] = instance
662
+ return {**defaults, **provided_kwargs}
520
663
 
521
- provider = self._get_or_register_provider(interface)
522
- scoped_context = self._get_scoped_context(provider.scope)
523
- return await scoped_context.aget(interface, provider)
664
+ async def _aget_provider_instance(
665
+ self,
666
+ provider: Provider,
667
+ parameter: inspect.Parameter,
668
+ context: InstanceContext | None,
669
+ /,
670
+ **defaults: Any,
671
+ ) -> Any:
672
+ """Asynchronously retrieve an instance of a dependency from the context."""
524
673
 
525
- def is_resolved(self, interface: AnyInterface) -> bool:
526
- """Check if an instance by interface exists.
674
+ # Try to get instance from defaults
675
+ if parameter.name in defaults:
676
+ return defaults[parameter.name]
527
677
 
528
- Args:
529
- interface: The interface type.
678
+ # Try to get instance from context
679
+ elif context and parameter.annotation in context:
680
+ instance = context[parameter.annotation]
530
681
 
531
- Returns:
532
- True if the instance exists, otherwise False.
533
- """
534
- try:
535
- provider = self._get_or_register_provider(interface)
536
- except LookupError:
537
- pass
682
+ # Resolve new instance
538
683
  else:
539
- scoped_context = self._get_scoped_context(provider.scope)
540
- if isinstance(scoped_context, ResourceScopedContext):
541
- return scoped_context.has(interface)
542
- return False
684
+ try:
685
+ instance = await self._aresolve_parameter(provider, parameter)
686
+ except LookupError:
687
+ if parameter.default is inspect.Parameter.empty:
688
+ raise
689
+ return parameter.default
690
+
691
+ # Wrap the instance in a proxy for testing
692
+ if self.testing:
693
+ return InstanceProxy(instance, interface=parameter.annotation)
694
+ return instance
695
+
696
+ def _resolve_parameter(
697
+ self, provider: Provider, parameter: inspect.Parameter
698
+ ) -> Any:
699
+ self._validate_resolvable_parameter(parameter, call=provider.call)
700
+ return self.resolve(parameter.annotation)
701
+
702
+ async def _aresolve_parameter(
703
+ self, provider: Provider, parameter: inspect.Parameter
704
+ ) -> Any:
705
+ self._validate_resolvable_parameter(parameter, call=provider.call)
706
+ return await self.aresolve(parameter.annotation)
707
+
708
+ def _validate_resolvable_parameter(
709
+ self, parameter: inspect.Parameter, call: Callable[..., Any]
710
+ ) -> None:
711
+ """Ensure that the specified interface is resolved."""
712
+ if parameter.annotation in self._unresolved_interfaces:
713
+ raise LookupError(
714
+ f"You are attempting to get the parameter `{parameter.name}` with the "
715
+ f"annotation `{get_full_qualname(parameter.annotation)}` as a "
716
+ f"dependency into `{get_full_qualname(call)}` which is not registered "
717
+ "or set in the scoped context."
718
+ )
543
719
 
544
- def release(self, interface: AnyInterface) -> None:
545
- """Release an instance by interface.
720
+ def _patch_test_resolver(self, interface: type[Any], instance: Any) -> Any:
721
+ """Patch the test resolver for the instance."""
722
+ if interface in self._override_instances:
723
+ return self._override_instances[interface]
546
724
 
547
- Args:
548
- interface: The interface type.
725
+ if not hasattr(instance, "__dict__") or hasattr(
726
+ instance, "__resolver_getter__"
727
+ ):
728
+ return instance
549
729
 
550
- Raises:
551
- LookupError: If the provider for the interface is not registered.
552
- """
553
- provider = self._get_or_register_provider(interface)
554
- scoped_context = self._get_scoped_context(provider.scope)
555
- if isinstance(scoped_context, ResourceScopedContext):
556
- scoped_context.delete(interface)
730
+ wrapped = {
731
+ name: value.interface
732
+ for name, value in instance.__dict__.items()
733
+ if isinstance(value, InstanceProxy)
734
+ }
557
735
 
558
- def _get_scoped_context(self, scope: Scope) -> ScopedContext:
559
- """Get the scoped context based on the specified scope.
736
+ def __resolver_getter__(name: str) -> Any:
737
+ if name in wrapped:
738
+ _interface = wrapped[name]
739
+ # Resolve the dependency if it's wrapped
740
+ return self.resolve(_interface)
741
+ raise LookupError
560
742
 
561
- Args:
562
- scope: The scope of the provider.
743
+ # Attach the resolver getter to the instance
744
+ instance.__resolver_getter__ = __resolver_getter__
563
745
 
564
- Returns:
565
- The scoped context, or None if the scope is not applicable.
566
- """
746
+ if not hasattr(instance.__class__, "__getattribute_patched__"):
747
+
748
+ def __getattribute__(_self: Any, name: str) -> Any:
749
+ # Skip the resolver getter
750
+ if name in {"__resolver_getter__"}:
751
+ return object.__getattribute__(_self, name)
752
+
753
+ if hasattr(_self, "__resolver_getter__"):
754
+ try:
755
+ return _self.__resolver_getter__(name)
756
+ except LookupError:
757
+ pass
758
+
759
+ # Fall back to default behavior
760
+ return object.__getattribute__(_self, name)
761
+
762
+ # Apply the patched resolver if wrapped attributes exist
763
+ instance.__class__.__getattribute__ = __getattribute__
764
+ instance.__class__.__getattribute_patched__ = True
765
+
766
+ return instance
767
+
768
+ def is_resolved(self, interface: AnyInterface) -> bool:
769
+ """Check if an instance by interface exists."""
770
+ try:
771
+ provider = self._get_provider(interface)
772
+ except LookupError:
773
+ return False
774
+ if provider.scope == "transient":
775
+ return False
776
+ context = self._get_scoped_context(provider.scope)
777
+ return interface in context
778
+
779
+ def release(self, interface: AnyInterface) -> None:
780
+ """Release an instance by interface."""
781
+ provider = self._get_provider(interface)
782
+ if provider.scope == "transient":
783
+ return None
784
+ context = self._get_scoped_context(provider.scope)
785
+ del context[interface]
786
+
787
+ def _get_scoped_context(self, scope: Scope) -> InstanceContext:
788
+ """Get the instance context for the specified scope."""
567
789
  if scope == "singleton":
568
790
  return self._singleton_context
569
- elif scope == "request":
570
- request_context = self._get_request_context()
571
- return request_context
572
- return self._transient_context
791
+ return self._get_request_context()
573
792
 
574
793
  @contextlib.contextmanager
575
794
  def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
576
- """Override the provider for the specified interface with a specific instance.
577
-
578
- Args:
579
- interface: The interface type to override.
580
- instance: The instance to use as the override.
581
-
582
- Yields:
583
- None
584
-
585
- Raises:
586
- LookupError: If the provider for the interface is not registered.
587
795
  """
796
+ Override the provider for the specified interface with a specific instance.
797
+ """
798
+ if not self.testing:
799
+ raise RuntimeError(
800
+ "The `override` method can only be used in testing mode."
801
+ )
588
802
  if not self.is_registered(interface) and self.strict:
589
803
  raise LookupError(
590
804
  f"The provider interface `{get_full_qualname(interface)}` "
@@ -597,166 +811,74 @@ class Container:
597
811
  def provider(
598
812
  self, *, scope: Scope, override: bool = False
599
813
  ) -> Callable[[Callable[P, T]], Callable[P, T]]:
600
- """Decorator to register a provider function with the specified scope.
601
-
602
- Args:
603
- scope : The scope of the provider.
604
- override: Whether the provider should override an existing provider
605
- for the same interface. Defaults to False.
606
-
607
- Returns:
608
- The decorator function.
609
- """
814
+ """Decorator to register a provider function with the specified scope."""
610
815
 
611
- def decorator(func: Callable[P, T]) -> Callable[P, T]:
612
- interface = self._get_provider_annotation(func)
613
- self.register(interface, func, scope=scope, override=override)
614
- return func
816
+ def decorator(call: Callable[P, T]) -> Callable[P, T]:
817
+ provider = Provider(call=call, scope=scope)
818
+ self._register_provider(provider, override)
819
+ return call
615
820
 
616
821
  return decorator
617
822
 
618
823
  @overload
619
- def inject(self, obj: Callable[P, T]) -> Callable[P, T]: ...
824
+ def inject(self, func: Callable[P, T]) -> Callable[P, T]: ...
620
825
 
621
826
  @overload
622
827
  def inject(self) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
623
828
 
624
829
  def inject(
625
- self, obj: Union[Callable[P, Union[T, Awaitable[T]]], None] = None
626
- ) -> Union[
627
- Callable[
628
- [Callable[P, Union[T, Awaitable[T]]]],
629
- Callable[P, Union[T, Awaitable[T]]],
630
- ],
631
- Callable[P, Union[T, Awaitable[T]]],
632
- ]:
633
- """Decorator to inject dependencies into a callable.
634
-
635
- Args:
636
- obj: The callable object to be decorated. If None, returns
637
- the decorator itself.
638
-
639
- Returns:
640
- The decorated callable object or decorator function.
641
- """
642
-
643
- def decorator(
644
- obj: Callable[P, Union[T, Awaitable[T]]],
645
- ) -> Callable[P, Union[T, Awaitable[T]]]:
646
- injected_params = self._get_injected_params(obj)
647
-
648
- if inspect.iscoroutinefunction(obj):
830
+ self, func: Callable[P, T] | None = None
831
+ ) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
832
+ """Decorator to inject dependencies into a callable."""
649
833
 
650
- @wraps(obj)
651
- async def awrapped(*args: P.args, **kwargs: P.kwargs) -> T:
652
- for name, annotation in injected_params.items():
653
- kwargs[name] = await self.aresolve(annotation)
654
- return cast(T, await obj(*args, **kwargs))
834
+ def decorator(call: Callable[P, T]) -> Callable[P, T]:
835
+ return self._inject(call)
655
836
 
656
- return awrapped
657
-
658
- @wraps(obj)
659
- def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
660
- for name, annotation in injected_params.items():
661
- kwargs[name] = self.resolve(annotation)
662
- return cast(T, obj(*args, **kwargs))
663
-
664
- return wrapped
665
-
666
- if obj is None:
837
+ if func is None:
667
838
  return decorator
668
- return decorator(obj)
839
+ return decorator(func)
669
840
 
670
- def run(self, obj: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
671
- """Run the given function with injected dependencies.
841
+ def _inject(self, call: Callable[P, T]) -> Callable[P, T]:
842
+ """Inject dependencies into a callable."""
843
+ if call in self._inject_cache:
844
+ return cast(Callable[P, T], self._inject_cache[call])
672
845
 
673
- Args:
674
- obj: The callable object.
675
- args: The positional arguments to pass to the object.
676
- kwargs: The keyword arguments to pass to the object.
677
-
678
- Returns:
679
- The result of the callable object.
680
- """
681
- return self.inject(obj)(*args, **kwargs)
682
-
683
- def scan(
684
- self,
685
- /,
686
- packages: Union[
687
- Union[types.ModuleType, str],
688
- Iterable[Union[types.ModuleType, str]],
689
- ],
690
- *,
691
- tags: Optional[Iterable[str]] = None,
692
- ) -> None:
693
- """Scan packages or modules for decorated members and inject dependencies.
694
-
695
- Args:
696
- packages: A single package or module to scan,
697
- or an iterable of packages or modules to scan.
698
- tags: Optional list of tags to filter the scanned members. Only members
699
- with at least one matching tag will be scanned. Defaults to None.
700
- """
701
- self._scanner.scan(packages, tags=tags)
846
+ injected_params = self._get_injected_params(call)
702
847
 
703
- def _get_provider_annotation(self, obj: Callable[..., Any]) -> Any:
704
- """Retrieve the provider return annotation from a callable object.
848
+ if inspect.iscoroutinefunction(call):
705
849
 
706
- Args:
707
- obj: The callable object (provider).
708
-
709
- Returns:
710
- The provider return annotation.
711
-
712
- Raises:
713
- TypeError: If the provider return annotation is missing or invalid.
714
- """
715
- annotation = get_signature(obj).return_annotation
716
-
717
- if annotation is inspect._empty: # noqa
718
- raise TypeError(
719
- f"Missing `{get_full_qualname(obj)}` provider return annotation."
720
- )
850
+ @functools.wraps(call)
851
+ async def awrapper(*args: P.args, **kwargs: P.kwargs) -> T:
852
+ for name, annotation in injected_params.items():
853
+ kwargs[name] = await self.aresolve(annotation)
854
+ return cast(T, await call(*args, **kwargs))
721
855
 
722
- origin = get_origin(annotation) or annotation
723
- args = get_args(annotation)
856
+ self._inject_cache[call] = awrapper
724
857
 
725
- # Supported generic types
726
- if origin in (list, dict, tuple, Annotated):
727
- if args:
728
- return annotation
729
- else:
730
- raise TypeError(
731
- f"Cannot use `{get_full_qualname(obj)}` generic type annotation "
732
- "without actual type."
733
- )
858
+ return awrapper # type: ignore[return-value]
734
859
 
735
- try:
736
- return args[0]
737
- except IndexError:
738
- return annotation
860
+ @functools.wraps(call)
861
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
862
+ for name, annotation in injected_params.items():
863
+ kwargs[name] = self.resolve(annotation)
864
+ return call(*args, **kwargs)
739
865
 
740
- def _get_injected_params(self, obj: Callable[..., Any]) -> Dict[str, Any]:
741
- """Get the injected parameters of a callable object.
866
+ self._inject_cache[call] = wrapper
742
867
 
743
- Args:
744
- obj: The callable object.
868
+ return wrapper
745
869
 
746
- Returns:
747
- A dictionary containing the names and annotations
748
- of the injected parameters.
749
- """
870
+ def _get_injected_params(self, call: Callable[..., Any]) -> dict[str, Any]:
871
+ """Get the injected parameters of a callable object."""
750
872
  injected_params = {}
751
- for parameter in get_signature(obj).parameters.values():
752
- if not isinstance(parameter.default, Marker):
873
+ for parameter in get_typed_parameters(call):
874
+ if not is_marker(parameter.default):
753
875
  continue
754
876
  try:
755
- self._validate_injected_parameter(obj, parameter)
877
+ self._validate_injected_parameter(call, parameter)
756
878
  except LookupError as exc:
757
879
  if not self.strict:
758
- logger.debug(
759
- f"Cannot validate the `{get_full_qualname(obj)}` parameter "
880
+ self.logger.debug(
881
+ f"Cannot validate the `{get_full_qualname(call)}` parameter "
760
882
  f"`{parameter.name}` with an annotation of "
761
883
  f"`{get_full_qualname(parameter.annotation)} due to being "
762
884
  "in non-strict mode. It will be validated at the first call."
@@ -767,65 +889,183 @@ class Container:
767
889
  return injected_params
768
890
 
769
891
  def _validate_injected_parameter(
770
- self, obj: Callable[..., Any], parameter: inspect.Parameter
892
+ self, call: Callable[..., Any], parameter: inspect.Parameter
771
893
  ) -> None:
772
- """Validate an injected parameter.
773
-
774
- Args:
775
- obj: The callable object.
776
- parameter: The parameter to validate.
777
-
778
- Raises:
779
- TypeError: If the parameter annotation is missing or an unknown dependency.
780
- """
781
- if parameter.annotation is inspect._empty: # noqa
894
+ """Validate an injected parameter."""
895
+ if parameter.annotation is inspect.Parameter.empty:
782
896
  raise TypeError(
783
- f"Missing `{get_full_qualname(obj)}` parameter "
897
+ f"Missing `{get_full_qualname(call)}` parameter "
784
898
  f"`{parameter.name}` annotation."
785
899
  )
786
900
 
787
901
  if not self.is_registered(parameter.annotation):
788
902
  raise LookupError(
789
- f"`{get_full_qualname(obj)}` has an unknown dependency parameter "
903
+ f"`{get_full_qualname(call)}` has an unknown dependency parameter "
790
904
  f"`{parameter.name}` with an annotation of "
791
905
  f"`{get_full_qualname(parameter.annotation)}`."
792
906
  )
793
907
 
908
+ def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
909
+ """Run the given function with injected dependencies."""
910
+ return self._inject(func)(*args, **kwargs)
794
911
 
795
- def transient(target: T) -> T:
796
- """Decorator for marking a class as transient scope.
912
+ def scan(
913
+ self,
914
+ /,
915
+ packages: ModuleType | str | Iterable[ModuleType | str],
916
+ *,
917
+ tags: Iterable[str] | None = None,
918
+ ) -> None:
919
+ """Scan packages or modules for decorated members and inject dependencies."""
920
+ dependencies: list[Dependency] = []
921
+
922
+ if isinstance(packages, Iterable) and not isinstance(packages, str):
923
+ scan_packages: Iterable[ModuleType | str] = packages
924
+ else:
925
+ scan_packages = cast(Iterable[Union[ModuleType, str]], [packages])
926
+
927
+ for package in scan_packages:
928
+ dependencies.extend(self._scan_package(package, tags=tags))
929
+
930
+ for dependency in dependencies:
931
+ decorator = self.inject()(dependency.member)
932
+ setattr(dependency.module, dependency.member.__name__, decorator)
933
+
934
+ def _scan_package(
935
+ self,
936
+ package: ModuleType | str,
937
+ *,
938
+ tags: Iterable[str] | None = None,
939
+ ) -> list[Dependency]:
940
+ """Scan a package or module for decorated members."""
941
+ tags = tags or []
942
+ if isinstance(package, str):
943
+ package = importlib.import_module(package)
944
+
945
+ package_path = getattr(package, "__path__", None)
946
+
947
+ if not package_path:
948
+ return self._scan_module(package, tags=tags)
949
+
950
+ dependencies: list[Dependency] = []
951
+
952
+ for module_info in pkgutil.walk_packages(
953
+ path=package_path, prefix=package.__name__ + "."
954
+ ):
955
+ module = importlib.import_module(module_info.name)
956
+ dependencies.extend(self._scan_module(module, tags=tags))
957
+
958
+ return dependencies
959
+
960
+ def _scan_module(
961
+ self, module: ModuleType, *, tags: Iterable[str]
962
+ ) -> list[Dependency]:
963
+ """Scan a module for decorated members."""
964
+ dependencies: list[Dependency] = []
965
+
966
+ for _, member in inspect.getmembers(module):
967
+ if getattr(member, "__module__", None) != module.__name__ or not callable(
968
+ member
969
+ ):
970
+ continue
971
+
972
+ decorator_args: InjectableDecoratorArgs = getattr(
973
+ member,
974
+ "__injectable__",
975
+ InjectableDecoratorArgs(wrapped=False, tags=[]),
976
+ )
977
+
978
+ if tags and (
979
+ decorator_args.tags
980
+ and not set(decorator_args.tags).intersection(tags)
981
+ or not decorator_args.tags
982
+ ):
983
+ continue
984
+
985
+ if decorator_args.wrapped:
986
+ dependencies.append(
987
+ self._create_dependency(member=member, module=module)
988
+ )
989
+ continue
990
+
991
+ # Get by Marker
992
+ for parameter in get_typed_parameters(member):
993
+ if is_marker(parameter.default):
994
+ dependencies.append(
995
+ self._create_dependency(member=member, module=module)
996
+ )
997
+ continue
998
+
999
+ return dependencies
1000
+
1001
+ def _create_dependency(self, member: Any, module: ModuleType) -> Dependency:
1002
+ """Create a `Dependency` object from the scanned member and module."""
1003
+ if hasattr(member, "__wrapped__"):
1004
+ member = member.__wrapped__
1005
+ return Dependency(member=member, module=module)
797
1006
 
798
- Args:
799
- target: The target class to be decorated.
800
1007
 
801
- Returns:
802
- The decorated target class.
803
- """
1008
+ def transient(target: T) -> T:
1009
+ """Decorator for marking a class as transient scope."""
804
1010
  setattr(target, "__scope__", "transient")
805
1011
  return target
806
1012
 
807
1013
 
808
1014
  def request(target: T) -> T:
809
- """Decorator for marking a class as request scope.
810
-
811
- Args:
812
- target: The target class to be decorated.
813
-
814
- Returns:
815
- The decorated target class.
816
- """
1015
+ """Decorator for marking a class as request scope."""
817
1016
  setattr(target, "__scope__", "request")
818
1017
  return target
819
1018
 
820
1019
 
821
1020
  def singleton(target: T) -> T:
822
- """Decorator for marking a class as singleton scope.
823
-
824
- Args:
825
- target: The target class to be decorated.
826
-
827
- Returns:
828
- The decorated target class.
829
- """
1021
+ """Decorator for marking a class as singleton scope."""
830
1022
  setattr(target, "__scope__", "singleton")
831
1023
  return target
1024
+
1025
+
1026
+ def provider(
1027
+ *, scope: Scope, override: bool = False
1028
+ ) -> Callable[[Callable[Concatenate[M, P], T]], Callable[Concatenate[M, P], T]]:
1029
+ """Decorator for marking a function or method as a provider in a AnyDI module."""
1030
+
1031
+ def decorator(
1032
+ target: Callable[Concatenate[M, P], T],
1033
+ ) -> Callable[Concatenate[M, P], T]:
1034
+ setattr(
1035
+ target,
1036
+ "__provider__",
1037
+ ProviderDecoratorArgs(scope=scope, override=override),
1038
+ )
1039
+ return target
1040
+
1041
+ return decorator
1042
+
1043
+
1044
+ @overload
1045
+ def injectable(func: Callable[P, T]) -> Callable[P, T]: ...
1046
+
1047
+
1048
+ @overload
1049
+ def injectable(
1050
+ *, tags: Iterable[str] | None = None
1051
+ ) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
1052
+
1053
+
1054
+ def injectable(
1055
+ func: Callable[P, T] | None = None,
1056
+ tags: Iterable[str] | None = None,
1057
+ ) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
1058
+ """Decorator for marking a function or method as requiring dependency injection."""
1059
+
1060
+ def decorator(inner: Callable[P, T]) -> Callable[P, T]:
1061
+ setattr(
1062
+ inner,
1063
+ "__injectable__",
1064
+ InjectableDecoratorArgs(wrapped=True, tags=tags),
1065
+ )
1066
+ return inner
1067
+
1068
+ if func is None:
1069
+ return decorator
1070
+
1071
+ return decorator(func)