anydi 0.70.2__tar.gz → 0.72.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {anydi-0.70.2 → anydi-0.72.0}/PKG-INFO +1 -1
  2. {anydi-0.70.2 → anydi-0.72.0}/anydi/_container.py +20 -3
  3. {anydi-0.70.2 → anydi-0.72.0}/anydi/_decorators.py +13 -0
  4. {anydi-0.70.2 → anydi-0.72.0}/anydi/_graph.py +4 -5
  5. {anydi-0.70.2 → anydi-0.72.0}/anydi/_resolver.py +23 -18
  6. anydi-0.72.0/anydi/_scanner.py +298 -0
  7. {anydi-0.70.2 → anydi-0.72.0}/anydi/ext/typer.py +3 -1
  8. {anydi-0.70.2 → anydi-0.72.0}/pyproject.toml +1 -1
  9. anydi-0.70.2/anydi/_scanner.py +0 -160
  10. {anydi-0.70.2 → anydi-0.72.0}/README.md +0 -0
  11. {anydi-0.70.2 → anydi-0.72.0}/anydi/__init__.py +0 -0
  12. {anydi-0.70.2 → anydi-0.72.0}/anydi/_async_lock.py +0 -0
  13. {anydi-0.70.2 → anydi-0.72.0}/anydi/_cli.py +0 -0
  14. {anydi-0.70.2 → anydi-0.72.0}/anydi/_context.py +0 -0
  15. {anydi-0.70.2 → anydi-0.72.0}/anydi/_injector.py +0 -0
  16. {anydi-0.70.2 → anydi-0.72.0}/anydi/_marker.py +0 -0
  17. {anydi-0.70.2 → anydi-0.72.0}/anydi/_module.py +0 -0
  18. {anydi-0.70.2 → anydi-0.72.0}/anydi/_provider.py +0 -0
  19. {anydi-0.70.2 → anydi-0.72.0}/anydi/_types.py +0 -0
  20. {anydi-0.70.2 → anydi-0.72.0}/anydi/ext/__init__.py +0 -0
  21. {anydi-0.70.2 → anydi-0.72.0}/anydi/ext/django/__init__.py +0 -0
  22. {anydi-0.70.2 → anydi-0.72.0}/anydi/ext/fastapi.py +0 -0
  23. {anydi-0.70.2 → anydi-0.72.0}/anydi/ext/faststream.py +0 -0
  24. {anydi-0.70.2 → anydi-0.72.0}/anydi/ext/pydantic_settings.py +0 -0
  25. {anydi-0.70.2 → anydi-0.72.0}/anydi/ext/pytest_plugin.py +0 -0
  26. {anydi-0.70.2 → anydi-0.72.0}/anydi/ext/starlette/__init__.py +0 -0
  27. {anydi-0.70.2 → anydi-0.72.0}/anydi/ext/starlette/middleware.py +0 -0
  28. {anydi-0.70.2 → anydi-0.72.0}/anydi/py.typed +0 -0
  29. {anydi-0.70.2 → anydi-0.72.0}/anydi/testing.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: anydi
3
- Version: 0.70.2
3
+ Version: 0.72.0
4
4
  Summary: Dependency Injection library
5
5
  Keywords: dependency injection,dependencies,di,async,asyncio,application
6
6
  Author: Anton Ruhlov
@@ -58,6 +58,7 @@ class Container:
58
58
 
59
59
  self._resources: dict[str, list[Any]] = defaultdict(list)
60
60
  self._aliases: dict[Any, Any] = {} # alias_type → canonical_type
61
+ self._aliases_by_canonical: dict[Any, set[Any]] = defaultdict(set)
61
62
  self._singleton_context = InstanceContext()
62
63
  self._scoped_context: dict[str, ContextVar[InstanceContext]] = {}
63
64
 
@@ -405,12 +406,22 @@ class Container:
405
406
  f"Alias `{type_repr(alias_type)}` is already registered "
406
407
  f"for `{type_repr(self._aliases[alias_type])}`."
407
408
  )
409
+ if canonical_type not in self._providers:
410
+ raise ValueError(
411
+ f"Cannot create alias: provider for `{type_repr(canonical_type)}` "
412
+ "is not registered. Register the provider first."
413
+ )
408
414
  self._aliases[alias_type] = canonical_type
415
+ self._aliases_by_canonical[canonical_type].add(alias_type)
409
416
 
410
417
  def _resolve_alias(self, dependency_type: Any) -> Any:
411
418
  """Resolve an alias to its canonical type."""
412
419
  return self._aliases.get(dependency_type, dependency_type)
413
420
 
421
def get_aliases_for(self, canonical_type: Any, /) -> set[Any]:
    """Return the aliases that point to the given canonical type.

    A fresh set is returned on every call, so callers may modify the result
    without corrupting the container's internal alias index. An empty set is
    returned when no alias points to ``canonical_type``.
    """
    # `.get` (rather than indexing) so that querying an unknown type does not
    # insert an empty entry into the `defaultdict(set)` index.
    return set(self._aliases_by_canonical.get(canonical_type, ()))
424
+
414
425
  def is_registered(self, dependency_type: Any, /) -> bool:
415
426
  """Check if a provider is registered for the specified dependency type."""
416
427
  canonical = self._resolve_alias(dependency_type)
@@ -722,20 +733,23 @@ class Container:
722
733
  # Check if it's a @provided class
723
734
  if inspect.isclass(dependency_type) and is_provided(dependency_type):
724
735
  provided_scope = dependency_type.__provided__["scope"]
736
+ from_context = dependency_type.__provided__.get(
737
+ "from_context", False
738
+ )
725
739
 
726
740
  # Auto-register @provided class
727
741
  dep_provider = self._register_provider(
728
742
  dependency_type,
729
743
  dependency_type,
730
744
  provided_scope,
731
- False,
745
+ from_context,
732
746
  False,
733
747
  None,
734
748
  )
735
749
  # Register aliases if specified
736
750
  aliases = to_list(dependency_type.__provided__.get("alias"))
737
751
  for alias_type in aliases:
738
- self._aliases[alias_type] = dependency_type
752
+ self.alias(alias_type, dependency_type)
739
753
  # Recursively ensure the @provided class is resolved
740
754
  dep_provider = self._ensure_provider_resolved(
741
755
  dep_provider, resolving
@@ -833,6 +847,9 @@ class Container:
833
847
  del self._providers[provider.dependency_type]
834
848
  if provider.is_resource:
835
849
  self._resources[provider.scope].remove(provider.dependency_type)
850
+ # Remove aliases pointing to this provider
851
+ for alias in self._aliases_by_canonical.pop(provider.dependency_type, set()):
852
+ self._aliases.pop(alias, None)
836
853
 
837
854
  # == Instance Resolution ==
838
855
 
@@ -1039,7 +1056,7 @@ class Container:
1039
1056
  # Register aliases if specified
1040
1057
  provided_meta = param_dependency_type.__provided__
1041
1058
  for alias_type in to_list(provided_meta.get("alias")):
1042
- self._aliases[alias_type] = param_dependency_type
1059
+ self.alias(alias_type, param_dependency_type)
1043
1060
  elif param.has_default:
1044
1061
  # Has default, can be missing
1045
1062
  resolved_params.append(param)
@@ -34,12 +34,22 @@ class ProvidedMetadata(TypedDict):
34
34
  from_context: NotRequired[bool]
35
35
 
36
36
 
37
+ def _check_already_provided(cls: type) -> None:
38
+ """Check if class already has __provided__ defined directly on it."""
39
+ if "__provided__" in cls.__dict__:
40
+ raise TypeError(
41
+ f"Class `{cls.__name__}` already has `__provided__` defined. "
42
+ "Remove the duplicate scope decorator or manual `__provided__` attribute."
43
+ )
44
+
45
+
37
46
  def provided(
38
47
  *, scope: Scope, alias: Any = NOT_SET, from_context: bool = False
39
48
  ) -> Callable[[ClassT], ClassT]:
40
49
  """Decorator for marking a class as provided by AnyDI with a specific scope."""
41
50
 
42
51
  def decorator(cls: ClassT) -> ClassT:
52
+ _check_already_provided(cls)
43
53
  metadata: ProvidedMetadata = {"scope": scope}
44
54
  if alias is not NOT_SET:
45
55
  metadata["alias"] = alias
@@ -67,6 +77,7 @@ def singleton(
67
77
  """Decorator for marking a class as a singleton dependency."""
68
78
 
69
79
  def decorator(c: ClassT) -> ClassT:
80
+ _check_already_provided(c)
70
81
  metadata: ProvidedMetadata = {"scope": "singleton"}
71
82
  if alias is not NOT_SET:
72
83
  metadata["alias"] = alias
@@ -95,6 +106,7 @@ def transient(
95
106
  """Decorator for marking a class as a transient dependency."""
96
107
 
97
108
  def decorator(c: ClassT) -> ClassT:
109
+ _check_already_provided(c)
98
110
  metadata: ProvidedMetadata = {"scope": "transient"}
99
111
  if alias is not NOT_SET:
100
112
  metadata["alias"] = alias
@@ -127,6 +139,7 @@ def request(
127
139
  """Decorator for marking a class as a request-scoped dependency."""
128
140
 
129
141
  def decorator(c: ClassT) -> ClassT:
142
+ _check_already_provided(c)
130
143
  metadata: ProvidedMetadata = {"scope": "request"}
131
144
  if alias is not NOT_SET:
132
145
  metadata["alias"] = alias
@@ -21,11 +21,10 @@ class Graph:
21
21
 
22
22
  def _get_aliases_for(self, dependency_type: Any) -> list[str]:
23
23
  """Get list of alias names that point to a dependency type."""
24
- aliases: list[str] = []
25
- for alias, canonical in self._container.aliases.items():
26
- if canonical == dependency_type:
27
- aliases.append(type_repr(alias).rsplit(".", 1)[-1])
28
- return aliases
24
+ return [
25
+ type_repr(alias).rsplit(".", 1)[-1]
26
+ for alias in self._container.get_aliases_for(dependency_type)
27
+ ]
29
28
 
30
29
  def draw(
31
30
  self,
@@ -53,22 +53,20 @@ class Resolver:
53
53
  def add_override(self, dependency_type: Any, instance: Any) -> None:
54
54
  """Add an override for a type, its canonical type, and all aliases."""
55
55
  self._overrides[dependency_type] = instance
56
- canonical = self._container.aliases.get(dependency_type)
57
- if canonical is not None:
58
- self._overrides[canonical] = instance
59
- for alias, canon in self._container.aliases.items():
60
- if canon == dependency_type:
61
- self._overrides[alias] = instance
56
+ canonical_type = self._container.aliases.get(dependency_type)
57
+ if canonical_type is not None:
58
+ self._overrides[canonical_type] = instance
59
+ for alias in self._container.get_aliases_for(dependency_type):
60
+ self._overrides[alias] = instance
62
61
 
63
62
  def remove_override(self, dependency_type: Any) -> None:
64
63
  """Remove an override for a type, its canonical type, and all aliases."""
65
64
  self._overrides.pop(dependency_type, None)
66
- canonical = self._container.aliases.get(dependency_type)
67
- if canonical is not None:
68
- self._overrides.pop(canonical, None)
69
- for alias, canon in self._container.aliases.items():
70
- if canon == dependency_type:
71
- self._overrides.pop(alias, None)
65
+ canonical_type = self._container.aliases.get(dependency_type)
66
+ if canonical_type is not None:
67
+ self._overrides.pop(canonical_type, None)
68
+ for alias in self._container.get_aliases_for(dependency_type):
69
+ self._overrides.pop(alias, None)
72
70
 
73
71
  def clear_caches(self) -> None:
74
72
  """Clear all cached resolvers."""
@@ -103,7 +101,11 @@ class Resolver:
103
101
  for param in provider.parameters:
104
102
  if param.provider is not None:
105
103
  # Look up the current provider to handle overrides
106
- current_provider = self._container.providers.get(param.dependency_type)
104
+ # Resolve alias to canonical type if needed
105
+ canonical_type = self._container.aliases.get(
106
+ param.dependency_type, param.dependency_type
107
+ )
108
+ current_provider = self._container.providers.get(canonical_type)
107
109
  if current_provider is not None:
108
110
  self.compile(current_provider, is_async=is_async)
109
111
  else:
@@ -117,10 +119,9 @@ class Resolver:
117
119
  # Store the compiled functions in the cache
118
120
  cache[provider.dependency_type] = compiled
119
121
 
120
- # Also store under all aliases that point to this type
121
- for alias, canonical in self._container.aliases.items():
122
- if canonical == provider.dependency_type:
123
- cache[alias] = compiled
122
+ # Also store under aliases that point to this canonical type
123
+ for alias in self._container.get_aliases_for(provider.dependency_type):
124
+ cache[alias] = compiled
124
125
 
125
126
  return compiled
126
127
 
@@ -198,7 +199,11 @@ class Resolver:
198
199
 
199
200
  if param.provider is not None:
200
201
  # Look up the current provider from the container to handle overrides
201
- current_provider = self._container.providers.get(param.dependency_type)
202
+ # Resolve alias to canonical type if needed
203
+ canonical_type = self._container.aliases.get(
204
+ param.dependency_type, param.dependency_type
205
+ )
206
+ current_provider = self._container.providers.get(canonical_type)
202
207
  if current_provider is not None:
203
208
  compiled = cache.get(current_provider.dependency_type)
204
209
  else:
@@ -0,0 +1,298 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ import inspect
5
+ import pkgutil
6
+ from collections.abc import Iterable, Iterator
7
+ from dataclasses import dataclass
8
+ from types import ModuleType
9
+ from typing import TYPE_CHECKING, Any
10
+
11
+ from ._decorators import Provided, is_injectable, is_provided
12
+ from ._types import to_list
13
+
14
+ if TYPE_CHECKING:
15
+ from ._container import Container
16
+
17
# Something scan() accepts as a scan target: an imported module object or a
# dotted module/package name (which scan() also allows to be relative, e.g. ".").
Package = ModuleType | str
# A single Package or any iterable of them (used for `packages` and `ignore`).
PackageOrIterable = Package | Iterable[Package]
19
+
20
+
21
+ @dataclass(kw_only=True)
22
+ class ScannedDependency:
23
+ member: Any
24
+ module: ModuleType
25
+
26
+ def __post_init__(self) -> None:
27
+ # Unwrap decorated functions if necessary
28
+ if hasattr(self.member, "__wrapped__"):
29
+ self.member = self.member.__wrapped__
30
+
31
+
32
class Scanner:
    """Discover decorated members in packages and wire them into a container.

    ``@provided`` classes found in scanned modules are registered with the
    container (together with their aliases), and ``@injectable`` functions are
    replaced in their defining module by ``container.inject()`` wrappers.
    """

    # Names of packages currently being scanned by *any* Scanner instance.
    # Kept at class level on purpose: a scan started while another scan is
    # still importing modules (e.g. a scanned module lazily creating a new
    # container) must see the outer scan's packages to report the recursion.
    _scanning_packages: set[str] = set()

    def __init__(self, container: Container) -> None:
        self._container = container
        # Module names whose import is in progress through this scanner; used
        # by _import_module_with_tracking to detect re-entrant imports.
        self._importing_modules: set[str] = set()

    def scan(
        self,
        /,
        packages: PackageOrIterable,
        *,
        tags: Iterable[str] | None = None,
        ignore: PackageOrIterable | None = None,
    ) -> None:
        """Scan packages or modules for decorated members and inject dependencies.

        Supports relative package paths (like Python's relative imports):
        - "." scans the caller's package
        - ".submodule" scans a submodule of the caller's package
        - ".." scans the parent package
        - "..sibling" scans a sibling package

        Args:
            packages: Module objects and/or dotted names to scan.
            tags: When given, only ``@injectable`` members sharing at least one
                of these tags are injected.
            ignore: Modules/packages (including all their submodules) to skip.

        Raises:
            RuntimeError: If scan() is re-entered for a package that is
                already being scanned.
            ValueError: If a relative path cannot be resolved.
        """
        if isinstance(packages, (ModuleType, str)):
            packages = [packages]

        # Resolve relative package paths
        caller_package = self._get_caller_package(packages, ignore)
        packages = self._resolve_relative_packages(packages, caller_package)
        ignore = self._resolve_relative_packages(ignore, caller_package)

        pkg_names = {p if isinstance(p, str) else p.__name__ for p in packages}
        overlap = pkg_names & Scanner._scanning_packages
        if overlap:
            raise RuntimeError(
                f"Circular import detected: scan() called recursively!\n\n"
                f"Already scanning packages: {', '.join(sorted(overlap))}\n\n"
                "This happens when a scanned module triggers container creation "
                "(e.g., via lazy proxy).\n\n"
                "Solutions:\n"
                "- Add the problematic module to scan() ignore list\n"
                "- Move container imports inside functions (lazy import)\n"
                "- Avoid lazy container initialization in scanned modules"
            )

        Scanner._scanning_packages.update(pkg_names)
        try:
            self._do_scan(packages, tags=tags, ignore=ignore)
        finally:
            Scanner._scanning_packages -= pkg_names

    def _do_scan(  # noqa: C901
        self,
        packages: PackageOrIterable,
        *,
        tags: Iterable[str] | None = None,
        ignore: PackageOrIterable | None = None,
    ) -> None:
        """Collect decorated members, register providers, then inject functions."""
        if isinstance(packages, (ModuleType, str)):
            packages = [packages]

        tags_set: set[str] = set(tags) if tags else set()
        ignore_prefixes = self._normalize_ignore(ignore)
        provided_classes: list[type[Provided]] = []
        injectable_dependencies: list[ScannedDependency] = []

        # Single pass: collect both @provided classes and @injectable functions
        for module in self._iter_modules(packages, ignore_prefixes=ignore_prefixes):
            module_name = module.__name__
            for name, member in vars(module).items():
                if name.startswith("_"):
                    continue
                # Only members defined in this module, not re-exported imports.
                if getattr(member, "__module__", None) != module_name:
                    continue

                if inspect.isclass(member) and is_provided(member):
                    provided_classes.append(member)
                elif callable(member) and is_injectable(member):
                    member_tags = set(member.__injectable__["tags"] or ())
                    if not tags_set or (tags_set & member_tags):
                        injectable_dependencies.append(
                            ScannedDependency(member=member, module=module)
                        )

        # First: register @provided classes
        for cls in provided_classes:
            if not self._container.is_registered(cls):
                scope = cls.__provided__["scope"]
                from_context = cls.__provided__.get("from_context", False)
                self._container.register(
                    cls, cls, scope=scope, from_context=from_context
                )
                # Create aliases if specified (alias → cls)
                for alias_type in to_list(cls.__provided__.get("alias")):
                    self._container.alias(alias_type, cls)

        # Second: inject @injectable functions (after all providers exist)
        for dependency in injectable_dependencies:
            decorated = self._container.inject()(dependency.member)
            setattr(dependency.module, dependency.member.__name__, decorated)

    def _has_relative_packages(self, *package_lists: PackageOrIterable | None) -> bool:
        """Check if any package list contains relative paths."""
        for packages in package_lists:
            if packages is None:
                continue
            if isinstance(packages, str):
                if packages.startswith("."):
                    return True
            elif isinstance(packages, ModuleType):
                continue
            else:
                for p in packages:
                    if isinstance(p, str) and p.startswith("."):
                        return True
        return False

    def _get_caller_package(
        self,
        packages: Iterable[Package],
        ignore: PackageOrIterable | None,
    ) -> str | None:
        """Get the package of the first non-anydi frame that called scan().

        Returns None when no relative paths are used (no lookup needed).

        Raises:
            ValueError: If no suitable caller frame is found.
        """
        if not self._has_relative_packages(packages, ignore):
            return None

        frame = inspect.currentframe()
        try:
            while frame is not None:
                frame = frame.f_back
                if frame is None:
                    break
                frame_globals = frame.f_globals
                module_name = frame_globals.get("__name__")
                # Skip anydi's own frames; match the package boundary exactly
                # so user modules such as `anydictionary` are not skipped.
                if (
                    not module_name
                    or module_name == "anydi"
                    or module_name.startswith("anydi.")
                ):
                    continue
                # __package__ is what Python uses for relative imports; it is
                # correct for package __init__ modules and `python -m` runs.
                package = frame_globals.get("__package__")
                if package:
                    return package
                # No package (top-level module or script): strip the last
                # component of the module name, if any.
                if "." in module_name:
                    return module_name.rsplit(".", 1)[0]
                return module_name
        finally:
            del frame

        raise ValueError(
            "Cannot use relative package paths: unable to determine caller package. "
            "Use absolute package names instead."
        )

    def _resolve_relative_name(self, relative_name: str, base_package: str) -> str:
        """Resolve a relative package name (leading dots) against base_package."""
        num_dots = len(relative_name) - len(relative_name.lstrip("."))
        remainder = relative_name[num_dots:]

        package_parts = base_package.split(".")

        # Navigate up for parent references (..)
        if num_dots > 1:
            levels_up = num_dots - 1
            if levels_up >= len(package_parts):
                raise ValueError(
                    f"Cannot resolve '{relative_name}': "
                    f"too many parent levels for base package '{base_package}'"
                )
            package_parts = package_parts[:-levels_up]

        if remainder:
            return ".".join(package_parts) + "." + remainder
        return ".".join(package_parts)

    def _resolve_relative_packages(
        self,
        packages: PackageOrIterable | None,
        caller_package: str | None,
    ) -> list[Package]:
        """Resolve relative package names to absolute names (modules pass through)."""
        if packages is None:
            return []

        if isinstance(packages, (ModuleType, str)):
            packages = [packages]

        resolved: list[Package] = []
        for package in packages:
            if isinstance(package, ModuleType):
                resolved.append(package)
            elif not package.startswith("."):
                resolved.append(package)
            else:
                if caller_package is None:
                    raise ValueError(
                        "Cannot use relative package paths: "
                        "unable to determine caller package. "
                        "Use absolute package names instead."
                    )
                resolved.append(self._resolve_relative_name(package, caller_package))

        return resolved

    def _normalize_ignore(self, ignore: PackageOrIterable | None) -> tuple[str, ...]:
        """Normalize the ignore parameter to a tuple of dotted name prefixes.

        Each entry is the module name followed by ``"."``. Used with
        _should_ignore_module, ``"pkg.mod"`` ignores ``pkg.mod`` and
        ``pkg.mod.*`` but not an unrelated sibling such as ``pkg.module``.
        """
        if ignore is None:
            return ()

        if isinstance(ignore, (ModuleType, str)):
            ignore = [ignore]

        return tuple(
            (item.__name__ if isinstance(item, ModuleType) else item) + "."
            for item in ignore
        )

    def _should_ignore_module(
        self, module_name: str, ignore_prefixes: tuple[str, ...]
    ) -> bool:
        """Check if a module is an ignored module or inside an ignored package."""
        if not ignore_prefixes:
            return False
        # Appending "." makes one startswith() cover both the exact module and
        # its submodules while respecting the dotted-name boundary.
        return (module_name + ".").startswith(ignore_prefixes)

    def _iter_modules(
        self, packages: Iterable[Package], *, ignore_prefixes: tuple[str, ...]
    ) -> Iterator[ModuleType]:
        """Iterate over all non-ignored modules in the given packages.

        A plain module is yielded as-is. For a package, its submodules are
        yielded recursively, but the package's own ``__init__`` module is not.
        """
        for package in packages:
            if isinstance(package, str):
                package = importlib.import_module(package)

            # Single module (not a package)
            if not hasattr(package, "__path__"):
                if not self._should_ignore_module(package.__name__, ignore_prefixes):
                    yield package
                continue

            # Package - walk all submodules without importing ignored ones
            root_paths = list(package.__path__)
            yield from self._walk_package(
                root_paths, package.__name__ + ".", ignore_prefixes, set(root_paths)
            )

    def _walk_package(
        self,
        paths: Iterable[str],
        prefix: str,
        ignore_prefixes: tuple[str, ...],
        seen_paths: set[str],
    ) -> Iterator[ModuleType]:
        """Depth-first import of the modules found on ``paths``.

        Unlike ``pkgutil.walk_packages`` (which imports every subpackage in
        order to recurse into it), ignored modules and packages are skipped
        *before* being imported, so their code never runs.
        """
        for module_info in pkgutil.iter_modules(paths, prefix=prefix):
            if self._should_ignore_module(module_info.name, ignore_prefixes):
                continue

            module: ModuleType | None = None
            for module in self._import_module_with_tracking(module_info.name):
                yield module

            if not module_info.ispkg or module is None:
                continue
            # Skip already-visited search paths (e.g. symlinked directories).
            sub_paths = [
                path
                for path in getattr(module, "__path__", None) or []
                if path not in seen_paths
            ]
            seen_paths.update(sub_paths)
            yield from self._walk_package(
                sub_paths, module_info.name + ".", ignore_prefixes, seen_paths
            )

    def _import_module_with_tracking(self, module_name: str) -> Iterator[ModuleType]:
        """Import a module while tracking for circular imports.

        Yields the imported module once. The name stays registered as
        "importing" until the consumer resumes the generator.
        """
        # Check if we're already importing this module (circular import)
        if module_name in self._importing_modules:
            import_chain = " -> ".join(sorted(self._importing_modules))
            raise RuntimeError(
                f"Circular import detected during container scanning!\n"
                f"Module '{module_name}' is being imported while already "
                f"in the import chain.\n"
                f"Import chain: {import_chain} -> {module_name}\n\n"
                f"This usually happens when:\n"
                f"1. A scanned module imports the container at module level\n"
                f"2. The container creation triggers scanning\n"
                f"3. Scanning tries to import the module again\n\n"
                f"Solutions:\n"
                f"- Add '{module_name}' to the ignore list\n"
                f"- Move container imports inside functions (lazy import)\n"
                f"- Check for modules importing the container module"
            )

        # Track that we're importing this module
        self._importing_modules.add(module_name)
        try:
            module = importlib.import_module(module_name)
            yield module
        finally:
            # Always cleanup, even if import fails
            self._importing_modules.discard(module_name)
@@ -97,8 +97,10 @@ def _process_callback(callback: Callable[..., Any], container: Container) -> Any
97
97
  processed_parameter = container._injector.unwrap_parameter(parameter)
98
98
  if should_inject:
99
99
  injected_param_names.add(parameter.name)
100
+ # Resolve alias to canonical type if needed
101
+ canonical_type = container.aliases.get(dependency_type, dependency_type)
100
102
  try:
101
- scopes.add(container.providers[dependency_type].scope)
103
+ scopes.add(container.providers[canonical_type].scope)
102
104
  except KeyError:
103
105
  if inspect.isclass(dependency_type) and is_provided(dependency_type):
104
106
  scopes.add(dependency_type.__provided__["scope"])
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "anydi"
3
- version = "0.70.2"
3
+ version = "0.72.0"
4
4
  description = "Dependency Injection library"
5
5
  authors = [{ name = "Anton Ruhlov", email = "antonruhlov@gmail.com" }]
6
6
  requires-python = ">=3.10.0, <3.15"
@@ -1,160 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import importlib
4
- import inspect
5
- import pkgutil
6
- from collections.abc import Callable, Iterable, Iterator
7
- from dataclasses import dataclass
8
- from types import ModuleType
9
- from typing import TYPE_CHECKING, Any
10
-
11
- from ._decorators import Provided, is_injectable, is_provided
12
- from ._types import to_list
13
-
14
- if TYPE_CHECKING:
15
- from ._container import Container
16
-
17
- Package = ModuleType | str
18
- PackageOrIterable = Package | Iterable[Package]
19
-
20
-
21
- @dataclass(kw_only=True)
22
- class ScannedDependency:
23
- member: Any
24
- module: ModuleType
25
-
26
- def __post_init__(self) -> None:
27
- # Unwrap decorated functions if necessary
28
- if hasattr(self.member, "__wrapped__"):
29
- self.member = self.member.__wrapped__
30
-
31
-
32
- class Scanner:
33
- def __init__(self, container: Container) -> None:
34
- self._container = container
35
-
36
- def scan(
37
- self,
38
- /,
39
- packages: PackageOrIterable,
40
- *,
41
- tags: Iterable[str] | None = None,
42
- ignore: PackageOrIterable | None = None,
43
- ) -> None:
44
- """Scan packages or modules for decorated members and inject dependencies."""
45
- if isinstance(packages, (ModuleType, str)):
46
- packages = [packages]
47
-
48
- tags_list = list(tags) if tags else []
49
- ignore_prefixes = self._normalize_ignore(ignore)
50
- provided_classes: list[type[Provided]] = []
51
- injectable_dependencies: list[ScannedDependency] = []
52
-
53
- # Single pass: collect both @provided classes and @injectable functions
54
- for module in self._iter_modules(packages, ignore_prefixes=ignore_prefixes):
55
- provided_classes.extend(self._scan_module_for_provided(module))
56
- injectable_dependencies.extend(
57
- self._scan_module_for_injectable(module, tags=tags_list)
58
- )
59
-
60
- # First: register @provided classes
61
- for cls in provided_classes:
62
- if not self._container.is_registered(cls):
63
- scope = cls.__provided__["scope"]
64
- from_context = cls.__provided__.get("from_context", False)
65
- self._container.register(
66
- cls, cls, scope=scope, from_context=from_context
67
- )
68
- # Create aliases if specified (alias → cls)
69
- for alias_type in to_list(cls.__provided__.get("alias")):
70
- self._container.alias(alias_type, cls)
71
-
72
- # Second: inject @injectable functions
73
- for dependency in injectable_dependencies:
74
- decorated = self._container.inject()(dependency.member)
75
- setattr(dependency.module, dependency.member.__name__, decorated)
76
-
77
- def _normalize_ignore(self, ignore: PackageOrIterable | None) -> list[str]:
78
- """Normalize ignore parameter to a list of module name prefixes."""
79
- if ignore is None:
80
- return []
81
-
82
- if isinstance(ignore, (ModuleType, str)):
83
- ignore = [ignore]
84
-
85
- prefixes: list[str] = []
86
- for item in ignore:
87
- if isinstance(item, ModuleType):
88
- prefixes.append(item.__name__)
89
- else:
90
- prefixes.append(item)
91
- return prefixes
92
-
93
- def _should_ignore_module(
94
- self, module_name: str, ignore_prefixes: list[str]
95
- ) -> bool:
96
- """Check if a module should be ignored based on ignore prefixes."""
97
- for prefix in ignore_prefixes:
98
- if module_name == prefix or module_name.startswith(prefix + "."):
99
- return True
100
- return False
101
-
102
- def _iter_modules(
103
- self, packages: Iterable[Package], *, ignore_prefixes: list[str]
104
- ) -> Iterator[ModuleType]:
105
- """Iterate over all modules in the given packages."""
106
- for package in packages:
107
- if isinstance(package, str):
108
- package = importlib.import_module(package)
109
-
110
- # Single module (not a package)
111
- if not hasattr(package, "__path__"):
112
- if not self._should_ignore_module(package.__name__, ignore_prefixes):
113
- yield package
114
- continue
115
-
116
- # Package - walk all submodules
117
- for module_info in pkgutil.walk_packages(
118
- package.__path__, prefix=package.__name__ + "."
119
- ):
120
- if not self._should_ignore_module(module_info.name, ignore_prefixes):
121
- yield importlib.import_module(module_info.name)
122
-
123
- def _scan_module_for_provided(self, module: ModuleType) -> list[type[Provided]]:
124
- """Scan a module for @provided classes."""
125
- provided_classes: list[type[Provided]] = []
126
-
127
- for _, member in inspect.getmembers(module, predicate=inspect.isclass):
128
- if getattr(member, "__module__", None) != module.__name__:
129
- continue
130
-
131
- if is_provided(member):
132
- provided_classes.append(member)
133
-
134
- return provided_classes
135
-
136
- def _scan_module_for_injectable(
137
- self, module: ModuleType, *, tags: list[str]
138
- ) -> list[ScannedDependency]:
139
- """Scan a module for @injectable functions."""
140
- dependencies: list[ScannedDependency] = []
141
-
142
- for _, member in inspect.getmembers(module, predicate=callable):
143
- if getattr(member, "__module__", None) != module.__name__:
144
- continue
145
-
146
- if self._should_include_member(member, tags=tags):
147
- dependencies.append(ScannedDependency(member=member, module=module))
148
-
149
- return dependencies
150
-
151
- @staticmethod
152
- def _should_include_member(member: Callable[..., Any], *, tags: list[str]) -> bool:
153
- """Determine if a member should be included based on tags or marker defaults."""
154
- if is_injectable(member):
155
- member_tags = set(member.__injectable__["tags"] or [])
156
- if tags:
157
- return bool(set(tags) & member_tags)
158
- return True # No tags passed → include all injectables
159
-
160
- return False
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes