anydi 0.40.0__py3-none-any.whl → 0.42.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
anydi/_container.py CHANGED
@@ -4,110 +4,74 @@ from __future__ import annotations
4
4
 
5
5
  import contextlib
6
6
  import functools
7
- import importlib
8
7
  import inspect
9
8
  import logging
10
- import pkgutil
11
9
  import types
12
10
  import uuid
13
11
  from collections import defaultdict
14
- from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
12
+ from collections.abc import AsyncIterator, Iterable, Iterator
15
13
  from contextvars import ContextVar
16
- from types import ModuleType
17
14
  from typing import Annotated, Any, Callable, TypeVar, cast, overload
18
15
 
19
- from typing_extensions import Concatenate, ParamSpec, Self, final, get_args, get_origin
16
+ from typing_extensions import ParamSpec, Self, get_args, get_origin
20
17
 
18
+ from ._async import run_sync
21
19
  from ._context import InstanceContext
22
- from ._types import (
23
- NOT_SET,
24
- AnyInterface,
25
- Event,
26
- InjectableDecoratorArgs,
27
- InstanceProxy,
20
+ from ._module import ModuleDef, ModuleRegistrar
21
+ from ._provider import (
28
22
  Provider,
29
- ProviderArgs,
30
- ProviderDecoratorArgs,
23
+ ProviderDef,
31
24
  ProviderKind,
32
- ScannedDependency,
33
- Scope,
34
- is_event_type,
35
- is_marker,
36
25
  )
37
- from ._utils import (
38
- get_full_qualname,
26
+ from ._scan import PackageOrIterable, Scanner
27
+ from ._scope import ALLOWED_SCOPES, Scope
28
+ from ._typing import (
29
+ NOT_SET,
30
+ Event,
39
31
  get_typed_annotation,
40
32
  get_typed_parameters,
41
- import_string,
42
33
  is_async_context_manager,
43
34
  is_builtin_type,
44
35
  is_context_manager,
36
+ is_event_type,
45
37
  is_iterator_type,
38
+ is_marker,
46
39
  is_none_type,
47
- run_async,
40
+ type_repr,
48
41
  )
49
42
 
50
43
  T = TypeVar("T", bound=Any)
51
- M = TypeVar("M", bound="Module")
52
44
  P = ParamSpec("P")
53
45
 
54
- ALLOWED_SCOPES: dict[Scope, list[Scope]] = {
55
- "singleton": ["singleton"],
56
- "request": ["request", "singleton"],
57
- "transient": ["transient", "request", "singleton"],
58
- }
59
-
60
-
61
- class ModuleMeta(type):
62
- """A metaclass used for the Module base class."""
63
-
64
- def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Any:
65
- attrs["providers"] = [
66
- (name, getattr(value, "__provider__"))
67
- for name, value in attrs.items()
68
- if hasattr(value, "__provider__")
69
- ]
70
- return super().__new__(cls, name, bases, attrs)
71
-
72
-
73
- class Module(metaclass=ModuleMeta):
74
- """A base class for defining AnyDI modules."""
75
-
76
- providers: list[tuple[str, ProviderDecoratorArgs]]
77
46
 
78
- def configure(self, container: Container) -> None:
79
- """Configure the AnyDI container with providers and their dependencies."""
80
-
81
-
82
- @final
83
47
  class Container:
84
48
  """AnyDI is a dependency injection container."""
85
49
 
86
50
  def __init__(
87
51
  self,
88
52
  *,
89
- providers: Sequence[ProviderArgs] | None = None,
90
- modules: Sequence[Module | type[Module] | Callable[[Container], None] | str]
91
- | None = None,
53
+ providers: Iterable[ProviderDef] | None = None,
54
+ modules: Iterable[ModuleDef] | None = None,
92
55
  strict: bool = False,
93
56
  default_scope: Scope = "transient",
94
- testing: bool = False,
95
57
  logger: logging.Logger | None = None,
96
58
  ) -> None:
97
59
  self._providers: dict[Any, Provider] = {}
98
60
  self._strict = strict
99
61
  self._default_scope: Scope = default_scope
100
- self._testing = testing
101
62
  self._logger = logger or logging.getLogger(__name__)
102
63
  self._resources: dict[str, list[Any]] = defaultdict(list)
103
64
  self._singleton_context = InstanceContext()
104
65
  self._request_context_var: ContextVar[InstanceContext | None] = ContextVar(
105
66
  "request_context", default=None
106
67
  )
107
- self._override_instances: dict[Any, Any] = {}
108
68
  self._unresolved_interfaces: set[Any] = set()
109
69
  self._inject_cache: dict[Callable[..., Any], Callable[..., Any]] = {}
110
70
 
71
+ # Components
72
+ self._modules = ModuleRegistrar(self)
73
+ self._scanner = Scanner(self)
74
+
111
75
  # Register providers
112
76
  providers = providers or []
113
77
  for provider in providers:
@@ -136,11 +100,6 @@ class Container:
136
100
  """Get the default scope."""
137
101
  return self._default_scope
138
102
 
139
- @property
140
- def testing(self) -> bool:
141
- """Check if testing mode is enabled."""
142
- return self._testing
143
-
144
103
  @property
145
104
  def providers(self) -> dict[type[Any], Provider]:
146
105
  """Get the registered providers."""
@@ -258,7 +217,7 @@ class Container:
258
217
 
259
218
  def register(
260
219
  self,
261
- interface: AnyInterface,
220
+ interface: Any,
262
221
  call: Callable[..., Any],
263
222
  *,
264
223
  scope: Scope,
@@ -267,16 +226,15 @@ class Container:
267
226
  """Register a provider for the specified interface."""
268
227
  return self._register_provider(call, scope, interface, override)
269
228
 
270
- def is_registered(self, interface: AnyInterface) -> bool:
229
+ def is_registered(self, interface: Any) -> bool:
271
230
  """Check if a provider is registered for the specified interface."""
272
231
  return interface in self._providers
273
232
 
274
- def unregister(self, interface: AnyInterface) -> None:
233
+ def unregister(self, interface: Any) -> None:
275
234
  """Unregister a provider by interface."""
276
235
  if not self.is_registered(interface):
277
236
  raise LookupError(
278
- "The provider interface "
279
- f"`{get_full_qualname(interface)}` not registered."
237
+ f"The provider interface `{type_repr(interface)}` not registered."
280
238
  )
281
239
 
282
240
  provider = self._get_provider(interface)
@@ -314,7 +272,7 @@ class Container:
314
272
  **defaults: Any,
315
273
  ) -> Provider:
316
274
  """Register a provider with the specified scope."""
317
- name = get_full_qualname(call)
275
+ name = type_repr(call)
318
276
  kind = ProviderKind.from_call(call)
319
277
  detected_scope = scope
320
278
 
@@ -359,8 +317,7 @@ class Container:
359
317
  # Check for existing provider
360
318
  if interface in self._providers and not override:
361
319
  raise LookupError(
362
- f"The provider interface `{get_full_qualname(interface)}` "
363
- "already registered."
320
+ f"The provider interface `{type_repr(interface)}` already registered."
364
321
  )
365
322
 
366
323
  unresolved_parameter = None
@@ -421,7 +378,7 @@ class Container:
421
378
  else:
422
379
  raise LookupError(
423
380
  f"The provider `{name}` depends on `{unresolved_parameter.name}` "
424
- f"of type `{get_full_qualname(unresolved_parameter.annotation)}`, "
381
+ f"of type `{type_repr(unresolved_parameter.annotation)}`, "
425
382
  "which has not been registered or set. To resolve this, ensure "
426
383
  f"that `{unresolved_parameter.name}` is registered before "
427
384
  f"attempting to use it."
@@ -464,19 +421,19 @@ class Container:
464
421
  "with a transient scope, which is not allowed."
465
422
  )
466
423
 
467
- def _get_provider(self, interface: AnyInterface) -> Provider:
424
+ def _get_provider(self, interface: Any) -> Provider:
468
425
  """Get provider by interface."""
469
426
  try:
470
427
  return self._providers[interface]
471
428
  except KeyError as exc:
472
429
  raise LookupError(
473
- f"The provider interface for `{get_full_qualname(interface)}` has "
430
+ f"The provider interface for `{type_repr(interface)}` has "
474
431
  "not been registered. Please ensure that the provider interface is "
475
432
  "properly registered before attempting to use it."
476
433
  ) from exc
477
434
 
478
435
  def _get_or_register_provider(
479
- self, interface: AnyInterface, parent_scope: Scope | None, /, **defaults: Any
436
+ self, interface: Any, parent_scope: Scope | None, /, **defaults: Any
480
437
  ) -> Provider:
481
438
  """Get or register a provider by interface."""
482
439
  try:
@@ -551,7 +508,7 @@ class Container:
551
508
  """Create an instance by interface asynchronously."""
552
509
  return await self._aresolve_or_create(interface, True, **defaults)
553
510
 
554
- def is_resolved(self, interface: AnyInterface) -> bool:
511
+ def is_resolved(self, interface: Any) -> bool:
555
512
  """Check if an instance by interface exists."""
556
513
  try:
557
514
  provider = self._get_provider(interface)
@@ -562,7 +519,7 @@ class Container:
562
519
  context = self._get_instance_context(provider.scope)
563
520
  return interface in context
564
521
 
565
- def release(self, interface: AnyInterface) -> None:
522
+ def release(self, interface: Any) -> None:
566
523
  """Release an instance by interface."""
567
524
  provider = self._get_provider(interface)
568
525
  if provider.scope == "transient":
@@ -597,9 +554,6 @@ class Container:
597
554
  else self._create_instance(provider, context, **defaults)
598
555
  )
599
556
 
600
- if self.testing:
601
- instance = self._patch_test_resolver(provider.interface, instance)
602
-
603
557
  return cast(T, instance)
604
558
 
605
559
  async def _aresolve_or_create(
@@ -618,9 +572,6 @@ class Container:
618
572
  else await self._acreate_instance(provider, context, **defaults)
619
573
  )
620
574
 
621
- if self.testing:
622
- instance = self._patch_test_resolver(provider.interface, instance)
623
-
624
575
  return cast(T, instance)
625
576
 
626
577
  def _get_or_create_instance(
@@ -695,9 +646,9 @@ class Container:
695
646
  cm = contextlib.contextmanager(provider.call)(**provider_kwargs)
696
647
  return context.enter(cm)
697
648
 
698
- return await run_async(_create)
649
+ return await run_sync(_create)
699
650
 
700
- instance = await run_async(provider.call, **provider_kwargs)
651
+ instance = await run_sync(provider.call, **provider_kwargs)
701
652
  if (
702
653
  context is not None
703
654
  and provider.is_class
@@ -712,7 +663,7 @@ class Container:
712
663
  """Retrieve the arguments for a provider."""
713
664
  provided_kwargs = {}
714
665
  for parameter in provider.parameters:
715
- instance = self._get_provider_instance(
666
+ instance, _ = self._get_provider_instance(
716
667
  provider, parameter, context, **defaults
717
668
  )
718
669
  provided_kwargs[parameter.name] = instance
@@ -725,12 +676,12 @@ class Container:
725
676
  context: InstanceContext | None,
726
677
  /,
727
678
  **defaults: Any,
728
- ) -> Any:
679
+ ) -> tuple[Any, bool]:
729
680
  """Retrieve an instance of a dependency from the scoped context."""
730
681
 
731
682
  # Try to get instance from defaults
732
683
  if parameter.name in defaults:
733
- return defaults[parameter.name]
684
+ return defaults[parameter.name], True
734
685
 
735
686
  # Try to get instance from context
736
687
  elif context and parameter.annotation in context:
@@ -743,12 +694,8 @@ class Container:
743
694
  except LookupError:
744
695
  if parameter.default is inspect.Parameter.empty:
745
696
  raise
746
- return parameter.default
747
-
748
- # Wrap the instance in a proxy for testing
749
- if self.testing:
750
- return InstanceProxy(instance, interface=parameter.annotation)
751
- return instance
697
+ return parameter.default, True
698
+ return instance, False
752
699
 
753
700
  async def _aget_provided_kwargs(
754
701
  self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
@@ -756,7 +703,7 @@ class Container:
756
703
  """Asynchronously retrieve the arguments for a provider."""
757
704
  provided_kwargs = {}
758
705
  for parameter in provider.parameters:
759
- instance = await self._aget_provider_instance(
706
+ instance, _ = await self._aget_provider_instance(
760
707
  provider, parameter, context, **defaults
761
708
  )
762
709
  provided_kwargs[parameter.name] = instance
@@ -769,12 +716,12 @@ class Container:
769
716
  context: InstanceContext | None,
770
717
  /,
771
718
  **defaults: Any,
772
- ) -> Any:
719
+ ) -> tuple[Any, bool]:
773
720
  """Asynchronously retrieve an instance of a dependency from the context."""
774
721
 
775
722
  # Try to get instance from defaults
776
723
  if parameter.name in defaults:
777
- return defaults[parameter.name]
724
+ return defaults[parameter.name], True
778
725
 
779
726
  # Try to get instance from context
780
727
  elif context and parameter.annotation in context:
@@ -787,12 +734,8 @@ class Container:
787
734
  except LookupError:
788
735
  if parameter.default is inspect.Parameter.empty:
789
736
  raise
790
- return parameter.default
791
-
792
- # Wrap the instance in a proxy for testing
793
- if self.testing:
794
- return InstanceProxy(instance, interface=parameter.annotation)
795
- return instance
737
+ return parameter.default, True
738
+ return instance, False
796
739
 
797
740
  def _resolve_parameter(
798
741
  self, provider: Provider, parameter: inspect.Parameter
@@ -813,83 +756,11 @@ class Container:
813
756
  if parameter.annotation in self._unresolved_interfaces:
814
757
  raise LookupError(
815
758
  f"You are attempting to get the parameter `{parameter.name}` with the "
816
- f"annotation `{get_full_qualname(parameter.annotation)}` as a "
817
- f"dependency into `{get_full_qualname(provider.call)}` which is "
759
+ f"annotation `{type_repr(parameter.annotation)}` as a "
760
+ f"dependency into `{type_repr(provider.call)}` which is "
818
761
  "not registered or set in the scoped context."
819
762
  )
820
763
 
821
- @contextlib.contextmanager
822
- def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
823
- """
824
- Override the provider for the specified interface with a specific instance.
825
- """
826
- if not self.testing:
827
- raise RuntimeError(
828
- "The `override` method can only be used in testing mode."
829
- )
830
- if not self.is_registered(interface) and self.strict:
831
- raise LookupError(
832
- f"The provider interface `{get_full_qualname(interface)}` "
833
- "not registered."
834
- )
835
- self._override_instances[interface] = instance
836
- try:
837
- yield
838
- finally:
839
- self._override_instances.pop(interface, None)
840
-
841
- ############################
842
- # Testing Methods
843
- ############################
844
-
845
- def _patch_test_resolver(self, interface: type[Any], instance: Any) -> Any:
846
- """Patch the test resolver for the instance."""
847
- if interface in self._override_instances:
848
- return self._override_instances[interface]
849
-
850
- if not hasattr(instance, "__dict__") or hasattr(
851
- instance, "__resolver_getter__"
852
- ):
853
- return instance
854
-
855
- wrapped = {
856
- name: value.interface
857
- for name, value in instance.__dict__.items()
858
- if isinstance(value, InstanceProxy)
859
- }
860
-
861
- def __resolver_getter__(name: str) -> Any:
862
- if name in wrapped:
863
- _interface = wrapped[name]
864
- # Resolve the dependency if it's wrapped
865
- return self.resolve(_interface)
866
- raise LookupError
867
-
868
- # Attach the resolver getter to the instance
869
- instance.__resolver_getter__ = __resolver_getter__
870
-
871
- if not hasattr(instance.__class__, "__getattribute_patched__"):
872
-
873
- def __getattribute__(_self: Any, name: str) -> Any:
874
- # Skip the resolver getter
875
- if name in {"__resolver_getter__", "__class__"}:
876
- return object.__getattribute__(_self, name)
877
-
878
- if hasattr(_self, "__resolver_getter__"):
879
- try:
880
- return _self.__resolver_getter__(name)
881
- except LookupError:
882
- pass
883
-
884
- # Fall back to default behavior
885
- return object.__getattribute__(_self, name)
886
-
887
- # Apply the patched resolver if wrapped attributes exist
888
- instance.__class__.__getattribute__ = __getattribute__
889
- instance.__class__.__getattribute_patched__ = True
890
-
891
- return instance
892
-
893
764
  ############################
894
765
  # Injector Methods
895
766
  ############################
@@ -923,6 +794,9 @@ class Container:
923
794
 
924
795
  injected_params = self._get_injected_params(call)
925
796
 
797
+ if not injected_params:
798
+ return call
799
+
926
800
  if inspect.iscoroutinefunction(call):
927
801
 
928
802
  @functools.wraps(call)
@@ -933,7 +807,7 @@ class Container:
933
807
 
934
808
  self._inject_cache[call] = awrapper
935
809
 
936
- return awrapper # type: ignore[return-value]
810
+ return awrapper # type: ignore
937
811
 
938
812
  @functools.wraps(call)
939
813
  def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
@@ -956,9 +830,9 @@ class Container:
956
830
  except LookupError as exc:
957
831
  if not self.strict:
958
832
  self.logger.debug(
959
- f"Cannot validate the `{get_full_qualname(call)}` parameter "
833
+ f"Cannot validate the `{type_repr(call)}` parameter "
960
834
  f"`{parameter.name}` with an annotation of "
961
- f"`{get_full_qualname(parameter.annotation)} due to being "
835
+ f"`{type_repr(parameter.annotation)} due to being "
962
836
  "in non-strict mode. It will be validated at the first call."
963
837
  )
964
838
  else:
@@ -972,215 +846,29 @@ class Container:
972
846
  """Validate an injected parameter."""
973
847
  if parameter.annotation is inspect.Parameter.empty:
974
848
  raise TypeError(
975
- f"Missing `{get_full_qualname(call)}` parameter "
976
- f"`{parameter.name}` annotation."
849
+ f"Missing `{type_repr(call)}` parameter `{parameter.name}` annotation."
977
850
  )
978
851
 
979
852
  if not self.is_registered(parameter.annotation):
980
853
  raise LookupError(
981
- f"`{get_full_qualname(call)}` has an unknown dependency parameter "
854
+ f"`{type_repr(call)}` has an unknown dependency parameter "
982
855
  f"`{parameter.name}` with an annotation of "
983
- f"`{get_full_qualname(parameter.annotation)}`."
856
+ f"`{type_repr(parameter.annotation)}`."
984
857
  )
985
858
 
986
859
  ############################
987
860
  # Module Methods
988
861
  ############################
989
862
 
990
- def register_module(
991
- self, module: Module | type[Module] | Callable[[Container], None] | str
992
- ) -> None:
863
+ def register_module(self, module: ModuleDef) -> None:
993
864
  """Register a module as a callable, module type, or module instance."""
994
- # Callable Module
995
- if inspect.isfunction(module):
996
- module(self)
997
- return
998
-
999
- # Module path
1000
- if isinstance(module, str):
1001
- module = import_string(module)
1002
-
1003
- # Class based Module or Module type
1004
- if inspect.isclass(module) and issubclass(module, Module):
1005
- module = module()
1006
-
1007
- if isinstance(module, Module):
1008
- module.configure(self)
1009
- for provider_name, decorator_args in module.providers:
1010
- obj = getattr(module, provider_name)
1011
- self.provider(
1012
- scope=decorator_args.scope,
1013
- override=decorator_args.override,
1014
- )(obj)
1015
- else:
1016
- raise TypeError(
1017
- "The module must be a callable, a module type, or a module instance."
1018
- )
865
+ self._modules.register(module)
1019
866
 
1020
867
  ############################
1021
868
  # Scanner Methods
1022
869
  ############################
1023
870
 
1024
871
  def scan(
1025
- self,
1026
- /,
1027
- packages: ModuleType | str | Iterable[ModuleType | str],
1028
- *,
1029
- tags: Iterable[str] | None = None,
872
+ self, /, packages: PackageOrIterable, *, tags: Iterable[str] | None = None
1030
873
  ) -> None:
1031
- """Scan packages or modules for decorated members and inject dependencies."""
1032
- dependencies: list[ScannedDependency] = []
1033
-
1034
- if isinstance(packages, ModuleType | str):
1035
- scan_packages: Iterable[ModuleType | str] = [packages]
1036
- else:
1037
- scan_packages = packages
1038
-
1039
- for package in scan_packages:
1040
- dependencies.extend(self._scan_package(package, tags=tags))
1041
-
1042
- for dependency in dependencies:
1043
- decorator = self.inject()(dependency.member)
1044
- setattr(dependency.module, dependency.member.__name__, decorator)
1045
-
1046
- def _scan_package(
1047
- self,
1048
- package: ModuleType | str,
1049
- *,
1050
- tags: Iterable[str] | None = None,
1051
- ) -> list[ScannedDependency]:
1052
- """Scan a package or module for decorated members."""
1053
- tags = tags or []
1054
- if isinstance(package, str):
1055
- package = importlib.import_module(package)
1056
-
1057
- package_path = getattr(package, "__path__", None)
1058
-
1059
- if not package_path:
1060
- return self._scan_module(package, tags=tags)
1061
-
1062
- dependencies: list[ScannedDependency] = []
1063
-
1064
- for module_info in pkgutil.walk_packages(
1065
- path=package_path, prefix=package.__name__ + "."
1066
- ):
1067
- module = importlib.import_module(module_info.name)
1068
- dependencies.extend(self._scan_module(module, tags=tags))
1069
-
1070
- return dependencies
1071
-
1072
- def _scan_module(
1073
- self, module: ModuleType, *, tags: Iterable[str]
1074
- ) -> list[ScannedDependency]:
1075
- """Scan a module for decorated members."""
1076
- dependencies: list[ScannedDependency] = []
1077
-
1078
- for _, member in inspect.getmembers(module):
1079
- if getattr(member, "__module__", None) != module.__name__ or not callable(
1080
- member
1081
- ):
1082
- continue
1083
-
1084
- decorator_args: InjectableDecoratorArgs = getattr(
1085
- member,
1086
- "__injectable__",
1087
- InjectableDecoratorArgs(wrapped=False, tags=[]),
1088
- )
1089
-
1090
- if tags and (
1091
- decorator_args.tags
1092
- and not set(decorator_args.tags).intersection(tags)
1093
- or not decorator_args.tags
1094
- ):
1095
- continue
1096
-
1097
- if decorator_args.wrapped:
1098
- dependencies.append(
1099
- self._create_scanned_dependency(member=member, module=module)
1100
- )
1101
- continue
1102
-
1103
- # Get by Marker
1104
- for parameter in get_typed_parameters(member):
1105
- if is_marker(parameter.default):
1106
- dependencies.append(
1107
- self._create_scanned_dependency(member=member, module=module)
1108
- )
1109
- continue
1110
-
1111
- return dependencies
1112
-
1113
- def _create_scanned_dependency(
1114
- self, member: Any, module: ModuleType
1115
- ) -> ScannedDependency:
1116
- """Create a `Dependency` object from the scanned member and module."""
1117
- if hasattr(member, "__wrapped__"):
1118
- member = member.__wrapped__
1119
- return ScannedDependency(member=member, module=module)
1120
-
1121
-
1122
- ############################
1123
- # Decorators
1124
- ############################
1125
-
1126
-
1127
- def transient(target: T) -> T:
1128
- """Decorator for marking a class as transient scope."""
1129
- target.__scope__ = "transient"
1130
- return target
1131
-
1132
-
1133
- def request(target: T) -> T:
1134
- """Decorator for marking a class as request scope."""
1135
- target.__scope__ = "request"
1136
- return target
1137
-
1138
-
1139
- def singleton(target: T) -> T:
1140
- """Decorator for marking a class as singleton scope."""
1141
- target.__scope__ = "singleton"
1142
- return target
1143
-
1144
-
1145
- def provider(
1146
- *, scope: Scope, override: bool = False
1147
- ) -> Callable[[Callable[Concatenate[M, P], T]], Callable[Concatenate[M, P], T]]:
1148
- """Decorator for marking a function or method as a provider in a AnyDI module."""
1149
-
1150
- def decorator(
1151
- target: Callable[Concatenate[M, P], T],
1152
- ) -> Callable[Concatenate[M, P], T]:
1153
- target.__provider__ = ProviderDecoratorArgs(scope=scope, override=override) # type: ignore
1154
- return target
1155
-
1156
- return decorator
1157
-
1158
-
1159
- @overload
1160
- def injectable(func: Callable[P, T]) -> Callable[P, T]: ...
1161
-
1162
-
1163
- @overload
1164
- def injectable(
1165
- *, tags: Iterable[str] | None = None
1166
- ) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
1167
-
1168
-
1169
- def injectable(
1170
- func: Callable[P, T] | None = None,
1171
- tags: Iterable[str] | None = None,
1172
- ) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
1173
- """Decorator for marking a function or method as requiring dependency injection."""
1174
-
1175
- def decorator(inner: Callable[P, T]) -> Callable[P, T]:
1176
- setattr(
1177
- inner,
1178
- "__injectable__",
1179
- InjectableDecoratorArgs(wrapped=True, tags=tags),
1180
- )
1181
- return inner
1182
-
1183
- if func is None:
1184
- return decorator
1185
-
1186
- return decorator(func)
874
+ self._scanner.scan(packages=packages, tags=tags)
anydi/_context.py CHANGED
@@ -7,7 +7,7 @@ from typing import Any
7
7
 
8
8
  from typing_extensions import Self
9
9
 
10
- from ._utils import AsyncRLock, run_async
10
+ from ._async import AsyncRLock, run_sync
11
11
 
12
12
 
13
13
  class InstanceContext:
@@ -78,7 +78,7 @@ class InstanceContext:
78
78
  exc_tb: TracebackType | None,
79
79
  ) -> bool:
80
80
  """Exit the context asynchronously."""
81
- sync_exit = await run_async(self.__exit__, exc_type, exc_val, exc_tb)
81
+ sync_exit = await run_sync(self.__exit__, exc_type, exc_val, exc_tb)
82
82
  async_exit = await self._async_stack.__aexit__(exc_type, exc_val, exc_tb)
83
83
  return bool(sync_exit) or bool(async_exit)
84
84