anydi 0.22.1__py3-none-any.whl → 0.37.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anydi/__init__.py +14 -14
- anydi/_container.py +811 -571
- anydi/_context.py +39 -281
- anydi/_provider.py +232 -0
- anydi/_types.py +49 -96
- anydi/_utils.py +108 -77
- anydi/ext/_utils.py +49 -28
- anydi/ext/django/__init__.py +9 -0
- anydi/ext/django/_container.py +18 -0
- anydi/ext/django/_settings.py +39 -0
- anydi/ext/django/_utils.py +128 -0
- anydi/ext/django/apps.py +85 -0
- anydi/ext/django/middleware.py +28 -0
- anydi/ext/django/ninja/__init__.py +16 -0
- anydi/ext/django/ninja/_operation.py +75 -0
- anydi/ext/django/ninja/_signature.py +64 -0
- anydi/ext/fastapi.py +11 -27
- anydi/ext/faststream.py +58 -0
- anydi/ext/pydantic_settings.py +48 -0
- anydi/ext/pytest_plugin.py +67 -41
- anydi/ext/starlette/middleware.py +2 -16
- {anydi-0.22.1.dist-info → anydi-0.37.4.dist-info}/METADATA +71 -21
- anydi-0.37.4.dist-info/RECORD +29 -0
- {anydi-0.22.1.dist-info → anydi-0.37.4.dist-info}/WHEEL +1 -1
- anydi-0.37.4.dist-info/entry_points.txt +2 -0
- anydi/_logger.py +0 -3
- anydi/_module.py +0 -124
- anydi/_scanner.py +0 -233
- anydi-0.22.1.dist-info/RECORD +0 -20
- anydi-0.22.1.dist-info/entry_points.txt +0 -3
- {anydi-0.22.1.dist-info → anydi-0.37.4.dist-info/licenses}/LICENSE +0 -0
anydi/_types.py
CHANGED
|
@@ -1,17 +1,16 @@
|
|
|
1
|
-
import
|
|
2
|
-
from dataclasses import dataclass
|
|
3
|
-
from functools import cached_property
|
|
4
|
-
from typing import Any, Callable, Type, TypeVar, Union
|
|
1
|
+
from __future__ import annotations
|
|
5
2
|
|
|
6
|
-
|
|
3
|
+
import inspect
|
|
4
|
+
from collections.abc import Iterable
|
|
5
|
+
from types import ModuleType
|
|
6
|
+
from typing import Annotated, Any, NamedTuple, Union
|
|
7
7
|
|
|
8
|
-
|
|
8
|
+
import wrapt
|
|
9
|
+
from typing_extensions import Literal, Self, TypeAlias
|
|
9
10
|
|
|
10
11
|
Scope = Literal["transient", "singleton", "request"]
|
|
11
12
|
|
|
12
|
-
|
|
13
|
-
AnyInterface: TypeAlias = Union[Type[Any], Annotated[Any, ...]]
|
|
14
|
-
Interface: TypeAlias = Type[T]
|
|
13
|
+
AnyInterface: TypeAlias = Union[type[Any], Annotated[Any, ...]]
|
|
15
14
|
|
|
16
15
|
|
|
17
16
|
class Marker:
|
|
@@ -19,97 +18,51 @@ class Marker:
|
|
|
19
18
|
|
|
20
19
|
__slots__ = ()
|
|
21
20
|
|
|
21
|
+
def __call__(self) -> Self:
|
|
22
|
+
return self
|
|
22
23
|
|
|
23
|
-
@dataclass(frozen=True)
|
|
24
|
-
class Provider:
|
|
25
|
-
"""Represents a provider object.
|
|
26
24
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
"""
|
|
25
|
+
def is_marker(obj: Any) -> bool:
|
|
26
|
+
"""Checks if an object is a marker."""
|
|
27
|
+
return isinstance(obj, Marker)
|
|
31
28
|
|
|
32
|
-
obj: Callable[..., Any]
|
|
33
|
-
scope: Scope
|
|
34
29
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
return get_full_qualname(self.obj)
|
|
51
|
-
|
|
52
|
-
@cached_property
|
|
53
|
-
def parameters(self) -> Mapping[str, inspect.Parameter]:
|
|
54
|
-
"""Returns the parameters of the provider as a mapping.
|
|
55
|
-
|
|
56
|
-
Returns:
|
|
57
|
-
The parameters of the provider.
|
|
58
|
-
"""
|
|
59
|
-
return get_signature(self.obj).parameters
|
|
60
|
-
|
|
61
|
-
@cached_property
|
|
62
|
-
def is_class(self) -> bool:
|
|
63
|
-
"""Checks if the provider object is a class.
|
|
64
|
-
|
|
65
|
-
Returns:
|
|
66
|
-
True if the provider object is a class, False otherwise.
|
|
67
|
-
"""
|
|
68
|
-
return inspect.isclass(self.obj)
|
|
69
|
-
|
|
70
|
-
@cached_property
|
|
71
|
-
def is_function(self) -> bool:
|
|
72
|
-
"""Checks if the provider object is a function.
|
|
73
|
-
|
|
74
|
-
Returns:
|
|
75
|
-
True if the provider object is a function, False otherwise.
|
|
76
|
-
"""
|
|
77
|
-
return (inspect.isfunction(self.obj) or inspect.ismethod(self.obj)) and not (
|
|
78
|
-
self.is_resource
|
|
79
|
-
)
|
|
80
|
-
|
|
81
|
-
@cached_property
|
|
82
|
-
def is_coroutine(self) -> bool:
|
|
83
|
-
"""Checks if the provider object is a coroutine function.
|
|
84
|
-
|
|
85
|
-
Returns:
|
|
86
|
-
True if the provider object is a coroutine function, False otherwise.
|
|
87
|
-
"""
|
|
88
|
-
return inspect.iscoroutinefunction(self.obj)
|
|
89
|
-
|
|
90
|
-
@cached_property
|
|
91
|
-
def is_generator(self) -> bool:
|
|
92
|
-
"""Checks if the provider object is a generator function.
|
|
93
|
-
|
|
94
|
-
Returns:
|
|
95
|
-
True if the provider object is a resource, False otherwise.
|
|
96
|
-
"""
|
|
97
|
-
return inspect.isgeneratorfunction(self.obj)
|
|
98
|
-
|
|
99
|
-
@cached_property
|
|
100
|
-
def is_async_generator(self) -> bool:
|
|
101
|
-
"""Checks if the provider object is an async generator function.
|
|
102
|
-
|
|
103
|
-
Returns:
|
|
104
|
-
True if the provider object is an async resource, False otherwise.
|
|
105
|
-
"""
|
|
106
|
-
return inspect.isasyncgenfunction(self.obj)
|
|
30
|
+
class Event:
|
|
31
|
+
"""Represents an event object."""
|
|
32
|
+
|
|
33
|
+
__slots__ = ()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def is_event_type(obj: Any) -> bool:
|
|
37
|
+
"""Checks if an object is an event type."""
|
|
38
|
+
return inspect.isclass(obj) and issubclass(obj, Event)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class InstanceProxy(wrapt.ObjectProxy): # type: ignore[misc]
|
|
42
|
+
def __init__(self, wrapped: Any, *, interface: type[Any]) -> None:
|
|
43
|
+
super().__init__(wrapped)
|
|
44
|
+
self._self_interface = interface
|
|
107
45
|
|
|
108
46
|
@property
|
|
109
|
-
def
|
|
110
|
-
|
|
47
|
+
def interface(self) -> type[Any]:
|
|
48
|
+
return self._self_interface
|
|
49
|
+
|
|
50
|
+
def __getattribute__(self, item: str) -> Any:
|
|
51
|
+
if item in "interface":
|
|
52
|
+
return object.__getattribute__(self, item)
|
|
53
|
+
return object.__getattribute__(self, item)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class ProviderDecoratorArgs(NamedTuple):
|
|
57
|
+
scope: Scope
|
|
58
|
+
override: bool
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Dependency(NamedTuple):
|
|
62
|
+
member: Any
|
|
63
|
+
module: ModuleType
|
|
64
|
+
|
|
111
65
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
return self.is_generator or self.is_async_generator
|
|
66
|
+
class InjectableDecoratorArgs(NamedTuple):
|
|
67
|
+
wrapped: bool
|
|
68
|
+
tags: Iterable[str] | None
|
anydi/_utils.py
CHANGED
|
@@ -1,87 +1,84 @@
|
|
|
1
1
|
"""Shared AnyDI utils module."""
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
3
5
|
import builtins
|
|
4
6
|
import functools
|
|
7
|
+
import importlib
|
|
5
8
|
import inspect
|
|
9
|
+
import re
|
|
6
10
|
import sys
|
|
7
|
-
from
|
|
8
|
-
|
|
9
|
-
from typing_extensions import Annotated, ParamSpec, get_origin
|
|
10
|
-
|
|
11
|
-
try:
|
|
12
|
-
import anyio # noqa
|
|
13
|
-
except ImportError:
|
|
14
|
-
anyio = None # type: ignore[assignment]
|
|
11
|
+
from types import TracebackType
|
|
12
|
+
from typing import Any, Callable, ForwardRef, TypeVar
|
|
15
13
|
|
|
14
|
+
import anyio
|
|
15
|
+
from typing_extensions import ParamSpec, Self, get_args, get_origin
|
|
16
16
|
|
|
17
17
|
T = TypeVar("T")
|
|
18
18
|
P = ParamSpec("P")
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
def get_full_qualname(obj: Any) -> str:
|
|
22
|
-
"""Get the fully qualified name of an object.
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
Args:
|
|
28
|
-
obj: The object for which to retrieve the fully qualified name.
|
|
22
|
+
"""Get the fully qualified name of an object."""
|
|
23
|
+
# Get module and qualname with defaults to handle non-types directly
|
|
24
|
+
module = getattr(obj, "__module__", type(obj).__module__)
|
|
25
|
+
qualname = getattr(obj, "__qualname__", type(obj).__qualname__)
|
|
29
26
|
|
|
30
|
-
Returns:
|
|
31
|
-
The fully qualified name of the object.
|
|
32
|
-
"""
|
|
33
27
|
origin = get_origin(obj)
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
28
|
+
# If origin exists, handle generics recursively
|
|
29
|
+
if origin:
|
|
30
|
+
args = ", ".join(get_full_qualname(arg) for arg in get_args(obj))
|
|
31
|
+
return f"{get_full_qualname(origin)}[{args}]"
|
|
32
|
+
|
|
33
|
+
# Substitute standard library prefixes for clarity
|
|
34
|
+
full_qualname = f"{module}.{qualname}"
|
|
35
|
+
return re.sub(
|
|
36
|
+
r"\b(builtins|typing|typing_extensions|collections\.abc|types)\.",
|
|
37
|
+
"",
|
|
38
|
+
full_qualname,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def is_builtin_type(tp: type[Any]) -> bool:
|
|
43
|
+
"""Check if the given type is a built-in type."""
|
|
44
|
+
return tp.__module__ == builtins.__name__
|
|
50
45
|
|
|
51
|
-
if module_name == builtins.__name__:
|
|
52
|
-
return qualname
|
|
53
|
-
return f"{module_name}.{qualname}"
|
|
54
46
|
|
|
47
|
+
def is_context_manager(obj: Any) -> bool:
|
|
48
|
+
"""Check if the given object is a context manager."""
|
|
49
|
+
return hasattr(obj, "__enter__") and hasattr(obj, "__exit__")
|
|
55
50
|
|
|
56
|
-
def is_builtin_type(tp: Type[Any]) -> bool:
|
|
57
|
-
"""
|
|
58
|
-
Check if the given type is a built-in type.
|
|
59
|
-
Args:
|
|
60
|
-
tp (type): The type to check.
|
|
61
|
-
Returns:
|
|
62
|
-
bool: True if the type is a built-in type, False otherwise.
|
|
63
|
-
"""
|
|
64
|
-
return tp.__module__ == builtins.__name__
|
|
65
51
|
|
|
52
|
+
def is_async_context_manager(obj: Any) -> bool:
|
|
53
|
+
"""Check if the given object is an async context manager."""
|
|
54
|
+
return hasattr(obj, "__aenter__") and hasattr(obj, "__aexit__")
|
|
66
55
|
|
|
67
|
-
@functools.lru_cache(maxsize=None)
|
|
68
|
-
def get_signature(obj: Callable[..., Any]) -> inspect.Signature:
|
|
69
|
-
"""Get the signature of a callable object.
|
|
70
56
|
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
57
|
+
def get_typed_annotation(
|
|
58
|
+
annotation: Any, globalns: dict[str, Any], module: Any = None
|
|
59
|
+
) -> Any:
|
|
60
|
+
"""Get the typed annotation of a callable object."""
|
|
61
|
+
if isinstance(annotation, str):
|
|
62
|
+
if sys.version_info >= (3, 10):
|
|
63
|
+
ref = ForwardRef(annotation, module=module)
|
|
64
|
+
else:
|
|
65
|
+
ref = ForwardRef(annotation)
|
|
66
|
+
annotation = ref._evaluate(globalns, globalns, recursive_guard=frozenset()) # noqa
|
|
67
|
+
return annotation
|
|
74
68
|
|
|
75
|
-
Args:
|
|
76
|
-
obj: The callable object to inspect.
|
|
77
69
|
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
""
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
70
|
+
def get_typed_parameters(obj: Callable[..., Any]) -> list[inspect.Parameter]:
|
|
71
|
+
"""Get the typed parameters of a callable object."""
|
|
72
|
+
globalns = getattr(obj, "__globals__", {})
|
|
73
|
+
module = getattr(obj, "__module__", None)
|
|
74
|
+
return [
|
|
75
|
+
parameter.replace(
|
|
76
|
+
annotation=get_typed_annotation(
|
|
77
|
+
parameter.annotation, globalns, module=module
|
|
78
|
+
)
|
|
79
|
+
)
|
|
80
|
+
for parameter in inspect.signature(obj).parameters.values()
|
|
81
|
+
]
|
|
85
82
|
|
|
86
83
|
|
|
87
84
|
async def run_async(
|
|
@@ -90,22 +87,56 @@ async def run_async(
|
|
|
90
87
|
*args: P.args,
|
|
91
88
|
**kwargs: P.kwargs,
|
|
92
89
|
) -> T:
|
|
93
|
-
"""Runs the given function asynchronously using the `anyio` library.
|
|
94
|
-
|
|
95
|
-
Args:
|
|
96
|
-
func: The function to run asynchronously.
|
|
97
|
-
args: The positional arguments to pass to the function.
|
|
98
|
-
kwargs: The keyword arguments to pass to the function.
|
|
90
|
+
"""Runs the given function asynchronously using the `anyio` library."""
|
|
91
|
+
return await anyio.to_thread.run_sync(functools.partial(func, *args, **kwargs))
|
|
99
92
|
|
|
100
|
-
Returns:
|
|
101
|
-
The result of the function.
|
|
102
93
|
|
|
103
|
-
|
|
104
|
-
ImportError: If the `anyio` library is not installed.
|
|
94
|
+
def import_string(dotted_path: str) -> Any:
|
|
105
95
|
"""
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
96
|
+
Import a module or a specific attribute from a module using its dotted string path.
|
|
97
|
+
"""
|
|
98
|
+
try:
|
|
99
|
+
module_path, _, attribute_name = dotted_path.rpartition(".")
|
|
100
|
+
if module_path:
|
|
101
|
+
module = importlib.import_module(module_path)
|
|
102
|
+
return getattr(module, attribute_name)
|
|
103
|
+
else:
|
|
104
|
+
return importlib.import_module(attribute_name)
|
|
105
|
+
except (ImportError, AttributeError) as exc:
|
|
106
|
+
raise ImportError(f"Cannot import '{dotted_path}': {exc}") from exc
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class AsyncRLock:
|
|
110
|
+
def __init__(self) -> None:
|
|
111
|
+
self._lock = anyio.Lock()
|
|
112
|
+
self._owner: anyio.TaskInfo | None = None
|
|
113
|
+
self._count = 0
|
|
114
|
+
|
|
115
|
+
async def acquire(self) -> None:
|
|
116
|
+
current_task = anyio.get_current_task()
|
|
117
|
+
if self._owner == current_task:
|
|
118
|
+
self._count += 1
|
|
119
|
+
else:
|
|
120
|
+
await self._lock.acquire()
|
|
121
|
+
self._owner = current_task
|
|
122
|
+
self._count = 1
|
|
123
|
+
|
|
124
|
+
def release(self) -> None:
|
|
125
|
+
if self._owner != anyio.get_current_task():
|
|
126
|
+
raise RuntimeError("Lock can only be released by the owner")
|
|
127
|
+
self._count -= 1
|
|
128
|
+
if self._count == 0:
|
|
129
|
+
self._owner = None
|
|
130
|
+
self._lock.release()
|
|
131
|
+
|
|
132
|
+
async def __aenter__(self) -> Self:
|
|
133
|
+
await self.acquire()
|
|
134
|
+
return self
|
|
135
|
+
|
|
136
|
+
async def __aexit__(
|
|
137
|
+
self,
|
|
138
|
+
exc_type: type[BaseException] | None,
|
|
139
|
+
exc_val: BaseException | None,
|
|
140
|
+
exc_tb: TracebackType | None,
|
|
141
|
+
) -> Any:
|
|
142
|
+
self.release()
|
anydi/ext/_utils.py
CHANGED
|
@@ -1,16 +1,21 @@
|
|
|
1
|
+
"""AnyDI FastAPI extension."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
1
5
|
import inspect
|
|
2
|
-
|
|
6
|
+
import logging
|
|
7
|
+
from typing import Annotated, Any, Callable
|
|
3
8
|
|
|
4
|
-
from typing_extensions import
|
|
9
|
+
from typing_extensions import get_args, get_origin
|
|
5
10
|
|
|
6
11
|
from anydi import Container
|
|
7
|
-
from anydi._logger import logger
|
|
8
12
|
from anydi._utils import get_full_qualname
|
|
9
13
|
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
10
16
|
|
|
11
17
|
class HasInterface:
|
|
12
|
-
|
|
13
|
-
self._interface: Any = None
|
|
18
|
+
_interface: Any = None
|
|
14
19
|
|
|
15
20
|
@property
|
|
16
21
|
def interface(self) -> Any:
|
|
@@ -23,36 +28,52 @@ class HasInterface:
|
|
|
23
28
|
self._interface = interface
|
|
24
29
|
|
|
25
30
|
|
|
26
|
-
def
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
interface, default = parameter.annotation, parameter.default
|
|
37
|
-
|
|
38
|
-
if get_origin(interface) is Annotated:
|
|
39
|
-
args = get_args(interface)
|
|
40
|
-
if len(args) == 2:
|
|
41
|
-
interface, default = args
|
|
42
|
-
elif len(args) == 3:
|
|
43
|
-
interface, metadata, default = args
|
|
44
|
-
interface = Annotated[interface, metadata]
|
|
31
|
+
def patch_annotated_parameter(parameter: inspect.Parameter) -> inspect.Parameter:
|
|
32
|
+
"""Patch an annotated parameter to resolve the default value."""
|
|
33
|
+
if not (
|
|
34
|
+
get_origin(parameter.annotation) is Annotated
|
|
35
|
+
and parameter.default is inspect.Parameter.empty
|
|
36
|
+
):
|
|
37
|
+
return parameter
|
|
38
|
+
|
|
39
|
+
tp_origin, *tp_metadata = get_args(parameter.annotation)
|
|
40
|
+
default = tp_metadata[-1]
|
|
45
41
|
|
|
46
42
|
if not isinstance(default, HasInterface):
|
|
47
|
-
return
|
|
43
|
+
return parameter
|
|
48
44
|
|
|
49
|
-
|
|
45
|
+
if (num := len(tp_metadata[:-1])) == 0:
|
|
46
|
+
interface = tp_origin
|
|
47
|
+
elif num == 1:
|
|
48
|
+
interface = Annotated[tp_origin, tp_metadata[0]]
|
|
49
|
+
elif num == 2:
|
|
50
|
+
interface = Annotated[tp_origin, tp_metadata[0], tp_metadata[1]]
|
|
51
|
+
elif num == 3:
|
|
52
|
+
interface = Annotated[
|
|
53
|
+
tp_origin,
|
|
54
|
+
tp_metadata[0],
|
|
55
|
+
tp_metadata[1],
|
|
56
|
+
tp_metadata[2],
|
|
57
|
+
]
|
|
58
|
+
else:
|
|
59
|
+
raise TypeError("Too many annotated arguments.") # pragma: no cover
|
|
60
|
+
return parameter.replace(annotation=interface, default=default)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def patch_call_parameter(
|
|
64
|
+
container: Container, call: Callable[..., Any], parameter: inspect.Parameter
|
|
65
|
+
) -> None:
|
|
66
|
+
"""Patch a parameter to inject dependencies using AnyDI."""
|
|
67
|
+
parameter = patch_annotated_parameter(parameter)
|
|
68
|
+
|
|
69
|
+
if not isinstance(parameter.default, HasInterface):
|
|
70
|
+
return None
|
|
50
71
|
|
|
51
|
-
if not container.strict and not container.is_registered(
|
|
72
|
+
if not container.strict and not container.is_registered(parameter.annotation):
|
|
52
73
|
logger.debug(
|
|
53
74
|
f"Callable `{get_full_qualname(call)}` injected parameter "
|
|
54
75
|
f"`{parameter.name}` with an annotation of "
|
|
55
|
-
f"`{get_full_qualname(
|
|
76
|
+
f"`{get_full_qualname(parameter.annotation)}` "
|
|
56
77
|
"is not registered. It will be registered at runtime with the "
|
|
57
78
|
"first call because it is running in non-strict mode."
|
|
58
79
|
)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from typing import cast
|
|
2
|
+
|
|
3
|
+
from django.apps.registry import apps
|
|
4
|
+
from django.utils.functional import SimpleLazyObject
|
|
5
|
+
|
|
6
|
+
import anydi
|
|
7
|
+
|
|
8
|
+
from .apps import ContainerConfig
|
|
9
|
+
|
|
10
|
+
__all__ = ["container"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _get_container() -> anydi.Container:
|
|
14
|
+
app_config = cast(ContainerConfig, apps.get_app_config(ContainerConfig.label))
|
|
15
|
+
return app_config.container
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
container = cast(anydi.Container, SimpleLazyObject(_get_container))
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
from django.conf import settings
|
|
6
|
+
from typing_extensions import TypedDict
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Settings(TypedDict):
|
|
10
|
+
CONTAINER_FACTORY: str | None
|
|
11
|
+
STRICT_MODE: bool
|
|
12
|
+
REGISTER_SETTINGS: bool
|
|
13
|
+
REGISTER_COMPONENTS: bool
|
|
14
|
+
INJECT_URLCONF: str | Sequence[str] | None
|
|
15
|
+
MODULES: Sequence[str]
|
|
16
|
+
SCAN_PACKAGES: Sequence[str]
|
|
17
|
+
PATCH_NINJA: bool
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
DEFAULTS = Settings(
|
|
21
|
+
CONTAINER_FACTORY=None,
|
|
22
|
+
STRICT_MODE=False,
|
|
23
|
+
REGISTER_SETTINGS=False,
|
|
24
|
+
REGISTER_COMPONENTS=False,
|
|
25
|
+
MODULES=[],
|
|
26
|
+
PATCH_NINJA=False,
|
|
27
|
+
INJECT_URLCONF=None,
|
|
28
|
+
SCAN_PACKAGES=[],
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_settings() -> Settings:
|
|
33
|
+
"""Get the AnyDI settings from the Django settings."""
|
|
34
|
+
return Settings(
|
|
35
|
+
**{
|
|
36
|
+
**DEFAULTS,
|
|
37
|
+
**getattr(settings, "ANYDI", {}),
|
|
38
|
+
}
|
|
39
|
+
)
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterator
|
|
4
|
+
from functools import wraps
|
|
5
|
+
from typing import Annotated, Any
|
|
6
|
+
|
|
7
|
+
from django.conf import settings
|
|
8
|
+
from django.core.cache import BaseCache, caches
|
|
9
|
+
from django.db import connections
|
|
10
|
+
from django.db.backends.base.base import BaseDatabaseWrapper
|
|
11
|
+
from django.urls import URLPattern, URLResolver, get_resolver
|
|
12
|
+
from typing_extensions import get_origin
|
|
13
|
+
|
|
14
|
+
from anydi import Container
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def register_settings(
|
|
18
|
+
container: Container, prefix: str = "django.conf.settings."
|
|
19
|
+
) -> None:
|
|
20
|
+
"""Register Django settings into the container."""
|
|
21
|
+
|
|
22
|
+
# Ensure prefix ends with a dot
|
|
23
|
+
if prefix[-1] != ".":
|
|
24
|
+
prefix += "."
|
|
25
|
+
|
|
26
|
+
for setting_name in dir(settings):
|
|
27
|
+
setting_value = getattr(settings, setting_name)
|
|
28
|
+
if not setting_name.isupper():
|
|
29
|
+
continue
|
|
30
|
+
|
|
31
|
+
container.register(
|
|
32
|
+
Annotated[Any, f"{prefix}{setting_name}"],
|
|
33
|
+
_get_setting_value(setting_value),
|
|
34
|
+
scope="singleton",
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
# Patch AnyDI to resolve Any types for annotated settings
|
|
38
|
+
_patch_any_typed_annotated(container, prefix=prefix)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def register_components(container: Container) -> None:
|
|
42
|
+
"""Register Django components into the container."""
|
|
43
|
+
|
|
44
|
+
# Register caches
|
|
45
|
+
def _get_cache(cache_name: str) -> Any:
|
|
46
|
+
return lambda: caches[cache_name]
|
|
47
|
+
|
|
48
|
+
for cache_name in caches:
|
|
49
|
+
container.register(
|
|
50
|
+
Annotated[BaseCache, cache_name],
|
|
51
|
+
_get_cache(cache_name),
|
|
52
|
+
scope="singleton",
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
# Register database connections
|
|
56
|
+
def _get_connection(alias: str) -> Any:
|
|
57
|
+
return lambda: connections[alias]
|
|
58
|
+
|
|
59
|
+
for alias in connections:
|
|
60
|
+
container.register(
|
|
61
|
+
Annotated[BaseDatabaseWrapper, alias],
|
|
62
|
+
_get_connection(alias),
|
|
63
|
+
scope="singleton",
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def inject_urlpatterns(container: Container, *, urlconf: str) -> None:
|
|
68
|
+
"""Auto-inject the container into views."""
|
|
69
|
+
resolver = get_resolver(urlconf)
|
|
70
|
+
for pattern in iter_urlpatterns(resolver.url_patterns):
|
|
71
|
+
# Skip already injected views
|
|
72
|
+
if hasattr(pattern.callback, "_injected"):
|
|
73
|
+
continue
|
|
74
|
+
# Skip django-ninja views
|
|
75
|
+
if pattern.lookup_str.startswith("ninja."):
|
|
76
|
+
continue # pragma: no cover
|
|
77
|
+
pattern.callback = container.inject(pattern.callback)
|
|
78
|
+
pattern.callback._injected = True # type: ignore[attr-defined]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def iter_urlpatterns(
|
|
82
|
+
urlpatterns: list[URLPattern | URLResolver],
|
|
83
|
+
) -> Iterator[URLPattern]:
|
|
84
|
+
"""Iterate over all views in urlpatterns."""
|
|
85
|
+
for url_pattern in urlpatterns:
|
|
86
|
+
if isinstance(url_pattern, URLResolver):
|
|
87
|
+
yield from iter_urlpatterns(url_pattern.url_patterns)
|
|
88
|
+
else:
|
|
89
|
+
yield url_pattern
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _get_setting_value(value: Any) -> Any:
|
|
93
|
+
return lambda: value
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _any_typed_interface(interface: Any, prefix: str) -> Any:
|
|
97
|
+
origin = get_origin(interface)
|
|
98
|
+
if origin is not Annotated:
|
|
99
|
+
return interface # pragma: no cover
|
|
100
|
+
named = interface.__metadata__[-1]
|
|
101
|
+
|
|
102
|
+
if isinstance(named, str) and named.startswith(prefix):
|
|
103
|
+
_, setting_name = named.rsplit(prefix, maxsplit=1)
|
|
104
|
+
return Annotated[Any, f"{prefix}{setting_name}"]
|
|
105
|
+
return interface
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _patch_any_typed_annotated(container: Container, *, prefix: str) -> None:
|
|
109
|
+
def _patch_resolve(resolve: Any) -> Any:
|
|
110
|
+
@wraps(resolve)
|
|
111
|
+
def wrapper(interface: Any) -> Any:
|
|
112
|
+
return resolve(_any_typed_interface(interface, prefix))
|
|
113
|
+
|
|
114
|
+
return wrapper
|
|
115
|
+
|
|
116
|
+
def _patch_aresolve(resolve: Any) -> Any:
|
|
117
|
+
@wraps(resolve)
|
|
118
|
+
async def wrapper(interface: Any) -> Any:
|
|
119
|
+
return await resolve(_any_typed_interface(interface, prefix))
|
|
120
|
+
|
|
121
|
+
return wrapper
|
|
122
|
+
|
|
123
|
+
container.resolve = _patch_resolve( # type: ignore[method-assign]
|
|
124
|
+
container.resolve
|
|
125
|
+
)
|
|
126
|
+
container.aresolve = _patch_aresolve( # type: ignore[method-assign]
|
|
127
|
+
container.aresolve
|
|
128
|
+
)
|