wireup 2.1.0__tar.gz → 2.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wireup-2.1.0 → wireup-2.2.0}/PKG-INFO +1 -1
- {wireup-2.1.0 → wireup-2.2.0}/pyproject.toml +1 -1
- {wireup-2.1.0 → wireup-2.2.0}/wireup/_decorators.py +4 -4
- {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/aiohttp.py +3 -1
- {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/django/apps.py +3 -3
- {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/fastapi.py +4 -4
- {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/container/__init__.py +32 -13
- {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/container/async_container.py +30 -7
- wireup-2.2.0/wireup/ioc/container/base_container.py +93 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/container/sync_container.py +5 -3
- wireup-2.2.0/wireup/ioc/factory_compiler.py +159 -0
- wireup-2.2.0/wireup/ioc/override_manager.py +139 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/service_registry.py +144 -8
- {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/types.py +1 -6
- {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/util.py +68 -0
- wireup-2.1.0/wireup/_async_to_sync.py +0 -34
- wireup-2.1.0/wireup/ioc/container/base_container.py +0 -189
- wireup-2.1.0/wireup/ioc/override_manager.py +0 -75
- wireup-2.1.0/wireup/ioc/validation.py +0 -170
- {wireup-2.1.0 → wireup-2.2.0}/readme.md +0 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/__init__.py +0 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/_annotations.py +0 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/_discovery.py +0 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/errors.py +0 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/__init__.py +0 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/click.py +0 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/django/__init__.py +0 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/flask.py +0 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/starlette.py +0 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/__init__.py +0 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/_exit_stack.py +0 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/parameter.py +0 -0
- {wireup-2.1.0 → wireup-2.2.0}/wireup/py.typed +0 -0
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import asyncio
|
|
4
3
|
import contextlib
|
|
5
4
|
import functools
|
|
5
|
+
import inspect
|
|
6
6
|
from contextlib import AsyncExitStack, ExitStack
|
|
7
7
|
from typing import TYPE_CHECKING, Any, Callable
|
|
8
8
|
|
|
@@ -10,7 +10,7 @@ from wireup.errors import WireupError
|
|
|
10
10
|
from wireup.ioc.container.async_container import AsyncContainer, ScopedAsyncContainer, async_container_force_sync_scope
|
|
11
11
|
from wireup.ioc.container.sync_container import SyncContainer
|
|
12
12
|
from wireup.ioc.types import AnnotatedParameter, ParameterWrapper
|
|
13
|
-
from wireup.ioc.
|
|
13
|
+
from wireup.ioc.util import (
|
|
14
14
|
get_inject_annotated_parameters,
|
|
15
15
|
get_valid_injection_annotated_parameters,
|
|
16
16
|
)
|
|
@@ -60,7 +60,7 @@ def inject_from_container(
|
|
|
60
60
|
"""
|
|
61
61
|
|
|
62
62
|
def _decorator(target: Callable[..., Any]) -> Callable[..., Any]:
|
|
63
|
-
if
|
|
63
|
+
if inspect.iscoroutinefunction(target) and isinstance(container, SyncContainer):
|
|
64
64
|
msg = (
|
|
65
65
|
"Sync container cannot perform injection on async targets. "
|
|
66
66
|
"Create an async container via wireup.create_async_container."
|
|
@@ -96,7 +96,7 @@ def inject_from_container_util( # noqa: C901
|
|
|
96
96
|
if not names_to_inject:
|
|
97
97
|
return target
|
|
98
98
|
|
|
99
|
-
if
|
|
99
|
+
if inspect.iscoroutinefunction(target):
|
|
100
100
|
|
|
101
101
|
@functools.wraps(target)
|
|
102
102
|
async def _inject_async_target(*args: Any, **kwargs: Any) -> Any:
|
|
@@ -77,7 +77,9 @@ def _get_startup_event(
|
|
|
77
77
|
handlers: Optional[Iterable[Type[_WireupHandler]]],
|
|
78
78
|
) -> Callable[[web.Application], Awaitable[None]]:
|
|
79
79
|
for handler_type in handlers or []:
|
|
80
|
-
container._registry.
|
|
80
|
+
container._registry.extend(impls=[ServiceDeclaration(handler_type)])
|
|
81
|
+
container._compiler.compile()
|
|
82
|
+
container._scoped_compiler.compile()
|
|
81
83
|
|
|
82
84
|
async def _inner(app: web.Application) -> None:
|
|
83
85
|
if handlers:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import asyncio
|
|
2
1
|
import functools
|
|
3
2
|
import importlib
|
|
3
|
+
import inspect
|
|
4
4
|
from contextvars import ContextVar
|
|
5
5
|
from dataclasses import dataclass
|
|
6
6
|
from types import ModuleType
|
|
@@ -21,7 +21,7 @@ from wireup.errors import WireupError
|
|
|
21
21
|
from wireup.ioc.container.async_container import AsyncContainer, ScopedAsyncContainer, async_container_force_sync_scope
|
|
22
22
|
from wireup.ioc.container.sync_container import ScopedSyncContainer
|
|
23
23
|
from wireup.ioc.types import ParameterWrapper
|
|
24
|
-
from wireup.ioc.
|
|
24
|
+
from wireup.ioc.util import get_valid_injection_annotated_parameters
|
|
25
25
|
|
|
26
26
|
if TYPE_CHECKING:
|
|
27
27
|
from wireup.integration.django import WireupSettings
|
|
@@ -38,7 +38,7 @@ def wireup_middleware(
|
|
|
38
38
|
) -> Callable[[HttpRequest], Union[HttpResponse, Awaitable[HttpResponse]]]:
|
|
39
39
|
container = get_app_container()
|
|
40
40
|
|
|
41
|
-
if
|
|
41
|
+
if inspect.iscoroutinefunction(get_response):
|
|
42
42
|
|
|
43
43
|
async def async_inner(request: HttpRequest) -> HttpResponse:
|
|
44
44
|
async with container.enter_scope() as scoped:
|
|
@@ -33,8 +33,7 @@ from wireup.integration.starlette import (
|
|
|
33
33
|
from wireup.ioc.container.async_container import AsyncContainer, ScopedAsyncContainer
|
|
34
34
|
from wireup.ioc.container.sync_container import ScopedSyncContainer
|
|
35
35
|
from wireup.ioc.types import AnyCallable
|
|
36
|
-
from wireup.ioc.
|
|
37
|
-
assert_dependencies_valid,
|
|
36
|
+
from wireup.ioc.util import (
|
|
38
37
|
get_inject_annotated_parameters,
|
|
39
38
|
hide_annotated_names,
|
|
40
39
|
)
|
|
@@ -166,8 +165,9 @@ def _update_lifespan(
|
|
|
166
165
|
async def lifespan(app: FastAPI) -> AsyncIterator[Any]:
|
|
167
166
|
if class_based_routes:
|
|
168
167
|
for cbr in class_based_routes:
|
|
169
|
-
container._registry.
|
|
170
|
-
|
|
168
|
+
container._registry.extend(impls=[ServiceDeclaration(cbr)])
|
|
169
|
+
container._compiler.compile()
|
|
170
|
+
container._scoped_compiler.compile()
|
|
171
171
|
|
|
172
172
|
for cbr in class_based_routes:
|
|
173
173
|
await _instantiate_class_based_route(app, container, cbr)
|
|
@@ -8,11 +8,10 @@ from wireup.errors import WireupError
|
|
|
8
8
|
from wireup.ioc.container.async_container import AsyncContainer
|
|
9
9
|
from wireup.ioc.container.base_container import BaseContainer
|
|
10
10
|
from wireup.ioc.container.sync_container import SyncContainer
|
|
11
|
+
from wireup.ioc.factory_compiler import FactoryCompiler
|
|
11
12
|
from wireup.ioc.override_manager import OverrideManager
|
|
12
13
|
from wireup.ioc.parameter import ParameterBag
|
|
13
14
|
from wireup.ioc.service_registry import ServiceRegistry
|
|
14
|
-
from wireup.ioc.types import ContainerScope
|
|
15
|
-
from wireup.ioc.validation import assert_dependencies_valid
|
|
16
15
|
|
|
17
16
|
if TYPE_CHECKING:
|
|
18
17
|
from types import ModuleType
|
|
@@ -36,6 +35,36 @@ def _create_container(
|
|
|
36
35
|
:param parameters: Dict containing parameters you want to expose to the container. Services or factories can
|
|
37
36
|
request parameters via the `Inject(param="name")` syntax.
|
|
38
37
|
"""
|
|
38
|
+
abstracts, impls = _merge_definitions(service_modules, services)
|
|
39
|
+
registry = ServiceRegistry(parameters=ParameterBag(parameters), abstracts=abstracts, impls=impls)
|
|
40
|
+
|
|
41
|
+
# The container uses a dual-compiler optimization strategy:
|
|
42
|
+
# 1. The singleton compiler generates optimized factories for singleton dependencies
|
|
43
|
+
# and throws errors if scoped dependencies are accessed outside a scope.
|
|
44
|
+
# 2. The scoped compiler handles dependencies that require request/scope isolation.
|
|
45
|
+
#
|
|
46
|
+
# When entering/exiting scopes, the container switches between these compilers.
|
|
47
|
+
# This eliminates the need to check lifetime rules at runtime.
|
|
48
|
+
singleton_compiler = FactoryCompiler(registry, is_scoped_container=False)
|
|
49
|
+
scoped_compiler = FactoryCompiler(registry, is_scoped_container=True)
|
|
50
|
+
singleton_compiler.compile()
|
|
51
|
+
scoped_compiler.compile()
|
|
52
|
+
|
|
53
|
+
override_manager = OverrideManager(registry.is_type_with_qualifier_known, singleton_compiler, scoped_compiler)
|
|
54
|
+
return klass(
|
|
55
|
+
registry=registry,
|
|
56
|
+
factory_compiler=singleton_compiler,
|
|
57
|
+
scoped_compiler=scoped_compiler,
|
|
58
|
+
global_scope_objects={},
|
|
59
|
+
global_scope_exit_stack=[],
|
|
60
|
+
override_manager=override_manager,
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _merge_definitions(
|
|
65
|
+
service_modules: Iterable[ModuleType] | None = None,
|
|
66
|
+
services: Iterable[Any] | None = None,
|
|
67
|
+
) -> tuple[list[AbstractDeclaration], list[ServiceDeclaration]]:
|
|
39
68
|
abstracts: list[AbstractDeclaration] = []
|
|
40
69
|
impls: list[ServiceDeclaration] = []
|
|
41
70
|
|
|
@@ -57,17 +86,7 @@ def _create_container(
|
|
|
57
86
|
abstracts.extend(discovered_abstracts)
|
|
58
87
|
impls.extend(discovered_services)
|
|
59
88
|
|
|
60
|
-
|
|
61
|
-
container = klass(
|
|
62
|
-
registry=registry,
|
|
63
|
-
parameters=ParameterBag(parameters),
|
|
64
|
-
global_scope=ContainerScope(objects={}, exit_stack=[]),
|
|
65
|
-
override_manager=OverrideManager({}, registry.is_type_with_qualifier_known),
|
|
66
|
-
)
|
|
67
|
-
|
|
68
|
-
assert_dependencies_valid(container)
|
|
69
|
-
|
|
70
|
-
return container
|
|
89
|
+
return abstracts, impls
|
|
71
90
|
|
|
72
91
|
|
|
73
92
|
def create_sync_container(
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import TYPE_CHECKING
|
|
3
|
+
from typing import TYPE_CHECKING, TypeVar
|
|
4
4
|
|
|
5
5
|
from typing_extensions import Self
|
|
6
6
|
|
|
7
|
+
from wireup.errors import UnknownServiceRequestedError
|
|
7
8
|
from wireup.ioc._exit_stack import async_clean_exit_stack
|
|
8
9
|
from wireup.ioc.container.base_container import BaseContainer
|
|
9
10
|
from wireup.ioc.container.sync_container import ScopedSyncContainer
|
|
@@ -11,12 +12,30 @@ from wireup.ioc.container.sync_container import ScopedSyncContainer
|
|
|
11
12
|
if TYPE_CHECKING:
|
|
12
13
|
from types import TracebackType
|
|
13
14
|
|
|
15
|
+
from wireup.ioc.types import Qualifier
|
|
16
|
+
|
|
17
|
+
T = TypeVar("T")
|
|
18
|
+
|
|
14
19
|
|
|
15
20
|
class BareAsyncContainer(BaseContainer):
|
|
16
|
-
get =
|
|
21
|
+
async def get(self, klass: type[T], qualifier: Qualifier | None = None) -> T:
|
|
22
|
+
"""Get an instance of the requested type.
|
|
23
|
+
|
|
24
|
+
:param qualifier: Qualifier for the class if it was registered with one.
|
|
25
|
+
:param klass: Class of the dependency already registered in the container.
|
|
26
|
+
:return: An instance of the requested object. Always returns an existing instance when one is available.
|
|
27
|
+
"""
|
|
28
|
+
obj_id = hash(klass if qualifier is None else (klass, qualifier))
|
|
29
|
+
|
|
30
|
+
if compiled_factory := self._factories.get(obj_id):
|
|
31
|
+
res = compiled_factory.factory(self)
|
|
32
|
+
|
|
33
|
+
return await res if compiled_factory.is_async else res # type:ignore[no-any-return]
|
|
34
|
+
|
|
35
|
+
raise UnknownServiceRequestedError(klass, qualifier)
|
|
17
36
|
|
|
18
37
|
async def close(self) -> None:
|
|
19
|
-
await async_clean_exit_stack(self.
|
|
38
|
+
await async_clean_exit_stack(self._global_scope_exit_stack)
|
|
20
39
|
|
|
21
40
|
|
|
22
41
|
class ScopedAsyncContainer(BareAsyncContainer):
|
|
@@ -37,11 +56,13 @@ class AsyncContainer(BareAsyncContainer):
|
|
|
37
56
|
def enter_scope(self) -> ScopedAsyncContainer:
|
|
38
57
|
return ScopedAsyncContainer(
|
|
39
58
|
registry=self._registry,
|
|
40
|
-
parameters=self._params,
|
|
41
59
|
override_manager=self._override_mgr,
|
|
42
|
-
|
|
60
|
+
global_scope_objects=self._global_scope_objects,
|
|
61
|
+
global_scope_exit_stack=self._global_scope_exit_stack,
|
|
43
62
|
current_scope_objects={},
|
|
44
63
|
current_scope_exit_stack=[],
|
|
64
|
+
factory_compiler=self._scoped_compiler,
|
|
65
|
+
scoped_compiler=self._scoped_compiler,
|
|
45
66
|
)
|
|
46
67
|
|
|
47
68
|
|
|
@@ -53,9 +74,11 @@ def async_container_force_sync_scope(container: AsyncContainer) -> ScopedSyncCon
|
|
|
53
74
|
"""
|
|
54
75
|
return ScopedSyncContainer(
|
|
55
76
|
registry=container._registry,
|
|
56
|
-
parameters=container._params,
|
|
57
77
|
override_manager=container._override_mgr,
|
|
58
|
-
|
|
78
|
+
global_scope_objects=container._global_scope_objects,
|
|
79
|
+
global_scope_exit_stack=container._global_scope_exit_stack,
|
|
59
80
|
current_scope_objects={},
|
|
60
81
|
current_scope_exit_stack=[],
|
|
82
|
+
factory_compiler=container._scoped_compiler,
|
|
83
|
+
scoped_compiler=container._scoped_compiler,
|
|
61
84
|
)
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from typing import (
|
|
2
|
+
Any,
|
|
3
|
+
AsyncGenerator,
|
|
4
|
+
Dict,
|
|
5
|
+
Generator,
|
|
6
|
+
List,
|
|
7
|
+
Optional,
|
|
8
|
+
Type,
|
|
9
|
+
TypeVar,
|
|
10
|
+
Union,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from wireup.errors import (
|
|
14
|
+
UnknownServiceRequestedError,
|
|
15
|
+
WireupError,
|
|
16
|
+
)
|
|
17
|
+
from wireup.ioc.factory_compiler import FactoryCompiler
|
|
18
|
+
from wireup.ioc.override_manager import OverrideManager
|
|
19
|
+
from wireup.ioc.parameter import ParameterBag
|
|
20
|
+
from wireup.ioc.service_registry import ServiceRegistry
|
|
21
|
+
from wireup.ioc.types import (
|
|
22
|
+
ContainerObjectIdentifier,
|
|
23
|
+
Qualifier,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
T = TypeVar("T")
|
|
27
|
+
ContainerExitStack = List[Union[Generator[Any, Any, Any], AsyncGenerator[Any, Any]]]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class BaseContainer:
    """State and lookup logic shared by the synchronous and asynchronous containers.

    A container resolves objects through ``_factories``, the table of compiled factories owned by the
    compiler it was constructed with. Singleton instances and the generators that must be cleaned up
    for them live in the global-scope dict/stack; the current-scope dict/stack hold per-scope state.
    """

    __slots__ = (
        "_compiler",
        "_current_scope_exit_stack",
        "_current_scope_objects",
        "_factories",
        "_global_scope_exit_stack",
        "_global_scope_objects",
        "_override_mgr",
        "_registry",
        "_scoped_compiler",
    )

    def __init__(  # noqa: PLR0913
        self,
        registry: ServiceRegistry,
        override_manager: OverrideManager,
        factory_compiler: FactoryCompiler,
        scoped_compiler: FactoryCompiler,
        global_scope_objects: Dict[ContainerObjectIdentifier, Any],
        global_scope_exit_stack: ContainerExitStack,
        current_scope_objects: Optional[Dict[ContainerObjectIdentifier, Any]] = None,
        current_scope_exit_stack: Optional[ContainerExitStack] = None,
    ) -> None:
        """Store the collaborators and scope state used by this container.

        :param registry: Registry holding service declarations and the parameter bag.
        :param override_manager: Manager used to override services on this container.
        :param factory_compiler: Compiler whose ``factories`` table backs lookups on this container.
        :param scoped_compiler: Compiler for scoped lookups; passed on when a scope is entered.
        :param global_scope_objects: Storage for singleton instances.
        :param global_scope_exit_stack: Generators to clean up when the container is closed.
        :param current_scope_objects: Storage for instances of the current scope, if any.
        :param current_scope_exit_stack: Generators to clean up when the current scope ends, if any.
        """
        # Compilers: the first backs this container's lookups, the scoped one is kept for child scopes.
        self._compiler = factory_compiler
        self._scoped_compiler = scoped_compiler
        # Cache the lookup table directly so `get` reads a single attribute.
        self._factories = factory_compiler.factories

        self._registry = registry
        self._override_mgr = override_manager

        # Singleton objects and their cleanup generators.
        self._global_scope_objects = global_scope_objects
        self._global_scope_exit_stack = global_scope_exit_stack

        # Per-scope objects and generators; None when not supplied by the caller.
        self._current_scope_objects = current_scope_objects
        self._current_scope_exit_stack = current_scope_exit_stack

    @property
    def params(self) -> ParameterBag:
        """Parameter bag associated with this container."""
        return self._registry.parameters

    @property
    def override(self) -> OverrideManager:
        """Override registered container services with new values."""
        return self._override_mgr

    def _synchronous_get(self, klass: Type[T], qualifier: Optional[Qualifier] = None) -> T:
        """Get an instance of the requested type without awaiting anything.

        :param klass: Class of the dependency already registered in the container.
        :param qualifier: Qualifier for the class if it was registered with one.
        :return: An instance of the requested object. Always returns an existing instance when one is available.
        :raises WireupError: If the dependency is produced by an async factory.
        :raises UnknownServiceRequestedError: If nothing is registered for ``klass``/``qualifier``.
        """
        lookup_key = hash(klass if qualifier is None else (klass, qualifier))
        compiled = self._factories.get(lookup_key)

        if not compiled:
            raise UnknownServiceRequestedError(klass, qualifier)

        if compiled.is_async:
            msg = (
                f"{klass} is an async dependency and it cannot be created in a synchronous context. "
                "Create and use an async container via wireup.create_async_container."
            )
            raise WireupError(msg)

        return compiled.factory(self)  # type:ignore[no-any-return]
|
|
@@ -15,7 +15,7 @@ class BareSyncContainer(BaseContainer):
|
|
|
15
15
|
get = BaseContainer._synchronous_get
|
|
16
16
|
|
|
17
17
|
def close(self) -> None:
|
|
18
|
-
clean_exit_stack(self.
|
|
18
|
+
clean_exit_stack(self._global_scope_exit_stack)
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class ScopedSyncContainer(BareSyncContainer):
|
|
@@ -36,9 +36,11 @@ class SyncContainer(BareSyncContainer):
|
|
|
36
36
|
def enter_scope(self) -> ScopedSyncContainer:
|
|
37
37
|
return ScopedSyncContainer(
|
|
38
38
|
registry=self._registry,
|
|
39
|
-
parameters=self._params,
|
|
40
39
|
override_manager=self._override_mgr,
|
|
41
|
-
|
|
40
|
+
global_scope_objects=self._global_scope_objects,
|
|
41
|
+
global_scope_exit_stack=self._global_scope_exit_stack,
|
|
42
42
|
current_scope_objects={},
|
|
43
43
|
current_scope_exit_stack=[],
|
|
44
|
+
factory_compiler=self._scoped_compiler,
|
|
45
|
+
scoped_compiler=self._scoped_compiler,
|
|
44
46
|
)
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable, Hashable
|
|
5
|
+
|
|
6
|
+
from wireup.errors import WireupError
|
|
7
|
+
from wireup.ioc.service_registry import GENERATOR_FACTORY_TYPES, FactoryType, ServiceRegistry
|
|
8
|
+
from wireup.ioc.types import ParameterWrapper, TemplatedString
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from wireup.ioc.container.base_container import BaseContainer
|
|
12
|
+
from wireup.ioc.service_registry import ServiceFactory
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class CompiledFactory:
    """A generated factory function together with its calling convention.

    ``factory`` is called with the container instance. When ``is_async`` is True the generated
    function is a coroutine function and callers must ``await`` its result.
    """

    # Generated function taking the container and returning (or, if async, resolving to) the instance.
    factory: Callable[[BaseContainer], Any]
    # Whether `factory` must be awaited.
    is_async: bool


# Message raised by the stub generated for non-singleton objects requested from the non-scoped container.
_CONTAINER_SCOPE_ERROR_MSG = (
    "Cannot create 'transient' or 'scoped' lifetime objects from the base container. "
    "Please enter a scope using container.enter_scope. "
    "If you are within a scope, use the scoped container instance to create dependencies."
)
# Name given to every generated function; it is looked up in the exec namespace after compilation.
_WIREUP_GENERATED_FACTORY_NAME = "_wireup_factory"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class FactoryCompiler:
    """Generate and compile a specialized factory function for every registered service and interface.

    Containers use two instances: a non-scoped one, which emits a stub that raises for any
    non-singleton object, and a scoped one, which can create objects of every lifetime.
    """

    # Sentinel handed to generated code so that a cached falsy instance still counts as a cache hit.
    _MISSING = object()

    def __init__(self, registry: ServiceRegistry, *, is_scoped_container: bool) -> None:
        """Create a compiler for ``registry``.

        :param registry: Registry describing services, interfaces, lifetimes and dependencies.
        :param is_scoped_container: Whether generated factories may create 'scoped'/'transient' objects.
        """
        self._registry = registry
        self._is_scoped_container = is_scoped_container
        # get_object_id(type, qualifier) -> compiled factory; read by containers and by generated code.
        self.factories: dict[int, CompiledFactory] = {}

    @classmethod
    def get_object_id(cls, impl: type, qualifier: Hashable) -> int:
        """Return the lookup key for ``impl``; the qualifier is only part of the key when it is not None."""
        return hash(impl if qualifier is None else (impl, qualifier))

    def compile(self) -> None:
        """Compile a factory for every registered implementation and interface binding.

        Results are stored in ``self.factories``; existing entries for the same key are replaced.
        """
        for impl, qualifiers in self._registry.impls.items():
            for qualifier in qualifiers:
                obj_id = FactoryCompiler.get_object_id(impl, qualifier)

                self.factories[obj_id] = self._compile_and_create_function(
                    self._registry.factories[impl, qualifier],
                    impl,
                    qualifier,
                )

        for interface, impls in self._registry.interfaces.items():
            for qualifier, impl in impls.items():
                obj_id = FactoryCompiler.get_object_id(interface, qualifier)

                # Interfaces are compiled with the concrete implementation's factory.
                self.factories[obj_id] = self._compile_and_create_function(
                    self._registry.factories[impl, qualifier],
                    interface,
                    qualifier,
                )

    def _get_factory_code(self, factory: ServiceFactory, impl: type, qualifier: Hashable) -> tuple[str, bool]:  # noqa: C901, PLR0912
        """Generate the source of the factory function for ``impl``/``qualifier``.

        The source references names supplied by ``_compile_and_create_function``'s namespace
        (``OBJ_ID``, ``ORIGINAL_OBJ_ID``, ``ORIGINAL_FACTORY``, ``PARAM_REFS``, ``factories``, ...).

        :return: The generated source and whether the generated function is a coroutine function.
        """
        is_interface = self._registry.is_interface_known(impl)
        if is_interface:
            lifetime = self._registry.lifetime[self._registry.interface_resolve_impl(impl, qualifier), qualifier]
        else:
            lifetime = self._registry.lifetime[impl, qualifier]

        # The non-scoped compiler cannot create scoped/transient objects: emit a stub that raises.
        if lifetime != "singleton" and not self._is_scoped_container:
            code = f"def {_WIREUP_GENERATED_FACTORY_NAME}(container):\n"
            code += "    raise WireupError(_CONTAINER_SCOPE_ERROR_MSG)\n"

            return code, False

        maybe_async = "async " if factory.is_async else ""
        code = f"{maybe_async}def {_WIREUP_GENERATED_FACTORY_NAME}(container):\n"
        cache_created_instance = lifetime != "transient"

        if cache_created_instance:
            if lifetime == "singleton":
                code += "    storage = container._global_scope_objects\n"
            else:
                code += "    storage = container._current_scope_objects\n"

            # Compare against a sentinel instead of relying on truthiness, otherwise falsy instances
            # (None, empty containers, ...) would be re-created on every request.
            code += "    if (res := storage.get(OBJ_ID, _MISSING)) is not _MISSING:\n"
            code += "        return res\n"

        kwargs = ""
        for name, dep in self._registry.dependencies[factory.factory].items():
            if isinstance(dep.annotation, ParameterWrapper):
                # The parameter reference object is passed through the namespace rather than spliced
                # into the source, so quotes, backslashes or template syntax cannot break compilation.
                code += f"    _obj_dep_{name} = parameters.get(PARAM_REFS[{name!r}])\n"
            else:
                if self._registry.is_interface_known(dep.klass):
                    dep_class = self._registry.interface_resolve_impl(dep.klass, dep.qualifier_value)
                else:
                    dep_class = dep.klass

                maybe_await = "await " if self._registry.factories[dep_class, dep.qualifier_value].is_async else ""
                dep_hash = FactoryCompiler.get_object_id(dep_class, dep.qualifier_value)
                code += f"    _obj_dep_{name} = {maybe_await}factories[{dep_hash}].factory(container)\n"
            kwargs += f"{name}=_obj_dep_{name}, "

        maybe_await = "await " if factory.factory_type == FactoryType.COROUTINE_FN else ""

        code += f"    instance = {maybe_await}ORIGINAL_FACTORY({kwargs.strip()})\n"

        if factory.factory_type in GENERATOR_FACTORY_TYPES:
            # Remember the generator so the owning scope can finalize it on close.
            if lifetime == "singleton":
                code += "    container._global_scope_exit_stack.append(instance)\n"
            else:
                code += "    container._current_scope_exit_stack.append(instance)\n"

            if factory.factory_type == FactoryType.GENERATOR:
                code += "    instance = next(instance)\n"
            else:
                code += "    instance = await instance.__anext__()\n"

        if cache_created_instance:
            code += "    storage[OBJ_ID] = instance\n"
            if is_interface:
                # Also cache under the interface key.
                code += "    storage[ORIGINAL_OBJ_ID] = instance\n"

        code += "    return instance\n"

        return code, factory.is_async

    def _compile_and_create_function(self, factory: ServiceFactory, impl: type, qualifier: Hashable) -> CompiledFactory:
        """Generate, compile and execute the factory source for ``impl``/``qualifier``.

        :raises WireupError: If building the namespace, compiling or executing the generated source fails.
        """
        obj_id = impl, qualifier
        # Interfaces cache their instance under the resolved implementation's key.
        resolved_obj_id = (
            (self._registry.interface_resolve_impl(impl, qualifier), qualifier)
            if self._registry.is_interface_known(impl)
            else obj_id
        )

        source, is_async = self._get_factory_code(factory, impl, qualifier)

        try:
            namespace: dict[str, Any] = {
                "factories": self.factories,
                "ORIGINAL_OBJ_ID": obj_id,
                "OBJ_ID": resolved_obj_id,
                "ORIGINAL_FACTORY": self._registry.ctors[obj_id][0],
                # Dependency name -> parameter reference, read by the generated `parameters.get(...)` calls.
                "PARAM_REFS": {
                    name: dep.annotation.param
                    for name, dep in self._registry.dependencies[factory.factory].items()
                    if isinstance(dep.annotation, ParameterWrapper)
                },
                "TemplatedString": TemplatedString,
                "WireupError": WireupError,
                "_CONTAINER_SCOPE_ERROR_MSG": _CONTAINER_SCOPE_ERROR_MSG,
                "_MISSING": self._MISSING,
                "parameters": self._registry.parameters,
            }

            compiled_code = compile(source, f"<{_WIREUP_GENERATED_FACTORY_NAME}_{obj_id}>", "exec")
            exec(compiled_code, namespace)  # noqa: S102

            return CompiledFactory(factory=namespace[_WIREUP_GENERATED_FACTORY_NAME], is_async=is_async)

        except Exception as e:
            msg = f"Failed to compile generated factory {obj_id}: {e}"
            raise WireupError(msg) from e
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Iterator
|
|
5
|
+
|
|
6
|
+
from wireup.errors import UnknownOverrideRequestedError
|
|
7
|
+
from wireup.ioc.factory_compiler import FactoryCompiler
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
|
|
12
|
+
from wireup.ioc.types import AnyCallable, Qualifier, ServiceOverride
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class OverrideManager:
    """Enables overriding of services registered with the container.

    An override replaces the generated factory function stored in both the singleton and the scoped
    ``FactoryCompiler`` tables. The functions installed before the first active override are
    remembered so that they can be restored.
    """

    def __init__(
        self,
        is_valid_override: Callable[[type, Qualifier], bool],
        factory_compiler: FactoryCompiler,
        scoped_factory_compiler: FactoryCompiler,
    ) -> None:
        """Create an override manager.

        :param is_valid_override: Predicate telling whether ``(target, qualifier)`` is a known service.
        :param factory_compiler: Compiler backing the non-scoped container.
        :param scoped_factory_compiler: Compiler backing scoped containers.
        """
        self.__is_valid_override = is_valid_override
        self._factory_compiler = factory_compiler
        self._scoped_factory_compiler = scoped_factory_compiler
        # (target, qualifier) -> (singleton compiler function, scoped compiler function) that were
        # installed before the first active override of that target.
        self._original_factory_functions: dict[tuple[type, Qualifier], tuple[Any, Any]] = {}

    def _compiler_override_obj_id(
        self,
        compiler: FactoryCompiler,
        target: type,
        qualifier: Qualifier,
        new: Callable[[Any], Any],
    ) -> None:
        """Install ``new`` as the factory function for ``target`` in ``compiler``."""
        compiler.factories[compiler.get_object_id(target, qualifier)].factory = new

    def _compiler_restore_obj_id(
        self,
        compiler: FactoryCompiler,
        target: type,
        qualifier: Qualifier,
        original: AnyCallable,
    ) -> None:
        """Put ``original`` back as the factory function for ``target`` in ``compiler``."""
        compiler.factories[compiler.get_object_id(target, qualifier)].factory = original

    def _current_factory_functions(self, target: type, qualifier: Qualifier | None) -> tuple[Any, Any]:
        """Return the factory functions currently installed for ``target`` in both compilers."""
        singleton_compiler = self._factory_compiler
        scoped_compiler = self._scoped_factory_compiler
        return (
            singleton_compiler.factories[singleton_compiler.get_object_id(target, qualifier)].factory,
            scoped_compiler.factories[scoped_compiler.get_object_id(target, qualifier)].factory,
        )

    def _active_override(self, target: type, qualifier: Qualifier | None) -> tuple[Any, Any] | None:
        """Return the override functions installed for ``target``, or None if it is not overridden."""
        if (target, qualifier) not in self._original_factory_functions:
            return None

        return self._current_factory_functions(target, qualifier)

    def _reinstate(self, target: type, qualifier: Qualifier | None, previous: tuple[Any, Any] | None) -> None:
        """Return ``target`` to the override state captured by ``_active_override``."""
        if previous is None:
            self.delete(target, qualifier)
            return

        factory_func, scoped_factory_func = previous
        self._compiler_restore_obj_id(
            compiler=self._factory_compiler,
            target=target,
            qualifier=qualifier,
            original=factory_func,
        )
        self._compiler_restore_obj_id(
            compiler=self._scoped_factory_compiler,
            target=target,
            qualifier=qualifier,
            original=scoped_factory_func,
        )

    def set(self, target: type, new: Any, qualifier: Qualifier | None = None) -> None:
        """Override the `target` service with `new`.

        Future requests to inject `target` will result in `new` being injected.

        :param target: The target service to override.
        :param qualifier: The qualifier of the service to override. Set this if service is registered
        with the qualifier parameter set to a value.
        :param new: The new object to be injected instead of `target`.
        :raises UnknownOverrideRequestedError: If `target`/`qualifier` is not a known service.
        """
        if not self.__is_valid_override(target, qualifier):
            raise UnknownOverrideRequestedError(klass=target, qualifier=qualifier)

        # Record the originals only for the first override. Otherwise a second `set` would record the
        # previous override as the "original", and `delete` would never restore the real factory.
        if (target, qualifier) not in self._original_factory_functions:
            self._original_factory_functions[target, qualifier] = self._current_factory_functions(target, qualifier)

        def override_factory(_container: Any) -> Any:
            return new

        async def async_override_factory(_container: Any) -> Any:
            return new

        for compiler in (self._factory_compiler, self._scoped_factory_compiler):
            compiled = compiler.factories[compiler.get_object_id(target, qualifier)]
            # Compiled factories flagged `is_async` are awaited by the async container and by generated
            # dependents, so the replacement must keep the same calling convention.
            self._compiler_override_obj_id(
                compiler=compiler,
                target=target,
                qualifier=qualifier,
                new=async_override_factory if compiled.is_async else override_factory,
            )

    def _restore_factory_methods(self, target: type, qualifier: Qualifier | None) -> None:
        """Restore original factory methods after override is removed."""
        if (target, qualifier) not in self._original_factory_functions:
            return

        factory_func, scoped_factory_func = self._original_factory_functions[target, qualifier]
        self._compiler_restore_obj_id(
            compiler=self._factory_compiler,
            target=target,
            qualifier=qualifier,
            original=factory_func,
        )
        self._compiler_restore_obj_id(
            compiler=self._scoped_factory_compiler,
            target=target,
            qualifier=qualifier,
            original=scoped_factory_func,
        )

        del self._original_factory_functions[target, qualifier]

    def delete(self, target: type, qualifier: Qualifier | None = None) -> None:
        """Clear active override for the `target` service. Does nothing if it is not overridden."""
        self._restore_factory_methods(target, qualifier)

    def clear(self) -> None:
        """Clear active service overrides."""
        # Iterate over a snapshot: restoring removes entries from the dict.
        for target, qualifier in list(self._original_factory_functions):
            self._restore_factory_methods(target, qualifier)

    @contextmanager
    def service(self, target: type, new: Any, qualifier: Qualifier | None = None) -> Iterator[None]:
        """Override the `target` service with `new` for the duration of the context manager.

        Future requests to inject `target` will result in `new` being injected. On exit the
        override state that was active before entering is restored, so nested overrides unwind correctly.

        :param target: The target service to override.
        :param qualifier: The qualifier of the service to override. Set this if service is registered
        with the qualifier parameter set to a value.
        :param new: The new object to be injected instead of `target`.
        """
        previous = self._active_override(target, qualifier)
        try:
            self.set(target, new, qualifier)
            yield
        finally:
            self._reinstate(target, qualifier, previous)

    @contextmanager
    def services(self, overrides: list[ServiceOverride]) -> Iterator[None]:
        """Override a number of services with new for the duration of the context manager."""
        # Capture the pre-existing state of every target before anything is overridden.
        previous = [(override, self._active_override(override.target, override.qualifier)) for override in overrides]
        try:
            for override, _ in previous:
                self.set(override.target, override.new, override.qualifier)
            yield
        finally:
            # Undo in reverse so repeated targets end up in the state they had before entering.
            for override, state in reversed(previous):
                self._reinstate(override.target, override.qualifier, state)
|