anydi 0.33.1__py3-none-any.whl → 0.34.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
anydi/_container.py CHANGED
@@ -4,40 +4,76 @@ from __future__ import annotations
4
4
 
5
5
  import contextlib
6
6
  import functools
7
+ import importlib
7
8
  import inspect
9
+ import pkgutil
10
+ import threading
8
11
  import types
9
12
  from collections import defaultdict
10
13
  from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
11
14
  from contextvars import ContextVar
12
- from typing import Any, Callable, TypeVar, cast, overload
15
+ from types import ModuleType
16
+ from typing import Any, Callable, TypeVar, Union, cast, overload
13
17
  from weakref import WeakKeyDictionary
14
18
 
15
- from typing_extensions import ParamSpec, Self, final
19
+ from typing_extensions import Concatenate, ParamSpec, Self, final
16
20
 
17
- from ._context import (
18
- RequestContext,
19
- ResourceScopedContext,
20
- ScopedContext,
21
- SingletonContext,
22
- TransientContext,
23
- )
21
+ from ._context import InstanceContext
24
22
  from ._logger import logger
25
- from ._module import Module, ModuleRegistry
26
23
  from ._provider import Provider
27
- from ._scanner import Scanner
28
- from ._types import AnyInterface, DependencyWrapper, Interface, Scope, is_marker
29
- from ._utils import get_full_qualname, get_typed_parameters, is_builtin_type
24
+ from ._types import (
25
+ AnyInterface,
26
+ Dependency,
27
+ DependencyWrapper,
28
+ InjectableDecoratorArgs,
29
+ ProviderDecoratorArgs,
30
+ Scope,
31
+ is_event_type,
32
+ is_marker,
33
+ )
34
+ from ._utils import (
35
+ AsyncRLock,
36
+ get_full_qualname,
37
+ get_typed_parameters,
38
+ import_string,
39
+ is_async_context_manager,
40
+ is_builtin_type,
41
+ is_context_manager,
42
+ run_async,
43
+ )
30
44
 
31
45
  T = TypeVar("T", bound=Any)
46
+ M = TypeVar("M", bound="Module")
32
47
  P = ParamSpec("P")
33
48
 
34
49
  ALLOWED_SCOPES: dict[Scope, list[Scope]] = {
35
50
  "singleton": ["singleton"],
36
51
  "request": ["request", "singleton"],
37
- "transient": ["transient", "singleton", "request"],
52
+ "transient": ["transient", "request", "singleton"],
38
53
  }
39
54
 
40
55
 
56
+ class ModuleMeta(type):
57
+ """A metaclass used for the Module base class."""
58
+
59
+ def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Any:
60
+ attrs["providers"] = [
61
+ (name, getattr(value, "__provider__"))
62
+ for name, value in attrs.items()
63
+ if hasattr(value, "__provider__")
64
+ ]
65
+ return super().__new__(cls, name, bases, attrs)
66
+
67
+
68
+ class Module(metaclass=ModuleMeta):
69
+ """A base class for defining AnyDI modules."""
70
+
71
+ providers: list[tuple[str, ProviderDecoratorArgs]]
72
+
73
+ def configure(self, container: Container) -> None:
74
+ """Configure the AnyDI container with providers and their dependencies."""
75
+
76
+
41
77
  @final
42
78
  class Container:
43
79
  """AnyDI is a dependency injection container."""
@@ -52,10 +88,11 @@ class Container:
52
88
  testing: bool = False,
53
89
  ) -> None:
54
90
  self._providers: dict[type[Any], Provider] = {}
55
- self._resource_cache: dict[Scope, list[type[Any]]] = defaultdict(list)
56
- self._singleton_context = SingletonContext(self)
57
- self._transient_context = TransientContext(self)
58
- self._request_context_var: ContextVar[RequestContext | None] = ContextVar(
91
+ self._resources: dict[str, list[type[Any]]] = defaultdict(list)
92
+ self._singleton_context = InstanceContext()
93
+ self._singleton_lock = threading.RLock()
94
+ self._singleton_async_lock = AsyncRLock()
95
+ self._request_context_var: ContextVar[InstanceContext | None] = ContextVar(
59
96
  "request_context", default=None
60
97
  )
61
98
  self._override_instances: dict[type[Any], Any] = {}
@@ -66,10 +103,6 @@ class Container:
66
103
  Callable[..., Any], Callable[..., Any]
67
104
  ] = WeakKeyDictionary()
68
105
 
69
- # Components
70
- self._modules = ModuleRegistry(self)
71
- self._scanner = Scanner(self)
72
-
73
106
  # Register providers
74
107
  providers = providers or []
75
108
  for provider in providers:
@@ -139,14 +172,14 @@ class Container:
139
172
 
140
173
  provider = self._get_provider(interface)
141
174
 
142
- # Cleanup scoped context instance
143
- try:
144
- scoped_context = self._get_scoped_context(provider.scope)
145
- except LookupError:
146
- pass
147
- else:
148
- if isinstance(scoped_context, ResourceScopedContext):
149
- scoped_context.delete(interface)
175
+ # Cleanup instance context
176
+ if provider.scope != "transient":
177
+ try:
178
+ context = self._get_scoped_context(provider.scope)
179
+ except LookupError:
180
+ pass
181
+ else:
182
+ del context[interface]
150
183
 
151
184
  # Cleanup provider references
152
185
  self._delete_provider(provider)
@@ -187,14 +220,14 @@ class Container:
187
220
  """Set a provider by interface."""
188
221
  self._providers[provider.interface] = provider
189
222
  if provider.is_resource:
190
- self._resource_cache[provider.scope].append(provider.interface)
223
+ self._resources[provider.scope].append(provider.interface)
191
224
 
192
225
  def _delete_provider(self, provider: Provider) -> None:
193
226
  """Delete a provider."""
194
227
  if provider.interface in self._providers:
195
228
  del self._providers[provider.interface]
196
229
  if provider.is_resource:
197
- self._resource_cache[provider.scope].remove(provider.interface)
230
+ self._resources[provider.scope].remove(provider.interface)
198
231
 
199
232
  def _validate_sub_providers(self, provider: Provider) -> None:
200
233
  """Validate the sub-providers of a provider."""
@@ -268,7 +301,27 @@ class Container:
268
301
  self, module: Module | type[Module] | Callable[[Container], None] | str
269
302
  ) -> None:
270
303
  """Register a module as a callable, module type, or module instance."""
271
- self._modules.register(module)
304
+ # Callable Module
305
+ if inspect.isfunction(module):
306
+ module(self)
307
+ return
308
+
309
+ # Module path
310
+ if isinstance(module, str):
311
+ module = import_string(module)
312
+
313
+ # Class based Module or Module type
314
+ if inspect.isclass(module) and issubclass(module, Module):
315
+ module = module()
316
+
317
+ if isinstance(module, Module):
318
+ module.configure(self)
319
+ for provider_name, decorator_args in module.providers:
320
+ obj = getattr(module, provider_name)
321
+ self.provider(
322
+ scope=decorator_args.scope,
323
+ override=decorator_args.override,
324
+ )(obj)
272
325
 
273
326
  def __enter__(self) -> Self:
274
327
  """Enter the singleton context."""
@@ -280,23 +333,33 @@ class Container:
280
333
  exc_type: type[BaseException] | None,
281
334
  exc_val: BaseException | None,
282
335
  exc_tb: types.TracebackType | None,
283
- ) -> bool:
336
+ ) -> Any:
284
337
  """Exit the singleton context."""
285
338
  return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
286
339
 
287
340
  def start(self) -> None:
288
341
  """Start the singleton context."""
289
- self._singleton_context.start()
342
+ # Resolve all singleton resources
343
+ for interface in self._resources.get("singleton", []):
344
+ self.resolve(interface)
290
345
 
291
346
  def close(self) -> None:
292
347
  """Close the singleton context."""
293
348
  self._singleton_context.close()
294
349
 
295
350
  @contextlib.contextmanager
296
- def request_context(self) -> Iterator[RequestContext]:
351
+ def request_context(self) -> Iterator[InstanceContext]:
297
352
  """Obtain a context manager for the request-scoped context."""
298
- context = RequestContext(self)
353
+ context = InstanceContext()
354
+
299
355
  token = self._request_context_var.set(context)
356
+
357
+ # Resolve all request resources
358
+ for interface in self._resources.get("request", []):
359
+ if not is_event_type(interface):
360
+ continue
361
+ self.resolve(interface)
362
+
300
363
  with context:
301
364
  yield context
302
365
  self._request_context_var.reset(token)
@@ -317,22 +380,30 @@ class Container:
317
380
 
318
381
  async def astart(self) -> None:
319
382
  """Start the singleton context asynchronously."""
320
- await self._singleton_context.astart()
383
+ for interface in self._resources.get("singleton", []):
384
+ await self.aresolve(interface)
321
385
 
322
386
  async def aclose(self) -> None:
323
387
  """Close the singleton context asynchronously."""
324
388
  await self._singleton_context.aclose()
325
389
 
326
390
  @contextlib.asynccontextmanager
327
- async def arequest_context(self) -> AsyncIterator[RequestContext]:
391
+ async def arequest_context(self) -> AsyncIterator[InstanceContext]:
328
392
  """Obtain an async context manager for the request-scoped context."""
329
- context = RequestContext(self)
393
+ context = InstanceContext()
394
+
330
395
  token = self._request_context_var.set(context)
396
+
397
+ for interface in self._resources.get("request", []):
398
+ if not is_event_type(interface):
399
+ continue
400
+ await self.aresolve(interface)
401
+
331
402
  async with context:
332
403
  yield context
333
404
  self._request_context_var.reset(token)
334
405
 
335
- def _get_request_context(self) -> RequestContext:
406
+ def _get_request_context(self) -> InstanceContext:
336
407
  """Get the current request context."""
337
408
  request_context = self._request_context_var.get()
338
409
  if request_context is None:
@@ -346,49 +417,248 @@ class Container:
346
417
  def reset(self) -> None:
347
418
  """Reset resolved instances."""
348
419
  for interface, provider in self._providers.items():
420
+ if provider.scope == "transient":
421
+ continue
349
422
  try:
350
- scoped_context = self._get_scoped_context(provider.scope)
423
+ context = self._get_scoped_context(provider.scope)
351
424
  except LookupError:
352
425
  continue
353
- if isinstance(scoped_context, ResourceScopedContext):
354
- scoped_context.delete(interface)
426
+ del context[interface]
355
427
 
356
428
  @overload
357
- def resolve(self, interface: Interface[T]) -> T: ...
429
+ def resolve(self, interface: type[T]) -> T: ...
358
430
 
359
431
  @overload
360
432
  def resolve(self, interface: T) -> T: ...
361
433
 
362
- def resolve(self, interface: Interface[T]) -> T:
434
+ def resolve(self, interface: type[T]) -> T:
363
435
  """Resolve an instance by interface."""
364
436
  if interface in self._override_instances:
365
437
  return cast(T, self._override_instances[interface])
366
438
 
367
439
  provider = self._get_or_register_provider(interface)
368
- scoped_context = self._get_scoped_context(provider.scope)
369
- instance, created = scoped_context.get_or_create(provider)
440
+ if provider.scope == "transient":
441
+ instance, created = self._create_instance(provider), True
442
+ else:
443
+ context = self._get_scoped_context(provider.scope)
444
+ if provider.scope == "singleton":
445
+ with self._singleton_lock:
446
+ instance, created = self._get_or_create_instance(
447
+ provider, context=context
448
+ )
449
+ else:
450
+ instance, created = self._get_or_create_instance(
451
+ provider, context=context
452
+ )
370
453
  if self.testing and created:
371
454
  self._patch_test_resolver(instance)
372
455
  return cast(T, instance)
373
456
 
374
457
  @overload
375
- async def aresolve(self, interface: Interface[T]) -> T: ...
458
+ async def aresolve(self, interface: type[T]) -> T: ...
376
459
 
377
460
  @overload
378
461
  async def aresolve(self, interface: T) -> T: ...
379
462
 
380
- async def aresolve(self, interface: Interface[T]) -> T:
463
+ async def aresolve(self, interface: type[T]) -> T:
381
464
  """Resolve an instance by interface asynchronously."""
382
465
  if interface in self._override_instances:
383
466
  return cast(T, self._override_instances[interface])
384
467
 
385
468
  provider = self._get_or_register_provider(interface)
386
- scoped_context = self._get_scoped_context(provider.scope)
387
- instance, created = await scoped_context.aget_or_create(provider)
469
+ if provider.scope == "transient":
470
+ instance, created = await self._acreate_instance(provider), True
471
+ else:
472
+ context = self._get_scoped_context(provider.scope)
473
+ if provider.scope == "singleton":
474
+ async with self._singleton_async_lock:
475
+ instance, created = await self._aget_or_create_instance(
476
+ provider, context=context
477
+ )
478
+ else:
479
+ instance, created = await self._aget_or_create_instance(
480
+ provider, context=context
481
+ )
388
482
  if self.testing and created:
389
483
  self._patch_test_resolver(instance)
390
484
  return cast(T, instance)
391
485
 
486
+ def _get_or_create_instance(
487
+ self, provider: Provider, context: InstanceContext
488
+ ) -> tuple[Any, bool]:
489
+ """Get an instance of a dependency from the scoped context."""
490
+ instance = context.get(provider.interface)
491
+ if instance is None:
492
+ instance = self._create_instance(provider, context=context)
493
+ context.set(provider.interface, instance)
494
+ return instance, True
495
+ return instance, False
496
+
497
+ async def _aget_or_create_instance(
498
+ self, provider: Provider, context: InstanceContext
499
+ ) -> tuple[Any, bool]:
500
+ """Get an async instance of a dependency from the scoped context."""
501
+ instance = context.get(provider.interface)
502
+ if instance is None:
503
+ instance = await self._acreate_instance(provider, context=context)
504
+ context.set(provider.interface, instance)
505
+ return instance, True
506
+ return instance, False
507
+
508
+ def _create_instance(
509
+ self,
510
+ provider: Provider,
511
+ context: InstanceContext | None = None,
512
+ ) -> Any:
513
+ """Create an instance using the provider."""
514
+ if provider.is_async:
515
+ raise TypeError(
516
+ f"The instance for the provider `{provider}` cannot be created in "
517
+ "synchronous mode."
518
+ )
519
+
520
+ args, kwargs = self._get_provided_args(provider, context=context)
521
+
522
+ if provider.is_generator:
523
+ if context is None:
524
+ raise ValueError("The context is required for generator providers.")
525
+ cm = contextlib.contextmanager(provider.call)(*args, **kwargs)
526
+ return context.enter(cm)
527
+
528
+ instance = provider.call(*args, **kwargs)
529
+ if context is not None and is_context_manager(instance):
530
+ context.enter(instance)
531
+ return instance
532
+
533
+ async def _acreate_instance(
534
+ self,
535
+ provider: Provider,
536
+ context: InstanceContext | None = None,
537
+ ) -> Any:
538
+ """Create an instance asynchronously using the provider."""
539
+ args, kwargs = await self._aget_provided_args(provider, context=context)
540
+
541
+ if provider.is_coroutine:
542
+ instance = await provider.call(*args, **kwargs)
543
+ if context is not None and is_async_context_manager(instance):
544
+ await context.aenter(instance)
545
+ return instance
546
+
547
+ if provider.is_async_generator:
548
+ if context is None:
549
+ raise ValueError(
550
+ "The async stack is required for async generator providers."
551
+ )
552
+ cm = contextlib.asynccontextmanager(provider.call)(*args, **kwargs)
553
+ return await context.aenter(cm)
554
+
555
+ if provider.is_generator:
556
+
557
+ def _create() -> Any:
558
+ if context is None:
559
+ raise ValueError("The stack is required for generator providers.")
560
+ cm = contextlib.contextmanager(provider.call)(*args, **kwargs)
561
+ return context.enter(cm)
562
+
563
+ return await run_async(_create)
564
+
565
+ instance = await run_async(provider.call, *args, **kwargs)
566
+ if context is not None and is_async_context_manager(instance):
567
+ await context.aenter(instance)
568
+ return instance
569
+
570
+ def _get_provided_args(
571
+ self,
572
+ provider: Provider,
573
+ context: InstanceContext | None,
574
+ *args: Any,
575
+ **kwargs: Any,
576
+ ) -> tuple[list[Any], dict[str, Any]]:
577
+ """Retrieve the arguments for a provider."""
578
+ provided_args: list[Any] = []
579
+ provided_kwargs: dict[str, Any] = {}
580
+
581
+ for parameter in provider.parameters:
582
+ if parameter.annotation in self._override_instances:
583
+ instance = self._override_instances[parameter.annotation]
584
+ elif context and parameter.annotation in context:
585
+ instance = context[parameter.annotation]
586
+ else:
587
+ try:
588
+ instance = self._resolve_parameter(provider, parameter)
589
+ except LookupError:
590
+ if parameter.default is inspect.Parameter.empty:
591
+ raise
592
+ instance = parameter.default
593
+ else:
594
+ if self.testing:
595
+ instance = DependencyWrapper(
596
+ interface=parameter.annotation, instance=instance
597
+ )
598
+ if parameter.kind == parameter.POSITIONAL_ONLY:
599
+ provided_args.append(instance)
600
+ else:
601
+ provided_kwargs[parameter.name] = instance
602
+ return provided_args, provided_kwargs
603
+
604
+ async def _aget_provided_args(
605
+ self,
606
+ provider: Provider,
607
+ context: InstanceContext | None,
608
+ *args: Any,
609
+ **kwargs: Any,
610
+ ) -> tuple[list[Any], dict[str, Any]]:
611
+ """Asynchronously retrieve the arguments for a provider."""
612
+ provided_args: list[Any] = []
613
+ provided_kwargs: dict[str, Any] = {}
614
+
615
+ for parameter in provider.parameters:
616
+ if parameter.annotation in self._override_instances:
617
+ instance = self._override_instances[parameter.annotation]
618
+ elif context and parameter.annotation in context:
619
+ instance = context[parameter.annotation]
620
+ else:
621
+ try:
622
+ instance = await self._aresolve_parameter(provider, parameter)
623
+ except LookupError:
624
+ if parameter.default is inspect.Parameter.empty:
625
+ raise
626
+ instance = parameter.default
627
+ else:
628
+ if self.testing:
629
+ instance = DependencyWrapper(
630
+ interface=parameter.annotation, instance=instance
631
+ )
632
+ if parameter.kind == parameter.POSITIONAL_ONLY:
633
+ provided_args.append(instance)
634
+ else:
635
+ provided_kwargs[parameter.name] = instance
636
+ return provided_args, provided_kwargs
637
+
638
+ def _resolve_parameter(
639
+ self, provider: Provider, parameter: inspect.Parameter
640
+ ) -> Any:
641
+ self._validate_resolvable_parameter(parameter, call=provider.call)
642
+ return self.resolve(parameter.annotation)
643
+
644
+ async def _aresolve_parameter(
645
+ self, provider: Provider, parameter: inspect.Parameter
646
+ ) -> Any:
647
+ self._validate_resolvable_parameter(parameter, call=provider.call)
648
+ return await self.aresolve(parameter.annotation)
649
+
650
+ def _validate_resolvable_parameter(
651
+ self, parameter: inspect.Parameter, call: Callable[..., Any]
652
+ ) -> None:
653
+ """Ensure that the specified interface is resolved."""
654
+ if parameter.annotation in self._unresolved_interfaces:
655
+ raise LookupError(
656
+ f"You are attempting to get the parameter `{parameter.name}` with the "
657
+ f"annotation `{get_full_qualname(parameter.annotation)}` as a "
658
+ f"dependency into `{get_full_qualname(call)}` which is not registered "
659
+ "or set in the scoped context."
660
+ )
661
+
392
662
  def _patch_test_resolver(self, instance: Any) -> None:
393
663
  """Patch the test resolver for the instance."""
394
664
  if not hasattr(instance, "__dict__"):
@@ -417,28 +687,25 @@ class Container:
417
687
  try:
418
688
  provider = self._get_provider(interface)
419
689
  except LookupError:
420
- pass
421
- else:
422
- scoped_context = self._get_scoped_context(provider.scope)
423
- if isinstance(scoped_context, ResourceScopedContext):
424
- return scoped_context.has(interface)
425
- return False
690
+ return False
691
+ if provider.scope == "transient":
692
+ return False
693
+ context = self._get_scoped_context(provider.scope)
694
+ return interface in context
426
695
 
427
696
  def release(self, interface: AnyInterface) -> None:
428
697
  """Release an instance by interface."""
429
698
  provider = self._get_provider(interface)
430
- scoped_context = self._get_scoped_context(provider.scope)
431
- if isinstance(scoped_context, ResourceScopedContext):
432
- scoped_context.delete(interface)
699
+ if provider.scope == "transient":
700
+ return None
701
+ context = self._get_scoped_context(provider.scope)
702
+ del context[interface]
433
703
 
434
- def _get_scoped_context(self, scope: Scope) -> ScopedContext:
435
- """Get the scoped context based on the specified scope."""
704
+ def _get_scoped_context(self, scope: Scope) -> InstanceContext:
705
+ """Get the instance context for the specified scope."""
436
706
  if scope == "singleton":
437
707
  return self._singleton_context
438
- elif scope == "request":
439
- request_context = self._get_request_context()
440
- return request_context
441
- return self._transient_context
708
+ return self._get_request_context()
442
709
 
443
710
  @contextlib.contextmanager
444
711
  def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
@@ -477,19 +744,14 @@ class Container:
477
744
  ) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
478
745
  """Decorator to inject dependencies into a callable."""
479
746
 
480
- def decorator(
481
- call: Callable[P, T],
482
- ) -> Callable[P, T]:
747
+ def decorator(call: Callable[P, T]) -> Callable[P, T]:
483
748
  return self._inject(call)
484
749
 
485
750
  if func is None:
486
751
  return decorator
487
752
  return decorator(func)
488
753
 
489
- def _inject(
490
- self,
491
- call: Callable[P, T],
492
- ) -> Callable[P, T]:
754
+ def _inject(self, call: Callable[P, T]) -> Callable[P, T]:
493
755
  """Inject dependencies into a callable."""
494
756
  if call in self._inject_cache:
495
757
  return cast(Callable[P, T], self._inject_cache[call])
@@ -563,12 +825,97 @@ class Container:
563
825
  def scan(
564
826
  self,
565
827
  /,
566
- packages: types.ModuleType | str | Iterable[types.ModuleType | str],
828
+ packages: ModuleType | str | Iterable[ModuleType | str],
567
829
  *,
568
830
  tags: Iterable[str] | None = None,
569
831
  ) -> None:
570
832
  """Scan packages or modules for decorated members and inject dependencies."""
571
- self._scanner.scan(packages, tags=tags)
833
+ dependencies: list[Dependency] = []
834
+
835
+ if isinstance(packages, Iterable) and not isinstance(packages, str):
836
+ scan_packages: Iterable[ModuleType | str] = packages
837
+ else:
838
+ scan_packages = cast(Iterable[Union[ModuleType, str]], [packages])
839
+
840
+ for package in scan_packages:
841
+ dependencies.extend(self._scan_package(package, tags=tags))
842
+
843
+ for dependency in dependencies:
844
+ decorator = self.inject()(dependency.member)
845
+ setattr(dependency.module, dependency.member.__name__, decorator)
846
+
847
+ def _scan_package(
848
+ self,
849
+ package: ModuleType | str,
850
+ *,
851
+ tags: Iterable[str] | None = None,
852
+ ) -> list[Dependency]:
853
+ """Scan a package or module for decorated members."""
854
+ tags = tags or []
855
+ if isinstance(package, str):
856
+ package = importlib.import_module(package)
857
+
858
+ package_path = getattr(package, "__path__", None)
859
+
860
+ if not package_path:
861
+ return self._scan_module(package, tags=tags)
862
+
863
+ dependencies: list[Dependency] = []
864
+
865
+ for module_info in pkgutil.walk_packages(
866
+ path=package_path, prefix=package.__name__ + "."
867
+ ):
868
+ module = importlib.import_module(module_info.name)
869
+ dependencies.extend(self._scan_module(module, tags=tags))
870
+
871
+ return dependencies
872
+
873
+ def _scan_module(
874
+ self, module: ModuleType, *, tags: Iterable[str]
875
+ ) -> list[Dependency]:
876
+ """Scan a module for decorated members."""
877
+ dependencies: list[Dependency] = []
878
+
879
+ for _, member in inspect.getmembers(module):
880
+ if getattr(member, "__module__", None) != module.__name__ or not callable(
881
+ member
882
+ ):
883
+ continue
884
+
885
+ decorator_args: InjectableDecoratorArgs = getattr(
886
+ member,
887
+ "__injectable__",
888
+ InjectableDecoratorArgs(wrapped=False, tags=[]),
889
+ )
890
+
891
+ if tags and (
892
+ decorator_args.tags
893
+ and not set(decorator_args.tags).intersection(tags)
894
+ or not decorator_args.tags
895
+ ):
896
+ continue
897
+
898
+ if decorator_args.wrapped:
899
+ dependencies.append(
900
+ self._create_dependency(member=member, module=module)
901
+ )
902
+ continue
903
+
904
+ # Get by Marker
905
+ for parameter in get_typed_parameters(member):
906
+ if is_marker(parameter.default):
907
+ dependencies.append(
908
+ self._create_dependency(member=member, module=module)
909
+ )
910
+ continue
911
+
912
+ return dependencies
913
+
914
+ def _create_dependency(self, member: Any, module: ModuleType) -> Dependency:
915
+ """Create a `Dependency` object from the scanned member and module."""
916
+ if hasattr(member, "__wrapped__"):
917
+ member = member.__wrapped__
918
+ return Dependency(member=member, module=module)
572
919
 
573
920
 
574
921
  def transient(target: T) -> T:
@@ -587,3 +934,51 @@ def singleton(target: T) -> T:
587
934
  """Decorator for marking a class as singleton scope."""
588
935
  setattr(target, "__scope__", "singleton")
589
936
  return target
937
+
938
+
939
+ def provider(
940
+ *, scope: Scope, override: bool = False
941
+ ) -> Callable[[Callable[Concatenate[M, P], T]], Callable[Concatenate[M, P], T]]:
942
+ """Decorator for marking a function or method as a provider in a AnyDI module."""
943
+
944
+ def decorator(
945
+ target: Callable[Concatenate[M, P], T],
946
+ ) -> Callable[Concatenate[M, P], T]:
947
+ setattr(
948
+ target,
949
+ "__provider__",
950
+ ProviderDecoratorArgs(scope=scope, override=override),
951
+ )
952
+ return target
953
+
954
+ return decorator
955
+
956
+
957
+ @overload
958
+ def injectable(func: Callable[P, T]) -> Callable[P, T]: ...
959
+
960
+
961
+ @overload
962
+ def injectable(
963
+ *, tags: Iterable[str] | None = None
964
+ ) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
965
+
966
+
967
+ def injectable(
968
+ func: Callable[P, T] | None = None,
969
+ tags: Iterable[str] | None = None,
970
+ ) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
971
+ """Decorator for marking a function or method as requiring dependency injection."""
972
+
973
+ def decorator(inner: Callable[P, T]) -> Callable[P, T]:
974
+ setattr(
975
+ inner,
976
+ "__injectable__",
977
+ InjectableDecoratorArgs(wrapped=True, tags=tags),
978
+ )
979
+ return inner
980
+
981
+ if func is None:
982
+ return decorator
983
+
984
+ return decorator(func)