modern-di 0.16.2__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modern-di might be problematic.

Files changed (35):
  1. modern_di/__init__.py +6 -4
  2. modern_di/containers/__init__.py +0 -0
  3. modern_di/containers/abstract.py +110 -0
  4. modern_di/containers/async_container.py +104 -0
  5. modern_di/containers/sync_container.py +105 -0
  6. modern_di/group.py +26 -0
  7. modern_di/helpers/type_helpers.py +33 -0
  8. modern_di/providers/__init__.py +2 -4
  9. modern_di/providers/abstract.py +62 -74
  10. modern_di/providers/async_factory.py +9 -11
  11. modern_di/providers/async_singleton.py +14 -26
  12. modern_di/providers/container_provider.py +4 -8
  13. modern_di/providers/context_provider.py +16 -0
  14. modern_di/providers/dict.py +20 -13
  15. modern_di/providers/factory.py +17 -23
  16. modern_di/providers/list.py +20 -13
  17. modern_di/providers/object.py +6 -11
  18. modern_di/providers/resource.py +38 -66
  19. modern_di/providers/singleton.py +23 -43
  20. modern_di/registries/__init__.py +0 -0
  21. modern_di/registries/context_registry.py +16 -0
  22. modern_di/registries/overrides_registry.py +22 -0
  23. modern_di/registries/providers_registry.py +38 -0
  24. modern_di/registries/state_registry/__init__.py +0 -0
  25. modern_di/{provider_state.py → registries/state_registry/state.py} +15 -11
  26. modern_di/registries/state_registry/state_registry.py +52 -0
  27. {modern_di-0.16.2.dist-info → modern_di-0.17.0.dist-info}/METADATA +1 -1
  28. modern_di-0.17.0.dist-info/RECORD +33 -0
  29. modern_di/container.py +0 -171
  30. modern_di/graph.py +0 -39
  31. modern_di/providers/context_adapter.py +0 -27
  32. modern_di/providers/injected_factory.py +0 -45
  33. modern_di/providers/selector.py +0 -39
  34. modern_di-0.16.2.dist-info/RECORD +0 -25
  35. {modern_di-0.16.2.dist-info → modern_di-0.17.0.dist-info}/WHEEL +0 -0
modern_di/__init__.py CHANGED
@@ -1,11 +1,13 @@
1
- from modern_di.container import Container
2
- from modern_di.graph import BaseGraph
1
+ from modern_di.containers.async_container import AsyncContainer
2
+ from modern_di.containers.sync_container import SyncContainer
3
+ from modern_di.group import Group
3
4
  from modern_di.scope import Scope
4
5
 
5
6
 
6
7
  __all__ = [
7
- "BaseGraph",
8
- "Container",
8
+ "AsyncContainer",
9
+ "Group",
9
10
  "Scope",
11
+ "SyncContainer",
10
12
  "providers",
11
13
  ]
modern_di/containers/__init__.py ADDED (empty file, no content changes)
@@ -0,0 +1,110 @@
1
+ import enum
2
+ import typing
3
+
4
+ from modern_di.group import Group
5
+ from modern_di.providers.abstract import AbstractProvider
6
+ from modern_di.providers.context_provider import ContextProvider
7
+ from modern_di.registries.context_registry import ContextRegistry
8
+ from modern_di.registries.overrides_registry import OverridesRegistry
9
+ from modern_di.registries.providers_registry import ProvidersRegistry
10
+ from modern_di.scope import Scope
11
+
12
+
13
+ if typing.TYPE_CHECKING:
14
+ import typing_extensions
15
+
16
+
17
T_co = typing.TypeVar("T_co", covariant=True)


class AbstractContainer:
    """State and behaviour shared by ``SyncContainer`` and ``AsyncContainer``.

    A container represents one ``Scope`` and may have a parent container of a lower scope.
    Child containers share the parent's overrides registry and, unless they are given their
    own ``groups``, the parent's providers registry.
    """

    # Instance attribute names used by containers; concrete containers reuse this list
    # as their ``__slots__``. Must match the names actually assigned (``context_registry``).
    BASE_SLOTS: typing.ClassVar[list[str]] = [
        "_is_entered",
        "state_registry",
        "context_registry",
        "parent_container",
        "scope",
        "overrides_registry",
        "providers_registry",
    ]

    def __init__(
        self,
        *,
        scope: Scope = Scope.APP,
        parent_container: typing.Optional["typing_extensions.Self"] = None,
        context: dict[type[typing.Any], typing.Any] | None = None,
        groups: list[type[Group]] | None = None,
    ) -> None:
        """Initialise the container.

        Args:
            scope: scope this container represents.
            parent_container: container of the enclosing (lower) scope, if any.
            context: objects keyed by type, looked up when resolving a ``ContextProvider``.
            groups: provider groups to register. When a child container is created without
                groups it reuses the parent's providers registry, so dependencies registered
                on the root stay resolvable by type or name from every child.

        """
        self._is_entered = False
        self.scope = scope
        self.parent_container = parent_container
        self.context_registry = ContextRegistry(context or {})

        if parent_container is not None and not groups:
            # Share the parent's registry; an empty registry here would make
            # ``resolve()`` on child containers always return None.
            self.providers_registry = parent_container.providers_registry
        else:
            self.providers_registry = ProvidersRegistry()
            for one_group in groups or []:
                self.providers_registry.add_providers(**one_group.get_providers())

        # Overrides are shared across the whole container tree.
        self.overrides_registry: OverridesRegistry
        if parent_container is not None:
            self.overrides_registry = parent_container.overrides_registry
        else:
            self.overrides_registry = OverridesRegistry()

    def _check_entered(self) -> None:
        """Raise ``RuntimeError`` unless the container has been entered."""
        if not self._is_entered:
            msg = f"Enter the context of {self.scope.name} scope"
            raise RuntimeError(msg)

    def _resolve_context_provider(self, provider: ContextProvider[T_co]) -> T_co:
        """Return the context object registered for ``provider.context_type``.

        The lookup starts in this container and walks up through its ancestors, so context
        passed to a parent container is visible from its children. A value is treated as
        missing only when the registry returns ``None`` (falsy values such as ``0`` or ``""``
        are valid context).

        Raises:
            RuntimeError: if no container in the chain holds context of that type.

        """
        container: AbstractContainer | None = self
        while container is not None:
            context = container.context_registry.find_context(provider.context_type)
            if context is not None:
                return context
            container = container.parent_container

        msg = f"Context of type {provider.context_type} is missing"
        raise RuntimeError(msg)

    def build_child_container(
        self, context: dict[type[typing.Any], typing.Any] | None = None, scope: Scope | None = None
    ) -> "typing_extensions.Self":
        """Create a child container of the same class with a higher scope.

        Args:
            context: context objects for the child container.
            scope: scope of the child; defaults to the next scope after this one.

        Raises:
            RuntimeError: if this container is not entered, if ``scope`` is not greater than
                the current scope, or if there is no scope after the current one.

        """
        self._check_entered()
        # Compare with None explicitly: an enum member with value 0 would be falsy.
        if scope is not None and scope <= self.scope:
            msg = "Scope of child container must be more than current scope"
            raise RuntimeError(msg)

        if scope is None:
            try:
                scope = self.scope.__class__(self.scope.value + 1)
            except ValueError as exc:
                msg = f"Max scope is reached, {self.scope.name}"
                raise RuntimeError(msg) from exc

        return self.__class__(scope=scope, parent_container=self, context=context)

    def find_container(self, scope: enum.IntEnum) -> "typing_extensions.Self":
        """Return this container or the ancestor whose scope equals ``scope``.

        Raises:
            RuntimeError: if ``scope`` is greater than this container's scope ("not initialized"),
                or no container in the chain has exactly that scope ("skipped").

        """
        container = self
        if container.scope < scope:
            msg = f"Scope {scope.name} is not initialized"
            raise RuntimeError(msg)

        while container.scope > scope and container.parent_container:
            container = container.parent_container

        if container.scope != scope:
            msg = f"Scope {scope.name} is skipped"
            raise RuntimeError(msg)

        return container

    def override(self, provider: AbstractProvider[T_co], override_object: object) -> None:
        """Register ``override_object`` for ``provider`` in the shared overrides registry."""
        self.overrides_registry.override(provider.provider_id, override_object)

    def reset_override(self, provider: AbstractProvider[T_co] | None = None) -> None:
        """Remove the override of ``provider``; with ``None`` the call is forwarded as-is.

        NOTE(review): passing ``None`` presumably clears every override — confirm in
        ``OverridesRegistry.reset_override``.
        """
        self.overrides_registry.reset_override(provider.provider_id if provider else None)

    def __deepcopy__(self, *_: object, **__: object) -> "typing_extensions.Self":
        """Hack to prevent cloning object."""
        return self

    def __copy__(self, *_: object, **__: object) -> "typing_extensions.Self":
        """Hack to prevent cloning object."""
        return self
@@ -0,0 +1,104 @@
1
+ import contextlib
2
+ import types
3
+ import typing
4
+
5
+ from modern_di.containers.abstract import AbstractContainer
6
+ from modern_di.group import Group
7
+ from modern_di.providers.abstract import AbstractProvider
8
+ from modern_di.providers.container_provider import ContainerProvider
9
+ from modern_di.providers.context_provider import ContextProvider
10
+ from modern_di.registries.state_registry.state_registry import AsyncStateRegistry
11
+ from modern_di.scope import Scope
12
+
13
+
14
+ if typing.TYPE_CHECKING:
15
+ import typing_extensions
16
+
17
+
18
T_co = typing.TypeVar("T_co", covariant=True)


class AsyncContainer(contextlib.AbstractAsyncContextManager["AsyncContainer"], AbstractContainer):
    """Dependency container that resolves providers asynchronously.

    The container must be entered (``async with`` or ``enter()``) before anything can be
    resolved from it.
    """

    __slots__ = AbstractContainer.BASE_SLOTS

    def __init__(
        self,
        *,
        scope: Scope = Scope.APP,
        parent_container: typing.Optional["typing_extensions.Self"] = None,
        context: dict[type[typing.Any], typing.Any] | None = None,
        groups: list[type[Group]] | None = None,
        use_lock: bool = True,
    ) -> None:
        """Create the container; ``use_lock`` is forwarded to its ``AsyncStateRegistry``."""
        super().__init__(scope=scope, parent_container=parent_container, context=context, groups=groups)
        self.state_registry = AsyncStateRegistry(use_lock=use_lock)

    async def _resolve_args(self, args: list[typing.Any]) -> list[typing.Any]:
        """Resolve every provider in ``args``; other values are passed through unchanged."""
        return [await self.resolve_provider(x) if isinstance(x, AbstractProvider) else x for x in args]

    async def _resolve_kwargs(self, kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
        """Resolve every provider value in ``kwargs``; other values are passed through unchanged."""
        return {k: await self.resolve_provider(v) if isinstance(v, AbstractProvider) else v for k, v in kwargs.items()}

    async def resolve(
        self, dependency_type: type[T_co] | None = None, *, dependency_name: str | None = None
    ) -> T_co | None:
        """Resolve a dependency looked up by type or by name; ``None`` if no provider matches."""
        provider = self.providers_registry.find_provider(
            dependency_type=dependency_type, dependency_name=dependency_name
        )
        if not provider:
            return None

        return await self.resolve_provider(provider)

    async def resolve_provider(self, provider: AbstractProvider[T_co]) -> T_co:
        """Resolve ``provider`` in the container matching its scope.

        Resolution order: the container itself for ``ContainerProvider``, context for
        ``ContextProvider``, then a registered override, then a cached instance from the
        provider state, and finally the provider's own ``async_resolve``.

        Raises:
            RuntimeError: if the container is not entered or the provider's scope cannot be found.

        """
        self._check_entered()

        container = self.find_container(provider.scope)
        if isinstance(provider, ContainerProvider):
            return typing.cast(T_co, container)

        if isinstance(provider, ContextProvider):
            return typing.cast(T_co, self._resolve_context_provider(provider))

        if (override := container.overrides_registry.fetch_override(provider.provider_id)) is not None:
            return typing.cast(T_co, override)

        provider_state = container.state_registry.fetch_provider_state(provider)
        if provider_state and provider_state.instance is not None:
            return provider_state.instance

        # Double-checked locking: re-check the cached instance once the lock is held,
        # another task may have created it while this one was waiting.
        if provider_state and provider_state.lock:
            await provider_state.lock.acquire()
        try:
            if provider_state and provider_state.instance is not None:
                return provider_state.instance

            return await provider.async_resolve(
                args=await self._resolve_args(provider.args or []),
                kwargs=await self._resolve_kwargs(provider.kwargs or {}),
                provider_state=provider_state,
            )
        finally:
            if provider_state and provider_state.lock:
                provider_state.lock.release()

    def enter(self) -> "AsyncContainer":
        """Mark the container as entered and return it."""
        self._is_entered = True
        return self

    async def close(self) -> None:
        """Leave the container and clear its provider state.

        The overrides registry is shared by the whole container tree (children reuse their
        parent's registry), so overrides are reset only when the root container closes.
        Closing a child must not wipe overrides registered on an ancestor.

        Raises:
            RuntimeError: if the container is not entered.

        """
        self._check_entered()
        self._is_entered = False
        await self.state_registry.clear_state()
        if self.parent_container is None:
            self.overrides_registry.reset_override()

    async def __aenter__(self) -> "AsyncContainer":
        return self.enter()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        await self.close()
@@ -0,0 +1,105 @@
1
+ import contextlib
2
+ import types
3
+ import typing
4
+
5
+ from modern_di.containers.abstract import AbstractContainer
6
+ from modern_di.group import Group
7
+ from modern_di.providers import ContainerProvider
8
+ from modern_di.providers.abstract import AbstractProvider
9
+ from modern_di.providers.context_provider import ContextProvider
10
+ from modern_di.registries.state_registry.state_registry import SyncStateRegistry
11
+ from modern_di.scope import Scope
12
+
13
+
14
+ if typing.TYPE_CHECKING:
15
+ import typing_extensions
16
+
17
+
18
T_co = typing.TypeVar("T_co", covariant=True)


class SyncContainer(contextlib.AbstractContextManager["SyncContainer"], AbstractContainer):
    """Dependency container that resolves providers synchronously.

    The container must be entered (``with`` or ``enter()``) before anything can be resolved
    from it. Providers flagged ``is_async`` are rejected.
    """

    __slots__ = AbstractContainer.BASE_SLOTS

    def __init__(
        self,
        *,
        scope: Scope = Scope.APP,
        parent_container: typing.Optional["typing_extensions.Self"] = None,
        context: dict[type[typing.Any], typing.Any] | None = None,
        groups: list[type[Group]] | None = None,
        use_lock: bool = True,
    ) -> None:
        """Create the container; ``use_lock`` is forwarded to its ``SyncStateRegistry``."""
        super().__init__(scope=scope, parent_container=parent_container, context=context, groups=groups)
        self.state_registry = SyncStateRegistry(use_lock=use_lock)

    def _resolve_args(self, args: list[typing.Any]) -> list[typing.Any]:
        """Resolve every provider in ``args``; other values are passed through unchanged."""
        return [self.resolve_provider(x) if isinstance(x, AbstractProvider) else x for x in args]

    def _resolve_kwargs(self, kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
        """Resolve every provider value in ``kwargs``; other values are passed through unchanged."""
        return {k: self.resolve_provider(v) if isinstance(v, AbstractProvider) else v for k, v in kwargs.items()}

    def resolve(self, dependency_type: type[T_co] | None = None, *, dependency_name: str | None = None) -> T_co | None:
        """Resolve a dependency looked up by type or by name; ``None`` if no provider matches."""
        provider = self.providers_registry.find_provider(
            dependency_type=dependency_type, dependency_name=dependency_name
        )
        if not provider:
            return None

        return self.resolve_provider(provider)

    def resolve_provider(self, provider: AbstractProvider[T_co]) -> T_co:
        """Resolve ``provider`` in the container matching its scope.

        Resolution order: the container itself for ``ContainerProvider``, context for
        ``ContextProvider``, then a registered override, then a cached instance from the
        provider state, and finally the provider's own ``sync_resolve``.

        Raises:
            RuntimeError: if the container is not entered, the provider is async, or the
                provider's scope cannot be found.

        """
        self._check_entered()
        if provider.is_async:
            msg = f"{type(provider).__name__} cannot be resolved synchronously"
            raise RuntimeError(msg)

        container = self.find_container(provider.scope)
        if isinstance(provider, ContainerProvider):
            return typing.cast(T_co, container)

        if isinstance(provider, ContextProvider):
            return typing.cast(T_co, self._resolve_context_provider(provider))

        if (override := container.overrides_registry.fetch_override(provider.provider_id)) is not None:
            return typing.cast(T_co, override)

        provider_state = container.state_registry.fetch_provider_state(provider)
        if provider_state and provider_state.instance is not None:
            return provider_state.instance

        # Double-checked locking: re-check the cached instance once the lock is held,
        # another thread may have created it while this one was waiting.
        if provider_state and provider_state.lock:
            provider_state.lock.acquire()
        try:
            if provider_state and provider_state.instance is not None:
                return provider_state.instance

            return provider.sync_resolve(
                args=self._resolve_args(provider.args or []),
                kwargs=self._resolve_kwargs(provider.kwargs or {}),
                provider_state=provider_state,
            )
        finally:
            if provider_state and provider_state.lock:
                provider_state.lock.release()

    def enter(self) -> "SyncContainer":
        """Mark the container as entered and return it."""
        self._is_entered = True
        return self

    def close(self) -> None:
        """Leave the container and clear its provider state.

        The overrides registry is shared by the whole container tree (children reuse their
        parent's registry), so overrides are reset only when the root container closes.
        Closing a child must not wipe overrides registered on an ancestor.

        Raises:
            RuntimeError: if the container is not entered.

        """
        self._check_entered()
        self._is_entered = False
        self.state_registry.clear_state()
        if self.parent_container is None:
            self.overrides_registry.reset_override()

    def __enter__(self) -> "SyncContainer":
        return self.enter()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        self.close()
modern_di/group.py ADDED
@@ -0,0 +1,26 @@
1
+ import typing
2
+
3
+ from modern_di.providers.abstract import AbstractProvider
4
+
5
+
6
+ if typing.TYPE_CHECKING:
7
+ import typing_extensions
8
+
9
+
10
T = typing.TypeVar("T")
P = typing.ParamSpec("P")


class Group:
    """Namespace of providers declared as class attributes; never instantiated."""

    # Cached mapping of attribute name -> provider, filled lazily by ``get_providers``.
    providers: dict[str, AbstractProvider[typing.Any]]

    def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self":  # noqa: ANN401
        """Refuse instantiation; groups are used only as classes."""
        msg = f"{cls.__name__} cannot be instantiated"
        raise RuntimeError(msg)

    @classmethod
    def get_providers(cls) -> dict[str, AbstractProvider[typing.Any]]:
        """Return the providers declared directly on this class, keyed by attribute name.

        The result is cached on the class itself. The cache check looks at ``cls.__dict__``
        rather than using ``hasattr``. ``hasattr`` would also see a cache stored on a parent
        group, and a subclass would then return its parent's providers depending on which
        class was queried first.
        """
        if "providers" not in cls.__dict__:
            cls.providers = {k: v for k, v in cls.__dict__.items() if isinstance(v, AbstractProvider)}

        return cls.providers
@@ -0,0 +1,33 @@
1
+ import collections
2
+ import types
3
+ import typing
4
+
5
+
6
+ GENERIC_TYPES = {
7
+ typing.Iterator,
8
+ typing.AsyncIterator,
9
+ collections.abc.Iterator,
10
+ collections.abc.AsyncIterator,
11
+ }
12
+
13
+
14
+ def define_bound_type(creator: type | object) -> type | None:
15
+ if isinstance(creator, type):
16
+ return creator
17
+
18
+ type_hints = typing.get_type_hints(creator)
19
+ return_annotation = type_hints.get("return")
20
+ if not return_annotation:
21
+ return None
22
+
23
+ if isinstance(return_annotation, type) and not isinstance(return_annotation, types.GenericAlias):
24
+ return return_annotation
25
+
26
+ if typing.get_origin(return_annotation) not in GENERIC_TYPES:
27
+ return None
28
+
29
+ args = typing.get_args(return_annotation)
30
+ if not args:
31
+ return None
32
+
33
+ return typing.cast(type, args[0])
@@ -2,13 +2,12 @@ from modern_di.providers.abstract import AbstractProvider
2
2
  from modern_di.providers.async_factory import AsyncFactory
3
3
  from modern_di.providers.async_singleton import AsyncSingleton
4
4
  from modern_di.providers.container_provider import ContainerProvider
5
- from modern_di.providers.context_adapter import ContextAdapter
5
+ from modern_di.providers.context_provider import ContextProvider
6
6
  from modern_di.providers.dict import Dict
7
7
  from modern_di.providers.factory import Factory
8
8
  from modern_di.providers.list import List
9
9
  from modern_di.providers.object import Object
10
10
  from modern_di.providers.resource import Resource
11
- from modern_di.providers.selector import Selector
12
11
  from modern_di.providers.singleton import Singleton
13
12
 
14
13
 
@@ -17,12 +16,11 @@ __all__ = [
17
16
  "AsyncFactory",
18
17
  "AsyncSingleton",
19
18
  "ContainerProvider",
20
- "ContextAdapter",
19
+ "ContextProvider",
21
20
  "Dict",
22
21
  "Factory",
23
22
  "List",
24
23
  "Object",
25
24
  "Resource",
26
- "Selector",
27
25
  "Singleton",
28
26
  ]
@@ -3,10 +3,12 @@ import enum
3
3
  import typing
4
4
  import uuid
5
5
 
6
+ import typing_extensions
6
7
  from typing_extensions import override
7
8
 
8
- from modern_di import Container
9
9
  from modern_di.helpers.attr_getter_helpers import get_value_from_object_by_dotted_path
10
+ from modern_di.helpers.type_helpers import define_bound_type
11
+ from modern_di.registries.state_registry.state import AsyncState, SyncState
10
12
 
11
13
 
12
14
  T_co = typing.TypeVar("T_co", covariant=True)
@@ -14,36 +16,62 @@ R = typing.TypeVar("R")
14
16
  P = typing.ParamSpec("P")
15
17
 
16
18
 
17
- class AbstractProvider(typing.Generic[T_co], abc.ABC):
18
- BASE_SLOTS: typing.ClassVar = ["scope", "provider_id"]
19
+ class AbstractProvider(abc.ABC, typing.Generic[T_co]):
20
+ BASE_SLOTS: typing.ClassVar = ["scope", "provider_id", "args", "kwargs", "is_async", "bound_type"]
21
+ HAS_STATE: bool = False
19
22
 
20
- def __init__(self, scope: enum.IntEnum) -> None:
23
+ def __init__(
24
+ self,
25
+ scope: enum.IntEnum,
26
+ args: list[typing.Any] | None = None,
27
+ kwargs: dict[str, typing.Any] | None = None,
28
+ bound_type: type | None = None,
29
+ ) -> None:
21
30
  self.scope = scope
22
31
  self.provider_id: typing.Final = str(uuid.uuid4())
32
+ self.args = args
33
+ self.kwargs = kwargs
34
+ self.is_async = False
35
+ self.bound_type = bound_type
36
+ self._check_providers_scope()
37
+
38
+ def bind_type(self, new_type: type) -> typing_extensions.Self:
39
+ self.bound_type = new_type
40
+ return self
23
41
 
24
- @abc.abstractmethod
25
- async def async_resolve(self, container: Container) -> T_co:
42
+ async def async_resolve(
43
+ self,
44
+ *,
45
+ args: list[typing.Any],
46
+ kwargs: dict[str, typing.Any],
47
+ provider_state: AsyncState[T_co] | None,
48
+ ) -> T_co: # pragma: no cover
26
49
  """Resolve dependency asynchronously."""
50
+ raise NotImplementedError
27
51
 
28
- @abc.abstractmethod
29
- def sync_resolve(self, container: Container) -> T_co:
52
+ def sync_resolve(
53
+ self,
54
+ *,
55
+ args: list[typing.Any],
56
+ kwargs: dict[str, typing.Any],
57
+ provider_state: SyncState[T_co] | None,
58
+ ) -> T_co: # pragma: no cover
30
59
  """Resolve dependency synchronously."""
60
+ raise NotImplementedError
31
61
 
32
62
  @property
33
63
  def cast(self) -> T_co:
34
64
  return typing.cast(T_co, self)
35
65
 
36
- def _check_providers_scope(
37
- self, *, args: typing.Iterable[typing.Any] | None = None, kwargs: typing.Mapping[str, typing.Any] | None = None
38
- ) -> None:
39
- if args:
40
- for provider in args:
66
+ def _check_providers_scope(self) -> None:
67
+ if self.args:
68
+ for provider in self.args:
41
69
  if isinstance(provider, AbstractProvider) and provider.scope > self.scope:
42
70
  msg = f"Scope of dependency is {provider.scope.name} and current scope is {self.scope.name}"
43
71
  raise RuntimeError(msg)
44
72
 
45
- if kwargs:
46
- for name, provider in kwargs.items():
73
+ if self.kwargs:
74
+ for name, provider in self.kwargs.items():
47
75
  if isinstance(provider, AbstractProvider) and provider.scope > self.scope:
48
76
  msg = f"Scope of {name} is {provider.scope.name} and current scope is {self.scope.name}"
49
77
  raise RuntimeError(msg)
@@ -65,16 +93,8 @@ class AbstractProvider(typing.Generic[T_co], abc.ABC):
65
93
  return AttrGetter(provider=self, attr_name=attr_name)
66
94
 
67
95
 
68
- class AbstractOverrideProvider(AbstractProvider[T_co], abc.ABC):
69
- def override(self, override_object: object, container: Container) -> None:
70
- container.override(self.provider_id, override_object)
71
-
72
- def reset_override(self, container: Container) -> None:
73
- container.reset_override(self.provider_id)
74
-
75
-
76
- class AbstractCreatorProvider(AbstractOverrideProvider[T_co], abc.ABC):
77
- BASE_SLOTS: typing.ClassVar = [*AbstractProvider.BASE_SLOTS, "_args", "_kwargs", "_creator"]
96
+ class AbstractCreatorProvider(AbstractProvider[T_co], abc.ABC):
97
+ BASE_SLOTS: typing.ClassVar = [*AbstractProvider.BASE_SLOTS, "_creator"]
78
98
 
79
99
  def __init__(
80
100
  self,
@@ -83,55 +103,15 @@ class AbstractCreatorProvider(AbstractOverrideProvider[T_co], abc.ABC):
83
103
  *args: P.args,
84
104
  **kwargs: P.kwargs,
85
105
  ) -> None:
86
- super().__init__(scope)
87
- self._check_providers_scope(args=args, kwargs=kwargs)
106
+ super().__init__(scope, args=list(args), kwargs=kwargs, bound_type=define_bound_type(creator))
88
107
  self._creator: typing.Final = creator
89
- self._args: typing.Final = args
90
- self._kwargs: typing.Final = kwargs
91
-
92
- def _sync_resolve_args(self, container: Container) -> list[typing.Any]:
93
- return [x.sync_resolve(container) if isinstance(x, AbstractProvider) else x for x in self._args]
94
-
95
- def _sync_resolve_kwargs(self, container: Container) -> dict[str, typing.Any]:
96
- return {k: v.sync_resolve(container) if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}
97
-
98
- def _sync_build_creator(self, container: Container) -> typing.Any: # noqa: ANN401
99
- return self._creator(
100
- *self._sync_resolve_args(container),
101
- **self._sync_resolve_kwargs(container),
102
- )
103
-
104
- async def _async_resolve_args(self, container: Container) -> list[typing.Any]:
105
- return [await x.async_resolve(container) if isinstance(x, AbstractProvider) else x for x in self._args]
106
-
107
- async def _async_resolve_kwargs(self, container: Container) -> dict[str, typing.Any]:
108
- return {
109
- k: await v.async_resolve(container) if isinstance(v, AbstractProvider) else v
110
- for k, v in self._kwargs.items()
111
- }
112
-
113
- async def _async_build_creator(self, container: Container) -> typing.Any: # noqa: ANN401
114
- return self._creator(
115
- *await self._async_resolve_args(container),
116
- **await self._async_resolve_kwargs(container),
117
- )
118
108
 
119
109
 
120
110
  class AttrGetter(AbstractProvider[T_co]):
121
- """Provides an attribute after resolving the wrapped provider."""
122
-
123
- __slots__ = [*AbstractProvider.BASE_SLOTS, "_attrs", "_provider"]
111
+ __slots__ = [*AbstractProvider.BASE_SLOTS, "_attrs"]
124
112
 
125
113
  def __init__(self, provider: AbstractProvider[T_co], attr_name: str) -> None:
126
- """Create a new AttrGetter instance.
127
-
128
- Args:
129
- provider: provider to wrap.
130
- attr_name: attribute name to resolve when the provider is resolved.
131
-
132
- """
133
- super().__init__(scope=provider.scope)
134
- self._provider = provider
114
+ super().__init__(scope=provider.scope, args=[provider])
135
115
  self._attrs = [attr_name]
136
116
 
137
117
  def __getattr__(self, attr: str) -> "AttrGetter[T_co]":
@@ -142,13 +122,21 @@ class AttrGetter(AbstractProvider[T_co]):
142
122
  return self
143
123
 
144
124
  @override
145
- async def async_resolve(self, container: Container) -> typing.Any:
146
- resolved_provider_object = await self._provider.async_resolve(container)
125
+ async def async_resolve(
126
+ self,
127
+ *,
128
+ args: list[typing.Any],
129
+ **_: object,
130
+ ) -> typing.Any:
147
131
  attribute_path = ".".join(self._attrs)
148
- return get_value_from_object_by_dotted_path(resolved_provider_object, attribute_path)
132
+ return get_value_from_object_by_dotted_path(args[0], attribute_path)
149
133
 
150
134
  @override
151
- def sync_resolve(self, container: Container) -> typing.Any:
152
- resolved_provider_object = self._provider.sync_resolve(container)
135
+ def sync_resolve(
136
+ self,
137
+ *,
138
+ args: list[typing.Any],
139
+ **_: object,
140
+ ) -> typing.Any:
153
141
  attribute_path = ".".join(self._attrs)
154
- return get_value_from_object_by_dotted_path(resolved_provider_object, attribute_path)
142
+ return get_value_from_object_by_dotted_path(args[0], attribute_path)
@@ -1,7 +1,6 @@
1
1
  import enum
2
2
  import typing
3
3
 
4
- from modern_di import Container
5
4
  from modern_di.providers.abstract import AbstractCreatorProvider
6
5
 
7
6
 
@@ -20,15 +19,14 @@ class AsyncFactory(AbstractCreatorProvider[T_co]):
20
19
  **kwargs: P.kwargs,
21
20
  ) -> None:
22
21
  super().__init__(scope, creator, *args, **kwargs)
22
+ self.is_async = True
23
23
 
24
- async def async_resolve(self, container: Container) -> T_co:
25
- container = container.find_container(self.scope)
26
- if (override := container.fetch_override(self.provider_id)) is not None:
27
- return typing.cast(T_co, override)
28
-
29
- coroutine: typing.Awaitable[T_co] = await self._async_build_creator(container)
24
+ async def async_resolve(
25
+ self,
26
+ *,
27
+ args: list[typing.Any],
28
+ kwargs: dict[str, typing.Any],
29
+ **__: object,
30
+ ) -> T_co:
31
+ coroutine: typing.Awaitable[T_co] = self._creator(*args, **kwargs)
30
32
  return await coroutine
31
-
32
- def sync_resolve(self, _: Container) -> typing.NoReturn:
33
- msg = "AsyncFactory cannot be resolved synchronously"
34
- raise RuntimeError(msg)