anydi 0.34.0__py3-none-any.whl → 0.34.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
anydi/_container.py CHANGED
@@ -4,41 +4,76 @@ from __future__ import annotations
4
4
 
5
5
  import contextlib
6
6
  import functools
7
+ import importlib
7
8
  import inspect
9
+ import pkgutil
8
10
  import threading
9
11
  import types
10
12
  from collections import defaultdict
11
13
  from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
12
14
  from contextvars import ContextVar
13
- from typing import Any, Callable, TypeVar, cast, overload
15
+ from types import ModuleType
16
+ from typing import Any, Callable, TypeVar, Union, cast, overload
14
17
  from weakref import WeakKeyDictionary
15
18
 
16
- from typing_extensions import ParamSpec, Self, final
19
+ from typing_extensions import Concatenate, ParamSpec, Self, final
17
20
 
18
- from ._context import (
19
- RequestContext,
20
- ResourceScopedContext,
21
- ScopedContext,
22
- SingletonContext,
23
- TransientContext,
24
- )
21
+ from ._context import InstanceContext
25
22
  from ._logger import logger
26
- from ._module import Module, ModuleRegistry
27
23
  from ._provider import Provider
28
- from ._scanner import Scanner
29
- from ._types import AnyInterface, DependencyWrapper, Interface, Scope, is_marker
30
- from ._utils import AsyncRLock, get_full_qualname, get_typed_parameters, is_builtin_type
24
+ from ._types import (
25
+ AnyInterface,
26
+ Dependency,
27
+ DependencyWrapper,
28
+ InjectableDecoratorArgs,
29
+ ProviderDecoratorArgs,
30
+ Scope,
31
+ is_event_type,
32
+ is_marker,
33
+ )
34
+ from ._utils import (
35
+ AsyncRLock,
36
+ get_full_qualname,
37
+ get_typed_parameters,
38
+ import_string,
39
+ is_async_context_manager,
40
+ is_builtin_type,
41
+ is_context_manager,
42
+ run_async,
43
+ )
31
44
 
32
45
  T = TypeVar("T", bound=Any)
46
+ M = TypeVar("M", bound="Module")
33
47
  P = ParamSpec("P")
34
48
 
35
49
  ALLOWED_SCOPES: dict[Scope, list[Scope]] = {
36
50
  "singleton": ["singleton"],
37
51
  "request": ["request", "singleton"],
38
- "transient": ["transient", "singleton", "request"],
52
+ "transient": ["transient", "request", "singleton"],
39
53
  }
40
54
 
41
55
 
56
+ class ModuleMeta(type):
57
+ """A metaclass used for the Module base class."""
58
+
59
+ def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Any:
60
+ attrs["providers"] = [
61
+ (name, getattr(value, "__provider__"))
62
+ for name, value in attrs.items()
63
+ if hasattr(value, "__provider__")
64
+ ]
65
+ return super().__new__(cls, name, bases, attrs)
66
+
67
+
68
+ class Module(metaclass=ModuleMeta):
69
+ """A base class for defining AnyDI modules."""
70
+
71
+ providers: list[tuple[str, ProviderDecoratorArgs]]
72
+
73
+ def configure(self, container: Container) -> None:
74
+ """Configure the AnyDI container with providers and their dependencies."""
75
+
76
+
42
77
  @final
43
78
  class Container:
44
79
  """AnyDI is a dependency injection container."""
@@ -53,12 +88,11 @@ class Container:
53
88
  testing: bool = False,
54
89
  ) -> None:
55
90
  self._providers: dict[type[Any], Provider] = {}
56
- self._resource_cache: dict[Scope, list[type[Any]]] = defaultdict(list)
57
- self._singleton_context = SingletonContext(self)
91
+ self._resources: dict[str, list[type[Any]]] = defaultdict(list)
92
+ self._singleton_context = InstanceContext()
58
93
  self._singleton_lock = threading.RLock()
59
94
  self._singleton_async_lock = AsyncRLock()
60
- self._transient_context = TransientContext(self)
61
- self._request_context_var: ContextVar[RequestContext | None] = ContextVar(
95
+ self._request_context_var: ContextVar[InstanceContext | None] = ContextVar(
62
96
  "request_context", default=None
63
97
  )
64
98
  self._override_instances: dict[type[Any], Any] = {}
@@ -69,10 +103,6 @@ class Container:
69
103
  Callable[..., Any], Callable[..., Any]
70
104
  ] = WeakKeyDictionary()
71
105
 
72
- # Components
73
- self._modules = ModuleRegistry(self)
74
- self._scanner = Scanner(self)
75
-
76
106
  # Register providers
77
107
  providers = providers or []
78
108
  for provider in providers:
@@ -142,14 +172,14 @@ class Container:
142
172
 
143
173
  provider = self._get_provider(interface)
144
174
 
145
- # Cleanup scoped context instance
146
- try:
147
- scoped_context = self._get_scoped_context(provider.scope)
148
- except LookupError:
149
- pass
150
- else:
151
- if isinstance(scoped_context, ResourceScopedContext):
152
- scoped_context.delete(interface)
175
+ # Cleanup instance context
176
+ if provider.scope != "transient":
177
+ try:
178
+ context = self._get_scoped_context(provider.scope)
179
+ except LookupError:
180
+ pass
181
+ else:
182
+ del context[interface]
153
183
 
154
184
  # Cleanup provider references
155
185
  self._delete_provider(provider)
@@ -190,14 +220,14 @@ class Container:
190
220
  """Set a provider by interface."""
191
221
  self._providers[provider.interface] = provider
192
222
  if provider.is_resource:
193
- self._resource_cache[provider.scope].append(provider.interface)
223
+ self._resources[provider.scope].append(provider.interface)
194
224
 
195
225
  def _delete_provider(self, provider: Provider) -> None:
196
226
  """Delete a provider."""
197
227
  if provider.interface in self._providers:
198
228
  del self._providers[provider.interface]
199
229
  if provider.is_resource:
200
- self._resource_cache[provider.scope].remove(provider.interface)
230
+ self._resources[provider.scope].remove(provider.interface)
201
231
 
202
232
  def _validate_sub_providers(self, provider: Provider) -> None:
203
233
  """Validate the sub-providers of a provider."""
@@ -271,7 +301,27 @@ class Container:
271
301
  self, module: Module | type[Module] | Callable[[Container], None] | str
272
302
  ) -> None:
273
303
  """Register a module as a callable, module type, or module instance."""
274
- self._modules.register(module)
304
+ # Callable Module
305
+ if inspect.isfunction(module):
306
+ module(self)
307
+ return
308
+
309
+ # Module path
310
+ if isinstance(module, str):
311
+ module = import_string(module)
312
+
313
+ # Class based Module or Module type
314
+ if inspect.isclass(module) and issubclass(module, Module):
315
+ module = module()
316
+
317
+ if isinstance(module, Module):
318
+ module.configure(self)
319
+ for provider_name, decorator_args in module.providers:
320
+ obj = getattr(module, provider_name)
321
+ self.provider(
322
+ scope=decorator_args.scope,
323
+ override=decorator_args.override,
324
+ )(obj)
275
325
 
276
326
  def __enter__(self) -> Self:
277
327
  """Enter the singleton context."""
@@ -283,23 +333,33 @@ class Container:
283
333
  exc_type: type[BaseException] | None,
284
334
  exc_val: BaseException | None,
285
335
  exc_tb: types.TracebackType | None,
286
- ) -> bool:
336
+ ) -> Any:
287
337
  """Exit the singleton context."""
288
338
  return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
289
339
 
290
340
  def start(self) -> None:
291
341
  """Start the singleton context."""
292
- self._singleton_context.start()
342
+ # Resolve all singleton resources
343
+ for interface in self._resources.get("singleton", []):
344
+ self.resolve(interface)
293
345
 
294
346
  def close(self) -> None:
295
347
  """Close the singleton context."""
296
348
  self._singleton_context.close()
297
349
 
298
350
  @contextlib.contextmanager
299
- def request_context(self) -> Iterator[RequestContext]:
351
+ def request_context(self) -> Iterator[InstanceContext]:
300
352
  """Obtain a context manager for the request-scoped context."""
301
- context = RequestContext(self)
353
+ context = InstanceContext()
354
+
302
355
  token = self._request_context_var.set(context)
356
+
357
+ # Resolve all request resources
358
+ for interface in self._resources.get("request", []):
359
+ if not is_event_type(interface):
360
+ continue
361
+ self.resolve(interface)
362
+
303
363
  with context:
304
364
  yield context
305
365
  self._request_context_var.reset(token)
@@ -320,22 +380,30 @@ class Container:
320
380
 
321
381
  async def astart(self) -> None:
322
382
  """Start the singleton context asynchronously."""
323
- await self._singleton_context.astart()
383
+ for interface in self._resources.get("singleton", []):
384
+ await self.aresolve(interface)
324
385
 
325
386
  async def aclose(self) -> None:
326
387
  """Close the singleton context asynchronously."""
327
388
  await self._singleton_context.aclose()
328
389
 
329
390
  @contextlib.asynccontextmanager
330
- async def arequest_context(self) -> AsyncIterator[RequestContext]:
391
+ async def arequest_context(self) -> AsyncIterator[InstanceContext]:
331
392
  """Obtain an async context manager for the request-scoped context."""
332
- context = RequestContext(self)
393
+ context = InstanceContext()
394
+
333
395
  token = self._request_context_var.set(context)
396
+
397
+ for interface in self._resources.get("request", []):
398
+ if not is_event_type(interface):
399
+ continue
400
+ await self.aresolve(interface)
401
+
334
402
  async with context:
335
403
  yield context
336
404
  self._request_context_var.reset(token)
337
405
 
338
- def _get_request_context(self) -> RequestContext:
406
+ def _get_request_context(self) -> InstanceContext:
339
407
  """Get the current request context."""
340
408
  request_context = self._request_context_var.get()
341
409
  if request_context is None:
@@ -349,57 +417,248 @@ class Container:
349
417
  def reset(self) -> None:
350
418
  """Reset resolved instances."""
351
419
  for interface, provider in self._providers.items():
420
+ if provider.scope == "transient":
421
+ continue
352
422
  try:
353
- scoped_context = self._get_scoped_context(provider.scope)
423
+ context = self._get_scoped_context(provider.scope)
354
424
  except LookupError:
355
425
  continue
356
- if isinstance(scoped_context, ResourceScopedContext):
357
- scoped_context.delete(interface)
426
+ del context[interface]
358
427
 
359
428
  @overload
360
- def resolve(self, interface: Interface[T]) -> T: ...
429
+ def resolve(self, interface: type[T]) -> T: ...
361
430
 
362
431
  @overload
363
432
  def resolve(self, interface: T) -> T: ...
364
433
 
365
- def resolve(self, interface: Interface[T]) -> T:
434
+ def resolve(self, interface: type[T]) -> T:
366
435
  """Resolve an instance by interface."""
367
436
  if interface in self._override_instances:
368
437
  return cast(T, self._override_instances[interface])
369
438
 
370
439
  provider = self._get_or_register_provider(interface)
371
- scoped_context = self._get_scoped_context(provider.scope)
372
- if provider.scope == "singleton":
373
- with self._singleton_lock:
374
- instance, created = scoped_context.get_or_create(provider)
440
+ if provider.scope == "transient":
441
+ instance, created = self._create_instance(provider), True
375
442
  else:
376
- instance, created = scoped_context.get_or_create(provider)
443
+ context = self._get_scoped_context(provider.scope)
444
+ if provider.scope == "singleton":
445
+ with self._singleton_lock:
446
+ instance, created = self._get_or_create_instance(
447
+ provider, context=context
448
+ )
449
+ else:
450
+ instance, created = self._get_or_create_instance(
451
+ provider, context=context
452
+ )
377
453
  if self.testing and created:
378
454
  self._patch_test_resolver(instance)
379
455
  return cast(T, instance)
380
456
 
381
457
  @overload
382
- async def aresolve(self, interface: Interface[T]) -> T: ...
458
+ async def aresolve(self, interface: type[T]) -> T: ...
383
459
 
384
460
  @overload
385
461
  async def aresolve(self, interface: T) -> T: ...
386
462
 
387
- async def aresolve(self, interface: Interface[T]) -> T:
463
+ async def aresolve(self, interface: type[T]) -> T:
388
464
  """Resolve an instance by interface asynchronously."""
389
465
  if interface in self._override_instances:
390
466
  return cast(T, self._override_instances[interface])
391
467
 
392
468
  provider = self._get_or_register_provider(interface)
393
- scoped_context = self._get_scoped_context(provider.scope)
394
- if provider.scope == "singleton":
395
- async with self._singleton_async_lock:
396
- instance, created = await scoped_context.aget_or_create(provider)
469
+ if provider.scope == "transient":
470
+ instance, created = await self._acreate_instance(provider), True
397
471
  else:
398
- instance, created = await scoped_context.aget_or_create(provider)
472
+ context = self._get_scoped_context(provider.scope)
473
+ if provider.scope == "singleton":
474
+ async with self._singleton_async_lock:
475
+ instance, created = await self._aget_or_create_instance(
476
+ provider, context=context
477
+ )
478
+ else:
479
+ instance, created = await self._aget_or_create_instance(
480
+ provider, context=context
481
+ )
399
482
  if self.testing and created:
400
483
  self._patch_test_resolver(instance)
401
484
  return cast(T, instance)
402
485
 
486
+ def _get_or_create_instance(
487
+ self, provider: Provider, context: InstanceContext
488
+ ) -> tuple[Any, bool]:
489
+ """Get an instance of a dependency from the scoped context."""
490
+ instance = context.get(provider.interface)
491
+ if instance is None:
492
+ instance = self._create_instance(provider, context=context)
493
+ context.set(provider.interface, instance)
494
+ return instance, True
495
+ return instance, False
496
+
497
+ async def _aget_or_create_instance(
498
+ self, provider: Provider, context: InstanceContext
499
+ ) -> tuple[Any, bool]:
500
+ """Get an async instance of a dependency from the scoped context."""
501
+ instance = context.get(provider.interface)
502
+ if instance is None:
503
+ instance = await self._acreate_instance(provider, context=context)
504
+ context.set(provider.interface, instance)
505
+ return instance, True
506
+ return instance, False
507
+
508
+ def _create_instance(
509
+ self,
510
+ provider: Provider,
511
+ context: InstanceContext | None = None,
512
+ ) -> Any:
513
+ """Create an instance using the provider."""
514
+ if provider.is_async:
515
+ raise TypeError(
516
+ f"The instance for the provider `{provider}` cannot be created in "
517
+ "synchronous mode."
518
+ )
519
+
520
+ args, kwargs = self._get_provided_args(provider, context=context)
521
+
522
+ if provider.is_generator:
523
+ if context is None:
524
+ raise ValueError("The context is required for generator providers.")
525
+ cm = contextlib.contextmanager(provider.call)(*args, **kwargs)
526
+ return context.enter(cm)
527
+
528
+ instance = provider.call(*args, **kwargs)
529
+ if context is not None and is_context_manager(instance):
530
+ context.enter(instance)
531
+ return instance
532
+
533
+ async def _acreate_instance(
534
+ self,
535
+ provider: Provider,
536
+ context: InstanceContext | None = None,
537
+ ) -> Any:
538
+ """Create an instance asynchronously using the provider."""
539
+ args, kwargs = await self._aget_provided_args(provider, context=context)
540
+
541
+ if provider.is_coroutine:
542
+ instance = await provider.call(*args, **kwargs)
543
+ if context is not None and is_async_context_manager(instance):
544
+ await context.aenter(instance)
545
+ return instance
546
+
547
+ if provider.is_async_generator:
548
+ if context is None:
549
+ raise ValueError(
550
+ "The async stack is required for async generator providers."
551
+ )
552
+ cm = contextlib.asynccontextmanager(provider.call)(*args, **kwargs)
553
+ return await context.aenter(cm)
554
+
555
+ if provider.is_generator:
556
+
557
+ def _create() -> Any:
558
+ if context is None:
559
+ raise ValueError("The stack is required for generator providers.")
560
+ cm = contextlib.contextmanager(provider.call)(*args, **kwargs)
561
+ return context.enter(cm)
562
+
563
+ return await run_async(_create)
564
+
565
+ instance = await run_async(provider.call, *args, **kwargs)
566
+ if context is not None and is_async_context_manager(instance):
567
+ await context.aenter(instance)
568
+ return instance
569
+
570
+ def _get_provided_args(
571
+ self,
572
+ provider: Provider,
573
+ context: InstanceContext | None,
574
+ *args: Any,
575
+ **kwargs: Any,
576
+ ) -> tuple[list[Any], dict[str, Any]]:
577
+ """Retrieve the arguments for a provider."""
578
+ provided_args: list[Any] = []
579
+ provided_kwargs: dict[str, Any] = {}
580
+
581
+ for parameter in provider.parameters:
582
+ if parameter.annotation in self._override_instances:
583
+ instance = self._override_instances[parameter.annotation]
584
+ elif context and parameter.annotation in context:
585
+ instance = context[parameter.annotation]
586
+ else:
587
+ try:
588
+ instance = self._resolve_parameter(provider, parameter)
589
+ except LookupError:
590
+ if parameter.default is inspect.Parameter.empty:
591
+ raise
592
+ instance = parameter.default
593
+ else:
594
+ if self.testing:
595
+ instance = DependencyWrapper(
596
+ interface=parameter.annotation, instance=instance
597
+ )
598
+ if parameter.kind == parameter.POSITIONAL_ONLY:
599
+ provided_args.append(instance)
600
+ else:
601
+ provided_kwargs[parameter.name] = instance
602
+ return provided_args, provided_kwargs
603
+
604
+ async def _aget_provided_args(
605
+ self,
606
+ provider: Provider,
607
+ context: InstanceContext | None,
608
+ *args: Any,
609
+ **kwargs: Any,
610
+ ) -> tuple[list[Any], dict[str, Any]]:
611
+ """Asynchronously retrieve the arguments for a provider."""
612
+ provided_args: list[Any] = []
613
+ provided_kwargs: dict[str, Any] = {}
614
+
615
+ for parameter in provider.parameters:
616
+ if parameter.annotation in self._override_instances:
617
+ instance = self._override_instances[parameter.annotation]
618
+ elif context and parameter.annotation in context:
619
+ instance = context[parameter.annotation]
620
+ else:
621
+ try:
622
+ instance = await self._aresolve_parameter(provider, parameter)
623
+ except LookupError:
624
+ if parameter.default is inspect.Parameter.empty:
625
+ raise
626
+ instance = parameter.default
627
+ else:
628
+ if self.testing:
629
+ instance = DependencyWrapper(
630
+ interface=parameter.annotation, instance=instance
631
+ )
632
+ if parameter.kind == parameter.POSITIONAL_ONLY:
633
+ provided_args.append(instance)
634
+ else:
635
+ provided_kwargs[parameter.name] = instance
636
+ return provided_args, provided_kwargs
637
+
638
+ def _resolve_parameter(
639
+ self, provider: Provider, parameter: inspect.Parameter
640
+ ) -> Any:
641
+ self._validate_resolvable_parameter(parameter, call=provider.call)
642
+ return self.resolve(parameter.annotation)
643
+
644
+ async def _aresolve_parameter(
645
+ self, provider: Provider, parameter: inspect.Parameter
646
+ ) -> Any:
647
+ self._validate_resolvable_parameter(parameter, call=provider.call)
648
+ return await self.aresolve(parameter.annotation)
649
+
650
+ def _validate_resolvable_parameter(
651
+ self, parameter: inspect.Parameter, call: Callable[..., Any]
652
+ ) -> None:
653
+ """Ensure that the specified interface is resolved."""
654
+ if parameter.annotation in self._unresolved_interfaces:
655
+ raise LookupError(
656
+ f"You are attempting to get the parameter `{parameter.name}` with the "
657
+ f"annotation `{get_full_qualname(parameter.annotation)}` as a "
658
+ f"dependency into `{get_full_qualname(call)}` which is not registered "
659
+ "or set in the scoped context."
660
+ )
661
+
403
662
  def _patch_test_resolver(self, instance: Any) -> None:
404
663
  """Patch the test resolver for the instance."""
405
664
  if not hasattr(instance, "__dict__"):
@@ -428,28 +687,25 @@ class Container:
428
687
  try:
429
688
  provider = self._get_provider(interface)
430
689
  except LookupError:
431
- pass
432
- else:
433
- scoped_context = self._get_scoped_context(provider.scope)
434
- if isinstance(scoped_context, ResourceScopedContext):
435
- return scoped_context.has(interface)
436
- return False
690
+ return False
691
+ if provider.scope == "transient":
692
+ return False
693
+ context = self._get_scoped_context(provider.scope)
694
+ return interface in context
437
695
 
438
696
  def release(self, interface: AnyInterface) -> None:
439
697
  """Release an instance by interface."""
440
698
  provider = self._get_provider(interface)
441
- scoped_context = self._get_scoped_context(provider.scope)
442
- if isinstance(scoped_context, ResourceScopedContext):
443
- scoped_context.delete(interface)
699
+ if provider.scope == "transient":
700
+ return None
701
+ context = self._get_scoped_context(provider.scope)
702
+ del context[interface]
444
703
 
445
- def _get_scoped_context(self, scope: Scope) -> ScopedContext:
446
- """Get the scoped context based on the specified scope."""
704
+ def _get_scoped_context(self, scope: Scope) -> InstanceContext:
705
+ """Get the instance context for the specified scope."""
447
706
  if scope == "singleton":
448
707
  return self._singleton_context
449
- elif scope == "request":
450
- request_context = self._get_request_context()
451
- return request_context
452
- return self._transient_context
708
+ return self._get_request_context()
453
709
 
454
710
  @contextlib.contextmanager
455
711
  def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
@@ -488,19 +744,14 @@ class Container:
488
744
  ) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
489
745
  """Decorator to inject dependencies into a callable."""
490
746
 
491
- def decorator(
492
- call: Callable[P, T],
493
- ) -> Callable[P, T]:
747
+ def decorator(call: Callable[P, T]) -> Callable[P, T]:
494
748
  return self._inject(call)
495
749
 
496
750
  if func is None:
497
751
  return decorator
498
752
  return decorator(func)
499
753
 
500
- def _inject(
501
- self,
502
- call: Callable[P, T],
503
- ) -> Callable[P, T]:
754
+ def _inject(self, call: Callable[P, T]) -> Callable[P, T]:
504
755
  """Inject dependencies into a callable."""
505
756
  if call in self._inject_cache:
506
757
  return cast(Callable[P, T], self._inject_cache[call])
@@ -574,12 +825,97 @@ class Container:
574
825
  def scan(
575
826
  self,
576
827
  /,
577
- packages: types.ModuleType | str | Iterable[types.ModuleType | str],
828
+ packages: ModuleType | str | Iterable[ModuleType | str],
578
829
  *,
579
830
  tags: Iterable[str] | None = None,
580
831
  ) -> None:
581
832
  """Scan packages or modules for decorated members and inject dependencies."""
582
- self._scanner.scan(packages, tags=tags)
833
+ dependencies: list[Dependency] = []
834
+
835
+ if isinstance(packages, Iterable) and not isinstance(packages, str):
836
+ scan_packages: Iterable[ModuleType | str] = packages
837
+ else:
838
+ scan_packages = cast(Iterable[Union[ModuleType, str]], [packages])
839
+
840
+ for package in scan_packages:
841
+ dependencies.extend(self._scan_package(package, tags=tags))
842
+
843
+ for dependency in dependencies:
844
+ decorator = self.inject()(dependency.member)
845
+ setattr(dependency.module, dependency.member.__name__, decorator)
846
+
847
+ def _scan_package(
848
+ self,
849
+ package: ModuleType | str,
850
+ *,
851
+ tags: Iterable[str] | None = None,
852
+ ) -> list[Dependency]:
853
+ """Scan a package or module for decorated members."""
854
+ tags = tags or []
855
+ if isinstance(package, str):
856
+ package = importlib.import_module(package)
857
+
858
+ package_path = getattr(package, "__path__", None)
859
+
860
+ if not package_path:
861
+ return self._scan_module(package, tags=tags)
862
+
863
+ dependencies: list[Dependency] = []
864
+
865
+ for module_info in pkgutil.walk_packages(
866
+ path=package_path, prefix=package.__name__ + "."
867
+ ):
868
+ module = importlib.import_module(module_info.name)
869
+ dependencies.extend(self._scan_module(module, tags=tags))
870
+
871
+ return dependencies
872
+
873
+ def _scan_module(
874
+ self, module: ModuleType, *, tags: Iterable[str]
875
+ ) -> list[Dependency]:
876
+ """Scan a module for decorated members."""
877
+ dependencies: list[Dependency] = []
878
+
879
+ for _, member in inspect.getmembers(module):
880
+ if getattr(member, "__module__", None) != module.__name__ or not callable(
881
+ member
882
+ ):
883
+ continue
884
+
885
+ decorator_args: InjectableDecoratorArgs = getattr(
886
+ member,
887
+ "__injectable__",
888
+ InjectableDecoratorArgs(wrapped=False, tags=[]),
889
+ )
890
+
891
+ if tags and (
892
+ decorator_args.tags
893
+ and not set(decorator_args.tags).intersection(tags)
894
+ or not decorator_args.tags
895
+ ):
896
+ continue
897
+
898
+ if decorator_args.wrapped:
899
+ dependencies.append(
900
+ self._create_dependency(member=member, module=module)
901
+ )
902
+ continue
903
+
904
+ # Get by Marker
905
+ for parameter in get_typed_parameters(member):
906
+ if is_marker(parameter.default):
907
+ dependencies.append(
908
+ self._create_dependency(member=member, module=module)
909
+ )
910
+ continue
911
+
912
+ return dependencies
913
+
914
+ def _create_dependency(self, member: Any, module: ModuleType) -> Dependency:
915
+ """Create a `Dependency` object from the scanned member and module."""
916
+ if hasattr(member, "__wrapped__"):
917
+ member = member.__wrapped__
918
+ return Dependency(member=member, module=module)
583
919
 
584
920
 
585
921
  def transient(target: T) -> T:
@@ -598,3 +934,51 @@ def singleton(target: T) -> T:
598
934
  """Decorator for marking a class as singleton scope."""
599
935
  setattr(target, "__scope__", "singleton")
600
936
  return target
937
+
938
+
939
+ def provider(
940
+ *, scope: Scope, override: bool = False
941
+ ) -> Callable[[Callable[Concatenate[M, P], T]], Callable[Concatenate[M, P], T]]:
942
+ """Decorator for marking a function or method as a provider in a AnyDI module."""
943
+
944
+ def decorator(
945
+ target: Callable[Concatenate[M, P], T],
946
+ ) -> Callable[Concatenate[M, P], T]:
947
+ setattr(
948
+ target,
949
+ "__provider__",
950
+ ProviderDecoratorArgs(scope=scope, override=override),
951
+ )
952
+ return target
953
+
954
+ return decorator
955
+
956
+
957
+ @overload
958
+ def injectable(func: Callable[P, T]) -> Callable[P, T]: ...
959
+
960
+
961
+ @overload
962
+ def injectable(
963
+ *, tags: Iterable[str] | None = None
964
+ ) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
965
+
966
+
967
+ def injectable(
968
+ func: Callable[P, T] | None = None,
969
+ tags: Iterable[str] | None = None,
970
+ ) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
971
+ """Decorator for marking a function or method as requiring dependency injection."""
972
+
973
+ def decorator(inner: Callable[P, T]) -> Callable[P, T]:
974
+ setattr(
975
+ inner,
976
+ "__injectable__",
977
+ InjectableDecoratorArgs(wrapped=True, tags=tags),
978
+ )
979
+ return inner
980
+
981
+ if func is None:
982
+ return decorator
983
+
984
+ return decorator(func)