wepositive-di 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,16 @@
1
+ Metadata-Version: 2.4
2
+ Name: wepositive-di
3
+ Version: 0.1.0
4
+ Summary: Adaptation of the dependency-injector framework to provide a simpler DI syntax similar to FastAPI's.
5
+ License-Expression: Apache-2.0
6
+ Author: Dolf Andringa
7
+ Author-email: dolf@wepositive.energy
8
+ Requires-Python: >=3.12
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Programming Language :: Python :: 3.12
11
+ Classifier: Programming Language :: Python :: 3.13
12
+ Classifier: Programming Language :: Python :: 3.14
13
+ Requires-Dist: aiologic (>=0.16.0,<0.17.0)
14
+ Requires-Dist: dependency-injector (>=4.49.0,<5.0.0)
15
+ Requires-Dist: pydantic (>=2.13.4,<3.0.0)
16
+ Requires-Dist: pydantic-settings (>=2.14.1,<3.0.0)
@@ -0,0 +1,103 @@
1
+ [project]
2
+ name = "wepositive-di"
3
+ version = "0.1.0"
4
+ description = "Adaptation of the dependency-injector framework to provide a simpler DI syntax similar to FastAPI's."
5
+ authors = [
6
+ {name = "Dolf Andringa",email = "dolf@wepositive.energy"}
7
+ ]
8
+ license = "Apache-2.0"
9
+ requires-python = ">=3.12"
10
+ packages = [{include = "wepositive_di", from = "src"}]
11
+ dependencies = [
12
+ "dependency-injector (>=4.49.0,<5.0.0)",
13
+ "aiologic (>=0.16.0,<0.17.0)",
14
+ "pydantic (>=2.13.4,<3.0.0)",
15
+ "pydantic-settings (>=2.14.1,<3.0.0)"
16
+ ]
17
+
18
+
19
+ [build-system]
20
+ requires = ["poetry-core>=2.0.0,<3.0.0"]
21
+ build-backend = "poetry.core.masonry.api"
22
+
23
+ [dependency-groups]
24
+ dev = [
25
+ "pytest (>=9.0.3,<10.0.0)",
26
+ "pytest-mock (>=3.15.1,<4.0.0)",
27
+ "pyright (>=1.1.410,<2.0.0)",
28
+ "pre-commit (>=4.6.0,<5.0.0)",
29
+ "pytest-cov (>=7.1.0,<8.0.0)",
30
+ "ruff (>=0.15.15,<0.16.0)",
31
+ "pytest-asyncio (>=1.4.0,<2.0.0)",
32
+ "mkdocs-material (>=9.7.6,<10.0.0)",
33
+ "mkdocstrings (>=1.0.4,<2.0.0)",
34
+ "mkdocs-autorefs (>=1.4.4,<2.0.0)",
35
+ "mkdocstrings-python (>=2.0.3,<3.0.0)"
36
+ ]
37
+
38
+ [tool.pyright]
39
+ stubPath = "./typings/"
40
+ reportImportCycles = false
41
+ typeCheckingMode = "strict"
42
+ venv = ".venv"
43
+ venvPath = "."
44
+ exclude = [
45
+ "./typings/*",
46
+ "**/node_modules",
47
+ "**/__pycache__",
48
+ "**/.*",
49
+ "alembic/*",
50
+ ]
51
+
52
+
53
+ [tool.pytest.ini_options]
54
+ addopts = [
55
+ "--import-mode=importlib",
56
+ "--cov",
57
+ ]
58
+ asyncio_mode = "auto"
59
+ asyncio_default_test_loop_scope = "session"
60
+ asyncio_default_fixture_loop_scope = "session"
61
+
62
+ [tool.coverage.run]
63
+ branch = true
64
+ source = ["./src/wepositive_di"]
65
+
66
+ [tool.coverage.report]
67
+ exclude_also = [
68
+ "if .*TYPE_CHECKING:",
69
+ "@overload",
70
+ "@typing.overload",
71
+ "class .*\\(Protocol\\):",
72
+ "@abc.abstractmethod",
73
+ "@abstractmethod",
74
+ "\\A(?s:.*# pragma: exclude file.*)\\Z"
75
+ ]
76
+ fail_under = 100
77
+ show_missing = true
78
+
79
+ [tool.ruff.lint]
80
+ ignore = ["PT013"]
81
+ select = [
82
+ "E4",
83
+ "E7",
84
+ "E9",
85
+ "F",
86
+ "BLE",
87
+ "ASYNC",
88
+ "DTZ",
89
+ "C4",
90
+ "T10",
91
+ "T20",
92
+ "PTH",
93
+ "UP",
94
+ "I",
95
+ "LOG",
96
+ "G",
97
+ "PT",
98
+ ]
99
+ extend-safe-fixes = ["UP"]
100
+
101
+ [tool.ruff]
102
+ extend-exclude = ["typings/*", "alembic/*"]
103
+ target-version = "py312"
@@ -0,0 +1,23 @@
1
+ from wepositive_di import context
2
+ from wepositive_di.di import (
3
+ Depends,
4
+ clear_overrides,
5
+ inject,
6
+ override_provider,
7
+ provider_overrides,
8
+ register_provider,
9
+ registry,
10
+ setup,
11
+ )
12
+
13
+ __all__ = [
14
+ "Depends",
15
+ "clear_overrides",
16
+ "inject",
17
+ "override_provider",
18
+ "provider_overrides",
19
+ "register_provider",
20
+ "registry",
21
+ "setup",
22
+ "context",
23
+ ]
@@ -0,0 +1,161 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections.abc import AsyncGenerator
3
+ from contextlib import AbstractAsyncContextManager, asynccontextmanager
4
+ from uuid import UUID
5
+
6
+ import aiologic
7
+ from pydantic import BaseModel
8
+
9
+ from wepositive_di.di import register_provider
10
+
11
+
12
+ class ContextStorage(ABC):
13
+ """ContextStorage interface.
14
+
15
+ This interface allows for different storage backends (in-memory, Redis, etc.)
16
+ to be used for storing context data, keyed by both context type and UUID.
17
+
18
+ One storage instance can hold multiple context types simultaneously.
19
+ Implementations must be thread-safe and async-safe.
20
+
21
+ get_context is an async context manager that yields the context while holding
22
+ a lock, ensuring safe modifications during the entire usage period.
23
+ """
24
+
25
+ @abstractmethod
26
+ def get_context[ContextTypeT: BaseModel](
27
+ self, ctx_type: type[ContextTypeT], context_id: UUID
28
+ ) -> AbstractAsyncContextManager[ContextTypeT]:
29
+ """Get a context for the given type and context_id.
30
+
31
+ This is an async context manager that yields the context while holding a lock.
32
+ The lock is held until the context manager exits.
33
+
34
+ Args:
35
+ ctx_type: The type of context to retrieve
36
+ context_id: The UUID identifying the context
37
+
38
+ Yields:
39
+ The context associated with this type and identifier
40
+
41
+ Raises:
42
+ KeyError: If the context does not exist
43
+ """
44
+ ...
45
+
46
+ @abstractmethod
47
+ async def store_context[ContextTypeT: BaseModel](
48
+ self, ctx_type: type[ContextTypeT], context_id: UUID, context: ContextTypeT
49
+ ) -> None:
50
+ """Store a new context.
51
+
52
+ This creates or replaces a context for the given type and context_id.
53
+ Thread-safe and async-safe.
54
+
55
+ Args:
56
+ ctx_type: The type of context being stored
57
+ context_id: The UUID identifying the context
58
+ context: The context to store
59
+ """
60
+ ...
61
+
62
+ @abstractmethod
63
+ async def get_context_snapshot[ContextTypeT: BaseModel](
64
+ self, ctx_type: type[ContextTypeT], context_id: UUID
65
+ ) -> ContextTypeT:
66
+ """Get a read-only snapshot source without taking the context lock.
67
+
68
+ This is intended for event emission paths that must not wait behind a
69
+ long-running mutable context lock.
70
+ """
71
+ ...
72
+
73
+
74
+ class InMemoryContextStorage(ContextStorage):
75
+ """Unified in-memory storage for contexts that works in both async and threaded environments.
76
+
77
+ Uses aiologic.RLock for synchronization, which works seamlessly across:
78
+ - Pure async servers (FastAPI with single event loop)
79
+ - Threaded servers with multiple threads
80
+ - Hybrid environments (multiple threads each with their own event loop)
81
+
82
+ Contexts are stored in a two-level dict keyed first by context type, then by UUID.
83
+ Fine-grained per-(type, id) locking allows concurrent access to different contexts.
84
+
85
+ Note: This implementation is single-process only. For multi-process
86
+ deployments (e.g., gunicorn with multiple processes), consider using
87
+ a distributed storage backend like Redis.
88
+ """
89
+
90
+ def __init__(self) -> None:
91
+ self._states: dict[type[BaseModel], dict[UUID, BaseModel]] = {}
92
+ self._locks: dict[type[BaseModel], dict[UUID, aiologic.RLock]] = {}
93
+ self._locks_lock = aiologic.RLock()
94
+
95
+ async def _get_lock(
96
+ self, ctx_type: type[BaseModel], context_id: UUID
97
+ ) -> aiologic.RLock:
98
+ """Get or create a lock for the given (ctx_type, context_id) pair."""
99
+ async with self._locks_lock:
100
+ if ctx_type not in self._locks:
101
+ self._locks[ctx_type] = {}
102
+ if context_id not in self._locks[ctx_type]:
103
+ self._locks[ctx_type][context_id] = aiologic.RLock()
104
+ return self._locks[ctx_type][context_id]
105
+
106
+ @asynccontextmanager
107
+ async def get_context[ContextTypeT: BaseModel]( # pyright: ignore [reportReturnType]
108
+ self, ctx_type: type[ContextTypeT], context_id: UUID
109
+ ) -> AsyncGenerator[ContextTypeT]:
110
+ lock = await self._get_lock(ctx_type, context_id)
111
+ async with lock:
112
+ type_store = self._states.get(ctx_type, {})
113
+ if context_id not in type_store:
114
+ raise KeyError(f"No {ctx_type.__name__} context known for {context_id}")
115
+ yield type_store[context_id] # pyright: ignore [reportReturnType]
116
+
117
+ async def store_context[ContextTypeT: BaseModel](
118
+ self,
119
+ ctx_type: type[ContextTypeT],
120
+ context_id: UUID,
121
+ context: ContextTypeT,
122
+ ) -> None:
123
+ """Store a new context.
124
+
125
+ Thread-safe creation/replacement of context.
126
+ Acquires the fine-grained lock for this (ctx_type, context_id) pair.
127
+ """
128
+ lock = await self._get_lock(ctx_type, context_id)
129
+ async with lock:
130
+ if ctx_type not in self._states:
131
+ self._states[ctx_type] = {}
132
+ self._states[ctx_type][context_id] = context
133
+
134
+ async def get_context_snapshot[ContextTypeT: BaseModel](
135
+ self,
136
+ ctx_type: type[ContextTypeT],
137
+ context_id: UUID,
138
+ ) -> ContextTypeT:
139
+ type_store = self._states.get(ctx_type, {})
140
+ if context_id not in type_store:
141
+ raise KeyError(f"No {ctx_type.__name__} context known for {context_id}")
142
+ context = type_store[context_id]
143
+ return context.model_copy(deep=True) # pyright: ignore [reportReturnType]
144
+
145
+
146
@register_provider(singleton=True)
def context_storage_singleton() -> InMemoryContextStorage:
    """Provide the application-wide ``InMemoryContextStorage``.

    Registered with ``singleton=True``: the first resolution creates the
    storage and every later resolution receives that same object. A single
    instance serves all context types, keyed by (type, UUID).

    The storage synchronises through ``aiologic.RLock``, intended to work in:

    - async servers (e.g. FastAPI) without blocking the event loop,
    - threaded servers,
    - setups that run one event loop per thread.

    For multi-process deployments, override this provider with a distributed
    ``ContextStorage`` implementation (for example one backed by Redis).
    """
    storage = InMemoryContextStorage()
    return storage
@@ -0,0 +1,161 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections.abc import AsyncIterator
3
+ from contextlib import AbstractAsyncContextManager, asynccontextmanager
4
+ from uuid import UUID
5
+
6
+ import aiologic
7
+ from pydantic import BaseModel
8
+
9
+ from cem.di import register_provider
10
+
11
+
12
+ class ContextStorage(ABC):
13
+ """ContextStorage interface.
14
+
15
+ This interface allows for different storage backends (in-memory, Redis, etc.)
16
+ to be used for storing context data, keyed by both context type and UUID.
17
+
18
+ One storage instance can hold multiple context types simultaneously.
19
+ Implementations must be thread-safe and async-safe.
20
+
21
+ get_context is an async context manager that yields the context while holding
22
+ a lock, ensuring safe modifications during the entire usage period.
23
+ """
24
+
25
+ @abstractmethod
26
+ def get_context[ContextTypeT: BaseModel](
27
+ self, ctx_type: type[ContextTypeT], context_id: UUID
28
+ ) -> AbstractAsyncContextManager[ContextTypeT]:
29
+ """Get a context for the given type and context_id.
30
+
31
+ This is an async context manager that yields the context while holding a lock.
32
+ The lock is held until the context manager exits.
33
+
34
+ Args:
35
+ ctx_type: The type of context to retrieve
36
+ context_id: The UUID identifying the context
37
+
38
+ Yields:
39
+ The context associated with this type and identifier
40
+
41
+ Raises:
42
+ KeyError: If the context does not exist
43
+ """
44
+ ...
45
+
46
+ @abstractmethod
47
+ async def store_context[ContextTypeT: BaseModel](
48
+ self, ctx_type: type[ContextTypeT], context_id: UUID, context: ContextTypeT
49
+ ) -> None:
50
+ """Store a new context.
51
+
52
+ This creates or replaces a context for the given type and context_id.
53
+ Thread-safe and async-safe.
54
+
55
+ Args:
56
+ ctx_type: The type of context being stored
57
+ context_id: The UUID identifying the context
58
+ context: The context to store
59
+ """
60
+ ...
61
+
62
+ @abstractmethod
63
+ async def get_context_snapshot[ContextTypeT: BaseModel](
64
+ self, ctx_type: type[ContextTypeT], context_id: UUID
65
+ ) -> ContextTypeT:
66
+ """Get a read-only snapshot source without taking the context lock.
67
+
68
+ This is intended for event emission paths that must not wait behind a
69
+ long-running mutable context lock.
70
+ """
71
+ ...
72
+
73
+
74
+ class InMemoryContextStorage(ContextStorage):
75
+ """Unified in-memory storage for contexts that works in both async and threaded environments.
76
+
77
+ Uses aiologic.RLock for synchronization, which works seamlessly across:
78
+ - Pure async servers (FastAPI with single event loop)
79
+ - Threaded servers with multiple threads
80
+ - Hybrid environments (multiple threads each with their own event loop)
81
+
82
+ Contexts are stored in a two-level dict keyed first by context type, then by UUID.
83
+ Fine-grained per-(type, id) locking allows concurrent access to different contexts.
84
+
85
+ Note: This implementation is single-process only. For multi-process
86
+ deployments (e.g., gunicorn with multiple processes), consider using
87
+ a distributed storage backend like Redis.
88
+ """
89
+
90
+ def __init__(self) -> None:
91
+ self._states: dict[type[BaseModel], dict[UUID, BaseModel]] = {}
92
+ self._locks: dict[type[BaseModel], dict[UUID, aiologic.RLock]] = {}
93
+ self._locks_lock = aiologic.RLock()
94
+
95
+ async def _get_lock(
96
+ self, ctx_type: type[BaseModel], context_id: UUID
97
+ ) -> aiologic.RLock:
98
+ """Get or create a lock for the given (ctx_type, context_id) pair."""
99
+ async with self._locks_lock:
100
+ if ctx_type not in self._locks:
101
+ self._locks[ctx_type] = {}
102
+ if context_id not in self._locks[ctx_type]:
103
+ self._locks[ctx_type][context_id] = aiologic.RLock()
104
+ return self._locks[ctx_type][context_id]
105
+
106
+ @asynccontextmanager
107
+ async def get_context[ContextTypeT: BaseModel]( # pyright: ignore [reportReturnType]
108
+ self, ctx_type: type[ContextTypeT], context_id: UUID
109
+ ) -> AsyncIterator[ContextTypeT]:
110
+ lock = await self._get_lock(ctx_type, context_id)
111
+ async with lock:
112
+ type_store = self._states.get(ctx_type, {})
113
+ if context_id not in type_store:
114
+ raise KeyError(f"No {ctx_type.__name__} context known for {context_id}")
115
+ yield type_store[context_id] # pyright: ignore [reportReturnType]
116
+
117
+ async def store_context[ContextTypeT: BaseModel](
118
+ self,
119
+ ctx_type: type[ContextTypeT],
120
+ context_id: UUID,
121
+ context: ContextTypeT,
122
+ ) -> None:
123
+ """Store a new context.
124
+
125
+ Thread-safe creation/replacement of context.
126
+ Acquires the fine-grained lock for this (ctx_type, context_id) pair.
127
+ """
128
+ lock = await self._get_lock(ctx_type, context_id)
129
+ async with lock:
130
+ if ctx_type not in self._states:
131
+ self._states[ctx_type] = {}
132
+ self._states[ctx_type][context_id] = context
133
+
134
+ async def get_context_snapshot[ContextTypeT: BaseModel](
135
+ self,
136
+ ctx_type: type[ContextTypeT],
137
+ context_id: UUID,
138
+ ) -> ContextTypeT:
139
+ type_store = self._states.get(ctx_type, {})
140
+ if context_id not in type_store:
141
+ raise KeyError(f"No {ctx_type.__name__} context known for {context_id}")
142
+ context = type_store[context_id]
143
+ return context.model_copy(deep=True) # pyright: ignore [reportReturnType]
144
+
145
+
146
# NOTE(review): in this module ``register_provider`` is imported from
# ``cem.di``, not ``wepositive_di.di``, and ``cem`` is not among the package's
# declared dependencies. Confirm this module is meant to ship in the sdist.
@register_provider(singleton=True)
def context_storage_singleton() -> InMemoryContextStorage:
    """Provide the application-wide ``InMemoryContextStorage``.

    Registered with ``singleton=True``: the first resolution creates the
    storage and every later resolution receives that same object. A single
    instance serves all context types, keyed by (type, UUID).

    The storage synchronises through ``aiologic.RLock``, intended to work in:

    - async servers (e.g. FastAPI) without blocking the event loop,
    - threaded servers,
    - setups that run one event loop per thread.

    For multi-process deployments, override this provider with a distributed
    ``ContextStorage`` implementation (for example one backed by Redis).
    """
    storage = InMemoryContextStorage()
    return storage
@@ -0,0 +1,655 @@
1
+ import asyncio
2
+ import functools
3
+ import inspect
4
+ import logging
5
+ import sys
6
+ import typing
7
+ from collections.abc import Callable, Coroutine
8
+ from contextlib import (
9
+ AbstractAsyncContextManager,
10
+ AbstractContextManager,
11
+ contextmanager,
12
+ )
13
+ from typing import (
14
+ Any,
15
+ TypeVar,
16
+ overload,
17
+ )
18
+
19
+ from dependency_injector import containers, providers
20
+
21
+ from wepositive_di.providers import (
22
+ AsyncCMFactory as _AsyncCMFactory,
23
+ )
24
+ from wepositive_di.providers import (
25
+ AsyncSingletonCMFactory as _AsyncSingletonCMFactory,
26
+ )
27
+ from wepositive_di.providers import (
28
+ CMFactory as _CMFactory,
29
+ )
30
+ from wepositive_di.providers import (
31
+ SyncSingletonCMFactory as _SyncSingletonCMFactory,
32
+ )
33
+
34
logger = logging.getLogger(__name__)

# Modules that contain registered provider functions; ``setup`` passes them to
# ``registry.wire``.
_registered_modules: set[Any] = set()
# Provider name -> replacement provider. ``_lookup_provider`` consults this
# before falling back to ``registry``.
_provider_overrides: dict[str, providers.Provider[Any]] = {}
# Names registered with ``context_manager=True``; overrides of these names are
# built as context-manager providers as well.
_context_manager_providers: set[str] = set()

# Container holding every provider from ``register_provider``, stored as an
# attribute named after the provider.
registry = containers.DynamicContainer()
41
+
42
+
43
def _detect_lifecycle(
    func: Callable[..., Any],
    cm_qualname: str,
    enter_attr: str,
    exit_attr: str,
    cache_attr: str,
) -> bool:
    """Return True if calling *func* produces a (sync or async) context manager.

    Two checks, in order:

    1. ``func.__code__.co_qualname == cm_qualname``. This recognises the
       ``helper`` closure created by ``@contextmanager`` /
       ``@asynccontextmanager``; ``functools.wraps`` copies ``__qualname__``
       onto it but cannot change the code object.
    2. The resolved return annotation defines both *enter_attr* and
       *exit_attr*. For a subscripted generic, the origin class is checked
       too. ``types.GenericAlias`` (e.g. ``AbstractAsyncContextManager[X]``)
       forwards dunder lookups to its origin. ``typing`` aliases such as
       ``typing.ContextManager[X]`` or a user-defined ``MyCM[int]`` do not.

    If the annotations cannot be resolved (e.g. an undefined forward
    reference), *func* is treated as not producing a context manager.

    The verdict is cached on *func* under *cache_attr* when the object accepts
    attribute assignment. NOTE: ``functools.wraps`` copies ``__dict__``, so a
    wrapper of an already-inspected function inherits that cached verdict.

    Args:
        func: Provider callable to inspect.
        cm_qualname: ``co_qualname`` of the contextlib helper to recognise.
        enter_attr: Enter method name (``__enter__`` or ``__aenter__``).
        exit_attr: Exit method name (``__exit__`` or ``__aexit__``).
        cache_attr: Attribute name used to cache the verdict on *func*.

    Returns:
        Whether *func* is considered a context-manager factory.
    """
    cached = getattr(func, cache_attr, None)
    if cached is not None:
        return cached  # type: ignore[return-value]

    code = getattr(func, "__code__", None)
    if getattr(code, "co_qualname", None) == cm_qualname:
        result = True
    else:
        result = False
        try:
            return_type = typing.get_type_hints(func).get("return")
            if return_type is not None:
                candidates: list[Any] = [return_type]
                origin = typing.get_origin(return_type)
                if origin is not None:
                    # typing._GenericAlias does not relay dunder lookups, so
                    # check the underlying class explicitly.
                    candidates.append(origin)
                result = any(
                    hasattr(candidate, enter_attr) and hasattr(candidate, exit_attr)
                    for candidate in candidates
                )
        except Exception:  # noqa: BLE001
            # Best effort: unresolvable annotations mean "not a context manager".
            pass

    try:
        setattr(func, cache_attr, result)
    except (AttributeError, TypeError):
        # e.g. bound methods or builtins reject attribute assignment.
        pass
    return result
74
+
75
+
76
def _is_async_lifecycle_annotation(func: Callable[..., Any]) -> bool:
    """Tell whether calling *func* yields an async context manager.

    Two signals are accepted:

    * *func* is the ``helper`` produced by ``@asynccontextmanager``. It is
      recognised through ``__code__.co_qualname``, which ``functools.wraps``
      leaves untouched.
    * *func*'s return annotation exposes ``__aenter__`` and ``__aexit__``.

    Where possible, the verdict is cached on *func* as ``_di_is_async_lifecycle``.
    """
    return _detect_lifecycle(
        func,
        cm_qualname="asynccontextmanager.<locals>.helper",
        enter_attr="__aenter__",
        exit_attr="__aexit__",
        cache_attr="_di_is_async_lifecycle",
    )
92
+
93
+
94
def _is_sync_lifecycle_annotation(func: Callable[..., Any]) -> bool:
    """Tell whether calling *func* yields a synchronous context manager.

    Two signals are accepted:

    * *func* is the ``helper`` produced by ``@contextmanager``. It is
      recognised through ``__code__.co_qualname``, which ``functools.wraps``
      leaves untouched.
    * *func*'s return annotation exposes ``__enter__`` and ``__exit__``.

    Where possible, the verdict is cached on *func* as ``_di_is_sync_lifecycle``.
    """
    return _detect_lifecycle(
        func,
        cm_qualname="contextmanager.<locals>.helper",
        enter_attr="__enter__",
        exit_attr="__exit__",
        cache_attr="_di_is_sync_lifecycle",
    )
110
+
111
+
112
def _lookup_provider(name: str) -> Any:
    """Resolve *name* to a provider, preferring an active override over ``registry``."""
    return (
        _provider_overrides[name]
        if name in _provider_overrides
        else getattr(registry, name)
    )
117
+
118
+
119
def _uses_context_manager(provider_name: str) -> bool:
    """Return whether *provider_name* was last registered with ``context_manager=True``."""
    is_cm_provider = provider_name in _context_manager_providers
    return is_cm_provider
121
+
122
+
123
def _resolve_deps_sync(
    sig: inspect.Signature, provider_name: str, func_name: str
) -> dict[str, Any]:
    """Build keyword arguments for a sync provider from its parameter defaults.

    Each default that is a ``Depends[...]`` marker is replaced by the result of
    calling the referenced provider (an active override takes precedence).
    Other defaults are kept as-is. Parameters without a default are absent,
    except ``*args``/``**kwargs``, which ``apply_defaults`` fills with ``()``
    and ``{}``.

    Raises:
        RuntimeError: A dependency returned a coroutine. The coroutine is
            closed first so no "never awaited" warning is emitted; sync
            providers cannot depend on async ones.
    """
    partial = sig.bind_partial()
    partial.apply_defaults()
    resolved = partial.arguments
    for param_name, default in list(resolved.items()):
        if not isinstance(default, _DependsMarker):
            continue
        dependency = _lookup_provider(default.name)()
        if asyncio.iscoroutine(dependency):
            dependency.close()
            raise RuntimeError(
                f"Cannot resolve async dependency '{param_name}' in sync provider "
                f"'{provider_name}'. Sync providers cannot have async dependencies. "
                f"Make your provider async instead: async def {func_name}(...)"
            )
        resolved[param_name] = dependency
    return resolved
146
+
147
+
148
async def _resolve_deps_async(sig: inspect.Signature) -> dict[str, Any]:
    """Build keyword arguments from *sig*'s defaults, resolving ``Depends`` markers.

    Markers are resolved one after another in parameter order. A provider
    result that is a coroutine is awaited; anything else is used as returned.
    Context managers returned by providers are NOT entered here.
    """
    bound = sig.bind_partial()
    bound.apply_defaults()
    pending = [
        (param_name, default)
        for param_name, default in bound.arguments.items()
        if isinstance(default, _DependsMarker)
    ]
    for param_name, marker in pending:
        produced = _lookup_provider(marker.name)()
        if asyncio.iscoroutine(produced):
            produced = await produced
        bound.arguments[param_name] = produced
    return bound.arguments
160
+
161
+
162
def _create_provider(
    func: Callable[..., Any],
    *,
    provider_name: str | None = None,
    singleton: bool = False,
    context_manager: bool = False,
) -> providers.Provider[Any]:
    """Build a provider that calls *func* with its ``Depends`` defaults resolved.

    The provider class depends on what *func* is:

    - async context manager, ``singleton=True`` -> ``AsyncSingletonCMFactory``
    - async context manager                      -> ``AsyncCMFactory``
    - sync context manager, ``singleton=True``  -> ``SyncSingletonCMFactory``
    - sync context manager                       -> ``CMFactory``
    - ``async def`` function                     -> ``providers.Coroutine``
      (``singleton`` is ignored in this case)
    - sync function, ``singleton=True``          -> ``providers.Singleton``
    - sync function                              -> ``providers.Factory``

    Context managers are only looked for when ``context_manager=True``. If
    both an async and a sync context manager are detected, the async form is
    used.

    Args:
        func: The provider function to wrap.
        provider_name: Name used in error messages; defaults to ``func.__name__``.
        singleton: Request a cached provider where supported (see above).
        context_manager: Treat *func* as returning a context manager.

    Returns:
        The configured provider object.

    Raises:
        ValueError: ``context_manager=True`` but *func* is not recognised as
            returning a sync or async context manager.
    """
    name = provider_name or func.__name__
    sig = inspect.signature(func)
    # Both detectors run when context_manager is set (each may cache on func).
    is_async_cm = context_manager and _is_async_lifecycle_annotation(func)
    is_sync_cm = context_manager and _is_sync_lifecycle_annotation(func)
    is_async_func = inspect.iscoroutinefunction(func)

    if context_manager and not is_async_cm and not is_sync_cm:
        raise ValueError(
            f"Provider '{name}' was registered with context_manager=True, "
            "but it does not return a supported sync or async context manager."
        )

    if is_async_cm:

        async def async_cm_wrapper():
            # Hands back the un-entered context manager; the caller enters it.
            return func(**(await _resolve_deps_async(sig)))

        if singleton:
            return _AsyncSingletonCMFactory(async_cm_wrapper)
        return _AsyncCMFactory(async_cm_wrapper)

    if is_sync_cm:

        def sync_cm_wrapper():
            return func(**_resolve_deps_sync(sig, name, func.__name__))

        if singleton:
            return _SyncSingletonCMFactory(sync_cm_wrapper)
        return _CMFactory(sync_cm_wrapper)  # type: ignore[return-value]

    if is_async_func:

        async def async_func_wrapper():
            return await func(**(await _resolve_deps_async(sig)))

        return providers.Coroutine(async_func_wrapper)

    def sync_wrapper():
        return func(**_resolve_deps_sync(sig, name, func.__name__))

    if singleton:
        return providers.Singleton(sync_wrapper)
    return providers.Factory(sync_wrapper)  # type: ignore[return-value]
225
+
226
+
227
+ def register_provider(
228
+ name: str | None = None,
229
+ singleton: bool = False,
230
+ context_manager: bool = False,
231
+ ):
232
+ def decorator(func: Callable[..., Any]):
233
+ """Register a provider function (sync or async) in the registry.
234
+
235
+ Args:
236
+ name: Optional name for the provider (defaults to function name)
237
+ singleton: If True, caches and reuses the first created instance.
238
+ context_manager: If True, enter and exit the provider's context manager
239
+ when resolving dependencies. Context-manager handling is opt-in.
240
+ """
241
+ provider_name = name or func.__name__
242
+ is_async_func = inspect.iscoroutinefunction(func)
243
+
244
+ if is_async_func and not context_manager and singleton:
245
+ raise ValueError(
246
+ f"Async provider '{provider_name}' cannot be a singleton. "
247
+ f"The dependency-injector library doesn't support singleton caching for Coroutine providers. "
248
+ f"Make your provider a sync function instead: def {func.__name__}(...)"
249
+ )
250
+
251
+ provider = _create_provider(
252
+ func,
253
+ provider_name=provider_name,
254
+ singleton=singleton,
255
+ context_manager=context_manager,
256
+ )
257
+ setattr(registry, provider_name, provider)
258
+ if context_manager:
259
+ _context_manager_providers.add(provider_name)
260
+ else:
261
+ _context_manager_providers.discard(provider_name)
262
+
263
+ module = inspect.getmodule(func)
264
+ if module is not None: # pragma: no branch
265
+ _registered_modules.add(module)
266
+ return func
267
+
268
+ return decorator
269
+
270
+
271
def setup(
    overrides: dict[Callable[..., Any] | str, Callable[..., Any]] | None = None,
):
    """Install optional overrides, then wire every module that registered a provider.

    Args:
        overrides: Optional mapping from an original provider to its
            replacement function. The original is given either as the function
            itself or as its registered name. A function is keyed by its
            ``__name__``, so pass the string when the provider was registered
            under a custom ``name``. Each replacement is wrapped with
            ``_create_provider``. The original's ``context_manager`` flag is
            carried over; ``singleton`` is not, so the replacement gets the
            default ``singleton=False``.

    Usage:

    ```python
    setup()

    def redis_storage() -> ContextStorage:
        return RedisContextStorage()

    setup(overrides={context_storage_singleton: redis_storage})
    ```
    """
    for original, override_func in (overrides or {}).items():
        target = original if isinstance(original, str) else original.__name__
        _provider_overrides[target] = _create_provider(
            override_func,
            provider_name=target,
            context_manager=_uses_context_manager(target),
        )

    registry.wire(modules=list(_registered_modules))
301
+
302
+
303
@overload
def override_provider(
    original: Callable[..., Any] | str,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ...


@overload
def override_provider(
    original: Callable[..., Any] | str,
    override: Callable[..., Any],
) -> None: ...


def override_provider(
    original: Callable[..., Any] | str,
    override: Callable[..., Any] | None = None,
) -> None | Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Replace a provider with another implementation, until overrides are cleared.

    Works either as a plain call or as a decorator factory. The replacement is
    wrapped with ``_create_provider``. It inherits the original's
    ``context_manager`` flag but is built with the default ``singleton=False``.

    Args:
        original: The original provider function, or its registered name.
        override: The replacement function (plain-call form only).

    Returns:
        ``None`` for the plain-call form. For ``@override_provider(original)``,
        a decorator that installs the override and returns the decorated
        function unchanged.

    Usage:

    ```python
    @register_provider()
    async def config() -> Config:
        return Config()

    async def prod_config() -> Config:
        return Config(db_url="production")

    override_provider(config, prod_config)
    ```

    ```python
    @override_provider(config)
    async def prod_config() -> Config:
        return Config(db_url="production")
    ```
    """
    provider_name = original if isinstance(original, str) else original.__name__

    def install(replacement: Callable[..., Any]) -> Callable[..., Any]:
        _provider_overrides[provider_name] = _create_provider(
            replacement,
            provider_name=provider_name,
            context_manager=_uses_context_manager(provider_name),
        )
        return replacement

    if override is None:
        # Decorator form: ``@override_provider(original)``.
        return install

    # Plain-call form: ``override_provider(original, replacement)``.
    install(override)
    return None
372
+
373
+
374
def clear_overrides() -> None:
    """Remove every active override, so each lookup goes back to ``registry``."""
    _provider_overrides.clear()
377
+
378
+
379
@contextmanager
def provider_overrides(
    overrides: dict[Callable[..., Any] | str, Callable[..., Any]],
):
    """Context manager to temporarily override providers for testing.

    All replacement providers are built before the override table is touched.
    If building one fails, nothing is installed. For example,
    ``_create_provider`` raises ``ValueError`` for a context-manager provider
    whose replacement is not a context manager. On exit, whether normal or
    through an exception, the override table is restored to exactly its
    state on entry.

    Args:
        overrides: Dictionary mapping original providers (function or
            registered name) to their replacement functions.

    Usage:

    ```python
    async def test_config() -> Config:
        return Config(sqlalchemy_db_uri=SecretStr("sqlite:///:memory:"))

    with provider_overrides({config: test_config}):
        result = await my_function()
    ```
    """
    # Build first: a failure here must not leave half-applied overrides behind.
    staged: dict[str, providers.Provider[Any]] = {}
    for original, override in overrides.items():
        provider_name = original if isinstance(original, str) else original.__name__
        staged[provider_name] = _create_provider(
            override,
            provider_name=provider_name,
            context_manager=_uses_context_manager(provider_name),
        )

    saved = _provider_overrides.copy()
    _provider_overrides.update(staged)
    try:
        yield
    finally:
        # Restore the exact previous state (drops overrides added meanwhile too).
        _provider_overrides.clear()
        _provider_overrides.update(saved)
416
+
417
+
418
class _DependsMarker:
    """Parameter-default marker naming the provider to inject.

    Created by ``Depends[...]``. The injection helpers look for instances of
    this class among parameter defaults and replace them with the resolved
    dependency.

    Attributes:
        name: Registry name of the provider to resolve.
    """

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        # Appears in ``inspect.signature``/``help()`` output. It also appears
        # wherever an unresolved marker is printed, e.g. when a function using
        # ``Depends[...]`` is called without ``@inject``.
        return f"Depends[{self.name!r}]"
423
+
424
+
425
class _DependsType:
    """Implements the ``Depends[provider]`` subscript syntax.

    Only ``__getitem__`` is defined. ``Depends[func]`` and ``Depends["name"]``
    both produce a ``_DependsMarker`` that names the provider to inject.
    """

    def __getitem__(self, func: Callable[..., Any] | str) -> Any:
        """Return a ``_DependsMarker`` for *func*.

        A string is used verbatim as the provider name; any other object
        contributes its ``__name__``.

        Usage: def my_func(config: Config = Depends[config]):
        """
        provider_name = func if isinstance(func, str) else func.__name__
        return _DependsMarker(provider_name)
439
+
440
+
441
# Module-level instance used as ``Depends[provider]`` in parameter defaults.
Depends = _DependsType()

# NOTE(review): ``inject`` declares its own PEP 695 ``T`` type parameter, which
# shadows this TypeVar inside ``inject``. Check whether anything else uses it.
T = TypeVar("T", bound=Callable[..., Any])
444
+
445
+
446
@contextmanager
def _create_event_loop(param_name: str, func_name: str):
    """Yield a fresh event loop for resolving an async dependency from sync code.

    The loop is created with ``asyncio.new_event_loop()``, is not installed as
    the current loop, and is closed when the ``with`` block exits.

    Raises:
        RuntimeError: An event loop is already running in this thread.
            ``run_until_complete`` cannot be used from inside a running loop,
            so the caller should be made async instead.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        pass  # no running loop: safe to spin up a private one
    else:
        raise RuntimeError(
            f"Cannot resolve async dependency '{param_name}' in sync function "
            f"'{func_name}' from within an async context. "
            f"Either make '{func_name}' async or call it from a sync context."
        )

    private_loop = asyncio.new_event_loop()
    try:
        yield private_loop
    finally:
        private_loop.close()
470
+
471
+
472
# Typing-only overloads for ``inject``. Coroutine functions keep a
# coroutine-returning signature; any other callable maps to Callable[..., Any].
# Note: both overloads erase the decorated function's parameter types.
@overload
def inject(
    func: Callable[..., Coroutine[Any, Any, Any]],
) -> Callable[..., Coroutine[Any, Any, Any]]: ...


@overload
def inject(func: Callable[..., Any]) -> Callable[..., Any]: ...
480
+
481
+
482
+ def inject[T: Callable[..., Any]](func: T) -> T:
483
+ """Decorator that resolves Depends markers in function arguments.
484
+
485
+ Works with both sync and async functions.
486
+
487
+ The decorator:
488
+
489
+ - Inspects the function signature for `_DependsMarker` defaults.
490
+ - At call time, resolves each marker by calling the registry provider.
491
+ - Handles context manager providers transparently: enters the context manager,
492
+ passes the yielded value to the function, and exits on completion.
493
+ - Returns the appropriate wrapper based on the function type.
494
+
495
+ Provider types and how they are resolved:
496
+
497
+ - AsyncCMFactory: await coroutine, get context manager, await `__aenter__`.
498
+ - CMFactory: get context manager, call `__enter__`.
499
+ - providers.Coroutine: await result.
500
+ - providers.Factory / providers.Singleton: use result directly.
501
+
502
+ Usage:
503
+
504
+ ```python
505
+ @inject
506
+ def my_func(config: Config = Depends[config]):
507
+ return config.value
508
+
509
+ @inject
510
+ async def my_async_func(session: AsyncSession = Depends[async_session]):
511
+ return session.query(...)
512
+ ```
513
+
514
+ """
515
+ sig = inspect.signature(func)
516
+ dependant_is_async = inspect.iscoroutinefunction(
517
+ func
518
+ ) or inspect.isasyncgenfunction(func)
519
+
520
+ async def _resolve_dependencies_async(
521
+ args: tuple[Any, ...], kwargs: dict[str, Any]
522
+ ) -> tuple[
523
+ dict[str, Any],
524
+ set[AbstractAsyncContextManager[Any]],
525
+ set[AbstractContextManager[Any]],
526
+ ]:
527
+ bound = sig.bind_partial(*args, **kwargs)
528
+ bound.apply_defaults()
529
+ sync_lifecycles_to_cleanup: set[AbstractContextManager[Any]] = set()
530
+ async_lifecycles_to_cleanup: set[AbstractAsyncContextManager[Any]] = set()
531
+
532
+ for param_name in list(bound.arguments.keys()):
533
+ value = bound.arguments[param_name]
534
+ if not isinstance(value, _DependsMarker):
535
+ continue
536
+ provider = _lookup_provider(value.name)
537
+ result = provider()
538
+
539
+ if isinstance(provider, _AsyncSingletonCMFactory):
540
+ result = await result
541
+ elif isinstance(provider, _SyncSingletonCMFactory):
542
+ pass # sync Resource returns the cached value directly
543
+ elif isinstance(provider, _AsyncCMFactory):
544
+ cm = await result # await async wrapper to get the CM object
545
+ async_lifecycles_to_cleanup.add(cm)
546
+ result = await cm.__aenter__()
547
+ elif isinstance(provider, _CMFactory):
548
+ cm = result # sync wrapper returns the CM directly
549
+ sync_lifecycles_to_cleanup.add(cm)
550
+ result = cm.__enter__()
551
+ elif isinstance(provider, providers.Coroutine):
552
+ result = await result
553
+
554
+ bound.arguments[param_name] = result
555
+
556
+ return bound.arguments, async_lifecycles_to_cleanup, sync_lifecycles_to_cleanup
557
+
558
+ def _resolve_dependencies_sync(
559
+ args: tuple[Any, ...], kwargs: dict[str, Any]
560
+ ) -> tuple[
561
+ dict[str, Any],
562
+ set[AbstractAsyncContextManager[Any]],
563
+ set[AbstractContextManager[Any]],
564
+ ]:
565
+ bound = sig.bind_partial(*args, **kwargs)
566
+ bound.apply_defaults()
567
+ sync_lifecycles_to_cleanup: set[AbstractContextManager[Any]] = set()
568
+ async_lifecycles_to_cleanup: set[AbstractAsyncContextManager[Any]] = set()
569
+
570
+ for param_name in list(bound.arguments.keys()):
571
+ value = bound.arguments[param_name]
572
+ if not isinstance(value, _DependsMarker):
573
+ continue
574
+ provider = _lookup_provider(value.name)
575
+ result = provider()
576
+
577
+ if isinstance(provider, _AsyncSingletonCMFactory):
578
+ with _create_event_loop(param_name, func.__name__) as loop:
579
+ result = loop.run_until_complete(result)
580
+ elif isinstance(provider, _SyncSingletonCMFactory):
581
+ pass # sync Resource returns the cached value directly
582
+ elif isinstance(provider, _AsyncCMFactory):
583
+ with _create_event_loop(param_name, func.__name__) as loop:
584
+ cm = loop.run_until_complete(result) # await async wrapper → CM
585
+ async_lifecycles_to_cleanup.add(cm)
586
+ result = loop.run_until_complete(cm.__aenter__())
587
+ elif isinstance(provider, _CMFactory):
588
+ cm = result
589
+ sync_lifecycles_to_cleanup.add(cm)
590
+ result = cm.__enter__()
591
+ elif isinstance(provider, providers.Coroutine):
592
+ with _create_event_loop(param_name, func.__name__) as loop:
593
+ result = loop.run_until_complete(result)
594
+
595
+ bound.arguments[param_name] = result
596
+
597
+ return bound.arguments, async_lifecycles_to_cleanup, sync_lifecycles_to_cleanup
598
+
599
+ if dependant_is_async:
600
+
601
+ @functools.wraps(func)
602
+ async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
603
+ (
604
+ resolved_args,
605
+ async_lifecycles_to_cleanup,
606
+ sync_lifecycles_to_cleanup,
607
+ ) = await _resolve_dependencies_async(args, kwargs)
608
+
609
+ exc_info = (None, None, None)
610
+ try:
611
+ result = await func(**resolved_args)
612
+ return result
613
+ except Exception:
614
+ exc_info = sys.exc_info()
615
+ raise
616
+ finally:
617
+ suppressed = False
618
+ for cm in async_lifecycles_to_cleanup:
619
+ if await cm.__aexit__(*exc_info):
620
+ suppressed = True
621
+ for cm in sync_lifecycles_to_cleanup:
622
+ if cm.__exit__(*exc_info):
623
+ suppressed = True
624
+ if exc_info[0] is not None and not suppressed:
625
+ raise exc_info[1].with_traceback(exc_info[2]) # type: ignore[union-attr]
626
+
627
+ return async_wrapper # type: ignore
628
+ else:
629
+
630
+ @functools.wraps(func)
631
+ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
632
+ resolved_args, async_lifecycles_to_cleanup, sync_lifecycles_to_cleanup = (
633
+ _resolve_dependencies_sync(args, kwargs)
634
+ )
635
+ exc_info = (None, None, None)
636
+ try:
637
+ result = func(**resolved_args)
638
+ return result
639
+ except Exception:
640
+ exc_info = sys.exc_info()
641
+ raise
642
+ finally:
643
+ suppressed = False
644
+ for cm in async_lifecycles_to_cleanup:
645
+ cm_name = getattr(cm, "__wrapped__", type(cm)).__name__
646
+ with _create_event_loop(cm_name, func.__name__) as loop:
647
+ if loop.run_until_complete(cm.__aexit__(*exc_info)):
648
+ suppressed = True
649
+ for cm in sync_lifecycles_to_cleanup:
650
+ if cm.__exit__(*exc_info):
651
+ suppressed = True
652
+ if exc_info[0] is not None and not suppressed:
653
+ raise exc_info[1].with_traceback(exc_info[2]) # type: ignore[union-attr]
654
+
655
+ return sync_wrapper # type: ignore
@@ -0,0 +1,50 @@
1
+ from collections.abc import Coroutine
2
+ from contextlib import AbstractAsyncContextManager, AbstractContextManager
3
+ from typing import Any, TypeVar
4
+
5
+ from dependency_injector import providers
6
+
7
# Type of the value a provided context manager yields, i.e. what the
# dependant function ultimately receives for the parameter.
_T = TypeVar("_T")
8
+
9
+
10
class CMFactory(providers.Factory[AbstractContextManager[_T]]):
    """Provider that creates a new sync context manager on every call.

    Calling the provider returns the context manager object itself (not the
    value it yields). The inject decorator detects this type and automatically
    calls ``__enter__``, passes the yielded value to the dependant function,
    then calls ``__exit__`` with any exception raised — transparently to the
    caller.
    """
17
+
18
+
19
class AsyncCMFactory(
    providers.Factory[Coroutine[Any, Any, AbstractAsyncContextManager[_T]]]
):
    """Provider that creates a new async context manager on every call.

    The callable wrapped by this provider is itself ``async`` (so that it can
    resolve async ``Depends``), so calling the provider returns a coroutine
    that resolves to the context manager object. The inject decorator awaits
    it to get the CM, then calls ``__aenter__``, passes the yielded value to the
    dependant function, and calls ``__aexit__`` with any exception raised —
    transparently to the caller.
    """
30
+
31
+
32
class SyncSingletonCMFactory(providers.Resource[_T]):
    """Singleton sync context manager provider.

    Inherits from providers.Resource so that the underlying context manager is
    entered once, the yielded value is cached, and teardown is triggered by
    registry.shutdown_resources(). Calling this provider returns the cached
    value directly (no await needed). The inject decorator passes that value to
    the dependant function as-is and does not exit anything after the call;
    teardown is left to the resource shutdown.
    """
40
+
41
+
42
class AsyncSingletonCMFactory(providers.Resource[_T]):
    """Singleton async context manager provider.

    Inherits from providers.Resource so that the underlying context manager is
    entered once, the yielded value is cached, and teardown is triggered by
    registry.shutdown_resources(). Calling this provider returns a coroutine on
    the first call (before the value is cached); the inject decorator awaits it
    to obtain the yielded value.

    NOTE(review): the inject decorator awaits the provider's return value on
    every resolution, not only the first, so this relies on later calls also
    returning something awaitable — confirm against dependency_injector's
    async ``Resource`` behaviour.
    """