anydi 0.56.0__py3-none-any.whl → 0.58.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
anydi/_resolver.py CHANGED
@@ -453,7 +453,7 @@ class Resolver:
453
453
  create_lines.append(" context.enter(inst)")
454
454
 
455
455
  create_lines.append(" if context is not None and store:")
456
- create_lines.append(" context.set(_interface, inst)")
456
+ create_lines.append(" context._instances[_interface] = inst")
457
457
 
458
458
  # Wrap instance if in override mode (only for override version)
459
459
  if with_override:
@@ -470,18 +470,28 @@ class Resolver:
470
470
  resolver_lines.append("def _resolver(container, context=None):")
471
471
 
472
472
  # Only define NOT_SET_ if we actually need it
473
- needs_not_set = scope in ("singleton", "request")
473
+ needs_not_set = scope != "transient"
474
474
  if needs_not_set:
475
475
  resolver_lines.append(" NOT_SET_ = _NOT_SET")
476
476
 
477
477
  if scope == "singleton":
478
478
  resolver_lines.append(" if context is None:")
479
479
  resolver_lines.append(" context = container._singleton_context")
480
- elif scope == "request":
481
- resolver_lines.append(" if context is None:")
482
- resolver_lines.append(" context = container._get_request_context()")
483
- else:
480
+ elif scope == "transient":
484
481
  resolver_lines.append(" context = None")
482
+ else:
483
+ # Custom scopes (including "request")
484
+ # Inline context retrieval to avoid method call overhead
485
+ resolver_lines.append(" if context is None:")
486
+ resolver_lines.append(" try:")
487
+ resolver_lines.append(" context = _scoped_context_var.get()")
488
+ resolver_lines.append(" except LookupError:")
489
+ resolver_lines.append(
490
+ f" raise LookupError("
491
+ f"'The {scope} context has not been started. "
492
+ f"Please ensure that the {scope} context is properly initialized "
493
+ f"before attempting to use it.')"
494
+ )
485
495
 
486
496
  if scope == "singleton":
487
497
  if with_override:
@@ -507,33 +517,36 @@ class Resolver:
507
517
  store=True,
508
518
  indent=" ",
509
519
  )
510
- elif scope == "request":
520
+ elif scope == "transient":
521
+ # Transient scope
511
522
  if with_override:
512
- self._add_override_check(resolver_lines)
513
-
514
- # Fast path: check cached instance
515
- resolver_lines.append(" inst = context.get(_interface)")
516
- resolver_lines.append(" if inst is not NOT_SET_:")
517
- resolver_lines.append(" return inst")
523
+ self._add_override_check(resolver_lines, include_not_set=True)
518
524
 
519
525
  self._add_create_call(
520
526
  resolver_lines,
521
527
  is_async=is_async,
522
528
  with_override=with_override,
523
- context="context",
524
- store=True,
529
+ context="",
530
+ store=False,
525
531
  )
526
532
  else:
527
- # Transient scope
533
+ # Custom scopes (including "request")
528
534
  if with_override:
529
- self._add_override_check(resolver_lines, include_not_set=True)
535
+ self._add_override_check(resolver_lines)
536
+
537
+ # Fast path: check cached instance (inline dict access for speed)
538
+ resolver_lines.append(
539
+ " inst = context._instances.get(_interface, NOT_SET_)"
540
+ )
541
+ resolver_lines.append(" if inst is not NOT_SET_:")
542
+ resolver_lines.append(" return inst")
530
543
 
531
544
  self._add_create_call(
532
545
  resolver_lines,
533
546
  is_async=is_async,
534
547
  with_override=with_override,
535
- context="",
536
- store=False,
548
+ context="context",
549
+ store=True,
537
550
  )
538
551
 
539
552
  create_resolver_lines: list[str] = []
@@ -552,18 +565,26 @@ class Resolver:
552
565
 
553
566
  if scope == "singleton":
554
567
  create_resolver_lines.append(" context = container._singleton_context")
555
- elif scope == "request":
568
+ elif scope == "transient":
569
+ create_resolver_lines.append(" context = None")
570
+ else:
571
+ # Custom scopes (including "request")
572
+ # Inline context retrieval to avoid method call overhead
573
+ create_resolver_lines.append(" try:")
574
+ create_resolver_lines.append(" context = _scoped_context_var.get()")
575
+ create_resolver_lines.append(" except LookupError:")
556
576
  create_resolver_lines.append(
557
- " context = container._get_request_context()"
577
+ f" raise LookupError("
578
+ f"'The {scope} context has not been started. "
579
+ f"Please ensure that the {scope} context is properly initialized "
580
+ f"before attempting to use it.')"
558
581
  )
559
- else:
560
- create_resolver_lines.append(" context = None")
561
582
 
562
583
  if with_override:
563
584
  self._add_override_check(create_resolver_lines, include_not_set=True)
564
585
 
565
586
  # Determine context for create call
566
- context_arg = "context" if scope in ("singleton", "request") else ""
587
+ context_arg = "context" if scope != "transient" else ""
567
588
 
568
589
  self._add_create_call(
569
590
  create_resolver_lines,
@@ -603,6 +624,12 @@ class Resolver:
603
624
  "resolver": self,
604
625
  }
605
626
 
627
+ # For custom scopes, cache the ContextVar to avoid dictionary lookups
628
+ if scope not in ("singleton", "transient"):
629
+ ns["_scoped_context_var"] = self._container._get_scoped_context_var( # type: ignore[reportPrivateUsage]
630
+ scope
631
+ )
632
+
606
633
  # Add async-specific namespace entries
607
634
  if is_async:
608
635
  ns["_asynccontextmanager"] = contextlib.asynccontextmanager
anydi/_scanner.py CHANGED
@@ -3,13 +3,12 @@ from __future__ import annotations
3
3
  import importlib
4
4
  import inspect
5
5
  import pkgutil
6
- from collections.abc import Callable, Iterable
6
+ from collections.abc import Callable, Iterable, Iterator
7
7
  from dataclasses import dataclass
8
8
  from types import ModuleType
9
9
  from typing import TYPE_CHECKING, Any
10
10
 
11
- from ._decorators import is_injectable
12
- from ._types import is_inject_marker
11
+ from ._decorators import Provided, is_injectable, is_provided
13
12
 
14
13
  if TYPE_CHECKING:
15
14
  from ._container import Container
@@ -38,45 +37,64 @@ class Scanner:
38
37
  ) -> None:
39
38
  """Scan packages or modules for decorated members and inject dependencies."""
40
39
  if isinstance(packages, (ModuleType, str)):
41
- scan_packages: Iterable[Package] = [packages]
42
- else:
43
- scan_packages = packages
44
-
45
- dependencies = [
46
- dependency
47
- for package in scan_packages
48
- for dependency in self._scan_package(package, tags=tags)
49
- ]
50
-
51
- for dependency in dependencies:
40
+ packages = [packages]
41
+
42
+ tags_list = list(tags) if tags else []
43
+ provided_classes: list[type[Provided]] = []
44
+ injectable_dependencies: list[ScannedDependency] = []
45
+
46
+ # Single pass: collect both @provided classes and @injectable functions
47
+ for module in self._iter_modules(packages):
48
+ provided_classes.extend(self._scan_module_for_provided(module))
49
+ injectable_dependencies.extend(
50
+ self._scan_module_for_injectable(module, tags=tags_list)
51
+ )
52
+
53
+ # First: register @provided classes
54
+ for cls in provided_classes:
55
+ if not self._container.is_registered(cls):
56
+ scope = cls.__provided__["scope"]
57
+ self._container.register(cls, scope=scope)
58
+
59
+ # Second: inject @injectable functions
60
+ for dependency in injectable_dependencies:
52
61
  decorated = self._container.inject()(dependency.member)
53
62
  setattr(dependency.module, dependency.member.__name__, decorated)
54
63
 
55
- def _scan_package(
56
- self, package: Package, *, tags: Iterable[str] | None = None
57
- ) -> list[ScannedDependency]:
58
- """Scan a package or module for decorated members."""
59
- tags = list(tags) if tags else []
64
+ def _iter_modules(self, packages: Iterable[Package]) -> Iterator[ModuleType]:
65
+ """Iterate over all modules in the given packages."""
66
+ for package in packages:
67
+ if isinstance(package, str):
68
+ package = importlib.import_module(package)
60
69
 
61
- if isinstance(package, str):
62
- package = importlib.import_module(package)
70
+ # Single module (not a package)
71
+ if not hasattr(package, "__path__"):
72
+ yield package
73
+ continue
63
74
 
64
- if not hasattr(package, "__path__"):
65
- return self._scan_module(package, tags=tags)
75
+ # Package - walk all submodules
76
+ for module_info in pkgutil.walk_packages(
77
+ package.__path__, prefix=package.__name__ + "."
78
+ ):
79
+ yield importlib.import_module(module_info.name)
66
80
 
67
- dependencies: list[ScannedDependency] = []
68
- for module_info in pkgutil.walk_packages(
69
- package.__path__, prefix=package.__name__ + "."
70
- ):
71
- module = importlib.import_module(module_info.name)
72
- dependencies.extend(self._scan_module(module, tags=tags))
81
+ def _scan_module_for_provided(self, module: ModuleType) -> list[type[Provided]]:
82
+ """Scan a module for @provided classes."""
83
+ provided_classes: list[type[Provided]] = []
73
84
 
74
- return dependencies
85
+ for _, member in inspect.getmembers(module, predicate=inspect.isclass):
86
+ if getattr(member, "__module__", None) != module.__name__:
87
+ continue
88
+
89
+ if is_provided(member):
90
+ provided_classes.append(member)
75
91
 
76
- def _scan_module(
77
- self, module: ModuleType, *, tags: Iterable[str]
92
+ return provided_classes
93
+
94
+ def _scan_module_for_injectable(
95
+ self, module: ModuleType, *, tags: list[str]
78
96
  ) -> list[ScannedDependency]:
79
- """Scan a module for decorated members."""
97
+ """Scan a module for @injectable functions."""
80
98
  dependencies: list[ScannedDependency] = []
81
99
 
82
100
  for _, member in inspect.getmembers(module, predicate=callable):
@@ -89,22 +107,12 @@ class Scanner:
89
107
  return dependencies
90
108
 
91
109
  @staticmethod
92
- def _should_include_member(
93
- member: Callable[..., Any], *, tags: Iterable[str]
94
- ) -> bool:
110
+ def _should_include_member(member: Callable[..., Any], *, tags: list[str]) -> bool:
95
111
  """Determine if a member should be included based on tags or marker defaults."""
96
-
97
112
  if is_injectable(member):
98
113
  member_tags = set(member.__injectable__["tags"] or [])
99
114
  if tags:
100
115
  return bool(set(tags) & member_tags)
101
116
  return True # No tags passed → include all injectables
102
117
 
103
- # If no tags are passed and not explicitly injectable,
104
- # check for parameter markers
105
- if not tags:
106
- for parameter in inspect.signature(member).parameters.values():
107
- if is_inject_marker(parameter.default):
108
- return True
109
-
110
118
  return False
anydi/_types.py CHANGED
@@ -3,19 +3,21 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import inspect
6
- from collections.abc import AsyncIterator, Iterator
6
+ from collections.abc import AsyncIterator, Callable, Iterator
7
7
  from types import NoneType
8
- from typing import Any, Literal
8
+ from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeVar
9
9
 
10
10
  from typing_extensions import Sentinel
11
11
 
12
- Scope = Literal["transient", "singleton", "request"]
12
+ T = TypeVar("T")
13
+
14
+ Scope = Literal["transient", "singleton", "request"] | str
13
15
 
14
16
  NOT_SET = Sentinel("NOT_SET")
15
17
 
16
18
 
17
- class InjectMarker:
18
- """A marker object for declaring injectable dependencies."""
19
+ class ProvideMarker:
20
+ """A marker object for declaring dependency."""
19
21
 
20
22
  __slots__ = ("_interface",)
21
23
 
@@ -32,13 +34,52 @@ class InjectMarker:
32
34
  def interface(self, interface: Any) -> None:
33
35
  self._interface = interface
34
36
 
37
+ def __class_getitem__(cls, item: Any) -> Any:
38
+ return Annotated[item, cls()]
39
+
40
+
41
+ _provide_factory: Callable[[], Any] = ProvideMarker
42
+
43
+
44
+ def set_provide_factory(factory: Callable[[], Any]) -> Callable[[], Any]:
45
+ """Set the global factory used by Inject() and Provide."""
46
+ global _provide_factory
47
+ previous = _provide_factory
48
+ _provide_factory = factory
49
+ return previous
50
+
51
+
52
+ def is_provide_marker(obj: Any) -> bool:
53
+ return isinstance(obj, ProvideMarker)
54
+
35
55
 
36
- def is_inject_marker(obj: Any) -> bool:
37
- return isinstance(obj, InjectMarker)
56
+ class _ProvideMeta(type):
57
+ """Metaclass for Provide that delegates __class_getitem__ to the current factory."""
58
+
59
+ def __getitem__(cls, item: Any) -> Any:
60
+ # Use the current factory's __class_getitem__ if available
61
+ factory = _provide_factory
62
+ if hasattr(factory, "__class_getitem__"):
63
+ return factory.__class_getitem__(item) # type: ignore[attr-defined]
64
+ # Fallback to creating Annotated with factory instance
65
+ return Annotated[item, factory()]
66
+
67
+
68
+ if TYPE_CHECKING:
69
+ Provide = Annotated[T, ProvideMarker()]
70
+
71
+ else:
72
+
73
+ class Provide(metaclass=_ProvideMeta):
74
+ pass
38
75
 
39
76
 
40
77
  def Inject() -> Any:
41
- return InjectMarker()
78
+ return _provide_factory()
79
+
80
+
81
+ # Alias from backward compatibility
82
+ is_inject_marker = is_provide_marker
42
83
 
43
84
 
44
85
  class Event:
anydi/ext/fastapi.py CHANGED
@@ -11,37 +11,12 @@ from fastapi.dependencies.models import Dependant
11
11
  from fastapi.routing import APIRoute
12
12
  from starlette.requests import Request
13
13
 
14
- from anydi._container import Container
15
- from anydi._types import InjectMarker
14
+ from anydi import Container, Inject
15
+ from anydi._types import ProvideMarker, set_provide_factory
16
16
 
17
17
  from .starlette.middleware import RequestScopedMiddleware
18
18
 
19
- __all__ = ["RequestScopedMiddleware", "install", "get_container", "Inject"]
20
-
21
-
22
- def install(app: FastAPI, container: Container) -> None:
23
- """Install AnyDI into a FastAPI application.
24
-
25
- This function installs the AnyDI container into a FastAPI application by attaching
26
- it to the application state. It also patches the route dependencies to inject the
27
- required dependencies using AnyDI.
28
- """
29
- app.state.container = container # noqa
30
-
31
- patched = []
32
-
33
- for route in app.routes:
34
- if not isinstance(route, APIRoute):
35
- continue
36
- for dependant in _iter_dependencies(route.dependant):
37
- if dependant.cache_key in patched:
38
- continue
39
- patched.append(dependant.cache_key)
40
- call, *params = dependant.cache_key
41
- if not call:
42
- continue # pragma: no cover
43
- for parameter in inspect.signature(call, eval_str=True).parameters.values():
44
- container.validate_injected_parameter(parameter, call=call)
19
+ __all__ = ["install", "get_container", "Inject", "RequestScopedMiddleware"]
45
20
 
46
21
 
47
22
  def get_container(request: Request) -> Container:
@@ -49,10 +24,10 @@ def get_container(request: Request) -> Container:
49
24
  return cast(Container, request.app.state.container)
50
25
 
51
26
 
52
- class _Inject(params.Depends, InjectMarker):
27
+ class _ProvideMarker(params.Depends, ProvideMarker):
53
28
  def __init__(self) -> None:
54
29
  super().__init__(dependency=self._dependency, use_cache=True)
55
- InjectMarker.__init__(self)
30
+ ProvideMarker.__init__(self)
56
31
 
57
32
  async def _dependency(
58
33
  self, container: Annotated[Container, Depends(get_container)]
@@ -60,13 +35,36 @@ class _Inject(params.Depends, InjectMarker):
60
35
  return await container.aresolve(self.interface)
61
36
 
62
37
 
63
- def Inject() -> Any:
64
- return _Inject()
38
+ # Configure Inject() and Provide[T] to use FastAPI-specific marker
39
+ set_provide_factory(_ProvideMarker)
65
40
 
66
41
 
67
42
  def _iter_dependencies(dependant: Dependant) -> Iterator[Dependant]:
68
- """Iterate over the dependencies of a dependant."""
69
43
  yield dependant
70
44
  if dependant.dependencies:
71
45
  for sub_dependant in dependant.dependencies:
72
46
  yield from _iter_dependencies(sub_dependant)
47
+
48
+
49
+ def _validate_route_dependencies(
50
+ route: APIRoute, container: Container, patched: set[tuple[Any, ...]]
51
+ ) -> None:
52
+ for dependant in _iter_dependencies(route.dependant):
53
+ if dependant.cache_key in patched:
54
+ continue
55
+ patched.add(dependant.cache_key)
56
+ call, *_ = dependant.cache_key
57
+ if not call:
58
+ continue # pragma: no cover
59
+ for parameter in inspect.signature(call, eval_str=True).parameters.values():
60
+ container.validate_injected_parameter(parameter, call=call)
61
+
62
+
63
+ def install(app: FastAPI, container: Container) -> None:
64
+ """Install AnyDI into a FastAPI application."""
65
+ app.state.container = container # noqa
66
+ patched: set[tuple[Any, ...]] = set()
67
+ for route in app.routes:
68
+ if not isinstance(route, APIRoute):
69
+ continue
70
+ _validate_route_dependencies(route, container, patched)
anydi/ext/faststream.py CHANGED
@@ -10,49 +10,43 @@ from faststream import ContextRepo
10
10
  from faststream.broker.core.usecase import BrokerUsecase
11
11
 
12
12
  from anydi import Container
13
- from anydi._types import InjectMarker
13
+ from anydi._types import Inject, ProvideMarker, set_provide_factory
14
14
 
15
-
16
- def install(broker: BrokerUsecase[Any, Any], container: Container) -> None:
17
- """Install AnyDI into a FastStream broker.
18
-
19
- This function installs the AnyDI container into a FastStream broker by attaching
20
- it to the broker. It also patches the broker handlers to inject the required
21
- dependencies using AnyDI.
22
- """
23
- broker._container = container # type: ignore
24
-
25
- for handler in _get_broken_handlers(broker):
26
- call = handler._original_call # noqa
27
- for parameter in inspect.signature(call, eval_str=True).parameters.values():
28
- container.validate_injected_parameter(parameter, call=call)
29
-
30
-
31
- def _get_broken_handlers(broker: BrokerUsecase[Any, Any]) -> list[Any]:
32
- if (handlers := getattr(broker, "handlers", None)) is not None:
33
- return [handler.calls[0][0] for handler in handlers.values()]
34
- # faststream > 0.5.0
35
- return [
36
- subscriber.calls[0].handler
37
- for subscriber in broker._subscribers.values() # noqa
38
- ]
15
+ __all__ = ["install", "get_container", "Inject"]
39
16
 
40
17
 
41
18
  def get_container(broker: BrokerUsecase[Any, Any]) -> Container:
19
+ """Get the AnyDI container from a FastStream broker."""
42
20
  return cast(Container, getattr(broker, "_container")) # noqa
43
21
 
44
22
 
45
- class _Inject(Depends, InjectMarker):
46
- """Parameter dependency class for injecting dependencies using AnyDI."""
47
-
23
+ class _ProvideMarker(Depends, ProvideMarker):
48
24
  def __init__(self) -> None:
49
25
  super().__init__(dependency=self._dependency, use_cache=True, cast=True)
50
- InjectMarker.__init__(self)
26
+ ProvideMarker.__init__(self)
51
27
 
52
28
  async def _dependency(self, context: ContextRepo) -> Any:
53
29
  container = get_container(context.get("broker"))
54
30
  return await container.aresolve(self.interface)
55
31
 
56
32
 
57
- def Inject() -> Any:
58
- return _Inject()
33
+ # Configure Inject() and Provide[T] to use FastStream-specific marker
34
+ set_provide_factory(_ProvideMarker)
35
+
36
+
37
+ def _get_broker_handlers(broker: BrokerUsecase[Any, Any]) -> list[Any]:
38
+ if (handlers := getattr(broker, "handlers", None)) is not None:
39
+ return [handler.calls[0][0] for handler in handlers.values()]
40
+ return [
41
+ subscriber.calls[0].handler
42
+ for subscriber in broker._subscribers.values() # noqa
43
+ ]
44
+
45
+
46
+ def install(broker: BrokerUsecase[Any, Any], container: Container) -> None:
47
+ """Install AnyDI into a FastStream broker."""
48
+ broker._container = container # type: ignore
49
+ for handler in _get_broker_handlers(broker):
50
+ call = handler._original_call # noqa
51
+ for parameter in inspect.signature(call, eval_str=True).parameters.values():
52
+ container.validate_injected_parameter(parameter, call=call)
@@ -22,7 +22,8 @@ def install(
22
22
  prefix += "."
23
23
 
24
24
  def _register_settings(_settings: BaseSettings) -> None:
25
- all_fields = {**_settings.model_fields, **_settings.model_computed_fields}
25
+ settings_cls = type(_settings)
26
+ all_fields = {**settings_cls.model_fields, **settings_cls.model_computed_fields}
26
27
  for setting_name, field_info in all_fields.items():
27
28
  if isinstance(field_info, ComputedFieldInfo):
28
29
  interface = field_info.return_type