modern-di 0.14.0__py3-none-any.whl → 1.0.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. modern_di/__init__.py +6 -4
  2. modern_di/containers/__init__.py +0 -0
  3. modern_di/containers/abstract.py +172 -0
  4. modern_di/containers/async_container.py +97 -0
  5. modern_di/containers/sync_container.py +40 -0
  6. modern_di/group.py +26 -0
  7. modern_di/helpers/type_helpers.py +33 -0
  8. modern_di/providers/__init__.py +4 -4
  9. modern_di/providers/abstract.py +52 -108
  10. modern_di/providers/async_factory.py +10 -12
  11. modern_di/providers/async_singleton.py +37 -0
  12. modern_di/providers/container_provider.py +4 -8
  13. modern_di/providers/context_provider.py +16 -0
  14. modern_di/providers/dict.py +20 -13
  15. modern_di/providers/factory.py +18 -24
  16. modern_di/providers/list.py +20 -13
  17. modern_di/providers/object.py +6 -11
  18. modern_di/providers/resource.py +39 -67
  19. modern_di/providers/singleton.py +24 -44
  20. modern_di/registries/__init__.py +0 -0
  21. modern_di/registries/context_registry.py +16 -0
  22. modern_di/registries/overrides_registry.py +22 -0
  23. modern_di/registries/providers_registry.py +38 -0
  24. modern_di/registries/state_registry/__init__.py +0 -0
  25. modern_di/{provider_state.py → registries/state_registry/state.py} +2 -6
  26. modern_di/registries/state_registry/state_registry.py +30 -0
  27. {modern_di-0.14.0.dist-info → modern_di-1.0.0a3.dist-info}/METADATA +1 -1
  28. modern_di-1.0.0a3.dist-info/RECORD +32 -0
  29. modern_di/container.py +0 -162
  30. modern_di/graph.py +0 -39
  31. modern_di/helpers/attr_getter_helpers.py +0 -7
  32. modern_di/providers/context_adapter.py +0 -27
  33. modern_di/providers/injected_factory.py +0 -45
  34. modern_di/providers/selector.py +0 -39
  35. modern_di-0.14.0.dist-info/RECORD +0 -24
  36. {modern_di-0.14.0.dist-info → modern_di-1.0.0a3.dist-info}/WHEEL +0 -0
modern_di/__init__.py CHANGED
@@ -1,11 +1,13 @@
1
- from modern_di.container import Container
2
- from modern_di.graph import BaseGraph
1
+ from modern_di.containers.async_container import AsyncContainer
2
+ from modern_di.containers.sync_container import SyncContainer
3
+ from modern_di.group import Group
3
4
  from modern_di.scope import Scope
4
5
 
5
6
 
6
7
  __all__ = [
7
- "BaseGraph",
8
- "Container",
8
+ "AsyncContainer",
9
+ "Group",
9
10
  "Scope",
11
+ "SyncContainer",
10
12
  "providers",
11
13
  ]
modern_di/containers/__init__.py ADDED
File without changes
modern_di/containers/abstract.py ADDED
@@ -0,0 +1,172 @@
1
+ import asyncio
2
+ import enum
3
+ import threading
4
+ import typing
5
+
6
+ from modern_di.group import Group
7
+ from modern_di.providers import ContainerProvider
8
+ from modern_di.providers.abstract import AbstractProvider
9
+ from modern_di.providers.context_provider import ContextProvider
10
+ from modern_di.registries.context_registry import ContextRegistry
11
+ from modern_di.registries.overrides_registry import OverridesRegistry
12
+ from modern_di.registries.providers_registry import ProvidersRegistry
13
+ from modern_di.registries.state_registry.state_registry import StateRegistry
14
+ from modern_di.scope import Scope
15
+
16
+
17
+ if typing.TYPE_CHECKING:
18
+ import typing_extensions
19
+
20
+
21
+ T_co = typing.TypeVar("T_co", covariant=True)
22
+
23
+
24
class AbstractContainer:
    """Base class holding the state and synchronous resolution logic of a container.

    Containers form a chain through ``parent_container``. Every container owns its
    own state registry and context registry, while the providers registry and the
    overrides registry are created by the root container and shared with every
    descendant.
    """

    # Attribute names that concrete subclasses reuse as their ``__slots__``.
    BASE_SLOTS: typing.ClassVar[list[str]] = [
        "_is_entered",
        "_async_lock",
        "_sync_lock",
        "scope",
        "parent_container",
        "state_registry",
        "context_registry",
        "providers_registry",
        "overrides_registry",
    ]

    def __init__(  # noqa: PLR0913
        self,
        *,
        scope: Scope = Scope.APP,
        parent_container: typing.Optional["typing_extensions.Self"] = None,
        context: dict[type[typing.Any], typing.Any] | None = None,
        groups: list[type[Group]] | None = None,
        use_async_lock: bool = True,
        use_sync_lock: bool = True,
    ) -> None:
        """Create a container (not yet entered).

        Args:
            scope: scope this container represents.
            parent_container: container of the enclosing, lower scope. When given,
                its providers registry and overrides registry are reused.
            context: initial context objects keyed by type.
            groups: provider groups whose providers are added to the providers registry.
            use_async_lock: create an ``asyncio.Lock`` used to serialize creation of
                stateful providers owned by this container.
            use_sync_lock: create a ``threading.Lock`` used for the same purpose on the
                synchronous path.

        """
        self._is_entered = False
        self._async_lock = asyncio.Lock() if use_async_lock else None
        self._sync_lock = threading.Lock() if use_sync_lock else None
        self.scope = scope
        self.parent_container = parent_container
        self.state_registry = StateRegistry()
        self.context_registry = ContextRegistry(context or {})
        self.providers_registry: ProvidersRegistry
        self.overrides_registry: OverridesRegistry
        if parent_container:
            # Children share registration and overrides with their ancestors.
            self.providers_registry = parent_container.providers_registry
            self.overrides_registry = parent_container.overrides_registry
        else:
            self.providers_registry = ProvidersRegistry()
            self.overrides_registry = OverridesRegistry()
        if groups:
            for one_group in groups:
                self.providers_registry.add_providers(**one_group.get_providers())

    def _check_entered(self) -> None:
        """Raise ``RuntimeError`` unless this container has been entered."""
        if not self._is_entered:
            msg = f"Enter the context of {self.scope.name} scope"
            raise RuntimeError(msg)

    def _resolve_context_provider(self, provider: ContextProvider[T_co]) -> T_co | None:
        """Look up the provider's context type in this container's context registry."""
        return self.context_registry.find_context(provider.context_type)

    def build_child_container(
        self, context: dict[type[typing.Any], typing.Any] | None = None, scope: Scope | None = None
    ) -> "typing_extensions.Self":
        """Create a (not yet entered) child container of a deeper scope.

        Args:
            context: initial context objects for the child.
            scope: scope of the child; defaults to the next scope value after this one.

        Raises:
            RuntimeError: if this container is not entered, if ``scope`` is not greater
                than the current scope, or if there is no next scope.

        """
        self._check_entered()
        if scope is not None and scope <= self.scope:
            msg = "Scope of child container must be more than current scope"
            raise RuntimeError(msg)

        if scope is None:
            try:
                scope = self.scope.__class__(self.scope.value + 1)
            except ValueError as exc:
                msg = f"Max scope is reached, {self.scope.name}"
                raise RuntimeError(msg) from exc

        return self.__class__(scope=scope, parent_container=self, context=context)

    def find_container(self, scope: enum.IntEnum) -> "typing_extensions.Self":
        """Return the container in the parent chain (starting at ``self``) with exactly ``scope``.

        Raises:
            RuntimeError: if ``scope`` is deeper than this container's scope, or if no
                container in the chain has exactly that scope.

        """
        container = self
        if container.scope < scope:
            msg = f"Scope {scope.name} is not initialized"
            raise RuntimeError(msg)

        while container.scope > scope and container.parent_container:
            container = container.parent_container

        if container.scope != scope:
            msg = f"Scope {scope.name} is skipped"
            raise RuntimeError(msg)

        return container

    def override(self, provider: AbstractProvider[T_co], override_object: object) -> None:
        """Make ``provider`` resolve to ``override_object`` (shared overrides registry)."""
        self.overrides_registry.override(provider.provider_id, override_object)

    def reset_override(self, provider: AbstractProvider[T_co] | None = None) -> None:
        """Remove the override of ``provider``.

        With ``provider=None``, ``None`` is forwarded to the registry — presumably
        clearing every override; see ``OverridesRegistry.reset_override``.
        """
        self.overrides_registry.reset_override(provider.provider_id if provider else None)

    def _sync_resolve_args(self, args: list[typing.Any]) -> list[typing.Any]:
        """Resolve provider entries of ``args`` synchronously; pass other values through."""
        return [self._sync_resolve_provider(x) if isinstance(x, AbstractProvider) else x for x in args]

    def _sync_resolve_kwargs(self, kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
        """Resolve provider values of ``kwargs`` synchronously; pass other values through."""
        return {k: self._sync_resolve_provider(v) if isinstance(v, AbstractProvider) else v for k, v in kwargs.items()}

    def _sync_resolve(self, dependency_type: type[T_co] | None = None, *, dependency_name: str | None = None) -> T_co:
        """Find a registered provider by type or name and resolve it synchronously.

        Raises:
            RuntimeError: if no matching provider is registered.

        """
        provider = self.providers_registry.find_provider(
            dependency_type=dependency_type, dependency_name=dependency_name
        )
        if not provider:
            msg = f"Provider is not found, {dependency_type=}, {dependency_name=}"
            raise RuntimeError(msg)

        return self._sync_resolve_provider(provider)

    def _sync_resolve_provider(self, provider: AbstractProvider[T_co]) -> T_co:
        """Resolve ``provider`` synchronously.

        Order: container/context providers, then overrides, then an already stored
        instance, and finally creation via ``provider.sync_resolve``.

        Raises:
            RuntimeError: if the container is not entered, the provider is async, or
                the provider's scope cannot be found in the container chain.

        """
        self._check_entered()
        if provider.is_async:
            msg = f"{type(provider).__name__} cannot be resolved synchronously"
            raise RuntimeError(msg)

        container = self.find_container(provider.scope)
        if isinstance(provider, ContainerProvider):
            return typing.cast(T_co, container)

        if isinstance(provider, ContextProvider):
            return typing.cast(T_co, self._resolve_context_provider(provider))

        if (override := container.overrides_registry.fetch_override(provider.provider_id)) is not None:
            return typing.cast(T_co, override)

        provider_state = container.state_registry.fetch_provider_state(provider)
        if provider_state and provider_state.instance is not None:
            return provider_state.instance

        # Dependencies are resolved outside the lock; the lock is not reentrant.
        args = self._sync_resolve_args(provider.args or [])
        kwargs = self._sync_resolve_kwargs(provider.kwargs or {})

        # The state lives in ``container`` (the scope owner), so its lock must guard it:
        # sibling child containers resolving the same ancestor-scoped provider from
        # different threads would not exclude each other with their own locks.
        lock = container._sync_lock if provider_state else None  # noqa: SLF001
        if lock is not None:
            lock.acquire()
        try:
            # Re-check under the lock: another thread may have created the instance.
            if provider_state and provider_state.instance is not None:
                return provider_state.instance

            return provider.sync_resolve(
                args=args,
                kwargs=kwargs,
                provider_state=provider_state,
            )
        finally:
            if lock is not None:
                lock.release()

    def __deepcopy__(self, *_: object, **__: object) -> "typing_extensions.Self":
        """Hack to prevent cloning object."""
        return self

    def __copy__(self, *_: object, **__: object) -> "typing_extensions.Self":
        """Hack to prevent cloning object."""
        return self
modern_di/containers/async_container.py ADDED
@@ -0,0 +1,97 @@
1
+ import contextlib
2
+ import types
3
+ import typing
4
+
5
+ from modern_di.containers.abstract import AbstractContainer
6
+ from modern_di.providers.abstract import AbstractProvider
7
+ from modern_di.providers.container_provider import ContainerProvider
8
+ from modern_di.providers.context_provider import ContextProvider
9
+
10
+
11
+ if typing.TYPE_CHECKING:
12
+ pass
13
+
14
+
15
+ T_co = typing.TypeVar("T_co", covariant=True)
16
+
17
+
18
class AsyncContainer(contextlib.AbstractAsyncContextManager["AsyncContainer"], AbstractContainer):
    """Container resolving providers asynchronously; usable as an ``async with`` context.

    Synchronous resolution is also available through ``sync_resolve`` and
    ``sync_resolve_provider``.
    """

    __slots__ = AbstractContainer.BASE_SLOTS

    async def _resolve_args(self, args: list[typing.Any]) -> list[typing.Any]:
        """Resolve provider entries of ``args`` (in order); pass other values through."""
        return [await self.resolve_provider(x) if isinstance(x, AbstractProvider) else x for x in args]

    async def _resolve_kwargs(self, kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
        """Resolve provider values of ``kwargs``; pass other values through."""
        return {k: await self.resolve_provider(v) if isinstance(v, AbstractProvider) else v for k, v in kwargs.items()}

    async def resolve(self, dependency_type: type[T_co] | None = None, *, dependency_name: str | None = None) -> T_co:
        """Find a registered provider by type or name and resolve it.

        Raises:
            RuntimeError: if no matching provider is registered.

        """
        provider = self.providers_registry.find_provider(
            dependency_type=dependency_type, dependency_name=dependency_name
        )
        if not provider:
            msg = f"Provider is not found, {dependency_type=}, {dependency_name=}"
            raise RuntimeError(msg)

        return await self.resolve_provider(provider)

    async def resolve_provider(self, provider: AbstractProvider[T_co]) -> T_co:
        """Resolve ``provider`` asynchronously.

        Order: container/context providers, then overrides, then an already stored
        instance, and finally creation via ``provider.async_resolve``.

        Raises:
            RuntimeError: if the container is not entered or the provider's scope
                cannot be found in the container chain.

        """
        self._check_entered()

        container = self.find_container(provider.scope)
        if isinstance(provider, ContainerProvider):
            return typing.cast(T_co, container)

        if isinstance(provider, ContextProvider):
            return typing.cast(T_co, self._resolve_context_provider(provider))

        if (override := container.overrides_registry.fetch_override(provider.provider_id)) is not None:
            return typing.cast(T_co, override)

        provider_state = container.state_registry.fetch_provider_state(provider)
        if provider_state and provider_state.instance is not None:
            return provider_state.instance

        # Dependencies are resolved before taking the (non-reentrant) lock.
        args = await self._resolve_args(provider.args or [])
        kwargs = await self._resolve_kwargs(provider.kwargs or {})

        # Guard the state with the lock of the container that owns it, so that tasks
        # resolving through different child containers still exclude each other.
        lock = container._async_lock if provider_state else None  # noqa: SLF001
        if lock is not None:
            await lock.acquire()
        try:
            # Re-check under the lock: another task may have created the instance.
            if provider_state and provider_state.instance is not None:
                return provider_state.instance

            return await provider.async_resolve(
                args=args,
                kwargs=kwargs,
                provider_state=provider_state,
            )
        finally:
            if lock is not None:
                lock.release()

    def sync_resolve(self, dependency_type: type[T_co] | None = None, *, dependency_name: str | None = None) -> T_co:
        """Synchronous counterpart of ``resolve`` (async providers are rejected)."""
        return self._sync_resolve(dependency_type=dependency_type, dependency_name=dependency_name)

    def sync_resolve_provider(self, provider: AbstractProvider[T_co]) -> T_co:
        """Synchronous counterpart of ``resolve_provider`` (async providers are rejected)."""
        return self._sync_resolve_provider(provider)

    def enter(self) -> "AsyncContainer":
        """Mark the container as entered and return it."""
        self._is_entered = True
        return self

    async def close(self) -> None:
        """Leave the container and tear down the state it owns.

        Raises:
            RuntimeError: if the container is not entered.

        """
        self._check_entered()
        self._is_entered = False
        await self.state_registry.async_tear_down()
        # The overrides registry is shared with every ancestor/descendant. Only the root
        # clears it; otherwise closing any child (e.g. per request) would silently drop
        # overrides registered on the root container.
        if self.parent_container is None:
            self.overrides_registry.reset_override()

    async def __aenter__(self) -> "AsyncContainer":
        return self.enter()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        await self.close()
modern_di/containers/sync_container.py ADDED
@@ -0,0 +1,40 @@
1
+ import contextlib
2
+ import types
3
+ import typing
4
+
5
+ from modern_di.containers.abstract import AbstractContainer
6
+ from modern_di.providers.abstract import AbstractProvider
7
+
8
+
9
+ T_co = typing.TypeVar("T_co", covariant=True)
10
+
11
+
12
class SyncContainer(contextlib.AbstractContextManager["SyncContainer"], AbstractContainer):
    """Container resolving providers synchronously; usable as a ``with`` context."""

    __slots__ = AbstractContainer.BASE_SLOTS

    def resolve(self, dependency_type: type[T_co] | None = None, *, dependency_name: str | None = None) -> T_co:
        """Find a registered provider by type or name and resolve it synchronously."""
        return self._sync_resolve(dependency_type=dependency_type, dependency_name=dependency_name)

    def resolve_provider(self, provider: AbstractProvider[T_co]) -> T_co:
        """Resolve ``provider`` synchronously (async providers are rejected)."""
        return self._sync_resolve_provider(provider)

    def enter(self) -> "SyncContainer":
        """Mark the container as entered and return it."""
        self._is_entered = True
        return self

    def close(self) -> None:
        """Leave the container and tear down the state it owns.

        Raises:
            RuntimeError: if the container is not entered.

        """
        self._check_entered()
        self._is_entered = False
        self.state_registry.sync_tear_down()
        # The overrides registry is shared across the container chain; only the root
        # clears it so that closing a child does not drop overrides set on the root.
        if self.parent_container is None:
            self.overrides_registry.reset_override()

    def __enter__(self) -> "SyncContainer":
        return self.enter()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        self.close()
modern_di/group.py ADDED
@@ -0,0 +1,26 @@
1
+ import typing
2
+
3
+ from modern_di.providers.abstract import AbstractProvider
4
+
5
+
6
+ if typing.TYPE_CHECKING:
7
+ import typing_extensions
8
+
9
+
10
+ T = typing.TypeVar("T")
11
+ P = typing.ParamSpec("P")
12
+
13
+
14
class Group:
    """Namespace of providers declared as class attributes.

    A group is never instantiated; the class itself is passed around (containers
    accept ``list[type[Group]]``) and its providers are collected by ``get_providers``.
    """

    # Cache of this class's own providers, filled lazily by ``get_providers``.
    providers: dict[str, AbstractProvider[typing.Any]]

    def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self":  # noqa: ANN401
        """Always raise: groups are used as classes, not instances."""
        msg = f"{cls.__name__} cannot be instantiated"
        raise RuntimeError(msg)

    @classmethod
    def get_providers(cls) -> dict[str, AbstractProvider[typing.Any]]:
        """Return the providers defined directly on this class, keyed by attribute name.

        The mapping is computed once per class and cached in ``cls.providers``.
        """
        # Check the class's own namespace: ``hasattr`` would also see a cache stored on a
        # parent group and return the parent's providers instead of this class's.
        if "providers" not in cls.__dict__:
            cls.providers = {k: v for k, v in cls.__dict__.items() if isinstance(v, AbstractProvider)}

        return cls.providers
modern_di/helpers/type_helpers.py ADDED
@@ -0,0 +1,33 @@
1
+ import collections
2
+ import types
3
+ import typing
4
+
5
+
6
+ GENERIC_TYPES = {
7
+ typing.Iterator,
8
+ typing.AsyncIterator,
9
+ collections.abc.Iterator,
10
+ collections.abc.AsyncIterator,
11
+ }
12
+
13
+
14
+ def define_bound_type(creator: type | object) -> type | None:
15
+ if isinstance(creator, type):
16
+ return creator
17
+
18
+ type_hints = typing.get_type_hints(creator)
19
+ return_annotation = type_hints.get("return")
20
+ if not return_annotation:
21
+ return None
22
+
23
+ if isinstance(return_annotation, type) and not isinstance(return_annotation, types.GenericAlias):
24
+ return return_annotation
25
+
26
+ if typing.get_origin(return_annotation) not in GENERIC_TYPES:
27
+ return None
28
+
29
+ args = typing.get_args(return_annotation)
30
+ if not args:
31
+ return None
32
+
33
+ return typing.cast(type, args[0])
modern_di/providers/__init__.py CHANGED
@@ -1,26 +1,26 @@
1
1
  from modern_di.providers.abstract import AbstractProvider
2
2
  from modern_di.providers.async_factory import AsyncFactory
3
+ from modern_di.providers.async_singleton import AsyncSingleton
3
4
  from modern_di.providers.container_provider import ContainerProvider
4
- from modern_di.providers.context_adapter import ContextAdapter
5
+ from modern_di.providers.context_provider import ContextProvider
5
6
  from modern_di.providers.dict import Dict
6
7
  from modern_di.providers.factory import Factory
7
8
  from modern_di.providers.list import List
8
9
  from modern_di.providers.object import Object
9
10
  from modern_di.providers.resource import Resource
10
- from modern_di.providers.selector import Selector
11
11
  from modern_di.providers.singleton import Singleton
12
12
 
13
13
 
14
14
  __all__ = [
15
15
  "AbstractProvider",
16
16
  "AsyncFactory",
17
+ "AsyncSingleton",
17
18
  "ContainerProvider",
18
- "ContextAdapter",
19
+ "ContextProvider",
19
20
  "Dict",
20
21
  "Factory",
21
22
  "List",
22
23
  "Object",
23
24
  "Resource",
24
- "Selector",
25
25
  "Singleton",
26
26
  ]
modern_di/providers/abstract.py CHANGED
@@ -1,13 +1,12 @@
1
1
  import abc
2
2
  import enum
3
- import itertools
4
3
  import typing
5
4
  import uuid
6
5
 
7
- from typing_extensions import override
6
+ import typing_extensions
8
7
 
9
- from modern_di import Container
10
- from modern_di.helpers.attr_getter_helpers import get_value_from_object_by_dotted_path
8
+ from modern_di.helpers.type_helpers import define_bound_type
9
+ from modern_di.registries.state_registry.state import ProviderState
11
10
 
12
11
 
13
12
  T_co = typing.TypeVar("T_co", covariant=True)
@@ -15,57 +14,69 @@ R = typing.TypeVar("R")
15
14
  P = typing.ParamSpec("P")
16
15
 
17
16
 
18
- class AbstractProvider(typing.Generic[T_co], abc.ABC):
19
- BASE_SLOTS: typing.ClassVar = ["scope", "provider_id"]
17
+ class AbstractProvider(abc.ABC, typing.Generic[T_co]):
18
+ BASE_SLOTS: typing.ClassVar = ["scope", "provider_id", "args", "kwargs", "is_async", "bound_type"]
19
+ HAS_STATE: bool = False
20
20
 
21
- def __init__(self, scope: enum.IntEnum) -> None:
21
+ def __init__(
22
+ self,
23
+ scope: enum.IntEnum,
24
+ args: list[typing.Any] | None = None,
25
+ kwargs: dict[str, typing.Any] | None = None,
26
+ bound_type: type | None = None,
27
+ ) -> None:
22
28
  self.scope = scope
23
29
  self.provider_id: typing.Final = str(uuid.uuid4())
30
+ self.args = args
31
+ self.kwargs = kwargs
32
+ self.is_async = False
33
+ self.bound_type = bound_type
34
+ self._check_providers_scope()
35
+
36
+ def bind_type(self, new_type: type) -> typing_extensions.Self:
37
+ self.bound_type = new_type
38
+ return self
24
39
 
25
- @abc.abstractmethod
26
- async def async_resolve(self, container: Container) -> T_co:
40
+ async def async_resolve(
41
+ self,
42
+ *,
43
+ args: list[typing.Any],
44
+ kwargs: dict[str, typing.Any],
45
+ provider_state: ProviderState[T_co] | None,
46
+ ) -> T_co: # pragma: no cover
27
47
  """Resolve dependency asynchronously."""
48
+ raise NotImplementedError
28
49
 
29
- @abc.abstractmethod
30
- def sync_resolve(self, container: Container) -> T_co:
50
+ def sync_resolve(
51
+ self,
52
+ *,
53
+ args: list[typing.Any],
54
+ kwargs: dict[str, typing.Any],
55
+ provider_state: ProviderState[T_co] | None,
56
+ ) -> T_co: # pragma: no cover
31
57
  """Resolve dependency synchronously."""
58
+ raise NotImplementedError
32
59
 
33
60
  @property
34
61
  def cast(self) -> T_co:
35
62
  return typing.cast(T_co, self)
36
63
 
37
- def _check_providers_scope(self, providers: typing.Iterable[typing.Any]) -> None:
38
- if any(x.scope > self.scope for x in providers if isinstance(x, AbstractProvider)):
39
- msg = "Scope of dependency cannot be more than scope of dependent"
40
- raise RuntimeError(msg)
41
-
42
- def __getattr__(self, attr_name: str) -> typing.Any: # noqa: ANN401
43
- """Get an attribute from the resolve object.
64
+ def _check_providers_scope(self) -> None:
65
+ if self.args:
66
+ for provider in self.args:
67
+ if isinstance(provider, AbstractProvider) and provider.scope > self.scope:
68
+ msg = f"Scope of dependency is {provider.scope.name} and current scope is {self.scope.name}"
69
+ raise RuntimeError(msg)
44
70
 
45
- Args:
46
- attr_name: name of attribute to get.
71
+ if self.kwargs:
72
+ for name, provider in self.kwargs.items():
73
+ if isinstance(provider, AbstractProvider) and provider.scope > self.scope:
74
+ msg = f"Scope of {name} is {provider.scope.name} and current scope is {self.scope.name}"
75
+ raise RuntimeError(msg)
47
76
 
48
- Returns:
49
- An `AttrGetter` provider that will get the attribute after resolving the current provider.
50
77
 
51
- """
52
- if attr_name.startswith("_"):
53
- msg = f"'{type(self)}' object has no attribute '{attr_name}'"
54
- raise AttributeError(msg)
55
-
56
- return AttrGetter(provider=self, attr_name=attr_name)
57
-
58
-
59
- class AbstractOverrideProvider(AbstractProvider[T_co], abc.ABC):
60
- def override(self, override_object: object, container: Container) -> None:
61
- container.override(self.provider_id, override_object)
62
-
63
- def reset_override(self, container: Container) -> None:
64
- container.reset_override(self.provider_id)
65
-
66
-
67
- class AbstractCreatorProvider(AbstractOverrideProvider[T_co], abc.ABC):
68
- BASE_SLOTS: typing.ClassVar = [*AbstractProvider.BASE_SLOTS, "_args", "_kwargs", "_creator"]
78
+ class AbstractCreatorProvider(AbstractProvider[T_co], abc.ABC):
79
+ BASE_SLOTS: typing.ClassVar = [*AbstractProvider.BASE_SLOTS, "_creator"]
69
80
 
70
81
  def __init__(
71
82
  self,
@@ -74,72 +85,5 @@ class AbstractCreatorProvider(AbstractOverrideProvider[T_co], abc.ABC):
74
85
  *args: P.args,
75
86
  **kwargs: P.kwargs,
76
87
  ) -> None:
77
- super().__init__(scope)
78
- self._check_providers_scope(itertools.chain(args, kwargs.values()))
88
+ super().__init__(scope, args=list(args), kwargs=kwargs, bound_type=define_bound_type(creator))
79
89
  self._creator: typing.Final = creator
80
- self._args: typing.Final = args
81
- self._kwargs: typing.Final = kwargs
82
-
83
- def _sync_resolve_args(self, container: Container) -> list[typing.Any]:
84
- return [x.sync_resolve(container) if isinstance(x, AbstractProvider) else x for x in self._args]
85
-
86
- def _sync_resolve_kwargs(self, container: Container) -> dict[str, typing.Any]:
87
- return {k: v.sync_resolve(container) if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}
88
-
89
- def _sync_build_creator(self, container: Container) -> typing.Any: # noqa: ANN401
90
- return self._creator(
91
- *self._sync_resolve_args(container),
92
- **self._sync_resolve_kwargs(container),
93
- )
94
-
95
- async def _async_resolve_args(self, container: Container) -> list[typing.Any]:
96
- return [await x.async_resolve(container) if isinstance(x, AbstractProvider) else x for x in self._args]
97
-
98
- async def _async_resolve_kwargs(self, container: Container) -> dict[str, typing.Any]:
99
- return {
100
- k: await v.async_resolve(container) if isinstance(v, AbstractProvider) else v
101
- for k, v in self._kwargs.items()
102
- }
103
-
104
- async def _async_build_creator(self, container: Container) -> typing.Any: # noqa: ANN401
105
- return self._creator(
106
- *await self._async_resolve_args(container),
107
- **await self._async_resolve_kwargs(container),
108
- )
109
-
110
-
111
- class AttrGetter(AbstractProvider[T_co]):
112
- """Provides an attribute after resolving the wrapped provider."""
113
-
114
- __slots__ = [*AbstractProvider.BASE_SLOTS, "_attrs", "_provider"]
115
-
116
- def __init__(self, provider: AbstractProvider[T_co], attr_name: str) -> None:
117
- """Create a new AttrGetter instance.
118
-
119
- Args:
120
- provider: provider to wrap.
121
- attr_name: attribute name to resolve when the provider is resolved.
122
-
123
- """
124
- super().__init__(scope=provider.scope)
125
- self._provider = provider
126
- self._attrs = [attr_name]
127
-
128
- def __getattr__(self, attr: str) -> "AttrGetter[T_co]":
129
- if attr.startswith("_"):
130
- msg = f"'{type(self)}' object has no attribute '{attr}'"
131
- raise AttributeError(msg)
132
- self._attrs.append(attr)
133
- return self
134
-
135
- @override
136
- async def async_resolve(self, container: Container) -> typing.Any:
137
- resolved_provider_object = await self._provider.async_resolve(container)
138
- attribute_path = ".".join(self._attrs)
139
- return get_value_from_object_by_dotted_path(resolved_provider_object, attribute_path)
140
-
141
- @override
142
- def sync_resolve(self, container: Container) -> typing.Any:
143
- resolved_provider_object = self._provider.sync_resolve(container)
144
- attribute_path = ".".join(self._attrs)
145
- return get_value_from_object_by_dotted_path(resolved_provider_object, attribute_path)
modern_di/providers/async_factory.py CHANGED
@@ -1,7 +1,6 @@
1
1
  import enum
2
2
  import typing
3
3
 
4
- from modern_di import Container
5
4
  from modern_di.providers.abstract import AbstractCreatorProvider
6
5
 
7
6
 
@@ -10,7 +9,7 @@ P = typing.ParamSpec("P")
10
9
 
11
10
 
12
11
  class AsyncFactory(AbstractCreatorProvider[T_co]):
13
- __slots__ = [*AbstractCreatorProvider.BASE_SLOTS, "_creator"]
12
+ __slots__ = AbstractCreatorProvider.BASE_SLOTS
14
13
 
15
14
  def __init__(
16
15
  self,
@@ -20,15 +19,14 @@ class AsyncFactory(AbstractCreatorProvider[T_co]):
20
19
  **kwargs: P.kwargs,
21
20
  ) -> None:
22
21
  super().__init__(scope, creator, *args, **kwargs)
22
+ self.is_async = True
23
23
 
24
- async def async_resolve(self, container: Container) -> T_co:
25
- container = container.find_container(self.scope)
26
- if (override := container.fetch_override(self.provider_id)) is not None:
27
- return typing.cast(T_co, override)
28
-
29
- coroutine: typing.Awaitable[T_co] = await self._async_build_creator(container)
24
+ async def async_resolve(
25
+ self,
26
+ *,
27
+ args: list[typing.Any],
28
+ kwargs: dict[str, typing.Any],
29
+ **__: object,
30
+ ) -> T_co:
31
+ coroutine: typing.Awaitable[T_co] = self._creator(*args, **kwargs)
30
32
  return await coroutine
31
-
32
- def sync_resolve(self, _: Container) -> typing.NoReturn:
33
- msg = "AsyncFactory cannot be resolved synchronously"
34
- raise RuntimeError(msg)