anydi 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
anydi/_container.py CHANGED
@@ -5,31 +5,12 @@ from __future__ import annotations
5
5
  import contextlib
6
6
  import inspect
7
7
  import types
8
- import uuid
9
8
  from collections import defaultdict
9
+ from collections.abc import AsyncIterator, Awaitable, Iterable, Iterator, Sequence
10
10
  from contextvars import ContextVar
11
- from functools import wraps
12
- from typing import (
13
- Any,
14
- AsyncIterator,
15
- Awaitable,
16
- Callable,
17
- Iterable,
18
- Iterator,
19
- Mapping,
20
- Sequence,
21
- TypeVar,
22
- cast,
23
- overload,
24
- )
25
-
26
- from typing_extensions import ParamSpec, Self, final, get_args, get_origin
27
-
28
- try:
29
- from types import NoneType
30
- except ImportError:
31
- NoneType = type(None) # type: ignore[misc]
11
+ from typing import Any, Callable, TypeVar, cast, overload
32
12
 
13
+ from typing_extensions import ParamSpec, Self, final
33
14
 
34
15
  from ._context import (
35
16
  RequestContext,
@@ -38,17 +19,12 @@ from ._context import (
38
19
  SingletonContext,
39
20
  TransientContext,
40
21
  )
41
- from ._logger import logger
22
+ from ._injector import Injector
42
23
  from ._module import Module, ModuleRegistry
24
+ from ._provider import Provider
43
25
  from ._scanner import Scanner
44
- from ._types import AnyInterface, Event, Interface, Provider, Scope, is_marker
45
- from ._utils import (
46
- get_full_qualname,
47
- get_typed_parameters,
48
- get_typed_return_annotation,
49
- has_resource_origin,
50
- is_builtin_type,
51
- )
26
+ from ._types import AnyInterface, Interface, Scope
27
+ from ._utils import get_full_qualname, get_typed_parameters, is_builtin_type
52
28
 
53
29
  T = TypeVar("T", bound=Any)
54
30
  P = ParamSpec("P")
@@ -62,27 +38,16 @@ ALLOWED_SCOPES: dict[Scope, list[Scope]] = {
62
38
 
63
39
  @final
64
40
  class Container:
65
- """AnyDI is a dependency injection container.
66
-
67
- Args:
68
- modules: Optional sequence of modules to register during initialization.
69
- """
41
+ """AnyDI is a dependency injection container."""
70
42
 
71
43
  def __init__(
72
44
  self,
73
45
  *,
74
- providers: Mapping[type[Any], Provider] | None = None,
46
+ providers: Sequence[Provider] | None = None,
75
47
  modules: Sequence[Module | type[Module] | Callable[[Container], None] | str]
76
48
  | None = None,
77
49
  strict: bool = False,
78
50
  ) -> None:
79
- """Initialize the AnyDI instance.
80
-
81
- Args:
82
- providers: Optional mapping of providers to register during initialization.
83
- modules: Optional sequence of modules to register during initialization.
84
- strict: Whether to enable strict mode. Defaults to False.
85
- """
86
51
  self._providers: dict[type[Any], Provider] = {}
87
52
  self._resource_cache: dict[Scope, list[type[Any]]] = defaultdict(list)
88
53
  self._singleton_context = SingletonContext(self)
@@ -92,15 +57,17 @@ class Container:
92
57
  )
93
58
  self._override_instances: dict[type[Any], Any] = {}
94
59
  self._strict = strict
60
+ self._unresolved_interfaces: set[type[Any]] = set()
95
61
 
96
62
  # Components
63
+ self._injector = Injector(self)
97
64
  self._modules = ModuleRegistry(self)
98
65
  self._scanner = Scanner(self)
99
66
 
100
67
  # Register providers
101
- providers = providers or {}
102
- for interface, provider in providers.items():
103
- self.register(interface, provider.obj, scope=provider.scope)
68
+ providers = providers or []
69
+ for provider in providers:
70
+ self._register_provider(provider)
104
71
 
105
72
  # Register modules
106
73
  modules = modules or []
@@ -109,102 +76,50 @@ class Container:
109
76
 
110
77
  @property
111
78
  def strict(self) -> bool:
112
- """Check if strict mode is enabled.
113
-
114
- Returns:
115
- True if strict mode is enabled, False otherwise.
116
- """
79
+ """Check if strict mode is enabled."""
117
80
  return self._strict
118
81
 
119
82
  @property
120
83
  def providers(self) -> dict[type[Any], Provider]:
121
- """Get the registered providers.
122
-
123
- Returns:
124
- A dictionary containing the registered providers.
125
- """
84
+ """Get the registered providers."""
126
85
  return self._providers
127
86
 
128
87
  def is_registered(self, interface: AnyInterface) -> bool:
129
- """Check if a provider is registered for the specified interface.
130
-
131
- Args:
132
- interface: The interface to check for a registered provider.
133
-
134
- Returns:
135
- True if a provider is registered for the interface, False otherwise.
136
- """
88
+ """Check if a provider is registered for the specified interface."""
137
89
  return interface in self._providers
138
90
 
139
91
  def register(
140
92
  self,
141
93
  interface: AnyInterface,
142
- obj: Callable[..., Any],
94
+ call: Callable[..., Any],
143
95
  *,
144
96
  scope: Scope,
145
97
  override: bool = False,
146
98
  ) -> Provider:
147
- """Register a provider for the specified interface.
148
-
149
- Args:
150
- interface: The interface for which the provider is being registered.
151
- obj: The provider function or callable object.
152
- scope: The scope of the provider.
153
- override: If True, override an existing provider for the interface
154
- if one is already registered. Defaults to False.
155
-
156
- Returns:
157
- The registered provider.
158
-
159
- Raises:
160
- LookupError: If a provider for the interface is already registered
161
- and override is False.
162
-
163
- Notes:
164
- - If the provider is a resource or an asynchronous resource, and the
165
- interface is None, an Event type will be automatically created and used
166
- as the interface.
167
- - The provider will be validated for its scope, type, and matching scopes.
168
- """
169
- provider = Provider(obj=obj, scope=scope)
170
-
171
- # Create Event type
172
- if provider.is_resource and (interface is NoneType or interface is None):
173
- interface = type(f"Event_{uuid.uuid4().hex}", (Event,), {})
99
+ """Register a provider for the specified interface."""
100
+ provider = Provider(call=call, scope=scope, interface=interface)
101
+ return self._register_provider(provider, override=override)
174
102
 
175
- if interface in self._providers:
103
+ def _register_provider(
104
+ self, provider: Provider, *, override: bool = False
105
+ ) -> Provider:
106
+ """Register a provider."""
107
+ if provider.interface in self._providers:
176
108
  if override:
177
- self._set_provider(interface, provider)
109
+ self._set_provider(provider)
178
110
  return provider
179
111
 
180
112
  raise LookupError(
181
- f"The provider interface `{get_full_qualname(interface)}` "
113
+ f"The provider interface `{get_full_qualname(provider.interface)}` "
182
114
  "already registered."
183
115
  )
184
116
 
185
- # Validate provider
186
- self._validate_provider_scope(provider)
187
- self._validate_provider_type(provider)
188
- self._validate_provider_match_scopes(interface, provider)
189
-
190
- self._set_provider(interface, provider)
117
+ self._validate_sub_providers(provider)
118
+ self._set_provider(provider)
191
119
  return provider
192
120
 
193
121
  def unregister(self, interface: AnyInterface) -> None:
194
- """Unregister a provider by interface.
195
-
196
- Args:
197
- interface: The interface of the provider to unregister.
198
-
199
- Raises:
200
- LookupError: If the provider interface is not registered.
201
-
202
- Notes:
203
- - The method cleans up any scoped context instance associated with
204
- the provider's scope.
205
- - The method removes the provider reference from the internal dictionary
206
- of registered providers.
207
- """
122
+ """Unregister a provider by interface."""
208
123
  if not self.is_registered(interface):
209
124
  raise LookupError(
210
125
  "The provider interface "
@@ -223,20 +138,10 @@ class Container:
223
138
  scoped_context.delete(interface)
224
139
 
225
140
  # Cleanup provider references
226
- self._delete_provider(interface)
141
+ self._delete_provider(provider)
227
142
 
228
143
  def _get_provider(self, interface: AnyInterface) -> Provider:
229
- """Get provider by interface.
230
-
231
- Args:
232
- interface: The interface for which to retrieve the provider.
233
-
234
- Returns:
235
- Provider: The provider object associated with the interface.
236
-
237
- Raises:
238
- LookupError: If the provider interface has not been registered.
239
- """
144
+ """Get provider by interface."""
240
145
  try:
241
146
  return self._providers[interface]
242
147
  except KeyError as exc:
@@ -246,18 +151,10 @@ class Container:
246
151
  "properly registered before attempting to use it."
247
152
  ) from exc
248
153
 
249
- def _get_or_register_provider(self, interface: AnyInterface) -> Provider:
250
- """Get or register a provider by interface.
251
-
252
- Args:
253
- interface: The interface for which to retrieve the provider.
254
-
255
- Returns:
256
- Provider: The provider object associated with the interface.
257
-
258
- Raises:
259
- LookupError: If the provider interface has not been registered.
260
- """
154
+ def _get_or_register_provider(
155
+ self, interface: AnyInterface, parent_scope: Scope | None = None
156
+ ) -> Provider:
157
+ """Get or register a provider by interface."""
261
158
  try:
262
159
  return self._get_provider(interface)
263
160
  except LookupError:
@@ -265,158 +162,92 @@ class Container:
265
162
  not self.strict
266
163
  and inspect.isclass(interface)
267
164
  and not is_builtin_type(interface)
268
- and interface is not inspect._empty # noqa
165
+ and interface is not inspect.Parameter.empty
269
166
  ):
270
167
  # Try to get defined scope
271
- scope = getattr(interface, "__scope__", None)
168
+ scope = getattr(interface, "__scope__", parent_scope)
272
169
  # Try to detect scope
273
170
  if scope is None:
274
171
  scope = self._detect_scope(interface)
275
172
  return self.register(interface, interface, scope=scope or "transient")
276
173
  raise
277
174
 
278
- def _set_provider(self, interface: AnyInterface, provider: Provider) -> None:
279
- """Set a provider by interface.
280
-
281
- Args:
282
- interface: The interface for which to set the provider.
283
- provider: The provider object to set.
284
- """
285
- self._providers[interface] = provider
175
+ def _set_provider(self, provider: Provider) -> None:
176
+ """Set a provider by interface."""
177
+ self._providers[provider.interface] = provider
286
178
  if provider.is_resource:
287
- self._resource_cache[provider.scope].append(interface)
288
-
289
- def _delete_provider(self, interface: AnyInterface) -> None:
290
- """Delete a provider by interface.
291
-
292
- Args:
293
- interface: The interface for which to delete the provider.
294
- """
295
- provider = self._providers.pop(interface, None)
296
- if provider is not None and provider.is_resource:
297
- self._resource_cache[provider.scope].remove(interface)
298
-
299
- def _validate_provider_scope(self, provider: Provider) -> None:
300
- """Validate the scope of a provider.
301
-
302
- Args:
303
- provider: The provider to validate.
304
-
305
- Raises:
306
- ValueError: If the scope provided is invalid.
307
- """
308
- if provider.scope not in get_args(Scope):
309
- raise ValueError(
310
- "The scope provided is invalid. Only the following scopes are "
311
- f"supported: {', '.join(get_args(Scope))}. Please use one of the "
312
- "supported scopes when registering a provider."
313
- )
314
-
315
- def _validate_provider_type(self, provider: Provider) -> None:
316
- """Validate the type of provider.
317
-
318
- Args:
319
- provider: The provider to validate.
320
-
321
- Raises:
322
- TypeError: If the provider has an invalid type.
323
- """
324
- if provider.is_function or provider.is_class:
325
- return
179
+ self._resource_cache[provider.scope].append(provider.interface)
326
180
 
181
+ def _delete_provider(self, provider: Provider) -> None:
182
+ """Delete a provider."""
183
+ if provider.interface in self._providers:
184
+ del self._providers[provider.interface]
327
185
  if provider.is_resource:
328
- if provider.scope == "transient":
329
- raise TypeError(
330
- f"The resource provider `{provider}` is attempting to register "
331
- "with a transient scope, which is not allowed. Please update the "
332
- "provider's scope to an appropriate value before registering it."
333
- )
334
- return
186
+ self._resource_cache[provider.scope].remove(provider.interface)
335
187
 
336
- raise TypeError(
337
- f"The provider `{provider.obj}` is invalid because it is not a callable "
338
- "object. Only callable providers are allowed. Please update the provider "
339
- "to a callable object before attempting to register it."
340
- )
341
-
342
- def _validate_provider_match_scopes(
343
- self, interface: AnyInterface, provider: Provider
344
- ) -> None:
345
- """Validate that the provider and its dependencies have matching scopes.
346
-
347
- Args:
348
- interface: The interface associated with the provider.
349
- provider: The provider to validate.
350
-
351
- Raises:
352
- ValueError: If the provider and its dependencies have mismatched scopes.
353
- TypeError: If a dependency is missing an annotation.
354
- """
355
- related_providers = []
188
+ def _validate_sub_providers(self, provider: Provider) -> None:
189
+ """Validate the sub-providers of a provider."""
356
190
 
357
191
  for parameter in provider.parameters:
358
- if parameter.annotation is inspect._empty: # noqa
192
+ annotation = parameter.annotation
193
+
194
+ if annotation is inspect.Parameter.empty:
359
195
  raise TypeError(
360
196
  f"Missing provider `{provider}` "
361
197
  f"dependency `{parameter.name}` annotation."
362
198
  )
199
+
363
200
  try:
364
- sub_provider = self._get_or_register_provider(parameter.annotation)
201
+ sub_provider = self._get_or_register_provider(
202
+ annotation, parent_scope=provider.scope
203
+ )
365
204
  except LookupError:
205
+ if provider.scope not in {"singleton", "transient"}:
206
+ self._unresolved_interfaces.add(provider.interface)
207
+ continue
366
208
  raise ValueError(
367
- f"The provider `{get_full_qualname(provider.obj)}` depends on "
368
- f"`{parameter.name}` of type "
369
- f"`{get_full_qualname(parameter.annotation)}`, which "
370
- "has not been registered. To resolve this, ensure that "
209
+ f"The provider `{provider}` depends on `{parameter.name}` of type "
210
+ f"`{get_full_qualname(annotation)}`, which "
211
+ "has not been registered or set. To resolve this, ensure that "
371
212
  f"`{parameter.name}` is registered before attempting to use it."
372
213
  ) from None
373
- related_providers.append(sub_provider)
374
214
 
375
- for related_provider in related_providers:
376
- left_scope, right_scope = related_provider.scope, provider.scope
377
- allowed_scopes = ALLOWED_SCOPES.get(right_scope) or []
378
- if left_scope not in allowed_scopes:
215
+ # Check scope compatibility
216
+ if sub_provider.scope not in ALLOWED_SCOPES.get(provider.scope, []):
379
217
  raise ValueError(
380
- f"The provider `{provider}` with a {provider.scope} scope was "
381
- "attempted to be registered with the provider "
382
- f"`{related_provider}` with a `{related_provider.scope}` scope, "
383
- "which is not allowed. Please ensure that all providers are "
384
- "registered with matching scopes."
218
+ f"The provider `{provider}` with a `{provider.scope}` scope cannot "
219
+ f"depend on `{sub_provider}` with a `{sub_provider.scope}` scope. "
220
+ "Please ensure all providers are registered with matching scopes."
385
221
  )
386
222
 
387
- def _detect_scope(self, obj: Callable[..., Any]) -> Scope | None:
388
- """Detect the scope for a provider.
223
+ def _detect_scope(self, call: Callable[..., Any]) -> Scope | None:
224
+ """Detect the scope for a callable."""
225
+ scopes = set()
389
226
 
390
- Args:
391
- obj: The provider to detect the auto scope for.
392
- Returns:
393
- The auto scope, or None if the auto scope cannot be detected.
394
- """
395
- has_transient, has_request, has_singleton = False, False, False
396
- for parameter in get_typed_parameters(obj):
227
+ for parameter in get_typed_parameters(call):
397
228
  sub_provider = self._get_or_register_provider(parameter.annotation)
398
- if not has_transient and sub_provider.scope == "transient":
399
- has_transient = True
400
- if not has_request and sub_provider.scope == "request":
401
- has_request = True
402
- if not has_singleton and sub_provider.scope == "singleton":
403
- has_singleton = True
404
- if has_transient:
405
- return "transient"
406
- if has_request:
229
+ scope = sub_provider.scope
230
+
231
+ if scope == "transient":
232
+ return "transient"
233
+ scopes.add(scope)
234
+
235
+ # If all scopes are found, we can return based on priority order
236
+ if {"transient", "request", "singleton"}.issubset(scopes):
237
+ break
238
+
239
+ # Determine scope based on priority
240
+ if "request" in scopes:
407
241
  return "request"
408
- if has_singleton:
242
+ if "singleton" in scopes:
409
243
  return "singleton"
244
+
410
245
  return None
411
246
 
412
247
  def register_module(
413
248
  self, module: Module | type[Module] | Callable[[Container], None] | str
414
249
  ) -> None:
415
- """Register a module as a callable, module type, or module instance.
416
-
417
- Args:
418
- module: The module to register.
419
- """
250
+ """Register a module as a callable, module type, or module instance."""
420
251
  self._modules.register(module)
421
252
 
422
253
  def __enter__(self) -> Self:
@@ -442,16 +273,12 @@ class Container:
442
273
  self._singleton_context.close()
443
274
 
444
275
  @contextlib.contextmanager
445
- def request_context(self) -> Iterator[None]:
446
- """Obtain a context manager for the request-scoped context.
447
-
448
- Returns:
449
- A context manager for the request-scoped context.
450
- """
276
+ def request_context(self) -> Iterator[RequestContext]:
277
+ """Obtain a context manager for the request-scoped context."""
451
278
  context = RequestContext(self)
452
279
  token = self._request_context_var.set(context)
453
280
  with context:
454
- yield
281
+ yield context
455
282
  self._request_context_var.reset(token)
456
283
 
457
284
  async def __aenter__(self) -> Self:
@@ -477,27 +304,16 @@ class Container:
477
304
  await self._singleton_context.aclose()
478
305
 
479
306
  @contextlib.asynccontextmanager
480
- async def arequest_context(self) -> AsyncIterator[None]:
481
- """Obtain an async context manager for the request-scoped context.
482
-
483
- Returns:
484
- An async context manager for the request-scoped context.
485
- """
307
+ async def arequest_context(self) -> AsyncIterator[RequestContext]:
308
+ """Obtain an async context manager for the request-scoped context."""
486
309
  context = RequestContext(self)
487
310
  token = self._request_context_var.set(context)
488
311
  async with context:
489
- yield
312
+ yield context
490
313
  self._request_context_var.reset(token)
491
314
 
492
315
  def _get_request_context(self) -> RequestContext:
493
- """Get the current request context.
494
-
495
- Returns:
496
- RequestContext: The current request context.
497
-
498
- Raises:
499
- LookupError: If the request context has not been started.
500
- """
316
+ """Get the current request context."""
501
317
  request_context = self._request_context_var.get()
502
318
  if request_context is None:
503
319
  raise LookupError(
@@ -524,23 +340,13 @@ class Container:
524
340
  def resolve(self, interface: T) -> T: ...
525
341
 
526
342
  def resolve(self, interface: Interface[T]) -> T:
527
- """Resolve an instance by interface.
528
-
529
- Args:
530
- interface: The interface type.
531
-
532
- Returns:
533
- The instance of the interface.
534
-
535
- Raises:
536
- LookupError: If the provider for the interface is not registered.
537
- """
343
+ """Resolve an instance by interface."""
538
344
  if interface in self._override_instances:
539
345
  return cast(T, self._override_instances[interface])
540
346
 
541
347
  provider = self._get_or_register_provider(interface)
542
348
  scoped_context = self._get_scoped_context(provider.scope)
543
- return scoped_context.get(interface, provider)
349
+ return cast(T, scoped_context.get(provider))
544
350
 
545
351
  @overload
546
352
  async def aresolve(self, interface: Interface[T]) -> T: ...
@@ -549,33 +355,16 @@ class Container:
549
355
  async def aresolve(self, interface: T) -> T: ...
550
356
 
551
357
  async def aresolve(self, interface: Interface[T]) -> T:
552
- """Resolve an instance by interface asynchronously.
553
-
554
- Args:
555
- interface: The interface type.
556
-
557
- Returns:
558
- The instance of the interface.
559
-
560
- Raises:
561
- LookupError: If the provider for the interface is not registered.
562
- """
358
+ """Resolve an instance by interface asynchronously."""
563
359
  if interface in self._override_instances:
564
360
  return cast(T, self._override_instances[interface])
565
361
 
566
362
  provider = self._get_or_register_provider(interface)
567
363
  scoped_context = self._get_scoped_context(provider.scope)
568
- return await scoped_context.aget(interface, provider)
364
+ return cast(T, await scoped_context.aget(provider))
569
365
 
570
366
  def is_resolved(self, interface: AnyInterface) -> bool:
571
- """Check if an instance by interface exists.
572
-
573
- Args:
574
- interface: The interface type.
575
-
576
- Returns:
577
- True if the instance exists, otherwise False.
578
- """
367
+ """Check if an instance by interface exists."""
579
368
  try:
580
369
  provider = self._get_provider(interface)
581
370
  except LookupError:
@@ -587,28 +376,14 @@ class Container:
587
376
  return False
588
377
 
589
378
  def release(self, interface: AnyInterface) -> None:
590
- """Release an instance by interface.
591
-
592
- Args:
593
- interface: The interface type.
594
-
595
- Raises:
596
- LookupError: If the provider for the interface is not registered.
597
- """
379
+ """Release an instance by interface."""
598
380
  provider = self._get_provider(interface)
599
381
  scoped_context = self._get_scoped_context(provider.scope)
600
382
  if isinstance(scoped_context, ResourceScopedContext):
601
383
  scoped_context.delete(interface)
602
384
 
603
385
  def _get_scoped_context(self, scope: Scope) -> ScopedContext:
604
- """Get the scoped context based on the specified scope.
605
-
606
- Args:
607
- scope: The scope of the provider.
608
-
609
- Returns:
610
- The scoped context, or None if the scope is not applicable.
611
- """
386
+ """Get the scoped context based on the specified scope."""
612
387
  if scope == "singleton":
613
388
  return self._singleton_context
614
389
  elif scope == "request":
@@ -618,17 +393,8 @@ class Container:
618
393
 
619
394
  @contextlib.contextmanager
620
395
  def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
621
- """Override the provider for the specified interface with a specific instance.
622
-
623
- Args:
624
- interface: The interface type to override.
625
- instance: The instance to use as the override.
626
-
627
- Yields:
628
- None
629
-
630
- Raises:
631
- LookupError: If the provider for the interface is not registered.
396
+ """
397
+ Override the provider for the specified interface with a specific instance.
632
398
  """
633
399
  if not self.is_registered(interface) and self.strict:
634
400
  raise LookupError(
@@ -642,85 +408,41 @@ class Container:
642
408
  def provider(
643
409
  self, *, scope: Scope, override: bool = False
644
410
  ) -> Callable[[Callable[P, T]], Callable[P, T]]:
645
- """Decorator to register a provider function with the specified scope.
411
+ """Decorator to register a provider function with the specified scope."""
646
412
 
647
- Args:
648
- scope : The scope of the provider.
649
- override: Whether the provider should override an existing provider
650
- for the same interface. Defaults to False.
651
-
652
- Returns:
653
- The decorator function.
654
- """
655
-
656
- def decorator(func: Callable[P, T]) -> Callable[P, T]:
657
- interface = self._get_provider_annotation(func)
658
- self.register(interface, func, scope=scope, override=override)
659
- return func
413
+ def decorator(call: Callable[P, T]) -> Callable[P, T]:
414
+ provider = Provider(call=call, scope=scope)
415
+ self._register_provider(provider, override=override)
416
+ return call
660
417
 
661
418
  return decorator
662
419
 
663
420
  @overload
664
- def inject(self, obj: Callable[P, T]) -> Callable[P, T]: ...
421
+ def inject(self, func: Callable[P, T]) -> Callable[P, T]: ...
665
422
 
666
423
  @overload
667
424
  def inject(self) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
668
425
 
669
426
  def inject(
670
- self, obj: Callable[P, T | Awaitable[T]] | None = None
427
+ self, func: Callable[P, T | Awaitable[T]] | None = None
671
428
  ) -> (
672
429
  Callable[[Callable[P, T | Awaitable[T]]], Callable[P, T | Awaitable[T]]]
673
430
  | Callable[P, T | Awaitable[T]]
674
431
  ):
675
- """Decorator to inject dependencies into a callable.
676
-
677
- Args:
678
- obj: The callable object to be decorated. If None, returns
679
- the decorator itself.
680
-
681
- Returns:
682
- The decorated callable object or decorator function.
683
- """
432
+ """Decorator to inject dependencies into a callable."""
684
433
 
685
434
  def decorator(
686
- obj: Callable[P, T | Awaitable[T]],
435
+ inner: Callable[P, T | Awaitable[T]],
687
436
  ) -> Callable[P, T | Awaitable[T]]:
688
- injected_params = self._get_injected_params(obj)
689
-
690
- if inspect.iscoroutinefunction(obj):
691
-
692
- @wraps(obj)
693
- async def awrapped(*args: P.args, **kwargs: P.kwargs) -> T:
694
- for name, annotation in injected_params.items():
695
- kwargs[name] = await self.aresolve(annotation)
696
- return cast(T, await obj(*args, **kwargs))
697
-
698
- return awrapped
437
+ return self._injector.inject(inner)
699
438
 
700
- @wraps(obj)
701
- def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
702
- for name, annotation in injected_params.items():
703
- kwargs[name] = self.resolve(annotation)
704
- return cast(T, obj(*args, **kwargs))
705
-
706
- return wrapped
707
-
708
- if obj is None:
439
+ if func is None:
709
440
  return decorator
710
- return decorator(obj)
711
-
712
- def run(self, obj: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
713
- """Run the given function with injected dependencies.
714
-
715
- Args:
716
- obj: The callable object.
717
- args: The positional arguments to pass to the object.
718
- kwargs: The keyword arguments to pass to the object.
441
+ return decorator(func)
719
442
 
720
- Returns:
721
- The result of the callable object.
722
- """
723
- return self.inject(obj)(*args, **kwargs)
443
+ def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
444
+ """Run the given function with injected dependencies."""
445
+ return cast(T, self._injector.inject(func)(*args, **kwargs))
724
446
 
725
447
  def scan(
726
448
  self,
@@ -729,138 +451,23 @@ class Container:
729
451
  *,
730
452
  tags: Iterable[str] | None = None,
731
453
  ) -> None:
732
- """Scan packages or modules for decorated members and inject dependencies.
733
-
734
- Args:
735
- packages: A single package or module to scan,
736
- or an iterable of packages or modules to scan.
737
- tags: Optional list of tags to filter the scanned members. Only members
738
- with at least one matching tag will be scanned. Defaults to None.
739
- """
454
+ """Scan packages or modules for decorated members and inject dependencies."""
740
455
  self._scanner.scan(packages, tags=tags)
741
456
 
742
- def _get_provider_annotation(self, obj: Callable[..., Any]) -> Any:
743
- """Retrieve the provider return annotation from a callable object.
744
-
745
- Args:
746
- obj: The callable object (provider).
747
-
748
- Returns:
749
- The provider return annotation.
750
-
751
- Raises:
752
- TypeError: If the provider return annotation is missing or invalid.
753
- """
754
- annotation = get_typed_return_annotation(obj)
755
-
756
- if annotation is None:
757
- raise TypeError(
758
- f"Missing `{get_full_qualname(obj)}` provider return annotation."
759
- )
760
-
761
- origin = get_origin(annotation)
762
-
763
- if has_resource_origin(origin):
764
- args = get_args(annotation)
765
- if args:
766
- return args[0]
767
- else:
768
- raise TypeError(
769
- f"Cannot use `{get_full_qualname(obj)}` resource type annotation "
770
- "without actual type."
771
- )
772
-
773
- return annotation
774
-
775
- def _get_injected_params(self, obj: Callable[..., Any]) -> dict[str, Any]:
776
- """Get the injected parameters of a callable object.
777
-
778
- Args:
779
- obj: The callable object.
780
-
781
- Returns:
782
- A dictionary containing the names and annotations
783
- of the injected parameters.
784
- """
785
- injected_params = {}
786
- for parameter in get_typed_parameters(obj):
787
- if not is_marker(parameter.default):
788
- continue
789
- try:
790
- self._validate_injected_parameter(obj, parameter)
791
- except LookupError as exc:
792
- if not self.strict:
793
- logger.debug(
794
- f"Cannot validate the `{get_full_qualname(obj)}` parameter "
795
- f"`{parameter.name}` with an annotation of "
796
- f"`{get_full_qualname(parameter.annotation)} due to being "
797
- "in non-strict mode. It will be validated at the first call."
798
- )
799
- else:
800
- raise exc
801
- injected_params[parameter.name] = parameter.annotation
802
- return injected_params
803
-
804
- def _validate_injected_parameter(
805
- self, obj: Callable[..., Any], parameter: inspect.Parameter
806
- ) -> None:
807
- """Validate an injected parameter.
808
-
809
- Args:
810
- obj: The callable object.
811
- parameter: The parameter to validate.
812
-
813
- Raises:
814
- TypeError: If the parameter annotation is missing or an unknown dependency.
815
- """
816
- if parameter.annotation is inspect._empty: # noqa
817
- raise TypeError(
818
- f"Missing `{get_full_qualname(obj)}` parameter "
819
- f"`{parameter.name}` annotation."
820
- )
821
-
822
- if not self.is_registered(parameter.annotation):
823
- raise LookupError(
824
- f"`{get_full_qualname(obj)}` has an unknown dependency parameter "
825
- f"`{parameter.name}` with an annotation of "
826
- f"`{get_full_qualname(parameter.annotation)}`."
827
- )
828
-
829
457
 
830
458
  def transient(target: T) -> T:
831
- """Decorator for marking a class as transient scope.
832
-
833
- Args:
834
- target: The target class to be decorated.
835
-
836
- Returns:
837
- The decorated target class.
838
- """
459
+ """Decorator for marking a class as transient scope."""
839
460
  setattr(target, "__scope__", "transient")
840
461
  return target
841
462
 
842
463
 
843
464
  def request(target: T) -> T:
844
- """Decorator for marking a class as request scope.
845
-
846
- Args:
847
- target: The target class to be decorated.
848
-
849
- Returns:
850
- The decorated target class.
851
- """
465
+ """Decorator for marking a class as request scope."""
852
466
  setattr(target, "__scope__", "request")
853
467
  return target
854
468
 
855
469
 
856
470
  def singleton(target: T) -> T:
857
- """Decorator for marking a class as singleton scope.
858
-
859
- Args:
860
- target: The target class to be decorated.
861
-
862
- Returns:
863
- The decorated target class.
864
- """
471
+ """Decorator for marking a class as singleton scope."""
865
472
  setattr(target, "__scope__", "singleton")
866
473
  return target