wireup 2.1.0__tar.gz → 2.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {wireup-2.1.0 → wireup-2.2.0}/PKG-INFO +1 -1
  2. {wireup-2.1.0 → wireup-2.2.0}/pyproject.toml +1 -1
  3. {wireup-2.1.0 → wireup-2.2.0}/wireup/_decorators.py +4 -4
  4. {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/aiohttp.py +3 -1
  5. {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/django/apps.py +3 -3
  6. {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/fastapi.py +4 -4
  7. {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/container/__init__.py +32 -13
  8. {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/container/async_container.py +30 -7
  9. wireup-2.2.0/wireup/ioc/container/base_container.py +93 -0
  10. {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/container/sync_container.py +5 -3
  11. wireup-2.2.0/wireup/ioc/factory_compiler.py +159 -0
  12. wireup-2.2.0/wireup/ioc/override_manager.py +139 -0
  13. {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/service_registry.py +144 -8
  14. {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/types.py +1 -6
  15. {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/util.py +68 -0
  16. wireup-2.1.0/wireup/_async_to_sync.py +0 -34
  17. wireup-2.1.0/wireup/ioc/container/base_container.py +0 -189
  18. wireup-2.1.0/wireup/ioc/override_manager.py +0 -75
  19. wireup-2.1.0/wireup/ioc/validation.py +0 -170
  20. {wireup-2.1.0 → wireup-2.2.0}/readme.md +0 -0
  21. {wireup-2.1.0 → wireup-2.2.0}/wireup/__init__.py +0 -0
  22. {wireup-2.1.0 → wireup-2.2.0}/wireup/_annotations.py +0 -0
  23. {wireup-2.1.0 → wireup-2.2.0}/wireup/_discovery.py +0 -0
  24. {wireup-2.1.0 → wireup-2.2.0}/wireup/errors.py +0 -0
  25. {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/__init__.py +0 -0
  26. {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/click.py +0 -0
  27. {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/django/__init__.py +0 -0
  28. {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/flask.py +0 -0
  29. {wireup-2.1.0 → wireup-2.2.0}/wireup/integration/starlette.py +0 -0
  30. {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/__init__.py +0 -0
  31. {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/_exit_stack.py +0 -0
  32. {wireup-2.1.0 → wireup-2.2.0}/wireup/ioc/parameter.py +0 -0
  33. {wireup-2.1.0 → wireup-2.2.0}/wireup/py.typed +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: wireup
3
- Version: 2.1.0
3
+ Version: 2.2.0
4
4
  Summary: Python Dependency Injection Library
5
5
  License: MIT
6
6
  Keywords: flask,django,injector,dependency injection,dependency injection container,dependency injector
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "wireup"
3
- version = "2.1.0"
3
+ version = "2.2.0"
4
4
  description = "Python Dependency Injection Library"
5
5
  authors = ["Aldo Mateli <aldo.mateli@gmail.com>"]
6
6
  license = "MIT"
@@ -1,8 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
- import asyncio
4
3
  import contextlib
5
4
  import functools
5
+ import inspect
6
6
  from contextlib import AsyncExitStack, ExitStack
7
7
  from typing import TYPE_CHECKING, Any, Callable
8
8
 
@@ -10,7 +10,7 @@ from wireup.errors import WireupError
10
10
  from wireup.ioc.container.async_container import AsyncContainer, ScopedAsyncContainer, async_container_force_sync_scope
11
11
  from wireup.ioc.container.sync_container import SyncContainer
12
12
  from wireup.ioc.types import AnnotatedParameter, ParameterWrapper
13
- from wireup.ioc.validation import (
13
+ from wireup.ioc.util import (
14
14
  get_inject_annotated_parameters,
15
15
  get_valid_injection_annotated_parameters,
16
16
  )
@@ -60,7 +60,7 @@ def inject_from_container(
60
60
  """
61
61
 
62
62
  def _decorator(target: Callable[..., Any]) -> Callable[..., Any]:
63
- if asyncio.iscoroutinefunction(target) and isinstance(container, SyncContainer):
63
+ if inspect.iscoroutinefunction(target) and isinstance(container, SyncContainer):
64
64
  msg = (
65
65
  "Sync container cannot perform injection on async targets. "
66
66
  "Create an async container via wireup.create_async_container."
@@ -96,7 +96,7 @@ def inject_from_container_util( # noqa: C901
96
96
  if not names_to_inject:
97
97
  return target
98
98
 
99
- if asyncio.iscoroutinefunction(target):
99
+ if inspect.iscoroutinefunction(target):
100
100
 
101
101
  @functools.wraps(target)
102
102
  async def _inject_async_target(*args: Any, **kwargs: Any) -> Any:
@@ -77,7 +77,9 @@ def _get_startup_event(
77
77
  handlers: Optional[Iterable[Type[_WireupHandler]]],
78
78
  ) -> Callable[[web.Application], Awaitable[None]]:
79
79
  for handler_type in handlers or []:
80
- container._registry._extend_with_services(abstracts=[], impls=[ServiceDeclaration(handler_type)])
80
+ container._registry.extend(impls=[ServiceDeclaration(handler_type)])
81
+ container._compiler.compile()
82
+ container._scoped_compiler.compile()
81
83
 
82
84
  async def _inner(app: web.Application) -> None:
83
85
  if handlers:
@@ -1,6 +1,6 @@
1
- import asyncio
2
1
  import functools
3
2
  import importlib
3
+ import inspect
4
4
  from contextvars import ContextVar
5
5
  from dataclasses import dataclass
6
6
  from types import ModuleType
@@ -21,7 +21,7 @@ from wireup.errors import WireupError
21
21
  from wireup.ioc.container.async_container import AsyncContainer, ScopedAsyncContainer, async_container_force_sync_scope
22
22
  from wireup.ioc.container.sync_container import ScopedSyncContainer
23
23
  from wireup.ioc.types import ParameterWrapper
24
- from wireup.ioc.validation import get_valid_injection_annotated_parameters
24
+ from wireup.ioc.util import get_valid_injection_annotated_parameters
25
25
 
26
26
  if TYPE_CHECKING:
27
27
  from wireup.integration.django import WireupSettings
@@ -38,7 +38,7 @@ def wireup_middleware(
38
38
  ) -> Callable[[HttpRequest], Union[HttpResponse, Awaitable[HttpResponse]]]:
39
39
  container = get_app_container()
40
40
 
41
- if asyncio.iscoroutinefunction(get_response):
41
+ if inspect.iscoroutinefunction(get_response):
42
42
 
43
43
  async def async_inner(request: HttpRequest) -> HttpResponse:
44
44
  async with container.enter_scope() as scoped:
@@ -33,8 +33,7 @@ from wireup.integration.starlette import (
33
33
  from wireup.ioc.container.async_container import AsyncContainer, ScopedAsyncContainer
34
34
  from wireup.ioc.container.sync_container import ScopedSyncContainer
35
35
  from wireup.ioc.types import AnyCallable
36
- from wireup.ioc.validation import (
37
- assert_dependencies_valid,
36
+ from wireup.ioc.util import (
38
37
  get_inject_annotated_parameters,
39
38
  hide_annotated_names,
40
39
  )
@@ -166,8 +165,9 @@ def _update_lifespan(
166
165
  async def lifespan(app: FastAPI) -> AsyncIterator[Any]:
167
166
  if class_based_routes:
168
167
  for cbr in class_based_routes:
169
- container._registry._extend_with_services(abstracts=[], impls=[ServiceDeclaration(cbr)])
170
- assert_dependencies_valid(container)
168
+ container._registry.extend(impls=[ServiceDeclaration(cbr)])
169
+ container._compiler.compile()
170
+ container._scoped_compiler.compile()
171
171
 
172
172
  for cbr in class_based_routes:
173
173
  await _instantiate_class_based_route(app, container, cbr)
@@ -8,11 +8,10 @@ from wireup.errors import WireupError
8
8
  from wireup.ioc.container.async_container import AsyncContainer
9
9
  from wireup.ioc.container.base_container import BaseContainer
10
10
  from wireup.ioc.container.sync_container import SyncContainer
11
+ from wireup.ioc.factory_compiler import FactoryCompiler
11
12
  from wireup.ioc.override_manager import OverrideManager
12
13
  from wireup.ioc.parameter import ParameterBag
13
14
  from wireup.ioc.service_registry import ServiceRegistry
14
- from wireup.ioc.types import ContainerScope
15
- from wireup.ioc.validation import assert_dependencies_valid
16
15
 
17
16
  if TYPE_CHECKING:
18
17
  from types import ModuleType
@@ -36,6 +35,36 @@ def _create_container(
36
35
  :param parameters: Dict containing parameters you want to expose to the container. Services or factories can
37
36
  request parameters via the `Inject(param="name")` syntax.
38
37
  """
38
+ abstracts, impls = _merge_definitions(service_modules, services)
39
+ registry = ServiceRegistry(parameters=ParameterBag(parameters), abstracts=abstracts, impls=impls)
40
+
41
+ # The container uses a dual-compiler optimization strategy:
42
+ # 1. The singleton compiler generates optimized factories for singleton dependencies
43
+ # and throws errors if scoped dependencies are accessed outside a scope.
44
+ # 2. The scoped compiler handles dependencies that require request/scope isolation.
45
+ #
46
+ # When entering/exiting scopes, the container switches between these compilers.
47
+ # This eliminates the need to check lifetime rules at runtime.
48
+ singleton_compiler = FactoryCompiler(registry, is_scoped_container=False)
49
+ scoped_compiler = FactoryCompiler(registry, is_scoped_container=True)
50
+ singleton_compiler.compile()
51
+ scoped_compiler.compile()
52
+
53
+ override_manager = OverrideManager(registry.is_type_with_qualifier_known, singleton_compiler, scoped_compiler)
54
+ return klass(
55
+ registry=registry,
56
+ factory_compiler=singleton_compiler,
57
+ scoped_compiler=scoped_compiler,
58
+ global_scope_objects={},
59
+ global_scope_exit_stack=[],
60
+ override_manager=override_manager,
61
+ )
62
+
63
+
64
+ def _merge_definitions(
65
+ service_modules: Iterable[ModuleType] | None = None,
66
+ services: Iterable[Any] | None = None,
67
+ ) -> tuple[list[AbstractDeclaration], list[ServiceDeclaration]]:
39
68
  abstracts: list[AbstractDeclaration] = []
40
69
  impls: list[ServiceDeclaration] = []
41
70
 
@@ -57,17 +86,7 @@ def _create_container(
57
86
  abstracts.extend(discovered_abstracts)
58
87
  impls.extend(discovered_services)
59
88
 
60
- registry = ServiceRegistry(abstracts=abstracts, impls=impls)
61
- container = klass(
62
- registry=registry,
63
- parameters=ParameterBag(parameters),
64
- global_scope=ContainerScope(objects={}, exit_stack=[]),
65
- override_manager=OverrideManager({}, registry.is_type_with_qualifier_known),
66
- )
67
-
68
- assert_dependencies_valid(container)
69
-
70
- return container
89
+ return abstracts, impls
71
90
 
72
91
 
73
92
  def create_sync_container(
@@ -1,9 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import TYPE_CHECKING
3
+ from typing import TYPE_CHECKING, TypeVar
4
4
 
5
5
  from typing_extensions import Self
6
6
 
7
+ from wireup.errors import UnknownServiceRequestedError
7
8
  from wireup.ioc._exit_stack import async_clean_exit_stack
8
9
  from wireup.ioc.container.base_container import BaseContainer
9
10
  from wireup.ioc.container.sync_container import ScopedSyncContainer
@@ -11,12 +12,30 @@ from wireup.ioc.container.sync_container import ScopedSyncContainer
11
12
  if TYPE_CHECKING:
12
13
  from types import TracebackType
13
14
 
15
+ from wireup.ioc.types import Qualifier
16
+
17
+ T = TypeVar("T")
18
+
14
19
 
15
20
  class BareAsyncContainer(BaseContainer):
16
- get = BaseContainer._async_get
21
+ async def get(self, klass: type[T], qualifier: Qualifier | None = None) -> T:
22
+ """Get an instance of the requested type.
23
+
24
+ :param qualifier: Qualifier for the class if it was registered with one.
25
+ :param klass: Class of the dependency already registered in the container.
26
+ :return: An instance of the requested object. Always returns an existing instance when one is available.
27
+ """
28
+ obj_id = hash(klass if qualifier is None else (klass, qualifier))
29
+
30
+ if compiled_factory := self._factories.get(obj_id):
31
+ res = compiled_factory.factory(self)
32
+
33
+ return await res if compiled_factory.is_async else res # type:ignore[no-any-return]
34
+
35
+ raise UnknownServiceRequestedError(klass, qualifier)
17
36
 
18
37
  async def close(self) -> None:
19
- await async_clean_exit_stack(self._global_scope.exit_stack)
38
+ await async_clean_exit_stack(self._global_scope_exit_stack)
20
39
 
21
40
 
22
41
  class ScopedAsyncContainer(BareAsyncContainer):
@@ -37,11 +56,13 @@ class AsyncContainer(BareAsyncContainer):
37
56
  def enter_scope(self) -> ScopedAsyncContainer:
38
57
  return ScopedAsyncContainer(
39
58
  registry=self._registry,
40
- parameters=self._params,
41
59
  override_manager=self._override_mgr,
42
- global_scope=self._global_scope,
60
+ global_scope_objects=self._global_scope_objects,
61
+ global_scope_exit_stack=self._global_scope_exit_stack,
43
62
  current_scope_objects={},
44
63
  current_scope_exit_stack=[],
64
+ factory_compiler=self._scoped_compiler,
65
+ scoped_compiler=self._scoped_compiler,
45
66
  )
46
67
 
47
68
 
@@ -53,9 +74,11 @@ def async_container_force_sync_scope(container: AsyncContainer) -> ScopedSyncCon
53
74
  """
54
75
  return ScopedSyncContainer(
55
76
  registry=container._registry,
56
- parameters=container._params,
57
77
  override_manager=container._override_mgr,
58
- global_scope=container._global_scope,
78
+ global_scope_objects=container._global_scope_objects,
79
+ global_scope_exit_stack=container._global_scope_exit_stack,
59
80
  current_scope_objects={},
60
81
  current_scope_exit_stack=[],
82
+ factory_compiler=container._scoped_compiler,
83
+ scoped_compiler=container._scoped_compiler,
61
84
  )
@@ -0,0 +1,93 @@
1
+ from typing import (
2
+ Any,
3
+ AsyncGenerator,
4
+ Dict,
5
+ Generator,
6
+ List,
7
+ Optional,
8
+ Type,
9
+ TypeVar,
10
+ Union,
11
+ )
12
+
13
+ from wireup.errors import (
14
+ UnknownServiceRequestedError,
15
+ WireupError,
16
+ )
17
+ from wireup.ioc.factory_compiler import FactoryCompiler
18
+ from wireup.ioc.override_manager import OverrideManager
19
+ from wireup.ioc.parameter import ParameterBag
20
+ from wireup.ioc.service_registry import ServiceRegistry
21
+ from wireup.ioc.types import (
22
+ ContainerObjectIdentifier,
23
+ Qualifier,
24
+ )
25
+
26
+ T = TypeVar("T")
27
+ ContainerExitStack = List[Union[Generator[Any, Any, Any], AsyncGenerator[Any, Any]]]
28
+
29
+
30
class BaseContainer:
    """Common state shared by the sync and async container flavours.

    The container keeps references to the registry, the override manager, both
    factory compilers and the storage used for created objects and pending
    generator clean-up. Lookups are served from the compiled factories of the
    compiler this instance was created with.
    """

    __slots__ = (
        "_compiler",
        "_current_scope_exit_stack",
        "_current_scope_objects",
        "_factories",
        "_global_scope_exit_stack",
        "_global_scope_objects",
        "_override_mgr",
        "_registry",
        "_scoped_compiler",
    )

    def __init__(  # noqa: PLR0913
        self,
        registry: ServiceRegistry,
        override_manager: OverrideManager,
        factory_compiler: FactoryCompiler,
        scoped_compiler: FactoryCompiler,
        global_scope_objects: Dict[ContainerObjectIdentifier, Any],
        global_scope_exit_stack: List[Union[Generator[Any, Any, Any], AsyncGenerator[Any, Any]]],
        current_scope_objects: Optional[Dict[ContainerObjectIdentifier, Any]] = None,
        current_scope_exit_stack: Optional[List[Union[Generator[Any, Any, Any], AsyncGenerator[Any, Any]]]] = None,
    ) -> None:
        # Compilers first: the factory table used for lookups is taken from the
        # compiler this container instance is bound to.
        self._compiler = factory_compiler
        self._scoped_compiler = scoped_compiler
        self._factories = factory_compiler.factories

        self._registry = registry
        self._override_mgr = override_manager

        # Storage for created objects and generators awaiting clean-up.
        self._global_scope_objects = global_scope_objects
        self._global_scope_exit_stack = global_scope_exit_stack
        self._current_scope_objects = current_scope_objects
        self._current_scope_exit_stack = current_scope_exit_stack

    @property
    def params(self) -> ParameterBag:
        """Parameter bag held by the registry this container was built from."""
        return self._registry.parameters

    @property
    def override(self) -> OverrideManager:
        """Entry point for replacing registered services with other values."""
        return self._override_mgr

    def _synchronous_get(self, klass: Type[T], qualifier: Optional[Qualifier] = None) -> T:
        """Return an instance of `klass` without awaiting anything.

        :param klass: Class of the dependency registered in the container.
        :param qualifier: Qualifier the class was registered with, if any.
        :return: Whatever the compiled factory for the requested type returns.
        :raises WireupError: If the compiled factory is marked as async.
        :raises UnknownServiceRequestedError: If no factory exists for the requested type and qualifier.
        """
        # Same key scheme the factory table is populated with.
        lookup_key = hash(klass if qualifier is None else (klass, qualifier))
        entry = self._factories.get(lookup_key)

        if not entry:
            raise UnknownServiceRequestedError(klass, qualifier)

        if entry.is_async:
            msg = (
                f"{klass} is an async dependency and it cannot be created in a synchronous context. "
                "Create and use an async container via wireup.create_async_container."
            )
            raise WireupError(msg)

        return entry.factory(self)  # type:ignore[no-any-return]
@@ -15,7 +15,7 @@ class BareSyncContainer(BaseContainer):
15
15
  get = BaseContainer._synchronous_get
16
16
 
17
17
  def close(self) -> None:
18
- clean_exit_stack(self._global_scope.exit_stack)
18
+ clean_exit_stack(self._global_scope_exit_stack)
19
19
 
20
20
 
21
21
  class ScopedSyncContainer(BareSyncContainer):
@@ -36,9 +36,11 @@ class SyncContainer(BareSyncContainer):
36
36
  def enter_scope(self) -> ScopedSyncContainer:
37
37
  return ScopedSyncContainer(
38
38
  registry=self._registry,
39
- parameters=self._params,
40
39
  override_manager=self._override_mgr,
41
- global_scope=self._global_scope,
40
+ global_scope_objects=self._global_scope_objects,
41
+ global_scope_exit_stack=self._global_scope_exit_stack,
42
42
  current_scope_objects={},
43
43
  current_scope_exit_stack=[],
44
+ factory_compiler=self._scoped_compiler,
45
+ scoped_compiler=self._scoped_compiler,
44
46
  )
@@ -0,0 +1,159 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING, Any, Callable, Hashable
5
+
6
+ from wireup.errors import WireupError
7
+ from wireup.ioc.service_registry import GENERATOR_FACTORY_TYPES, FactoryType, ServiceRegistry
8
+ from wireup.ioc.types import ParameterWrapper, TemplatedString
9
+
10
+ if TYPE_CHECKING:
11
+ from wireup.ioc.container.base_container import BaseContainer
12
+ from wireup.ioc.service_registry import ServiceFactory
13
+
14
+
15
@dataclass
class CompiledFactory:
    """A generated factory function together with its calling convention."""

    # Callable invoked with the container to produce the requested object.
    # Left mutable on purpose: the attribute is reassigned after creation.
    factory: Callable[[BaseContainer], Any]

    # True when calling `factory` yields an awaitable that must be awaited.
    is_async: bool
19
+
20
+
21
+ _CONTAINER_SCOPE_ERROR_MSG = (
22
+ "Cannot create 'transient' or 'scoped' lifetime objects from the base container. "
23
+ "Please enter a scope using container.enter_scope. "
24
+ "If you are within a scope, use the scoped container instance to create dependencies."
25
+ )
26
+ _WIREUP_GENERATED_FACTORY_NAME = "_wireup_factory"
27
+
28
+
29
class FactoryCompiler:
    """Generate one Python factory function per registered service.

    For every implementation and interface known to the registry, source code
    for a `_wireup_factory(container)` function is generated and executed, and
    the resulting callable is stored in `factories` keyed by `get_object_id`.

    A compiler created with `is_scoped_container=False` emits, for every
    non-singleton lifetime, a factory that only raises `WireupError`.
    """

    # Sentinel marking "nothing cached yet" in generated code, so that falsy
    # instances (None, 0, empty containers, ...) are still served from the cache.
    _MISSING = object()

    def __init__(self, registry: ServiceRegistry, *, is_scoped_container: bool) -> None:
        self._registry = registry
        self._is_scoped_container = is_scoped_container
        # NOTE(review): keys are raw `hash()` values, so two distinct
        # (type, qualifier) pairs with colliding hashes would share an entry.
        self.factories: dict[int, CompiledFactory] = {}

    @classmethod
    def get_object_id(cls, impl: type, qualifier: Hashable) -> int:
        """Return the `factories` key for `impl` registered under `qualifier`."""
        return hash(impl if qualifier is None else (impl, qualifier))

    def compile(self) -> None:
        """(Re)build the compiled factory of every implementation and interface in the registry."""
        for impl, qualifiers in self._registry.impls.items():
            for qualifier in qualifiers:
                obj_id = FactoryCompiler.get_object_id(impl, qualifier)

                self.factories[obj_id] = self._compile_and_create_function(
                    self._registry.factories[impl, qualifier],
                    impl,
                    qualifier,
                )

        # Interfaces get their own entry, built from the factory of the
        # implementation registered for the same qualifier.
        for interface, impls in self._registry.interfaces.items():
            for qualifier, impl in impls.items():
                obj_id = FactoryCompiler.get_object_id(interface, qualifier)

                self.factories[obj_id] = self._compile_and_create_function(
                    self._registry.factories[impl, qualifier],
                    interface,
                    qualifier,
                )

    def _get_factory_code(self, factory: ServiceFactory, impl: type, qualifier: Hashable) -> tuple[str, bool]:  # noqa: C901, PLR0912
        """Build the source of the generated factory for `impl`/`qualifier`.

        :param factory: Registry factory entry used to create the object.
        :param impl: Requested type; may be an interface known to the registry.
        :param qualifier: Qualifier the type was registered with.
        :return: Tuple of generated source code and whether the generated function is a coroutine function.
        """
        is_interface = self._registry.is_interface_known(impl)
        if is_interface:
            lifetime = self._registry.lifetime[self._registry.interface_resolve_impl(impl, qualifier), qualifier]
        else:
            lifetime = self._registry.lifetime[impl, qualifier]

        # Outside of a scope only singletons may be created: anything else
        # compiles down to a function that raises immediately.
        if lifetime != "singleton" and not self._is_scoped_container:
            code = f"def {_WIREUP_GENERATED_FACTORY_NAME}(container):\n"
            code += "    raise WireupError(_CONTAINER_SCOPE_ERROR_MSG)\n"

            return code, False

        maybe_async = "async " if factory.is_async else ""
        code = f"{maybe_async}def {_WIREUP_GENERATED_FACTORY_NAME}(container):\n"
        cache_created_instance = lifetime != "transient"

        if cache_created_instance:
            if lifetime == "singleton":
                code += "    storage = container._global_scope_objects\n"
            else:
                code += "    storage = container._current_scope_objects\n"

            # Compare against a sentinel rather than truthiness so that a
            # cached falsy instance is returned instead of being re-created.
            code += "    res = storage.get(OBJ_ID, MISSING)\n"
            code += "    if res is not MISSING:\n"
            code += "        return res\n"

        kwargs = ""
        for name, dep in self._registry.dependencies[factory.factory].items():
            if isinstance(dep.annotation, ParameterWrapper):
                # repr() yields a correctly escaped string literal even when the
                # parameter name contains quotes or backslashes.
                param_value = (
                    str(dep.annotation.param)
                    if isinstance(dep.annotation.param, TemplatedString)
                    else repr(str(dep.annotation.param))
                )
                code += f"    _obj_dep_{name} = parameters.get({param_value})\n"
            else:
                if self._registry.is_interface_known(dep.klass):
                    dep_class = self._registry.interface_resolve_impl(dep.klass, dep.qualifier_value)
                else:
                    dep_class = dep.klass

                maybe_await = "await " if self._registry.factories[dep_class, dep.qualifier_value].is_async else ""
                dep_hash = FactoryCompiler.get_object_id(dep_class, dep.qualifier_value)
                code += f"    _obj_dep_{name} = {maybe_await}factories[{dep_hash}].factory(container)\n"
            kwargs += f"{name}=_obj_dep_{name}, "

        maybe_await = "await " if factory.factory_type == FactoryType.COROUTINE_FN else ""

        code += f"    instance = {maybe_await}ORIGINAL_FACTORY({kwargs.strip()})\n"

        if factory.factory_type in GENERATOR_FACTORY_TYPES:
            # The generator object itself is recorded on the exit stack of the
            # matching lifetime before its first value is pulled.
            if lifetime == "singleton":
                code += "    container._global_scope_exit_stack.append(instance)\n"
            else:
                code += "    container._current_scope_exit_stack.append(instance)\n"

            if factory.factory_type == FactoryType.GENERATOR:
                code += "    instance = next(instance)\n"
            else:
                code += "    instance = await instance.__anext__()\n"

        if cache_created_instance:
            code += "    storage[OBJ_ID] = instance\n"
            if is_interface:
                # Also cache under the interface's own id.
                code += "    storage[ORIGINAL_OBJ_ID] = instance\n"

        code += "    return instance\n"

        return code, factory.is_async

    def _compile_and_create_function(self, factory: ServiceFactory, impl: type, qualifier: Hashable) -> CompiledFactory:
        """Generate, compile and execute the factory source for `impl`/`qualifier`.

        :raises WireupError: If looking up the constructor, compiling or executing the generated source fails.
        """
        obj_id = impl, qualifier
        resolved_obj_id = (
            (self._registry.interface_resolve_impl(impl, qualifier), qualifier)
            if self._registry.is_interface_known(impl)
            else obj_id
        )

        source, is_async = self._get_factory_code(factory, impl, qualifier)

        try:
            # Names the generated source refers to as globals.
            namespace: dict[str, Any] = {
                "factories": self.factories,
                "ORIGINAL_OBJ_ID": obj_id,
                "OBJ_ID": resolved_obj_id,
                "ORIGINAL_FACTORY": self._registry.ctors[obj_id][0],
                "TemplatedString": TemplatedString,
                "WireupError": WireupError,
                "_CONTAINER_SCOPE_ERROR_MSG": _CONTAINER_SCOPE_ERROR_MSG,
                "MISSING": self._MISSING,
                "parameters": self._registry.parameters,
            }

            compiled_code = compile(source, f"<{_WIREUP_GENERATED_FACTORY_NAME}_{obj_id}>", "exec")
            exec(compiled_code, namespace)  # noqa: S102

            return CompiledFactory(factory=namespace[_WIREUP_GENERATED_FACTORY_NAME], is_async=is_async)

        except Exception as e:
            msg = f"Failed to compile generated factory {obj_id}: {e}"
            raise WireupError(msg) from e
@@ -0,0 +1,139 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import contextmanager
4
+ from typing import TYPE_CHECKING, Any, Iterator
5
+
6
+ from wireup.errors import UnknownOverrideRequestedError
7
+ from wireup.ioc.factory_compiler import FactoryCompiler
8
+
9
+ if TYPE_CHECKING:
10
+ from collections.abc import Callable
11
+
12
+ from wireup.ioc.types import AnyCallable, Qualifier, ServiceOverride
13
+
14
+
15
class OverrideManager:
    """Enables overriding of services registered with the container."""

    def __init__(
        self,
        is_valid_override: Callable[[type, Qualifier], bool],
        factory_compiler: FactoryCompiler,
        scoped_factory_compiler: FactoryCompiler,
    ) -> None:
        self.__is_valid_override = is_valid_override
        self._factory_compiler = factory_compiler
        self._scoped_factory_compiler = scoped_factory_compiler
        # Factories that were in place before the first override of a target,
        # stored as (factory from factory_compiler, factory from scoped_factory_compiler).
        self._original_factory_functions: dict[tuple[type, Qualifier], tuple[Any, Any]] = {}

    @staticmethod
    def _make_override_factory(new: Any, *, is_async: bool) -> Callable[[Any], Any]:
        """Return a factory yielding `new` that matches the calling convention of the entry it replaces."""
        if is_async:
            # Callers of an async entry await the result of the factory call,
            # so the replacement has to produce an awaitable as well.
            async def async_override_factory(_container: Any) -> Any:
                return new

            return async_override_factory

        def override_factory(_container: Any) -> Any:
            return new

        return override_factory

    def _compiler_override_obj_id(
        self,
        compiler: FactoryCompiler,
        target: type,
        qualifier: Qualifier,
        new: Callable[[Any], Any],
    ) -> None:
        """Replace the factory callable of `target` in `compiler` with `new`."""
        compiler.factories[compiler.get_object_id(target, qualifier)].factory = new

    def _compiler_restore_obj_id(
        self,
        compiler: FactoryCompiler,
        target: type,
        qualifier: Qualifier,
        original: AnyCallable,
    ) -> None:
        """Put the `original` factory callable of `target` back into `compiler`."""
        compiler.factories[compiler.get_object_id(target, qualifier)].factory = original

    def set(self, target: type, new: Any, qualifier: Qualifier | None = None) -> None:
        """Override the `target` service with `new`.

        Future requests to inject `target` will result in `new` being injected.
        Overriding an already overridden target replaces the active override
        while keeping the factories that were in place before the first one.

        :param target: The target service to override.
        :param qualifier: The qualifier of the service to override. Set this if service is registered
        with the qualifier parameter set to a value.
        :param new: The new object to be injected instead of `target`.
        :raises UnknownOverrideRequestedError: If `target` with `qualifier` is not a valid override.
        """
        if not self.__is_valid_override(target, qualifier):
            raise UnknownOverrideRequestedError(klass=target, qualifier=qualifier)

        key = (target, qualifier)
        compilers = (self._factory_compiler, self._scoped_factory_compiler)

        # Remember the originals only once; a repeated `set` must not record the
        # previous override as if it were the original factory.
        if key not in self._original_factory_functions:
            self._original_factory_functions[key] = (
                compilers[0].factories[compilers[0].get_object_id(target, qualifier)].factory,
                compilers[1].factories[compilers[1].get_object_id(target, qualifier)].factory,
            )

        for compiler in compilers:
            # `is_async` is read per compiler since the two entries for the
            # same target are separate objects and may differ.
            entry = compiler.factories[compiler.get_object_id(target, qualifier)]
            self._compiler_override_obj_id(
                target=target,
                qualifier=qualifier,
                compiler=compiler,
                new=self._make_override_factory(new, is_async=entry.is_async),
            )

    def _restore_factory_methods(self, target: type, qualifier: Qualifier | None) -> None:
        """Restore original factory methods after override is removed.

        Does nothing when no override is active for `target` and `qualifier`.
        """
        if (target, qualifier) not in self._original_factory_functions:
            return

        factory_func, scoped_factory_func = self._original_factory_functions[target, qualifier]
        self._compiler_restore_obj_id(
            compiler=self._factory_compiler,
            target=target,
            qualifier=qualifier,
            original=factory_func,
        )
        self._compiler_restore_obj_id(
            compiler=self._scoped_factory_compiler,
            target=target,
            qualifier=qualifier,
            original=scoped_factory_func,
        )

        del self._original_factory_functions[target, qualifier]

    def delete(self, target: type, qualifier: Qualifier | None = None) -> None:
        """Clear active override for the `target` service."""
        self._restore_factory_methods(target, qualifier)

    def clear(self) -> None:
        """Clear active service overrides."""
        # Iterate over a snapshot: restoring removes entries from the dict and
        # mutating a dict while iterating over it raises RuntimeError.
        for target, qualifier in list(self._original_factory_functions):
            self._restore_factory_methods(target, qualifier)

    @contextmanager
    def service(self, target: type, new: Any, qualifier: Qualifier | None = None) -> Iterator[None]:
        """Override the `target` service with `new` for the duration of the context manager.

        Future requests to inject `target` will result in `new` being injected.

        :param target: The target service to override.
        :param qualifier: The qualifier of the service to override. Set this if service is registered
        with the qualifier parameter set to a value.
        :param new: The new object to be injected instead of `target`.
        """
        try:
            self.set(target, new, qualifier)
            yield
        finally:
            self.delete(target, qualifier)

    @contextmanager
    def services(self, overrides: list[ServiceOverride]) -> Iterator[None]:
        """Override a number of services with new for the duration of the context manager."""
        try:
            for override in overrides:
                self.set(override.target, override.new, override.qualifier)
            yield
        finally:
            # Deleting a target that was never overridden is a no-op, so this is
            # safe even when one of the `set` calls above failed part-way.
            for override in overrides:
                self.delete(override.target, override.qualifier)