modern-di 0.16.3__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of modern-di might be problematic. Click here for more details.
- modern_di/__init__.py +6 -4
- modern_di/containers/__init__.py +0 -0
- modern_di/containers/abstract.py +110 -0
- modern_di/containers/async_container.py +104 -0
- modern_di/containers/sync_container.py +105 -0
- modern_di/group.py +26 -0
- modern_di/helpers/type_helpers.py +33 -0
- modern_di/providers/__init__.py +2 -4
- modern_di/providers/abstract.py +62 -74
- modern_di/providers/async_factory.py +9 -11
- modern_di/providers/async_singleton.py +14 -26
- modern_di/providers/container_provider.py +4 -8
- modern_di/providers/context_provider.py +16 -0
- modern_di/providers/dict.py +20 -13
- modern_di/providers/factory.py +17 -23
- modern_di/providers/list.py +20 -13
- modern_di/providers/object.py +6 -11
- modern_di/providers/resource.py +38 -66
- modern_di/providers/singleton.py +23 -43
- modern_di/registries/__init__.py +0 -0
- modern_di/registries/context_registry.py +16 -0
- modern_di/registries/overrides_registry.py +22 -0
- modern_di/registries/providers_registry.py +38 -0
- modern_di/registries/state_registry/__init__.py +0 -0
- modern_di/{provider_state.py → registries/state_registry/state.py} +15 -11
- modern_di/registries/state_registry/state_registry.py +52 -0
- {modern_di-0.16.3.dist-info → modern_di-0.17.0.dist-info}/METADATA +1 -1
- modern_di-0.17.0.dist-info/RECORD +33 -0
- modern_di/container.py +0 -171
- modern_di/graph.py +0 -39
- modern_di/providers/context_adapter.py +0 -27
- modern_di/providers/injected_factory.py +0 -45
- modern_di/providers/selector.py +0 -39
- modern_di-0.16.3.dist-info/RECORD +0 -25
- {modern_di-0.16.3.dist-info → modern_di-0.17.0.dist-info}/WHEEL +0 -0
modern_di/__init__.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
|
1
|
-
from modern_di.
|
|
2
|
-
from modern_di.
|
|
1
|
+
from modern_di.containers.async_container import AsyncContainer
|
|
2
|
+
from modern_di.containers.sync_container import SyncContainer
|
|
3
|
+
from modern_di.group import Group
|
|
3
4
|
from modern_di.scope import Scope
|
|
4
5
|
|
|
5
6
|
|
|
6
7
|
__all__ = [
|
|
7
|
-
"
|
|
8
|
-
"
|
|
8
|
+
"AsyncContainer",
|
|
9
|
+
"Group",
|
|
9
10
|
"Scope",
|
|
11
|
+
"SyncContainer",
|
|
10
12
|
"providers",
|
|
11
13
|
]
|
|
File without changes
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
from modern_di.group import Group
|
|
5
|
+
from modern_di.providers.abstract import AbstractProvider
|
|
6
|
+
from modern_di.providers.context_provider import ContextProvider
|
|
7
|
+
from modern_di.registries.context_registry import ContextRegistry
|
|
8
|
+
from modern_di.registries.overrides_registry import OverridesRegistry
|
|
9
|
+
from modern_di.registries.providers_registry import ProvidersRegistry
|
|
10
|
+
from modern_di.scope import Scope
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
if typing.TYPE_CHECKING:
|
|
14
|
+
import typing_extensions
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
T_co = typing.TypeVar("T_co", covariant=True)


class AbstractContainer:
    """Behaviour shared by the sync and async containers.

    A container holds a scope, an optional parent container (the container of
    the enclosing scope), a context registry and the providers/overrides
    registries. Concrete subclasses assign ``state_registry`` and implement
    the actual resolution logic.
    """

    # Used as ``__slots__`` by the concrete containers. The names must match the
    # attributes that are really assigned: the context lives in ``context_registry``
    # (the former entry ``"context"`` was never assigned anywhere).
    BASE_SLOTS: typing.ClassVar[list[str]] = [
        "_is_entered",
        "state_registry",
        "context_registry",
        "parent_container",
        "scope",
        "overrides_registry",
        "providers_registry",
    ]

    def __init__(
        self,
        *,
        scope: Scope = Scope.APP,
        parent_container: typing.Optional["typing_extensions.Self"] = None,
        context: dict[type[typing.Any], typing.Any] | None = None,
        groups: list[type[Group]] | None = None,
    ) -> None:
        """Initialise the container.

        Args:
            scope: scope served by this container.
            parent_container: container of the enclosing scope, if any.
            context: objects exposed to ``ContextProvider`` lookups, keyed by type.
            groups: provider groups whose providers are registered in the
                providers registry (used by ``resolve``).
        """
        self._is_entered = False
        self.scope = scope
        self.parent_container = parent_container
        self.context_registry = ContextRegistry(context or {})
        self.providers_registry: ProvidersRegistry
        self.overrides_registry: OverridesRegistry
        if parent_container is not None:
            # Children share their parent's registries: overrides apply at every
            # scope, and providers registered from groups on the root stay findable
            # by type/name in children (build_child_container passes no groups, so a
            # fresh registry would always be empty there).
            self.providers_registry = parent_container.providers_registry
            self.overrides_registry = parent_container.overrides_registry
        else:
            self.providers_registry = ProvidersRegistry()
            self.overrides_registry = OverridesRegistry()

        # NOTE(review): for a child container this registers the groups in the
        # registry shared with the parent.
        for one_group in groups or []:
            self.providers_registry.add_providers(**one_group.get_providers())

    def _check_entered(self) -> None:
        """Raise ``RuntimeError`` unless the container has been entered."""
        if not self._is_entered:
            msg = f"Enter the context of {self.scope.name} scope"
            raise RuntimeError(msg)

    def _resolve_context_provider(self, provider: ContextProvider[T_co]) -> T_co:
        """Return the context object requested by ``provider``.

        The context of this container is checked first, then the contexts of its
        parent containers, so a value passed to an outer scope remains available
        from nested scopes.

        Raises:
            RuntimeError: if no container in the chain holds the requested type.
        """
        container = self
        while container is not None:
            context = container.context_registry.find_context(provider.context_type)
            # Compare with None so legitimately falsy context values are returned.
            # NOTE(review): assumes find_context returns None for a missing type.
            if context is not None:
                return context
            container = container.parent_container

        msg = f"Context of type {provider.context_type} is missing"
        raise RuntimeError(msg)

    def build_child_container(
        self, context: dict[type[typing.Any], typing.Any] | None = None, scope: Scope | None = None
    ) -> "typing_extensions.Self":
        """Create a container for a deeper scope whose parent is this container.

        Args:
            context: context objects for the child container.
            scope: scope of the child; defaults to the scope right after this one.

        Raises:
            RuntimeError: if this container is not entered, if ``scope`` is not
                deeper than the current scope, or if there is no next scope.
        """
        self._check_entered()
        if scope and scope <= self.scope:
            msg = "Scope of child container must be more than current scope"
            raise RuntimeError(msg)

        if not scope:
            try:
                scope = self.scope.__class__(self.scope.value + 1)
            except ValueError as exc:
                msg = f"Max scope is reached, {self.scope.name}"
                raise RuntimeError(msg) from exc

        return self.__class__(scope=scope, parent_container=self, context=context)

    def find_container(self, scope: enum.IntEnum) -> "typing_extensions.Self":
        """Return the container in the parent chain whose scope equals ``scope``.

        Raises:
            RuntimeError: if ``scope`` is deeper than this container's scope, or if
                no container in the chain has exactly that scope.
        """
        container = self
        if container.scope < scope:
            msg = f"Scope {scope.name} is not initialized"
            raise RuntimeError(msg)

        while container.scope > scope and container.parent_container:
            container = container.parent_container

        if container.scope != scope:
            msg = f"Scope {scope.name} is skipped"
            raise RuntimeError(msg)

        return container

    def override(self, provider: AbstractProvider[T_co], override_object: object) -> None:
        """Make ``provider`` resolve to ``override_object`` via the overrides registry."""
        self.overrides_registry.override(provider.provider_id, override_object)

    def reset_override(self, provider: AbstractProvider[T_co] | None = None) -> None:
        """Remove the override of ``provider``, or all overrides when it is ``None``."""
        self.overrides_registry.reset_override(provider.provider_id if provider else None)

    def __deepcopy__(self, *_: object, **__: object) -> "typing_extensions.Self":
        """Hack to prevent cloning object."""
        return self

    def __copy__(self, *_: object, **__: object) -> "typing_extensions.Self":
        """Hack to prevent cloning object."""
        return self
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import types
|
|
3
|
+
import typing
|
|
4
|
+
|
|
5
|
+
from modern_di.containers.abstract import AbstractContainer
|
|
6
|
+
from modern_di.group import Group
|
|
7
|
+
from modern_di.providers.abstract import AbstractProvider
|
|
8
|
+
from modern_di.providers.container_provider import ContainerProvider
|
|
9
|
+
from modern_di.providers.context_provider import ContextProvider
|
|
10
|
+
from modern_di.registries.state_registry.state_registry import AsyncStateRegistry
|
|
11
|
+
from modern_di.scope import Scope
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
if typing.TYPE_CHECKING:
|
|
15
|
+
import typing_extensions
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
T_co = typing.TypeVar("T_co", covariant=True)


class AsyncContainer(contextlib.AbstractAsyncContextManager["AsyncContainer"], AbstractContainer):
    """Dependency container that resolves providers with ``await``.

    The container must be entered (``async with container`` or ``enter()``)
    before anything can be resolved from it.
    """

    __slots__ = AbstractContainer.BASE_SLOTS

    def __init__(
        self,
        *,
        scope: Scope = Scope.APP,
        parent_container: typing.Optional["typing_extensions.Self"] = None,
        context: dict[type[typing.Any], typing.Any] | None = None,
        groups: list[type[Group]] | None = None,
        use_lock: bool = True,
    ) -> None:
        """Create the container; ``use_lock`` is forwarded to its ``AsyncStateRegistry``."""
        super().__init__(scope=scope, parent_container=parent_container, context=context, groups=groups)
        self.state_registry = AsyncStateRegistry(use_lock=use_lock)

    async def _resolve_args(self, args: list[typing.Any]) -> list[typing.Any]:
        """Resolve every provider in ``args``; non-provider values pass through."""
        return [await self.resolve_provider(x) if isinstance(x, AbstractProvider) else x for x in args]

    async def _resolve_kwargs(self, kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
        """Resolve every provider value in ``kwargs``; other values pass through."""
        return {k: await self.resolve_provider(v) if isinstance(v, AbstractProvider) else v for k, v in kwargs.items()}

    async def resolve(
        self, dependency_type: type[T_co] | None = None, *, dependency_name: str | None = None
    ) -> T_co | None:
        """Resolve the provider registered for ``dependency_type`` / ``dependency_name``.

        Returns:
            The resolved object, or ``None`` when the providers registry finds no match.
        """
        provider = self.providers_registry.find_provider(
            dependency_type=dependency_type, dependency_name=dependency_name
        )
        if not provider:
            return None

        return await self.resolve_provider(provider)

    async def resolve_provider(self, provider: AbstractProvider[T_co]) -> T_co:
        """Resolve ``provider`` against the container of the provider's scope.

        Container and context providers are answered directly. Otherwise, an
        override wins, then a cached instance from the provider state. Failing
        both, the provider is resolved with its args/kwargs resolved first. When
        the state carries a lock, the cached-instance check is repeated after
        acquiring it, so concurrent callers do not create the instance twice.

        Raises:
            RuntimeError: if the container is not entered or the scope is unavailable.
        """
        self._check_entered()

        container = self.find_container(provider.scope)
        if isinstance(provider, ContainerProvider):
            return typing.cast(T_co, container)

        if isinstance(provider, ContextProvider):
            return typing.cast(T_co, self._resolve_context_provider(provider))

        if (override := container.overrides_registry.fetch_override(provider.provider_id)) is not None:
            return typing.cast(T_co, override)

        provider_state = container.state_registry.fetch_provider_state(provider)
        if provider_state and provider_state.instance is not None:
            return provider_state.instance

        if provider_state and provider_state.lock:
            await provider_state.lock.acquire()
        try:
            # Another coroutine may have created the instance while we waited for the lock.
            if provider_state and provider_state.instance is not None:
                return provider_state.instance

            return await provider.async_resolve(
                args=await self._resolve_args(provider.args or []),
                kwargs=await self._resolve_kwargs(provider.kwargs or {}),
                provider_state=provider_state,
            )
        finally:
            if provider_state and provider_state.lock:
                provider_state.lock.release()

    def enter(self) -> "AsyncContainer":
        """Mark the container as entered and return it."""
        self._is_entered = True
        return self

    async def close(self) -> None:
        """Leave the container and clear its provider state.

        Overrides are reset only by the root container. Child containers share
        their parent's overrides registry, so resetting it from a child would
        silently drop overrides registered on the parent.

        Raises:
            RuntimeError: if the container is not entered.
        """
        self._check_entered()
        self._is_entered = False
        await self.state_registry.clear_state()
        if self.parent_container is None:
            self.overrides_registry.reset_override()

    async def __aenter__(self) -> "AsyncContainer":
        return self.enter()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        await self.close()
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import types
|
|
3
|
+
import typing
|
|
4
|
+
|
|
5
|
+
from modern_di.containers.abstract import AbstractContainer
|
|
6
|
+
from modern_di.group import Group
|
|
7
|
+
from modern_di.providers import ContainerProvider
|
|
8
|
+
from modern_di.providers.abstract import AbstractProvider
|
|
9
|
+
from modern_di.providers.context_provider import ContextProvider
|
|
10
|
+
from modern_di.registries.state_registry.state_registry import SyncStateRegistry
|
|
11
|
+
from modern_di.scope import Scope
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
if typing.TYPE_CHECKING:
|
|
15
|
+
import typing_extensions
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
T_co = typing.TypeVar("T_co", covariant=True)


class SyncContainer(contextlib.AbstractContextManager["SyncContainer"], AbstractContainer):
    """Dependency container that resolves providers synchronously.

    The container must be entered (``with container`` or ``enter()``) before
    anything can be resolved from it. Async providers are rejected.
    """

    __slots__ = AbstractContainer.BASE_SLOTS

    def __init__(
        self,
        *,
        scope: Scope = Scope.APP,
        parent_container: typing.Optional["typing_extensions.Self"] = None,
        context: dict[type[typing.Any], typing.Any] | None = None,
        groups: list[type[Group]] | None = None,
        use_lock: bool = True,
    ) -> None:
        """Create the container; ``use_lock`` is forwarded to its ``SyncStateRegistry``."""
        super().__init__(scope=scope, parent_container=parent_container, context=context, groups=groups)
        self.state_registry = SyncStateRegistry(use_lock=use_lock)

    def _resolve_args(self, args: list[typing.Any]) -> list[typing.Any]:
        """Resolve every provider in ``args``; non-provider values pass through."""
        return [self.resolve_provider(x) if isinstance(x, AbstractProvider) else x for x in args]

    def _resolve_kwargs(self, kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
        """Resolve every provider value in ``kwargs``; other values pass through."""
        return {k: self.resolve_provider(v) if isinstance(v, AbstractProvider) else v for k, v in kwargs.items()}

    def resolve(self, dependency_type: type[T_co] | None = None, *, dependency_name: str | None = None) -> T_co | None:
        """Resolve the provider registered for ``dependency_type`` / ``dependency_name``.

        Returns:
            The resolved object, or ``None`` when the providers registry finds no match.
        """
        provider = self.providers_registry.find_provider(
            dependency_type=dependency_type, dependency_name=dependency_name
        )
        if not provider:
            return None

        return self.resolve_provider(provider)

    def resolve_provider(self, provider: AbstractProvider[T_co]) -> T_co:
        """Resolve ``provider`` against the container of the provider's scope.

        Container and context providers are answered directly. Otherwise, an
        override wins, then a cached instance from the provider state. Failing
        both, the provider is resolved with its args/kwargs resolved first. When
        the state carries a lock, the cached-instance check is repeated after
        acquiring it, so concurrent callers do not create the instance twice.

        Raises:
            RuntimeError: if the container is not entered, the provider is async,
                or the scope is unavailable.
        """
        self._check_entered()
        if provider.is_async:
            msg = f"{type(provider).__name__} cannot be resolved synchronously"
            raise RuntimeError(msg)

        container = self.find_container(provider.scope)
        if isinstance(provider, ContainerProvider):
            return typing.cast(T_co, container)

        if isinstance(provider, ContextProvider):
            return typing.cast(T_co, self._resolve_context_provider(provider))

        if (override := container.overrides_registry.fetch_override(provider.provider_id)) is not None:
            return typing.cast(T_co, override)

        provider_state = container.state_registry.fetch_provider_state(provider)
        if provider_state and provider_state.instance is not None:
            return provider_state.instance

        if provider_state and provider_state.lock:
            provider_state.lock.acquire()
        try:
            # Another thread may have created the instance while we waited for the lock.
            if provider_state and provider_state.instance is not None:
                return provider_state.instance

            return provider.sync_resolve(
                args=self._resolve_args(provider.args or []),
                kwargs=self._resolve_kwargs(provider.kwargs or {}),
                provider_state=provider_state,
            )
        finally:
            if provider_state and provider_state.lock:
                provider_state.lock.release()

    def enter(self) -> "SyncContainer":
        """Mark the container as entered and return it."""
        self._is_entered = True
        return self

    def close(self) -> None:
        """Leave the container and clear its provider state.

        Overrides are reset only by the root container. Child containers share
        their parent's overrides registry, so resetting it from a child would
        silently drop overrides registered on the parent.

        Raises:
            RuntimeError: if the container is not entered.
        """
        self._check_entered()
        self._is_entered = False
        self.state_registry.clear_state()
        if self.parent_container is None:
            self.overrides_registry.reset_override()

    def __enter__(self) -> "SyncContainer":
        return self.enter()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        self.close()
|
modern_di/group.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
from modern_di.providers.abstract import AbstractProvider
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
if typing.TYPE_CHECKING:
|
|
7
|
+
import typing_extensions
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
T = typing.TypeVar("T")
P = typing.ParamSpec("P")


class Group:
    """Namespace class bundling providers declared as class attributes.

    A group is never instantiated; containers read its providers through
    :meth:`get_providers`.
    """

    providers: dict[str, AbstractProvider[typing.Any]]

    def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self":  # noqa: ANN401
        msg = f"{cls.__name__} cannot be instantiated"
        raise RuntimeError(msg)

    @classmethod
    def get_providers(cls) -> dict[str, AbstractProvider[typing.Any]]:
        """Return the providers defined directly on this group class, keyed by attribute name.

        The mapping is computed once per class and cached in ``cls.providers``.
        """
        # Look in the class's own namespace rather than using hasattr(): a cache
        # computed on a base group would otherwise be inherited, and a subclass
        # would return its base's providers instead of its own.
        if "providers" not in cls.__dict__:
            cls.providers = {k: v for k, v in cls.__dict__.items() if isinstance(v, AbstractProvider)}

        return cls.providers
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import types
|
|
3
|
+
import typing
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
GENERIC_TYPES = {
|
|
7
|
+
typing.Iterator,
|
|
8
|
+
typing.AsyncIterator,
|
|
9
|
+
collections.abc.Iterator,
|
|
10
|
+
collections.abc.AsyncIterator,
|
|
11
|
+
}
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def define_bound_type(creator: type | object) -> type | None:
|
|
15
|
+
if isinstance(creator, type):
|
|
16
|
+
return creator
|
|
17
|
+
|
|
18
|
+
type_hints = typing.get_type_hints(creator)
|
|
19
|
+
return_annotation = type_hints.get("return")
|
|
20
|
+
if not return_annotation:
|
|
21
|
+
return None
|
|
22
|
+
|
|
23
|
+
if isinstance(return_annotation, type) and not isinstance(return_annotation, types.GenericAlias):
|
|
24
|
+
return return_annotation
|
|
25
|
+
|
|
26
|
+
if typing.get_origin(return_annotation) not in GENERIC_TYPES:
|
|
27
|
+
return None
|
|
28
|
+
|
|
29
|
+
args = typing.get_args(return_annotation)
|
|
30
|
+
if not args:
|
|
31
|
+
return None
|
|
32
|
+
|
|
33
|
+
return typing.cast(type, args[0])
|
modern_di/providers/__init__.py
CHANGED
|
@@ -2,13 +2,12 @@ from modern_di.providers.abstract import AbstractProvider
|
|
|
2
2
|
from modern_di.providers.async_factory import AsyncFactory
|
|
3
3
|
from modern_di.providers.async_singleton import AsyncSingleton
|
|
4
4
|
from modern_di.providers.container_provider import ContainerProvider
|
|
5
|
-
from modern_di.providers.
|
|
5
|
+
from modern_di.providers.context_provider import ContextProvider
|
|
6
6
|
from modern_di.providers.dict import Dict
|
|
7
7
|
from modern_di.providers.factory import Factory
|
|
8
8
|
from modern_di.providers.list import List
|
|
9
9
|
from modern_di.providers.object import Object
|
|
10
10
|
from modern_di.providers.resource import Resource
|
|
11
|
-
from modern_di.providers.selector import Selector
|
|
12
11
|
from modern_di.providers.singleton import Singleton
|
|
13
12
|
|
|
14
13
|
|
|
@@ -17,12 +16,11 @@ __all__ = [
|
|
|
17
16
|
"AsyncFactory",
|
|
18
17
|
"AsyncSingleton",
|
|
19
18
|
"ContainerProvider",
|
|
20
|
-
"
|
|
19
|
+
"ContextProvider",
|
|
21
20
|
"Dict",
|
|
22
21
|
"Factory",
|
|
23
22
|
"List",
|
|
24
23
|
"Object",
|
|
25
24
|
"Resource",
|
|
26
|
-
"Selector",
|
|
27
25
|
"Singleton",
|
|
28
26
|
]
|
modern_di/providers/abstract.py
CHANGED
|
@@ -3,10 +3,12 @@ import enum
|
|
|
3
3
|
import typing
|
|
4
4
|
import uuid
|
|
5
5
|
|
|
6
|
+
import typing_extensions
|
|
6
7
|
from typing_extensions import override
|
|
7
8
|
|
|
8
|
-
from modern_di import Container
|
|
9
9
|
from modern_di.helpers.attr_getter_helpers import get_value_from_object_by_dotted_path
|
|
10
|
+
from modern_di.helpers.type_helpers import define_bound_type
|
|
11
|
+
from modern_di.registries.state_registry.state import AsyncState, SyncState
|
|
10
12
|
|
|
11
13
|
|
|
12
14
|
T_co = typing.TypeVar("T_co", covariant=True)
|
|
@@ -14,36 +16,62 @@ R = typing.TypeVar("R")
|
|
|
14
16
|
P = typing.ParamSpec("P")
|
|
15
17
|
|
|
16
18
|
|
|
17
|
-
class AbstractProvider(typing.Generic[T_co]
|
|
18
|
-
BASE_SLOTS: typing.ClassVar = ["scope", "provider_id"]
|
|
19
|
+
class AbstractProvider(abc.ABC, typing.Generic[T_co]):
|
|
20
|
+
BASE_SLOTS: typing.ClassVar = ["scope", "provider_id", "args", "kwargs", "is_async", "bound_type"]
|
|
21
|
+
HAS_STATE: bool = False
|
|
19
22
|
|
|
20
|
-
def __init__(
|
|
23
|
+
def __init__(
|
|
24
|
+
self,
|
|
25
|
+
scope: enum.IntEnum,
|
|
26
|
+
args: list[typing.Any] | None = None,
|
|
27
|
+
kwargs: dict[str, typing.Any] | None = None,
|
|
28
|
+
bound_type: type | None = None,
|
|
29
|
+
) -> None:
|
|
21
30
|
self.scope = scope
|
|
22
31
|
self.provider_id: typing.Final = str(uuid.uuid4())
|
|
32
|
+
self.args = args
|
|
33
|
+
self.kwargs = kwargs
|
|
34
|
+
self.is_async = False
|
|
35
|
+
self.bound_type = bound_type
|
|
36
|
+
self._check_providers_scope()
|
|
37
|
+
|
|
38
|
+
def bind_type(self, new_type: type) -> typing_extensions.Self:
|
|
39
|
+
self.bound_type = new_type
|
|
40
|
+
return self
|
|
23
41
|
|
|
24
|
-
|
|
25
|
-
|
|
42
|
+
async def async_resolve(
|
|
43
|
+
self,
|
|
44
|
+
*,
|
|
45
|
+
args: list[typing.Any],
|
|
46
|
+
kwargs: dict[str, typing.Any],
|
|
47
|
+
provider_state: AsyncState[T_co] | None,
|
|
48
|
+
) -> T_co: # pragma: no cover
|
|
26
49
|
"""Resolve dependency asynchronously."""
|
|
50
|
+
raise NotImplementedError
|
|
27
51
|
|
|
28
|
-
|
|
29
|
-
|
|
52
|
+
def sync_resolve(
|
|
53
|
+
self,
|
|
54
|
+
*,
|
|
55
|
+
args: list[typing.Any],
|
|
56
|
+
kwargs: dict[str, typing.Any],
|
|
57
|
+
provider_state: SyncState[T_co] | None,
|
|
58
|
+
) -> T_co: # pragma: no cover
|
|
30
59
|
"""Resolve dependency synchronously."""
|
|
60
|
+
raise NotImplementedError
|
|
31
61
|
|
|
32
62
|
@property
|
|
33
63
|
def cast(self) -> T_co:
|
|
34
64
|
return typing.cast(T_co, self)
|
|
35
65
|
|
|
36
|
-
def _check_providers_scope(
|
|
37
|
-
self
|
|
38
|
-
|
|
39
|
-
if args:
|
|
40
|
-
for provider in args:
|
|
66
|
+
def _check_providers_scope(self) -> None:
|
|
67
|
+
if self.args:
|
|
68
|
+
for provider in self.args:
|
|
41
69
|
if isinstance(provider, AbstractProvider) and provider.scope > self.scope:
|
|
42
70
|
msg = f"Scope of dependency is {provider.scope.name} and current scope is {self.scope.name}"
|
|
43
71
|
raise RuntimeError(msg)
|
|
44
72
|
|
|
45
|
-
if kwargs:
|
|
46
|
-
for name, provider in kwargs.items():
|
|
73
|
+
if self.kwargs:
|
|
74
|
+
for name, provider in self.kwargs.items():
|
|
47
75
|
if isinstance(provider, AbstractProvider) and provider.scope > self.scope:
|
|
48
76
|
msg = f"Scope of {name} is {provider.scope.name} and current scope is {self.scope.name}"
|
|
49
77
|
raise RuntimeError(msg)
|
|
@@ -65,16 +93,8 @@ class AbstractProvider(typing.Generic[T_co], abc.ABC):
|
|
|
65
93
|
return AttrGetter(provider=self, attr_name=attr_name)
|
|
66
94
|
|
|
67
95
|
|
|
68
|
-
class
|
|
69
|
-
|
|
70
|
-
container.override(self.provider_id, override_object)
|
|
71
|
-
|
|
72
|
-
def reset_override(self, container: Container) -> None:
|
|
73
|
-
container.reset_override(self.provider_id)
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
class AbstractCreatorProvider(AbstractOverrideProvider[T_co], abc.ABC):
|
|
77
|
-
BASE_SLOTS: typing.ClassVar = [*AbstractProvider.BASE_SLOTS, "_args", "_kwargs", "_creator"]
|
|
96
|
+
class AbstractCreatorProvider(AbstractProvider[T_co], abc.ABC):
|
|
97
|
+
BASE_SLOTS: typing.ClassVar = [*AbstractProvider.BASE_SLOTS, "_creator"]
|
|
78
98
|
|
|
79
99
|
def __init__(
|
|
80
100
|
self,
|
|
@@ -83,55 +103,15 @@ class AbstractCreatorProvider(AbstractOverrideProvider[T_co], abc.ABC):
|
|
|
83
103
|
*args: P.args,
|
|
84
104
|
**kwargs: P.kwargs,
|
|
85
105
|
) -> None:
|
|
86
|
-
super().__init__(scope)
|
|
87
|
-
self._check_providers_scope(args=args, kwargs=kwargs)
|
|
106
|
+
super().__init__(scope, args=list(args), kwargs=kwargs, bound_type=define_bound_type(creator))
|
|
88
107
|
self._creator: typing.Final = creator
|
|
89
|
-
self._args: typing.Final = args
|
|
90
|
-
self._kwargs: typing.Final = kwargs
|
|
91
|
-
|
|
92
|
-
def _sync_resolve_args(self, container: Container) -> list[typing.Any]:
|
|
93
|
-
return [x.sync_resolve(container) if isinstance(x, AbstractProvider) else x for x in self._args]
|
|
94
|
-
|
|
95
|
-
def _sync_resolve_kwargs(self, container: Container) -> dict[str, typing.Any]:
|
|
96
|
-
return {k: v.sync_resolve(container) if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}
|
|
97
|
-
|
|
98
|
-
def _sync_build_creator(self, container: Container) -> typing.Any: # noqa: ANN401
|
|
99
|
-
return self._creator(
|
|
100
|
-
*self._sync_resolve_args(container),
|
|
101
|
-
**self._sync_resolve_kwargs(container),
|
|
102
|
-
)
|
|
103
|
-
|
|
104
|
-
async def _async_resolve_args(self, container: Container) -> list[typing.Any]:
|
|
105
|
-
return [await x.async_resolve(container) if isinstance(x, AbstractProvider) else x for x in self._args]
|
|
106
|
-
|
|
107
|
-
async def _async_resolve_kwargs(self, container: Container) -> dict[str, typing.Any]:
|
|
108
|
-
return {
|
|
109
|
-
k: await v.async_resolve(container) if isinstance(v, AbstractProvider) else v
|
|
110
|
-
for k, v in self._kwargs.items()
|
|
111
|
-
}
|
|
112
|
-
|
|
113
|
-
async def _async_build_creator(self, container: Container) -> typing.Any: # noqa: ANN401
|
|
114
|
-
return self._creator(
|
|
115
|
-
*await self._async_resolve_args(container),
|
|
116
|
-
**await self._async_resolve_kwargs(container),
|
|
117
|
-
)
|
|
118
108
|
|
|
119
109
|
|
|
120
110
|
class AttrGetter(AbstractProvider[T_co]):
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
__slots__ = [*AbstractProvider.BASE_SLOTS, "_attrs", "_provider"]
|
|
111
|
+
__slots__ = [*AbstractProvider.BASE_SLOTS, "_attrs"]
|
|
124
112
|
|
|
125
113
|
def __init__(self, provider: AbstractProvider[T_co], attr_name: str) -> None:
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
Args:
|
|
129
|
-
provider: provider to wrap.
|
|
130
|
-
attr_name: attribute name to resolve when the provider is resolved.
|
|
131
|
-
|
|
132
|
-
"""
|
|
133
|
-
super().__init__(scope=provider.scope)
|
|
134
|
-
self._provider = provider
|
|
114
|
+
super().__init__(scope=provider.scope, args=[provider])
|
|
135
115
|
self._attrs = [attr_name]
|
|
136
116
|
|
|
137
117
|
def __getattr__(self, attr: str) -> "AttrGetter[T_co]":
|
|
@@ -142,13 +122,21 @@ class AttrGetter(AbstractProvider[T_co]):
|
|
|
142
122
|
return self
|
|
143
123
|
|
|
144
124
|
@override
|
|
145
|
-
async def async_resolve(
|
|
146
|
-
|
|
125
|
+
async def async_resolve(
|
|
126
|
+
self,
|
|
127
|
+
*,
|
|
128
|
+
args: list[typing.Any],
|
|
129
|
+
**_: object,
|
|
130
|
+
) -> typing.Any:
|
|
147
131
|
attribute_path = ".".join(self._attrs)
|
|
148
|
-
return get_value_from_object_by_dotted_path(
|
|
132
|
+
return get_value_from_object_by_dotted_path(args[0], attribute_path)
|
|
149
133
|
|
|
150
134
|
@override
|
|
151
|
-
def sync_resolve(
|
|
152
|
-
|
|
135
|
+
def sync_resolve(
|
|
136
|
+
self,
|
|
137
|
+
*,
|
|
138
|
+
args: list[typing.Any],
|
|
139
|
+
**_: object,
|
|
140
|
+
) -> typing.Any:
|
|
153
141
|
attribute_path = ".".join(self._attrs)
|
|
154
|
-
return get_value_from_object_by_dotted_path(
|
|
142
|
+
return get_value_from_object_by_dotted_path(args[0], attribute_path)
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import enum
|
|
2
2
|
import typing
|
|
3
3
|
|
|
4
|
-
from modern_di import Container
|
|
5
4
|
from modern_di.providers.abstract import AbstractCreatorProvider
|
|
6
5
|
|
|
7
6
|
|
|
@@ -20,15 +19,14 @@ class AsyncFactory(AbstractCreatorProvider[T_co]):
|
|
|
20
19
|
**kwargs: P.kwargs,
|
|
21
20
|
) -> None:
|
|
22
21
|
super().__init__(scope, creator, *args, **kwargs)
|
|
22
|
+
self.is_async = True
|
|
23
23
|
|
|
24
|
-
async def async_resolve(
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
24
|
+
async def async_resolve(
|
|
25
|
+
self,
|
|
26
|
+
*,
|
|
27
|
+
args: list[typing.Any],
|
|
28
|
+
kwargs: dict[str, typing.Any],
|
|
29
|
+
**__: object,
|
|
30
|
+
) -> T_co:
|
|
31
|
+
coroutine: typing.Awaitable[T_co] = self._creator(*args, **kwargs)
|
|
30
32
|
return await coroutine
|
|
31
|
-
|
|
32
|
-
def sync_resolve(self, _: Container) -> typing.NoReturn:
|
|
33
|
-
msg = "AsyncFactory cannot be resolved synchronously"
|
|
34
|
-
raise RuntimeError(msg)
|