wepositive-di 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wepositive_di-0.1.0/PKG-INFO +16 -0
- wepositive_di-0.1.0/pyproject.toml +103 -0
- wepositive_di-0.1.0/src/wepositive_di/__init__.py +23 -0
- wepositive_di-0.1.0/src/wepositive_di/context.py +161 -0
- wepositive_di-0.1.0/src/wepositive_di/context.pye +161 -0
- wepositive_di-0.1.0/src/wepositive_di/di.py +655 -0
- wepositive_di-0.1.0/src/wepositive_di/providers.py +50 -0
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: wepositive-di
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Adaptation of the dependency-injector framework to provide a simpler DI syntax similar to FastAPI's.
|
|
5
|
+
License-Expression: Apache-2.0
|
|
6
|
+
Author: Dolf Andringa
|
|
7
|
+
Author-email: dolf@wepositive.energy
|
|
8
|
+
Requires-Python: >=3.12
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
13
|
+
Requires-Dist: aiologic (>=0.16.0,<0.17.0)
|
|
14
|
+
Requires-Dist: dependency-injector (>=4.49.0,<5.0.0)
|
|
15
|
+
Requires-Dist: pydantic (>=2.13.4,<3.0.0)
|
|
16
|
+
Requires-Dist: pydantic-settings (>=2.14.1,<3.0.0)
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "wepositive-di"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Adaptation of the dependency-injector framework to provide a simpler DI syntax similar to FastAPI's."
|
|
5
|
+
authors = [
|
|
6
|
+
{name = "Dolf Andringa",email = "dolf@wepositive.energy"}
|
|
7
|
+
]
|
|
8
|
+
license = "Apache-2.0"
|
|
9
|
+
requires-python = ">=3.12"
|
|
10
|
+
packages = [{include = "wepositive_di", from = "src"}]
|
|
11
|
+
dependencies = [
|
|
12
|
+
"dependency-injector (>=4.49.0,<5.0.0)",
|
|
13
|
+
"aiologic (>=0.16.0,<0.17.0)",
|
|
14
|
+
"pydantic (>=2.13.4,<3.0.0)",
|
|
15
|
+
"pydantic-settings (>=2.14.1,<3.0.0)"
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
[build-system]
|
|
20
|
+
requires = ["poetry-core>=2.0.0,<3.0.0"]
|
|
21
|
+
build-backend = "poetry.core.masonry.api"
|
|
22
|
+
|
|
23
|
+
[dependency-groups]
|
|
24
|
+
dev = [
|
|
25
|
+
"pytest (>=9.0.3,<10.0.0)",
|
|
26
|
+
"pytest-mock (>=3.15.1,<4.0.0)",
|
|
27
|
+
"pyright (>=1.1.410,<2.0.0)",
|
|
28
|
+
"pre-commit (>=4.6.0,<5.0.0)",
|
|
29
|
+
"pytest-cov (>=7.1.0,<8.0.0)",
|
|
30
|
+
"ruff (>=0.15.15,<0.16.0)",
|
|
31
|
+
"pytest-asyncio (>=1.4.0,<2.0.0)",
|
|
32
|
+
"mkdocs-material (>=9.7.6,<10.0.0)",
|
|
33
|
+
"mkdocstrings (>=1.0.4,<2.0.0)",
|
|
34
|
+
"mkdocs-autorefs (>=1.4.4,<2.0.0)",
|
|
35
|
+
"mkdocstrings-python (>=2.0.3,<3.0.0)"
|
|
36
|
+
]
|
|
37
|
+
|
|
38
|
+
[tool.pyright]
|
|
39
|
+
stubPath = "./typings/"
|
|
40
|
+
reportImportCycles = false
|
|
41
|
+
typeCheckingMode = "strict"
|
|
42
|
+
venv = ".venv"
|
|
43
|
+
venvPath = "."
|
|
44
|
+
exclude = [
|
|
45
|
+
"./typings/*",
|
|
46
|
+
"**/node_modules",
|
|
47
|
+
"**/__pycache__",
|
|
48
|
+
"**/.*",
|
|
49
|
+
"alembic/*",
|
|
50
|
+
]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
[tool.pytest.ini_options]
|
|
54
|
+
addopts = [
|
|
55
|
+
"--import-mode=importlib",
|
|
56
|
+
"--cov",
|
|
57
|
+
]
|
|
58
|
+
asyncio_mode = "auto"
|
|
59
|
+
asyncio_default_test_loop_scope = "session"
|
|
60
|
+
asyncio_default_fixture_loop_scope = "session"
|
|
61
|
+
|
|
62
|
+
[tool.coverage.run]
|
|
63
|
+
branch = true
|
|
64
|
+
source = ["./src/wepositive_di"]
|
|
65
|
+
|
|
66
|
+
[tool.coverage.report]
|
|
67
|
+
exclude_also = [
|
|
68
|
+
"if .*TYPE_CHECKING:",
|
|
69
|
+
"@overload",
|
|
70
|
+
"@typing.overload",
|
|
71
|
+
"class .*\\(Protocol\\):",
|
|
72
|
+
"@abc.abstractmethod",
|
|
73
|
+
"@abstractmethod",
|
|
74
|
+
"\\A(?s:.*# pragma: exclude file.*)\\Z"
|
|
75
|
+
]
|
|
76
|
+
fail_under = 100
|
|
77
|
+
show_missing = true
|
|
78
|
+
|
|
79
|
+
[tool.ruff.lint]
|
|
80
|
+
ignore = ["PT013"]
|
|
81
|
+
select = [
|
|
82
|
+
"E4",
|
|
83
|
+
"E7",
|
|
84
|
+
"E9",
|
|
85
|
+
"F",
|
|
86
|
+
"BLE",
|
|
87
|
+
"ASYNC",
|
|
88
|
+
"DTZ",
|
|
89
|
+
"C4",
|
|
90
|
+
"T10",
|
|
91
|
+
"T20",
|
|
92
|
+
"PTH",
|
|
93
|
+
"UP",
|
|
94
|
+
"I",
|
|
95
|
+
"LOG",
|
|
96
|
+
"G",
|
|
97
|
+
"PT",
|
|
98
|
+
]
|
|
99
|
+
extend-safe-fixes = ["UP"]
|
|
100
|
+
|
|
101
|
+
[tool.ruff]
|
|
102
|
+
extend-exclude = ["typings/*", "alembic/*"]
|
|
103
|
+
target-version = "py312"
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from wepositive_di import context
|
|
2
|
+
from wepositive_di.di import (
|
|
3
|
+
Depends,
|
|
4
|
+
clear_overrides,
|
|
5
|
+
inject,
|
|
6
|
+
override_provider,
|
|
7
|
+
provider_overrides,
|
|
8
|
+
register_provider,
|
|
9
|
+
registry,
|
|
10
|
+
setup,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"Depends",
|
|
15
|
+
"clear_overrides",
|
|
16
|
+
"inject",
|
|
17
|
+
"override_provider",
|
|
18
|
+
"provider_overrides",
|
|
19
|
+
"register_provider",
|
|
20
|
+
"registry",
|
|
21
|
+
"setup",
|
|
22
|
+
"context",
|
|
23
|
+
]
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from collections.abc import AsyncGenerator
|
|
3
|
+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
|
4
|
+
from uuid import UUID
|
|
5
|
+
|
|
6
|
+
import aiologic
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
|
|
9
|
+
from wepositive_di.di import register_provider
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ContextStorage(ABC):
|
|
13
|
+
"""ContextStorage interface.
|
|
14
|
+
|
|
15
|
+
This interface allows for different storage backends (in-memory, Redis, etc.)
|
|
16
|
+
to be used for storing context data, keyed by both context type and UUID.
|
|
17
|
+
|
|
18
|
+
One storage instance can hold multiple context types simultaneously.
|
|
19
|
+
Implementations must be thread-safe and async-safe.
|
|
20
|
+
|
|
21
|
+
get_context is an async context manager that yields the context while holding
|
|
22
|
+
a lock, ensuring safe modifications during the entire usage period.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
@abstractmethod
|
|
26
|
+
def get_context[ContextTypeT: BaseModel](
|
|
27
|
+
self, ctx_type: type[ContextTypeT], context_id: UUID
|
|
28
|
+
) -> AbstractAsyncContextManager[ContextTypeT]:
|
|
29
|
+
"""Get a context for the given type and context_id.
|
|
30
|
+
|
|
31
|
+
This is an async context manager that yields the context while holding a lock.
|
|
32
|
+
The lock is held until the context manager exits.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
ctx_type: The type of context to retrieve
|
|
36
|
+
context_id: The UUID identifying the context
|
|
37
|
+
|
|
38
|
+
Yields:
|
|
39
|
+
The context associated with this type and identifier
|
|
40
|
+
|
|
41
|
+
Raises:
|
|
42
|
+
KeyError: If the context does not exist
|
|
43
|
+
"""
|
|
44
|
+
...
|
|
45
|
+
|
|
46
|
+
@abstractmethod
|
|
47
|
+
async def store_context[ContextTypeT: BaseModel](
|
|
48
|
+
self, ctx_type: type[ContextTypeT], context_id: UUID, context: ContextTypeT
|
|
49
|
+
) -> None:
|
|
50
|
+
"""Store a new context.
|
|
51
|
+
|
|
52
|
+
This creates or replaces a context for the given type and context_id.
|
|
53
|
+
Thread-safe and async-safe.
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
ctx_type: The type of context being stored
|
|
57
|
+
context_id: The UUID identifying the context
|
|
58
|
+
context: The context to store
|
|
59
|
+
"""
|
|
60
|
+
...
|
|
61
|
+
|
|
62
|
+
@abstractmethod
|
|
63
|
+
async def get_context_snapshot[ContextTypeT: BaseModel](
|
|
64
|
+
self, ctx_type: type[ContextTypeT], context_id: UUID
|
|
65
|
+
) -> ContextTypeT:
|
|
66
|
+
"""Get a read-only snapshot source without taking the context lock.
|
|
67
|
+
|
|
68
|
+
This is intended for event emission paths that must not wait behind a
|
|
69
|
+
long-running mutable context lock.
|
|
70
|
+
"""
|
|
71
|
+
...
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class InMemoryContextStorage(ContextStorage):
|
|
75
|
+
"""Unified in-memory storage for contexts that works in both async and threaded environments.
|
|
76
|
+
|
|
77
|
+
Uses aiologic.RLock for synchronization, which works seamlessly across:
|
|
78
|
+
- Pure async servers (FastAPI with single event loop)
|
|
79
|
+
- Threaded servers with multiple threads
|
|
80
|
+
- Hybrid environments (multiple threads each with their own event loop)
|
|
81
|
+
|
|
82
|
+
Contexts are stored in a two-level dict keyed first by context type, then by UUID.
|
|
83
|
+
Fine-grained per-(type, id) locking allows concurrent access to different contexts.
|
|
84
|
+
|
|
85
|
+
Note: This implementation is single-process only. For multi-process
|
|
86
|
+
deployments (e.g., gunicorn with multiple processes), consider using
|
|
87
|
+
a distributed storage backend like Redis.
|
|
88
|
+
"""
|
|
89
|
+
|
|
90
|
+
def __init__(self) -> None:
|
|
91
|
+
self._states: dict[type[BaseModel], dict[UUID, BaseModel]] = {}
|
|
92
|
+
self._locks: dict[type[BaseModel], dict[UUID, aiologic.RLock]] = {}
|
|
93
|
+
self._locks_lock = aiologic.RLock()
|
|
94
|
+
|
|
95
|
+
async def _get_lock(
|
|
96
|
+
self, ctx_type: type[BaseModel], context_id: UUID
|
|
97
|
+
) -> aiologic.RLock:
|
|
98
|
+
"""Get or create a lock for the given (ctx_type, context_id) pair."""
|
|
99
|
+
async with self._locks_lock:
|
|
100
|
+
if ctx_type not in self._locks:
|
|
101
|
+
self._locks[ctx_type] = {}
|
|
102
|
+
if context_id not in self._locks[ctx_type]:
|
|
103
|
+
self._locks[ctx_type][context_id] = aiologic.RLock()
|
|
104
|
+
return self._locks[ctx_type][context_id]
|
|
105
|
+
|
|
106
|
+
@asynccontextmanager
|
|
107
|
+
async def get_context[ContextTypeT: BaseModel]( # pyright: ignore [reportReturnType]
|
|
108
|
+
self, ctx_type: type[ContextTypeT], context_id: UUID
|
|
109
|
+
) -> AsyncGenerator[ContextTypeT]:
|
|
110
|
+
lock = await self._get_lock(ctx_type, context_id)
|
|
111
|
+
async with lock:
|
|
112
|
+
type_store = self._states.get(ctx_type, {})
|
|
113
|
+
if context_id not in type_store:
|
|
114
|
+
raise KeyError(f"No {ctx_type.__name__} context known for {context_id}")
|
|
115
|
+
yield type_store[context_id] # pyright: ignore [reportReturnType]
|
|
116
|
+
|
|
117
|
+
async def store_context[ContextTypeT: BaseModel](
|
|
118
|
+
self,
|
|
119
|
+
ctx_type: type[ContextTypeT],
|
|
120
|
+
context_id: UUID,
|
|
121
|
+
context: ContextTypeT,
|
|
122
|
+
) -> None:
|
|
123
|
+
"""Store a new context.
|
|
124
|
+
|
|
125
|
+
Thread-safe creation/replacement of context.
|
|
126
|
+
Acquires the fine-grained lock for this (ctx_type, context_id) pair.
|
|
127
|
+
"""
|
|
128
|
+
lock = await self._get_lock(ctx_type, context_id)
|
|
129
|
+
async with lock:
|
|
130
|
+
if ctx_type not in self._states:
|
|
131
|
+
self._states[ctx_type] = {}
|
|
132
|
+
self._states[ctx_type][context_id] = context
|
|
133
|
+
|
|
134
|
+
async def get_context_snapshot[ContextTypeT: BaseModel](
|
|
135
|
+
self,
|
|
136
|
+
ctx_type: type[ContextTypeT],
|
|
137
|
+
context_id: UUID,
|
|
138
|
+
) -> ContextTypeT:
|
|
139
|
+
type_store = self._states.get(ctx_type, {})
|
|
140
|
+
if context_id not in type_store:
|
|
141
|
+
raise KeyError(f"No {ctx_type.__name__} context known for {context_id}")
|
|
142
|
+
context = type_store[context_id]
|
|
143
|
+
return context.model_copy(deep=True) # pyright: ignore [reportReturnType]
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@register_provider(singleton=True)
def context_storage_singleton() -> InMemoryContextStorage:
    """Provide the context storage, registered as a singleton provider.

    Because the provider is registered with ``singleton=True``, the same
    InMemoryContextStorage is meant to be reused for the lifetime of the
    application. That one instance holds every context type, keyed by
    (type, UUID).

    The storage synchronizes with aiologic.RLock, which is intended to
    cover:
    - Async servers (FastAPI): non-blocking async synchronization
    - Threaded servers: thread-safe synchronization
    - Hybrid environments: several threads, each with its own event loop

    For multi-process deployments, override this provider with a
    distributed implementation (for example a Redis-backed storage).
    """
    storage = InMemoryContextStorage()
    return storage
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from collections.abc import AsyncIterator
|
|
3
|
+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
|
4
|
+
from uuid import UUID
|
|
5
|
+
|
|
6
|
+
import aiologic
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
|
|
9
|
+
from cem.di import register_provider
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ContextStorage(ABC):
|
|
13
|
+
"""ContextStorage interface.
|
|
14
|
+
|
|
15
|
+
This interface allows for different storage backends (in-memory, Redis, etc.)
|
|
16
|
+
to be used for storing context data, keyed by both context type and UUID.
|
|
17
|
+
|
|
18
|
+
One storage instance can hold multiple context types simultaneously.
|
|
19
|
+
Implementations must be thread-safe and async-safe.
|
|
20
|
+
|
|
21
|
+
get_context is an async context manager that yields the context while holding
|
|
22
|
+
a lock, ensuring safe modifications during the entire usage period.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
@abstractmethod
|
|
26
|
+
def get_context[ContextTypeT: BaseModel](
|
|
27
|
+
self, ctx_type: type[ContextTypeT], context_id: UUID
|
|
28
|
+
) -> AbstractAsyncContextManager[ContextTypeT]:
|
|
29
|
+
"""Get a context for the given type and context_id.
|
|
30
|
+
|
|
31
|
+
This is an async context manager that yields the context while holding a lock.
|
|
32
|
+
The lock is held until the context manager exits.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
ctx_type: The type of context to retrieve
|
|
36
|
+
context_id: The UUID identifying the context
|
|
37
|
+
|
|
38
|
+
Yields:
|
|
39
|
+
The context associated with this type and identifier
|
|
40
|
+
|
|
41
|
+
Raises:
|
|
42
|
+
KeyError: If the context does not exist
|
|
43
|
+
"""
|
|
44
|
+
...
|
|
45
|
+
|
|
46
|
+
@abstractmethod
|
|
47
|
+
async def store_context[ContextTypeT: BaseModel](
|
|
48
|
+
self, ctx_type: type[ContextTypeT], context_id: UUID, context: ContextTypeT
|
|
49
|
+
) -> None:
|
|
50
|
+
"""Store a new context.
|
|
51
|
+
|
|
52
|
+
This creates or replaces a context for the given type and context_id.
|
|
53
|
+
Thread-safe and async-safe.
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
ctx_type: The type of context being stored
|
|
57
|
+
context_id: The UUID identifying the context
|
|
58
|
+
context: The context to store
|
|
59
|
+
"""
|
|
60
|
+
...
|
|
61
|
+
|
|
62
|
+
@abstractmethod
|
|
63
|
+
async def get_context_snapshot[ContextTypeT: BaseModel](
|
|
64
|
+
self, ctx_type: type[ContextTypeT], context_id: UUID
|
|
65
|
+
) -> ContextTypeT:
|
|
66
|
+
"""Get a read-only snapshot source without taking the context lock.
|
|
67
|
+
|
|
68
|
+
This is intended for event emission paths that must not wait behind a
|
|
69
|
+
long-running mutable context lock.
|
|
70
|
+
"""
|
|
71
|
+
...
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class InMemoryContextStorage(ContextStorage):
|
|
75
|
+
"""Unified in-memory storage for contexts that works in both async and threaded environments.
|
|
76
|
+
|
|
77
|
+
Uses aiologic.RLock for synchronization, which works seamlessly across:
|
|
78
|
+
- Pure async servers (FastAPI with single event loop)
|
|
79
|
+
- Threaded servers with multiple threads
|
|
80
|
+
- Hybrid environments (multiple threads each with their own event loop)
|
|
81
|
+
|
|
82
|
+
Contexts are stored in a two-level dict keyed first by context type, then by UUID.
|
|
83
|
+
Fine-grained per-(type, id) locking allows concurrent access to different contexts.
|
|
84
|
+
|
|
85
|
+
Note: This implementation is single-process only. For multi-process
|
|
86
|
+
deployments (e.g., gunicorn with multiple processes), consider using
|
|
87
|
+
a distributed storage backend like Redis.
|
|
88
|
+
"""
|
|
89
|
+
|
|
90
|
+
def __init__(self) -> None:
|
|
91
|
+
self._states: dict[type[BaseModel], dict[UUID, BaseModel]] = {}
|
|
92
|
+
self._locks: dict[type[BaseModel], dict[UUID, aiologic.RLock]] = {}
|
|
93
|
+
self._locks_lock = aiologic.RLock()
|
|
94
|
+
|
|
95
|
+
async def _get_lock(
|
|
96
|
+
self, ctx_type: type[BaseModel], context_id: UUID
|
|
97
|
+
) -> aiologic.RLock:
|
|
98
|
+
"""Get or create a lock for the given (ctx_type, context_id) pair."""
|
|
99
|
+
async with self._locks_lock:
|
|
100
|
+
if ctx_type not in self._locks:
|
|
101
|
+
self._locks[ctx_type] = {}
|
|
102
|
+
if context_id not in self._locks[ctx_type]:
|
|
103
|
+
self._locks[ctx_type][context_id] = aiologic.RLock()
|
|
104
|
+
return self._locks[ctx_type][context_id]
|
|
105
|
+
|
|
106
|
+
@asynccontextmanager
|
|
107
|
+
async def get_context[ContextTypeT: BaseModel]( # pyright: ignore [reportReturnType]
|
|
108
|
+
self, ctx_type: type[ContextTypeT], context_id: UUID
|
|
109
|
+
) -> AsyncIterator[ContextTypeT]:
|
|
110
|
+
lock = await self._get_lock(ctx_type, context_id)
|
|
111
|
+
async with lock:
|
|
112
|
+
type_store = self._states.get(ctx_type, {})
|
|
113
|
+
if context_id not in type_store:
|
|
114
|
+
raise KeyError(f"No {ctx_type.__name__} context known for {context_id}")
|
|
115
|
+
yield type_store[context_id] # pyright: ignore [reportReturnType]
|
|
116
|
+
|
|
117
|
+
async def store_context[ContextTypeT: BaseModel](
|
|
118
|
+
self,
|
|
119
|
+
ctx_type: type[ContextTypeT],
|
|
120
|
+
context_id: UUID,
|
|
121
|
+
context: ContextTypeT,
|
|
122
|
+
) -> None:
|
|
123
|
+
"""Store a new context.
|
|
124
|
+
|
|
125
|
+
Thread-safe creation/replacement of context.
|
|
126
|
+
Acquires the fine-grained lock for this (ctx_type, context_id) pair.
|
|
127
|
+
"""
|
|
128
|
+
lock = await self._get_lock(ctx_type, context_id)
|
|
129
|
+
async with lock:
|
|
130
|
+
if ctx_type not in self._states:
|
|
131
|
+
self._states[ctx_type] = {}
|
|
132
|
+
self._states[ctx_type][context_id] = context
|
|
133
|
+
|
|
134
|
+
async def get_context_snapshot[ContextTypeT: BaseModel](
|
|
135
|
+
self,
|
|
136
|
+
ctx_type: type[ContextTypeT],
|
|
137
|
+
context_id: UUID,
|
|
138
|
+
) -> ContextTypeT:
|
|
139
|
+
type_store = self._states.get(ctx_type, {})
|
|
140
|
+
if context_id not in type_store:
|
|
141
|
+
raise KeyError(f"No {ctx_type.__name__} context known for {context_id}")
|
|
142
|
+
context = type_store[context_id]
|
|
143
|
+
return context.model_copy(deep=True) # pyright: ignore [reportReturnType]
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@register_provider(singleton=True)
def context_storage_singleton() -> InMemoryContextStorage:
    """Provide the context storage, registered as a singleton provider.

    Because the provider is registered with ``singleton=True``, the same
    InMemoryContextStorage is meant to be reused for the lifetime of the
    application. That one instance holds every context type, keyed by
    (type, UUID).

    The storage synchronizes with aiologic.RLock, which is intended to
    cover:
    - Async servers (FastAPI): non-blocking async synchronization
    - Threaded servers: thread-safe synchronization
    - Hybrid environments: several threads, each with its own event loop

    For multi-process deployments, override this provider with a
    distributed implementation (for example a Redis-backed storage).
    """
    storage = InMemoryContextStorage()
    return storage
|
|
@@ -0,0 +1,655 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import functools
|
|
3
|
+
import inspect
|
|
4
|
+
import logging
|
|
5
|
+
import sys
|
|
6
|
+
import typing
|
|
7
|
+
from collections.abc import Callable, Coroutine
|
|
8
|
+
from contextlib import (
|
|
9
|
+
AbstractAsyncContextManager,
|
|
10
|
+
AbstractContextManager,
|
|
11
|
+
contextmanager,
|
|
12
|
+
)
|
|
13
|
+
from typing import (
|
|
14
|
+
Any,
|
|
15
|
+
TypeVar,
|
|
16
|
+
overload,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from dependency_injector import containers, providers
|
|
20
|
+
|
|
21
|
+
from wepositive_di.providers import (
|
|
22
|
+
AsyncCMFactory as _AsyncCMFactory,
|
|
23
|
+
)
|
|
24
|
+
from wepositive_di.providers import (
|
|
25
|
+
AsyncSingletonCMFactory as _AsyncSingletonCMFactory,
|
|
26
|
+
)
|
|
27
|
+
from wepositive_di.providers import (
|
|
28
|
+
CMFactory as _CMFactory,
|
|
29
|
+
)
|
|
30
|
+
from wepositive_di.providers import (
|
|
31
|
+
SyncSingletonCMFactory as _SyncSingletonCMFactory,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Modules that define at least one registered provider; passed to
# registry.wire() by setup().
_registered_modules: set[Any] = set()
# Provider name -> replacement provider; consulted before the registry
# when a dependency is looked up.
_provider_overrides: dict[str, providers.Provider[Any]] = {}
# Names of providers that were registered with context_manager=True.
_context_manager_providers: set[str] = set()

# Container on which every registered provider is set as an attribute.
registry = containers.DynamicContainer()
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _detect_lifecycle(
|
|
44
|
+
func: Callable[..., Any],
|
|
45
|
+
cm_qualname: str,
|
|
46
|
+
enter_attr: str,
|
|
47
|
+
exit_attr: str,
|
|
48
|
+
cache_attr: str,
|
|
49
|
+
) -> bool:
|
|
50
|
+
cached = getattr(func, cache_attr, None)
|
|
51
|
+
if cached is not None:
|
|
52
|
+
return cached # type: ignore[return-value]
|
|
53
|
+
|
|
54
|
+
code = getattr(func, "__code__", None)
|
|
55
|
+
if getattr(code, "co_qualname", None) == cm_qualname:
|
|
56
|
+
result = True
|
|
57
|
+
else:
|
|
58
|
+
result = False
|
|
59
|
+
try:
|
|
60
|
+
hints = typing.get_type_hints(func)
|
|
61
|
+
return_type = hints.get("return")
|
|
62
|
+
if return_type is not None:
|
|
63
|
+
result = hasattr(return_type, enter_attr) and hasattr(
|
|
64
|
+
return_type, exit_attr
|
|
65
|
+
)
|
|
66
|
+
except Exception: # noqa: BLE001
|
|
67
|
+
pass
|
|
68
|
+
|
|
69
|
+
try:
|
|
70
|
+
setattr(func, cache_attr, result)
|
|
71
|
+
except (AttributeError, TypeError):
|
|
72
|
+
pass
|
|
73
|
+
return result
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _is_async_lifecycle_annotation(func: Callable[..., Any]) -> bool:
    """Tell whether *func* yields an async context manager when called.

    Two shapes are recognised:
    1. A function wrapped by @asynccontextmanager, spotted through the
       ``co_qualname`` of its code object (which @wraps does not alter).
    2. A function whose return annotation exposes ``__aenter__`` and
       ``__aexit__`` (on the class itself, or via a generic alias that
       forwards attribute access to its origin).
    """
    return _detect_lifecycle(
        func,
        cm_qualname="asynccontextmanager.<locals>.helper",
        enter_attr="__aenter__",
        exit_attr="__aexit__",
        cache_attr="_di_is_async_lifecycle",
    )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _is_sync_lifecycle_annotation(func: Callable[..., Any]) -> bool:
    """Tell whether *func* yields a sync context manager when called.

    Two shapes are recognised:
    1. A function wrapped by @contextmanager, spotted through the
       ``co_qualname`` of its code object (which @wraps does not alter).
    2. A function whose return annotation exposes ``__enter__`` and
       ``__exit__`` (on the class itself, or via a generic alias that
       forwards attribute access to its origin).
    """
    return _detect_lifecycle(
        func,
        cm_qualname="contextmanager.<locals>.helper",
        enter_attr="__enter__",
        exit_attr="__exit__",
        cache_attr="_di_is_sync_lifecycle",
    )
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _lookup_provider(name: str) -> Any:
    """Fetch the provider registered as *name*; an override takes precedence.

    Falls back to attribute lookup on the registry container when no
    override has been recorded for *name*.
    """
    return (
        _provider_overrides[name]
        if name in _provider_overrides
        else getattr(registry, name)
    )
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _uses_context_manager(provider_name: str) -> bool:
    """Report whether *provider_name* was registered with context_manager=True."""
    registered_as_cm = provider_name in _context_manager_providers
    return registered_as_cm
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _resolve_deps_sync(
|
|
124
|
+
sig: inspect.Signature, provider_name: str, func_name: str
|
|
125
|
+
) -> dict[str, Any]:
|
|
126
|
+
"""Resolve Depends[...] markers in *sig* synchronously.
|
|
127
|
+
|
|
128
|
+
Raises RuntimeError if any dependency resolves to a coroutine — sync
|
|
129
|
+
providers cannot have async dependencies.
|
|
130
|
+
"""
|
|
131
|
+
bound = sig.bind_partial()
|
|
132
|
+
bound.apply_defaults()
|
|
133
|
+
for param_name in list(bound.arguments.keys()):
|
|
134
|
+
value = bound.arguments[param_name]
|
|
135
|
+
if isinstance(value, _DependsMarker):
|
|
136
|
+
result = _lookup_provider(value.name)()
|
|
137
|
+
if asyncio.iscoroutine(result):
|
|
138
|
+
result.close()
|
|
139
|
+
raise RuntimeError(
|
|
140
|
+
f"Cannot resolve async dependency '{param_name}' in sync provider "
|
|
141
|
+
f"'{provider_name}'. Sync providers cannot have async dependencies. "
|
|
142
|
+
f"Make your provider async instead: async def {func_name}(...)"
|
|
143
|
+
)
|
|
144
|
+
bound.arguments[param_name] = result
|
|
145
|
+
return bound.arguments
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
async def _resolve_deps_async(sig: inspect.Signature) -> dict[str, Any]:
|
|
149
|
+
"""Resolve Depends[...] markers in *sig*, awaiting any async results."""
|
|
150
|
+
bound = sig.bind_partial()
|
|
151
|
+
bound.apply_defaults()
|
|
152
|
+
for param_name in list(bound.arguments.keys()):
|
|
153
|
+
value = bound.arguments[param_name]
|
|
154
|
+
if isinstance(value, _DependsMarker):
|
|
155
|
+
result = _lookup_provider(value.name)()
|
|
156
|
+
if asyncio.iscoroutine(result):
|
|
157
|
+
result = await result
|
|
158
|
+
bound.arguments[param_name] = result
|
|
159
|
+
return bound.arguments
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _create_provider(
    func: Callable[..., Any],
    *,
    provider_name: str | None = None,
    singleton: bool = False,
    context_manager: bool = False,
) -> providers.Provider[Any]:
    """Build the provider object for *func*, resolving Depends defaults on call.

    The provider class depends on what *func* is:
    - Async context manager + singleton → AsyncSingletonCMFactory
    - Async context manager → AsyncCMFactory
    - Sync context manager + singleton → SyncSingletonCMFactory
    - Sync context manager → CMFactory
    - Plain async function → providers.Coroutine
    - Sync singleton → providers.Singleton
    - Sync factory (default) → providers.Factory

    Context-manager detection only happens when *context_manager* is True.

    Args:
        func: The provider function to wrap.
        provider_name: Name used in error messages; defaults to the
            function's ``__name__``.
        singleton: Select the singleton variant where one exists.
        context_manager: Treat *func* as producing a context manager.

    Raises:
        ValueError: If *context_manager* is True but *func* is detected as
            neither a sync nor an async context manager.
    """
    name = provider_name or func.__name__
    sig = inspect.signature(func)
    is_async_cm = context_manager and _is_async_lifecycle_annotation(func)
    is_sync_cm = context_manager and _is_sync_lifecycle_annotation(func)
    is_async_func = inspect.iscoroutinefunction(func)

    if context_manager and not is_async_cm and not is_sync_cm:
        raise ValueError(
            f"Provider '{name}' was registered with context_manager=True, "
            "but it does not return a supported sync or async context manager."
        )

    if is_async_cm:

        async def async_cm_wrapper():
            resolved = await _resolve_deps_async(sig)
            return func(**resolved)

        async_cm_cls = _AsyncSingletonCMFactory if singleton else _AsyncCMFactory
        return async_cm_cls(async_cm_wrapper)

    if is_sync_cm:

        def sync_cm_wrapper():
            resolved = _resolve_deps_sync(sig, name, func.__name__)
            return func(**resolved)

        sync_cm_cls = _SyncSingletonCMFactory if singleton else _CMFactory
        return sync_cm_cls(sync_cm_wrapper)  # type: ignore[return-value]

    if is_async_func:

        async def async_func_wrapper():
            resolved = await _resolve_deps_async(sig)
            return await func(**resolved)

        return providers.Coroutine(async_func_wrapper)

    def sync_wrapper():
        resolved = _resolve_deps_sync(sig, name, func.__name__)
        return func(**resolved)

    plain_cls = providers.Singleton if singleton else providers.Factory
    return plain_cls(sync_wrapper)  # type: ignore[return-value]
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def register_provider(
    name: str | None = None,
    singleton: bool = False,
    context_manager: bool = False,
):
    """Return a decorator that registers a provider function in the registry.

    The function may be sync or async. It is registered as an attribute of
    the module-level ``registry`` container and returned unchanged, so it
    can still be called directly.

    Args:
        name: Optional name for the provider (defaults to the function name).
        singleton: If True, caches and reuses the first created instance.
        context_manager: If True, enter and exit the provider's context
            manager when resolving dependencies. Context-manager handling
            is opt-in.

    Returns:
        A decorator taking the provider function and returning that same
        function after registering it.

    Raises:
        ValueError: When the decorator is applied, if the function is a
            coroutine function registered with ``singleton=True`` and
            without ``context_manager=True``, or if ``context_manager=True``
            is given for a function that is not detected as producing a
            context manager.
    """

    def decorator(func: Callable[..., Any]):
        provider_name = name or func.__name__
        is_async_func = inspect.iscoroutinefunction(func)

        # Plain coroutine providers have no singleton variant; fail early
        # with guidance instead of silently ignoring the flag.
        if is_async_func and not context_manager and singleton:
            raise ValueError(
                f"Async provider '{provider_name}' cannot be a singleton. "
                f"The dependency-injector library doesn't support singleton caching for Coroutine providers. "
                f"Make your provider a sync function instead: def {func.__name__}(...)"
            )

        provider = _create_provider(
            func,
            provider_name=provider_name,
            singleton=singleton,
            context_manager=context_manager,
        )
        setattr(registry, provider_name, provider)

        # Keep the context-manager bookkeeping in step with the latest
        # registration under this name (a re-registration may flip it).
        if context_manager:
            _context_manager_providers.add(provider_name)
        else:
            _context_manager_providers.discard(provider_name)

        # Remember the defining module so setup() can wire it.
        module = inspect.getmodule(func)
        if module is not None:  # pragma: no branch
            _registered_modules.add(module)
        return func

    return decorator
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def setup(
    overrides: dict[Callable[..., Any] | str, Callable[..., Any]] | None = None,
):
    """Apply optional provider overrides, then wire the registry.

    Every module that contributed a registered provider is passed to
    ``registry.wire``.

    Args:
        overrides: Optional mapping of original provider (the function
            itself, or its registered name as a string) to the function
            that should replace it. Overrides are recorded before wiring.
            An override is treated as a context manager only if the
            original name was registered with ``context_manager=True``.

    Usage:

    ```python
    setup()

    def redis_storage() -> ContextStorage:
        return RedisContextStorage()

    setup(overrides={context_storage_singleton: redis_storage})
    ```
    """
    for original, override_func in (overrides or {}).items():
        if isinstance(original, str):
            provider_name = original
        else:
            provider_name = original.__name__
        replacement = _create_provider(
            override_func,
            provider_name=provider_name,
            context_manager=_uses_context_manager(provider_name),
        )
        _provider_overrides[provider_name] = replacement

    registry.wire(modules=list(_registered_modules))
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
@overload
|
|
304
|
+
def override_provider(
|
|
305
|
+
original: Callable[..., Any] | str,
|
|
306
|
+
) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ...
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
@overload
|
|
310
|
+
def override_provider(
|
|
311
|
+
original: Callable[..., Any] | str,
|
|
312
|
+
override: Callable[..., Any],
|
|
313
|
+
) -> None: ...
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def override_provider(
|
|
317
|
+
original: Callable[..., Any] | str,
|
|
318
|
+
override: Callable[..., Any] | None = None,
|
|
319
|
+
) -> None | Callable[[Callable[..., Any]], Callable[..., Any]]:
|
|
320
|
+
"""Override a provider with a new implementation.
|
|
321
|
+
|
|
322
|
+
Can be used as a function or a decorator.
|
|
323
|
+
|
|
324
|
+
Args:
|
|
325
|
+
original: The original provider function or its name
|
|
326
|
+
override: The new provider function (when used as a function call)
|
|
327
|
+
|
|
328
|
+
Returns:
|
|
329
|
+
None when used as a function, decorator when used as @override_provider(original)
|
|
330
|
+
|
|
331
|
+
Usage:
|
|
332
|
+
|
|
333
|
+
```python
|
|
334
|
+
@register_provider()
|
|
335
|
+
async def config() -> Config:
|
|
336
|
+
return Config()
|
|
337
|
+
|
|
338
|
+
async def prod_config() -> Config:
|
|
339
|
+
return Config(db_url="production")
|
|
340
|
+
|
|
341
|
+
override_provider(config, prod_config)
|
|
342
|
+
```
|
|
343
|
+
|
|
344
|
+
```python
|
|
345
|
+
@override_provider(config)
|
|
346
|
+
async def prod_config() -> Config:
|
|
347
|
+
return Config(db_url="production")
|
|
348
|
+
```
|
|
349
|
+
"""
|
|
350
|
+
provider_name = original if isinstance(original, str) else original.__name__
|
|
351
|
+
|
|
352
|
+
# Used as @override_provider(original)
|
|
353
|
+
if override is None:
|
|
354
|
+
|
|
355
|
+
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
|
356
|
+
_provider_overrides[provider_name] = _create_provider(
|
|
357
|
+
func,
|
|
358
|
+
provider_name=provider_name,
|
|
359
|
+
context_manager=_uses_context_manager(provider_name),
|
|
360
|
+
)
|
|
361
|
+
return func
|
|
362
|
+
|
|
363
|
+
return decorator
|
|
364
|
+
|
|
365
|
+
# Used as override_provider(original, override_func)
|
|
366
|
+
_provider_overrides[provider_name] = _create_provider(
|
|
367
|
+
override,
|
|
368
|
+
provider_name=provider_name,
|
|
369
|
+
context_manager=_uses_context_manager(provider_name),
|
|
370
|
+
)
|
|
371
|
+
return None
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def clear_overrides() -> None:
    """Remove every provider override.

    Empties the module-level override table in place.
    """
    _provider_overrides.clear()
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
@contextmanager
def provider_overrides(
    overrides: dict[Callable[..., Any] | str, Callable[..., Any]],
):
    """Context manager to temporarily override providers for testing.

    The override table is snapshotted on entry and restored on exit, whether
    the body finishes normally or raises. The snapshot is also restored when
    building one of the overrides fails, so a partially applied set of
    overrides never leaks out of this context manager.

    Args:
        overrides: Dictionary mapping original providers (function or name)
            to their overrides.

    Usage:

    ```python
    async def test_config() -> Config:
        return Config(sqlalchemy_db_uri=SecretStr("sqlite:///:memory:"))

    with provider_overrides({config: test_config}):
        result = await my_function()
    ```
    """
    # Save current state before touching anything.
    snapshot = _provider_overrides.copy()

    try:
        # Apply inside the try block: if creating any provider raises, the
        # finally clause still rolls back the entries already written.
        for original, override in overrides.items():
            provider_name = original if isinstance(original, str) else original.__name__
            _provider_overrides[provider_name] = _create_provider(
                override,
                provider_name=provider_name,
                context_manager=_uses_context_manager(provider_name),
            )

        yield
    finally:
        # Restore previous state in place so other references to the table
        # keep seeing the same dict object.
        _provider_overrides.clear()
        _provider_overrides.update(snapshot)
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
class _DependsMarker:
|
|
419
|
+
"""Marker class for lazy dependency injection."""
|
|
420
|
+
|
|
421
|
+
def __init__(self, name: str):
|
|
422
|
+
self.name = name
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
class _DependsType:
    """Subscriptable helper behind the ``Depends[func]`` syntax."""

    def __getitem__(self, func: Callable[..., Any] | str) -> Any:
        """Build a dependency marker from subscript notation.

        A string is used as the provider name directly; for anything else the
        object's ``__name__`` is used.

        Usage: def my_func(config: Config = Depends[config]):
        """
        return _DependsMarker(func if isinstance(func, str) else func.__name__)
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
# Module-level instance that makes the ``Depends[provider]`` subscript
# syntax available.
Depends = _DependsType()

# Type variable bound to callables.
T = TypeVar("T", bound=Callable[..., Any])
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
@contextmanager
|
|
447
|
+
def _create_event_loop(param_name: str, func_name: str):
|
|
448
|
+
has_running_loop = False
|
|
449
|
+
try:
|
|
450
|
+
asyncio.get_running_loop()
|
|
451
|
+
has_running_loop = True
|
|
452
|
+
except RuntimeError:
|
|
453
|
+
# No running loop - this is fine
|
|
454
|
+
pass
|
|
455
|
+
|
|
456
|
+
if has_running_loop:
|
|
457
|
+
# Can't use run_until_complete in an already-running loop
|
|
458
|
+
raise RuntimeError(
|
|
459
|
+
f"Cannot resolve async dependency '{param_name}' in sync function "
|
|
460
|
+
f"'{func_name}' from within an async context. "
|
|
461
|
+
f"Either make '{func_name}' async or call it from a sync context."
|
|
462
|
+
)
|
|
463
|
+
|
|
464
|
+
# No running loop - create one for this sync context
|
|
465
|
+
loop = asyncio.new_event_loop()
|
|
466
|
+
try:
|
|
467
|
+
yield loop
|
|
468
|
+
finally:
|
|
469
|
+
loop.close()
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
@overload
|
|
473
|
+
def inject(
|
|
474
|
+
func: Callable[..., Coroutine[Any, Any, Any]],
|
|
475
|
+
) -> Callable[..., Coroutine[Any, Any, Any]]: ...
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
@overload
|
|
479
|
+
def inject(func: Callable[..., Any]) -> Callable[..., Any]: ...
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
def inject[T: Callable[..., Any]](func: T) -> T:
|
|
483
|
+
"""Decorator that resolves Depends markers in function arguments.
|
|
484
|
+
|
|
485
|
+
Works with both sync and async functions.
|
|
486
|
+
|
|
487
|
+
The decorator:
|
|
488
|
+
|
|
489
|
+
- Inspects the function signature for `_DependsMarker` defaults.
|
|
490
|
+
- At call time, resolves each marker by calling the registry provider.
|
|
491
|
+
- Handles context manager providers transparently: enters the context manager,
|
|
492
|
+
passes the yielded value to the function, and exits on completion.
|
|
493
|
+
- Returns the appropriate wrapper based on the function type.
|
|
494
|
+
|
|
495
|
+
Provider types and how they are resolved:
|
|
496
|
+
|
|
497
|
+
- AsyncCMFactory: await coroutine, get context manager, await `__aenter__`.
|
|
498
|
+
- CMFactory: get context manager, call `__enter__`.
|
|
499
|
+
- providers.Coroutine: await result.
|
|
500
|
+
- providers.Factory / providers.Singleton: use result directly.
|
|
501
|
+
|
|
502
|
+
Usage:
|
|
503
|
+
|
|
504
|
+
```python
|
|
505
|
+
@inject
|
|
506
|
+
def my_func(config: Config = Depends[config]):
|
|
507
|
+
return config.value
|
|
508
|
+
|
|
509
|
+
@inject
|
|
510
|
+
async def my_async_func(session: AsyncSession = Depends[async_session]):
|
|
511
|
+
return session.query(...)
|
|
512
|
+
```
|
|
513
|
+
|
|
514
|
+
"""
|
|
515
|
+
sig = inspect.signature(func)
|
|
516
|
+
dependant_is_async = inspect.iscoroutinefunction(
|
|
517
|
+
func
|
|
518
|
+
) or inspect.isasyncgenfunction(func)
|
|
519
|
+
|
|
520
|
+
async def _resolve_dependencies_async(
|
|
521
|
+
args: tuple[Any, ...], kwargs: dict[str, Any]
|
|
522
|
+
) -> tuple[
|
|
523
|
+
dict[str, Any],
|
|
524
|
+
set[AbstractAsyncContextManager[Any]],
|
|
525
|
+
set[AbstractContextManager[Any]],
|
|
526
|
+
]:
|
|
527
|
+
bound = sig.bind_partial(*args, **kwargs)
|
|
528
|
+
bound.apply_defaults()
|
|
529
|
+
sync_lifecycles_to_cleanup: set[AbstractContextManager[Any]] = set()
|
|
530
|
+
async_lifecycles_to_cleanup: set[AbstractAsyncContextManager[Any]] = set()
|
|
531
|
+
|
|
532
|
+
for param_name in list(bound.arguments.keys()):
|
|
533
|
+
value = bound.arguments[param_name]
|
|
534
|
+
if not isinstance(value, _DependsMarker):
|
|
535
|
+
continue
|
|
536
|
+
provider = _lookup_provider(value.name)
|
|
537
|
+
result = provider()
|
|
538
|
+
|
|
539
|
+
if isinstance(provider, _AsyncSingletonCMFactory):
|
|
540
|
+
result = await result
|
|
541
|
+
elif isinstance(provider, _SyncSingletonCMFactory):
|
|
542
|
+
pass # sync Resource returns the cached value directly
|
|
543
|
+
elif isinstance(provider, _AsyncCMFactory):
|
|
544
|
+
cm = await result # await async wrapper to get the CM object
|
|
545
|
+
async_lifecycles_to_cleanup.add(cm)
|
|
546
|
+
result = await cm.__aenter__()
|
|
547
|
+
elif isinstance(provider, _CMFactory):
|
|
548
|
+
cm = result # sync wrapper returns the CM directly
|
|
549
|
+
sync_lifecycles_to_cleanup.add(cm)
|
|
550
|
+
result = cm.__enter__()
|
|
551
|
+
elif isinstance(provider, providers.Coroutine):
|
|
552
|
+
result = await result
|
|
553
|
+
|
|
554
|
+
bound.arguments[param_name] = result
|
|
555
|
+
|
|
556
|
+
return bound.arguments, async_lifecycles_to_cleanup, sync_lifecycles_to_cleanup
|
|
557
|
+
|
|
558
|
+
def _resolve_dependencies_sync(
|
|
559
|
+
args: tuple[Any, ...], kwargs: dict[str, Any]
|
|
560
|
+
) -> tuple[
|
|
561
|
+
dict[str, Any],
|
|
562
|
+
set[AbstractAsyncContextManager[Any]],
|
|
563
|
+
set[AbstractContextManager[Any]],
|
|
564
|
+
]:
|
|
565
|
+
bound = sig.bind_partial(*args, **kwargs)
|
|
566
|
+
bound.apply_defaults()
|
|
567
|
+
sync_lifecycles_to_cleanup: set[AbstractContextManager[Any]] = set()
|
|
568
|
+
async_lifecycles_to_cleanup: set[AbstractAsyncContextManager[Any]] = set()
|
|
569
|
+
|
|
570
|
+
for param_name in list(bound.arguments.keys()):
|
|
571
|
+
value = bound.arguments[param_name]
|
|
572
|
+
if not isinstance(value, _DependsMarker):
|
|
573
|
+
continue
|
|
574
|
+
provider = _lookup_provider(value.name)
|
|
575
|
+
result = provider()
|
|
576
|
+
|
|
577
|
+
if isinstance(provider, _AsyncSingletonCMFactory):
|
|
578
|
+
with _create_event_loop(param_name, func.__name__) as loop:
|
|
579
|
+
result = loop.run_until_complete(result)
|
|
580
|
+
elif isinstance(provider, _SyncSingletonCMFactory):
|
|
581
|
+
pass # sync Resource returns the cached value directly
|
|
582
|
+
elif isinstance(provider, _AsyncCMFactory):
|
|
583
|
+
with _create_event_loop(param_name, func.__name__) as loop:
|
|
584
|
+
cm = loop.run_until_complete(result) # await async wrapper → CM
|
|
585
|
+
async_lifecycles_to_cleanup.add(cm)
|
|
586
|
+
result = loop.run_until_complete(cm.__aenter__())
|
|
587
|
+
elif isinstance(provider, _CMFactory):
|
|
588
|
+
cm = result
|
|
589
|
+
sync_lifecycles_to_cleanup.add(cm)
|
|
590
|
+
result = cm.__enter__()
|
|
591
|
+
elif isinstance(provider, providers.Coroutine):
|
|
592
|
+
with _create_event_loop(param_name, func.__name__) as loop:
|
|
593
|
+
result = loop.run_until_complete(result)
|
|
594
|
+
|
|
595
|
+
bound.arguments[param_name] = result
|
|
596
|
+
|
|
597
|
+
return bound.arguments, async_lifecycles_to_cleanup, sync_lifecycles_to_cleanup
|
|
598
|
+
|
|
599
|
+
if dependant_is_async:
|
|
600
|
+
|
|
601
|
+
@functools.wraps(func)
|
|
602
|
+
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
603
|
+
(
|
|
604
|
+
resolved_args,
|
|
605
|
+
async_lifecycles_to_cleanup,
|
|
606
|
+
sync_lifecycles_to_cleanup,
|
|
607
|
+
) = await _resolve_dependencies_async(args, kwargs)
|
|
608
|
+
|
|
609
|
+
exc_info = (None, None, None)
|
|
610
|
+
try:
|
|
611
|
+
result = await func(**resolved_args)
|
|
612
|
+
return result
|
|
613
|
+
except Exception:
|
|
614
|
+
exc_info = sys.exc_info()
|
|
615
|
+
raise
|
|
616
|
+
finally:
|
|
617
|
+
suppressed = False
|
|
618
|
+
for cm in async_lifecycles_to_cleanup:
|
|
619
|
+
if await cm.__aexit__(*exc_info):
|
|
620
|
+
suppressed = True
|
|
621
|
+
for cm in sync_lifecycles_to_cleanup:
|
|
622
|
+
if cm.__exit__(*exc_info):
|
|
623
|
+
suppressed = True
|
|
624
|
+
if exc_info[0] is not None and not suppressed:
|
|
625
|
+
raise exc_info[1].with_traceback(exc_info[2]) # type: ignore[union-attr]
|
|
626
|
+
|
|
627
|
+
return async_wrapper # type: ignore
|
|
628
|
+
else:
|
|
629
|
+
|
|
630
|
+
@functools.wraps(func)
|
|
631
|
+
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
632
|
+
resolved_args, async_lifecycles_to_cleanup, sync_lifecycles_to_cleanup = (
|
|
633
|
+
_resolve_dependencies_sync(args, kwargs)
|
|
634
|
+
)
|
|
635
|
+
exc_info = (None, None, None)
|
|
636
|
+
try:
|
|
637
|
+
result = func(**resolved_args)
|
|
638
|
+
return result
|
|
639
|
+
except Exception:
|
|
640
|
+
exc_info = sys.exc_info()
|
|
641
|
+
raise
|
|
642
|
+
finally:
|
|
643
|
+
suppressed = False
|
|
644
|
+
for cm in async_lifecycles_to_cleanup:
|
|
645
|
+
cm_name = getattr(cm, "__wrapped__", type(cm)).__name__
|
|
646
|
+
with _create_event_loop(cm_name, func.__name__) as loop:
|
|
647
|
+
if loop.run_until_complete(cm.__aexit__(*exc_info)):
|
|
648
|
+
suppressed = True
|
|
649
|
+
for cm in sync_lifecycles_to_cleanup:
|
|
650
|
+
if cm.__exit__(*exc_info):
|
|
651
|
+
suppressed = True
|
|
652
|
+
if exc_info[0] is not None and not suppressed:
|
|
653
|
+
raise exc_info[1].with_traceback(exc_info[2]) # type: ignore[union-attr]
|
|
654
|
+
|
|
655
|
+
return sync_wrapper # type: ignore
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from collections.abc import Coroutine
|
|
2
|
+
from contextlib import AbstractAsyncContextManager, AbstractContextManager
|
|
3
|
+
from typing import Any, TypeVar
|
|
4
|
+
|
|
5
|
+
from dependency_injector import providers
|
|
6
|
+
|
|
7
|
+
_T = TypeVar("_T")
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CMFactory(providers.Factory[AbstractContextManager[_T]]):
    """Factory provider that produces a fresh sync context manager per call.

    The inject decorator recognises this provider type. It calls `__enter__`
    on the produced context manager, hands the resulting value to the
    dependant function, and afterwards calls `__exit__`, forwarding any
    exception that was raised. The caller never handles the context manager
    itself.
    """
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AsyncCMFactory(
    providers.Factory[Coroutine[Any, Any, AbstractAsyncContextManager[_T]]]
):
    """Factory provider that produces a fresh async context manager per call.

    The wrapped function is async (so it can resolve async Depends), which
    means calling this provider returns a coroutine that resolves to the
    context manager object. The inject decorator awaits that coroutine to get
    the context manager, calls `__aenter__`, hands the resulting value to the
    dependant function, and afterwards calls `__aexit__`, forwarding any
    exception that was raised. The caller never handles the context manager
    itself.
    """
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SyncSingletonCMFactory(providers.Resource[_T]):
    """Sync context manager provider with singleton lifetime.

    Built on providers.Resource: the underlying context manager is entered
    once, the value it yields is cached, and teardown happens when
    registry.shutdown_resources() is called. Calling the provider returns the
    cached value directly, so nothing has to be awaited.
    """
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class AsyncSingletonCMFactory(providers.Resource[_T]):
    """Async context manager provider with singleton lifetime.

    Built on providers.Resource: the underlying context manager is entered
    once, the value it yields is cached, and teardown happens when
    registry.shutdown_resources() is called. Calling the provider returns a
    coroutine on the first call (before the value is cached); the inject
    decorator awaits it to obtain the yielded value.
    """
|