anydi 0.56.0__tar.gz → 0.57.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {anydi-0.56.0 → anydi-0.57.0}/PKG-INFO +9 -10
- {anydi-0.56.0 → anydi-0.57.0}/README.md +8 -9
- {anydi-0.56.0 → anydi-0.57.0}/anydi/__init__.py +4 -2
- {anydi-0.56.0 → anydi-0.57.0}/anydi/_container.py +57 -104
- anydi-0.57.0/anydi/_injector.py +132 -0
- anydi-0.57.0/anydi/_scanner.py +118 -0
- {anydi-0.56.0 → anydi-0.57.0}/anydi/_types.py +48 -7
- {anydi-0.56.0 → anydi-0.57.0}/anydi/ext/fastapi.py +31 -33
- {anydi-0.56.0 → anydi-0.57.0}/anydi/ext/faststream.py +25 -31
- {anydi-0.56.0 → anydi-0.57.0}/anydi/ext/pydantic_settings.py +2 -1
- anydi-0.57.0/anydi/ext/pytest_plugin.py +477 -0
- {anydi-0.56.0 → anydi-0.57.0}/pyproject.toml +6 -1
- anydi-0.56.0/anydi/_scanner.py +0 -110
- anydi-0.56.0/anydi/ext/pytest_plugin.py +0 -147
- {anydi-0.56.0 → anydi-0.57.0}/anydi/_async_lock.py +0 -0
- {anydi-0.56.0 → anydi-0.57.0}/anydi/_context.py +0 -0
- {anydi-0.56.0 → anydi-0.57.0}/anydi/_decorators.py +0 -0
- {anydi-0.56.0 → anydi-0.57.0}/anydi/_module.py +0 -0
- {anydi-0.56.0 → anydi-0.57.0}/anydi/_provider.py +0 -0
- {anydi-0.56.0 → anydi-0.57.0}/anydi/_resolver.py +0 -0
- {anydi-0.56.0 → anydi-0.57.0}/anydi/ext/__init__.py +0 -0
- {anydi-0.56.0 → anydi-0.57.0}/anydi/ext/django/__init__.py +0 -0
- {anydi-0.56.0 → anydi-0.57.0}/anydi/ext/starlette/__init__.py +0 -0
- {anydi-0.56.0 → anydi-0.57.0}/anydi/ext/starlette/middleware.py +0 -0
- {anydi-0.56.0 → anydi-0.57.0}/anydi/py.typed +0 -0
- {anydi-0.56.0 → anydi-0.57.0}/anydi/testing.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: anydi
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.57.0
|
|
4
4
|
Summary: Dependency Injection library
|
|
5
5
|
Keywords: dependency injection,dependencies,di,async,asyncio,application
|
|
6
6
|
Author: Anton Ruhlov
|
|
@@ -116,19 +116,18 @@ if __name__ == "__main__":
|
|
|
116
116
|
### Inject Into Functions (`app/main.py`)
|
|
117
117
|
|
|
118
118
|
```python
|
|
119
|
-
from anydi import
|
|
119
|
+
from anydi import Provide
|
|
120
120
|
|
|
121
121
|
from app.container import container
|
|
122
122
|
from app.services import GreetingService
|
|
123
123
|
|
|
124
124
|
|
|
125
|
-
|
|
126
|
-
def greet(service: GreetingService = Inject()) -> str:
|
|
125
|
+
def greet(service: Provide[GreetingService]) -> str:
|
|
127
126
|
return service.greet("World")
|
|
128
127
|
|
|
129
128
|
|
|
130
129
|
if __name__ == "__main__":
|
|
131
|
-
print(greet
|
|
130
|
+
print(container.run(greet))
|
|
132
131
|
```
|
|
133
132
|
|
|
134
133
|
### Test with Overrides (`tests/test_app.py`)
|
|
@@ -146,7 +145,7 @@ def test_greet() -> None:
|
|
|
146
145
|
service_mock.greet.return_value = "Mocked"
|
|
147
146
|
|
|
148
147
|
with container.override(GreetingService, service_mock):
|
|
149
|
-
result = greet
|
|
148
|
+
result = container.run(greet)
|
|
150
149
|
|
|
151
150
|
assert result == "Mocked"
|
|
152
151
|
```
|
|
@@ -158,8 +157,8 @@ from typing import Annotated
|
|
|
158
157
|
|
|
159
158
|
import anydi.ext.fastapi
|
|
160
159
|
from fastapi import FastAPI
|
|
161
|
-
from anydi.ext.fastapi import Inject
|
|
162
160
|
|
|
161
|
+
from anydi import Provide
|
|
163
162
|
from app.container import container
|
|
164
163
|
from app.services import GreetingService
|
|
165
164
|
|
|
@@ -169,7 +168,7 @@ app = FastAPI()
|
|
|
169
168
|
|
|
170
169
|
@app.get("/greeting")
|
|
171
170
|
async def greet(
|
|
172
|
-
service:
|
|
171
|
+
service: Provide[GreetingService]
|
|
173
172
|
) -> dict[str, str]:
|
|
174
173
|
return {"greeting": service.greet("World")}
|
|
175
174
|
|
|
@@ -245,7 +244,7 @@ Wire Django Ninja (`urls.py`):
|
|
|
245
244
|
```python
|
|
246
245
|
from typing import Annotated, Any
|
|
247
246
|
|
|
248
|
-
from anydi import
|
|
247
|
+
from anydi import Provide
|
|
249
248
|
from django.http import HttpRequest
|
|
250
249
|
from django.urls import path
|
|
251
250
|
from ninja import NinjaAPI
|
|
@@ -257,7 +256,7 @@ api = NinjaAPI()
|
|
|
257
256
|
|
|
258
257
|
|
|
259
258
|
@api.get("/greeting")
|
|
260
|
-
def greet(request: HttpRequest, service:
|
|
259
|
+
def greet(request: HttpRequest, service: Provide[GreetingService]) -> Any:
|
|
261
260
|
return {"greeting": service.greet("World")}
|
|
262
261
|
|
|
263
262
|
|
|
@@ -81,19 +81,18 @@ if __name__ == "__main__":
|
|
|
81
81
|
### Inject Into Functions (`app/main.py`)
|
|
82
82
|
|
|
83
83
|
```python
|
|
84
|
-
from anydi import
|
|
84
|
+
from anydi import Provide
|
|
85
85
|
|
|
86
86
|
from app.container import container
|
|
87
87
|
from app.services import GreetingService
|
|
88
88
|
|
|
89
89
|
|
|
90
|
-
|
|
91
|
-
def greet(service: GreetingService = Inject()) -> str:
|
|
90
|
+
def greet(service: Provide[GreetingService]) -> str:
|
|
92
91
|
return service.greet("World")
|
|
93
92
|
|
|
94
93
|
|
|
95
94
|
if __name__ == "__main__":
|
|
96
|
-
print(greet
|
|
95
|
+
print(container.run(greet))
|
|
97
96
|
```
|
|
98
97
|
|
|
99
98
|
### Test with Overrides (`tests/test_app.py`)
|
|
@@ -111,7 +110,7 @@ def test_greet() -> None:
|
|
|
111
110
|
service_mock.greet.return_value = "Mocked"
|
|
112
111
|
|
|
113
112
|
with container.override(GreetingService, service_mock):
|
|
114
|
-
result = greet
|
|
113
|
+
result = container.run(greet)
|
|
115
114
|
|
|
116
115
|
assert result == "Mocked"
|
|
117
116
|
```
|
|
@@ -123,8 +122,8 @@ from typing import Annotated
|
|
|
123
122
|
|
|
124
123
|
import anydi.ext.fastapi
|
|
125
124
|
from fastapi import FastAPI
|
|
126
|
-
from anydi.ext.fastapi import Inject
|
|
127
125
|
|
|
126
|
+
from anydi import Provide
|
|
128
127
|
from app.container import container
|
|
129
128
|
from app.services import GreetingService
|
|
130
129
|
|
|
@@ -134,7 +133,7 @@ app = FastAPI()
|
|
|
134
133
|
|
|
135
134
|
@app.get("/greeting")
|
|
136
135
|
async def greet(
|
|
137
|
-
service:
|
|
136
|
+
service: Provide[GreetingService]
|
|
138
137
|
) -> dict[str, str]:
|
|
139
138
|
return {"greeting": service.greet("World")}
|
|
140
139
|
|
|
@@ -210,7 +209,7 @@ Wire Django Ninja (`urls.py`):
|
|
|
210
209
|
```python
|
|
211
210
|
from typing import Annotated, Any
|
|
212
211
|
|
|
213
|
-
from anydi import
|
|
212
|
+
from anydi import Provide
|
|
214
213
|
from django.http import HttpRequest
|
|
215
214
|
from django.urls import path
|
|
216
215
|
from ninja import NinjaAPI
|
|
@@ -222,7 +221,7 @@ api = NinjaAPI()
|
|
|
222
221
|
|
|
223
222
|
|
|
224
223
|
@api.get("/greeting")
|
|
225
|
-
def greet(request: HttpRequest, service:
|
|
224
|
+
def greet(request: HttpRequest, service: Provide[GreetingService]) -> Any:
|
|
226
225
|
return {"greeting": service.greet("World")}
|
|
227
226
|
|
|
228
227
|
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
"""AnyDI public objects and functions."""
|
|
2
2
|
|
|
3
|
-
from ._container import Container
|
|
3
|
+
from ._container import Container, import_container
|
|
4
4
|
from ._decorators import injectable, provided, provider, request, singleton, transient
|
|
5
5
|
from ._module import Module
|
|
6
6
|
from ._provider import ProviderDef as Provider
|
|
7
|
-
from ._types import Inject, Scope
|
|
7
|
+
from ._types import Inject, Provide, Scope
|
|
8
8
|
|
|
9
9
|
# Alias for dependency auto marker
|
|
10
10
|
# TODO: deprecate it
|
|
@@ -15,9 +15,11 @@ __all__ = [
|
|
|
15
15
|
"Container",
|
|
16
16
|
"Inject",
|
|
17
17
|
"Module",
|
|
18
|
+
"Provide",
|
|
18
19
|
"Provider",
|
|
19
20
|
"Scope",
|
|
20
21
|
"auto",
|
|
22
|
+
"import_container",
|
|
21
23
|
"injectable",
|
|
22
24
|
"provided",
|
|
23
25
|
"provider",
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import contextlib
|
|
6
|
-
import
|
|
6
|
+
import importlib
|
|
7
7
|
import inspect
|
|
8
8
|
import logging
|
|
9
9
|
import types
|
|
@@ -11,12 +11,13 @@ import uuid
|
|
|
11
11
|
from collections import defaultdict
|
|
12
12
|
from collections.abc import AsyncIterator, Callable, Iterable, Iterator
|
|
13
13
|
from contextvars import ContextVar
|
|
14
|
-
from typing import
|
|
14
|
+
from typing import Any, TypeVar, get_args, get_origin, overload
|
|
15
15
|
|
|
16
16
|
from typing_extensions import ParamSpec, Self, type_repr
|
|
17
17
|
|
|
18
18
|
from ._context import InstanceContext
|
|
19
19
|
from ._decorators import is_provided
|
|
20
|
+
from ._injector import Injector
|
|
20
21
|
from ._module import ModuleDef, ModuleRegistrar
|
|
21
22
|
from ._provider import Provider, ProviderDef, ProviderKind, ProviderParameter
|
|
22
23
|
from ._resolver import Resolver
|
|
@@ -26,7 +27,6 @@ from ._types import (
|
|
|
26
27
|
Event,
|
|
27
28
|
Scope,
|
|
28
29
|
is_event_type,
|
|
29
|
-
is_inject_marker,
|
|
30
30
|
is_iterator_type,
|
|
31
31
|
is_none_type,
|
|
32
32
|
)
|
|
@@ -58,10 +58,10 @@ class Container:
|
|
|
58
58
|
self._request_context_var: ContextVar[InstanceContext | None] = ContextVar(
|
|
59
59
|
"request_context", default=None
|
|
60
60
|
)
|
|
61
|
-
self._inject_cache: dict[Callable[..., Any], Callable[..., Any]] = {}
|
|
62
61
|
|
|
63
62
|
# Components
|
|
64
63
|
self._resolver = Resolver(self)
|
|
64
|
+
self._injector = Injector(self)
|
|
65
65
|
self._modules = ModuleRegistrar(self)
|
|
66
66
|
self._scanner = Scanner(self)
|
|
67
67
|
|
|
@@ -195,12 +195,14 @@ class Container:
|
|
|
195
195
|
def register(
|
|
196
196
|
self,
|
|
197
197
|
interface: Any,
|
|
198
|
-
call: Callable[..., Any],
|
|
198
|
+
call: Callable[..., Any] = NOT_SET,
|
|
199
199
|
*,
|
|
200
|
-
scope: Scope,
|
|
200
|
+
scope: Scope = "singleton",
|
|
201
201
|
override: bool = False,
|
|
202
202
|
) -> Provider:
|
|
203
203
|
"""Register a provider for the specified interface."""
|
|
204
|
+
if call is NOT_SET:
|
|
205
|
+
call = interface
|
|
204
206
|
return self._register_provider(call, scope, interface, override)
|
|
205
207
|
|
|
206
208
|
def is_registered(self, interface: Any) -> bool:
|
|
@@ -548,115 +550,21 @@ class Container:
|
|
|
548
550
|
"""Decorator to inject dependencies into a callable."""
|
|
549
551
|
|
|
550
552
|
def decorator(call: Callable[P, T]) -> Callable[P, T]:
|
|
551
|
-
return self.
|
|
553
|
+
return self._injector.inject(call)
|
|
552
554
|
|
|
553
555
|
if func is None:
|
|
554
556
|
return decorator
|
|
555
557
|
return decorator(func)
|
|
556
558
|
|
|
557
|
-
def run(self, func: Callable[
|
|
559
|
+
def run(self, func: Callable[..., T], /, *args: Any, **kwargs: Any) -> T:
|
|
558
560
|
"""Run the given function with injected dependencies."""
|
|
559
|
-
return self.
|
|
560
|
-
|
|
561
|
-
def _inject(self, call: Callable[P, T]) -> Callable[P, T]:
|
|
562
|
-
"""Inject dependencies into a callable."""
|
|
563
|
-
if call in self._inject_cache:
|
|
564
|
-
return cast(Callable[P, T], self._inject_cache[call])
|
|
565
|
-
|
|
566
|
-
injected_params = self._get_injected_params(call)
|
|
567
|
-
if not injected_params:
|
|
568
|
-
self._inject_cache[call] = call
|
|
569
|
-
return call
|
|
570
|
-
|
|
571
|
-
if inspect.iscoroutinefunction(call):
|
|
572
|
-
|
|
573
|
-
@functools.wraps(call)
|
|
574
|
-
async def awrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
575
|
-
for name, annotation in injected_params.items():
|
|
576
|
-
kwargs[name] = await self.aresolve(annotation)
|
|
577
|
-
return cast(T, await call(*args, **kwargs))
|
|
578
|
-
|
|
579
|
-
self._inject_cache[call] = awrapper
|
|
580
|
-
|
|
581
|
-
return awrapper # type: ignore
|
|
582
|
-
|
|
583
|
-
@functools.wraps(call)
|
|
584
|
-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
585
|
-
for name, annotation in injected_params.items():
|
|
586
|
-
kwargs[name] = self.resolve(annotation)
|
|
587
|
-
return call(*args, **kwargs)
|
|
588
|
-
|
|
589
|
-
self._inject_cache[call] = wrapper
|
|
590
|
-
|
|
591
|
-
return wrapper
|
|
592
|
-
|
|
593
|
-
def _get_injected_params(self, call: Callable[..., Any]) -> dict[str, Any]:
|
|
594
|
-
"""Get the injected parameters of a callable object."""
|
|
595
|
-
injected_params: dict[str, Any] = {}
|
|
596
|
-
for parameter in inspect.signature(call, eval_str=True).parameters.values():
|
|
597
|
-
interface, should_inject = self.validate_injected_parameter(
|
|
598
|
-
parameter, call=call
|
|
599
|
-
)
|
|
600
|
-
if should_inject:
|
|
601
|
-
injected_params[parameter.name] = interface
|
|
602
|
-
return injected_params
|
|
603
|
-
|
|
604
|
-
@staticmethod
|
|
605
|
-
def _unwrap_injected_parameter(parameter: inspect.Parameter) -> inspect.Parameter:
|
|
606
|
-
if get_origin(parameter.annotation) is not Annotated:
|
|
607
|
-
return parameter
|
|
608
|
-
|
|
609
|
-
origin, *metadata = get_args(parameter.annotation)
|
|
610
|
-
|
|
611
|
-
if not metadata or not is_inject_marker(metadata[-1]):
|
|
612
|
-
return parameter
|
|
613
|
-
|
|
614
|
-
if is_inject_marker(parameter.default):
|
|
615
|
-
raise TypeError(
|
|
616
|
-
"Cannot specify `Inject` in `Annotated` and "
|
|
617
|
-
f"default value together for '{parameter.name}'"
|
|
618
|
-
)
|
|
619
|
-
|
|
620
|
-
if parameter.default is not inspect.Parameter.empty:
|
|
621
|
-
return parameter
|
|
622
|
-
|
|
623
|
-
marker = metadata[-1]
|
|
624
|
-
new_metadata = metadata[:-1]
|
|
625
|
-
if new_metadata:
|
|
626
|
-
if hasattr(Annotated, "__getitem__"):
|
|
627
|
-
new_annotation = Annotated.__getitem__((origin, *new_metadata)) # type: ignore
|
|
628
|
-
else:
|
|
629
|
-
new_annotation = Annotated.__class_getitem__((origin, *new_metadata)) # type: ignore
|
|
630
|
-
else:
|
|
631
|
-
new_annotation = origin
|
|
632
|
-
return parameter.replace(annotation=new_annotation, default=marker)
|
|
561
|
+
return self._injector.inject(func)(*args, **kwargs)
|
|
633
562
|
|
|
634
563
|
def validate_injected_parameter(
|
|
635
564
|
self, parameter: inspect.Parameter, *, call: Callable[..., Any]
|
|
636
565
|
) -> tuple[Any, bool]:
|
|
637
566
|
"""Validate an injected parameter."""
|
|
638
|
-
|
|
639
|
-
interface = parameter.annotation
|
|
640
|
-
|
|
641
|
-
if not is_inject_marker(parameter.default):
|
|
642
|
-
return interface, False
|
|
643
|
-
|
|
644
|
-
if interface is inspect.Parameter.empty:
|
|
645
|
-
raise TypeError(
|
|
646
|
-
f"Missing `{type_repr(call)}` parameter `{parameter.name}` annotation."
|
|
647
|
-
)
|
|
648
|
-
|
|
649
|
-
# Set inject marker interface
|
|
650
|
-
parameter.default.interface = interface
|
|
651
|
-
|
|
652
|
-
if not self.has_provider_for(interface):
|
|
653
|
-
raise LookupError(
|
|
654
|
-
f"`{type_repr(call)}` has an unknown dependency parameter "
|
|
655
|
-
f"`{parameter.name}` with an annotation of "
|
|
656
|
-
f"`{type_repr(interface)}`."
|
|
657
|
-
)
|
|
658
|
-
|
|
659
|
-
return interface, True
|
|
567
|
+
return self._injector.validate_parameter(parameter, call=call)
|
|
660
568
|
|
|
661
569
|
# == Module Registration ==
|
|
662
570
|
|
|
@@ -685,3 +593,48 @@ class Container:
|
|
|
685
593
|
yield
|
|
686
594
|
finally:
|
|
687
595
|
self._resolver.remove_override(interface)
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
def import_container(container_path: str) -> Container:
    """Import a container from a string path.

    Accepts ``"module.path:attribute"`` or ``"module.path.attribute"``. If the
    attribute is a callable that is not itself a ``Container`` (i.e. a
    factory), it is called with no arguments and its result is used.

    Args:
        container_path: Path to a ``Container`` instance or to a zero-argument
            factory returning one.

    Returns:
        The resolved ``Container`` instance.

    Raises:
        ImportError: If the path is malformed, the module cannot be imported,
            the attribute does not exist, or the resolved object is not a
            ``Container``.
    """
    # Normalize a copy (colon -> dot) but keep the caller's spelling for
    # error messages so they match what was actually passed in.
    normalized = container_path.replace(":", ".")
    module_path, sep, attr_name = normalized.rpartition(".")

    # Reject shapes importlib would fail on with non-ImportError exceptions:
    # an empty module name raises ValueError, and a leading "." (a relative
    # import without a package) raises TypeError.
    if not sep or not module_path or not attr_name or module_path.startswith("."):
        raise ImportError(
            f"Invalid container path '{container_path}'. "
            "Expected format: 'module.path:attribute' or 'module.path.attribute'"
        )

    try:
        module = importlib.import_module(module_path)
    except ImportError as exc:
        raise ImportError(
            f"Failed to import module '{module_path}' "
            f"from container path '{container_path}'"
        ) from exc

    try:
        container_or_factory = getattr(module, attr_name)
    except AttributeError as exc:
        raise ImportError(
            f"Module '{module_path}' has no attribute '{attr_name}'"
        ) from exc

    # A callable that is not already a Container is treated as a factory.
    if callable(container_or_factory) and not isinstance(
        container_or_factory, Container
    ):
        container = container_or_factory()
    else:
        container = container_or_factory

    if not isinstance(container, Container):
        raise ImportError(
            f"Expected Container instance, got {type(container).__name__} "
            f"from '{container_path}'"
        )

    return container
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""Dependency injection utilities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import functools
|
|
6
|
+
import inspect
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from typing import (
|
|
9
|
+
TYPE_CHECKING,
|
|
10
|
+
Annotated,
|
|
11
|
+
Any,
|
|
12
|
+
TypeVar,
|
|
13
|
+
cast,
|
|
14
|
+
get_args,
|
|
15
|
+
get_origin,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from typing_extensions import ParamSpec, type_repr
|
|
19
|
+
|
|
20
|
+
from ._types import is_provide_marker
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from ._container import Container
|
|
24
|
+
|
|
25
|
+
T = TypeVar("T", bound=Any)
|
|
26
|
+
P = ParamSpec("P")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Injector:
|
|
30
|
+
"""Handles dependency injection for callables."""
|
|
31
|
+
|
|
32
|
+
def __init__(self, container: Container) -> None:
|
|
33
|
+
self.container = container
|
|
34
|
+
self._cache: dict[Callable[..., Any], Callable[..., Any]] = {}
|
|
35
|
+
|
|
36
|
+
def inject(self, call: Callable[P, T]) -> Callable[P, T]:
|
|
37
|
+
"""Inject dependencies into a callable."""
|
|
38
|
+
if call in self._cache:
|
|
39
|
+
return cast(Callable[P, T], self._cache[call])
|
|
40
|
+
|
|
41
|
+
injected_params = self._get_injected_params(call)
|
|
42
|
+
if not injected_params:
|
|
43
|
+
self._cache[call] = call
|
|
44
|
+
return call
|
|
45
|
+
|
|
46
|
+
if inspect.iscoroutinefunction(call):
|
|
47
|
+
|
|
48
|
+
@functools.wraps(call)
|
|
49
|
+
async def awrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
50
|
+
for name, annotation in injected_params.items():
|
|
51
|
+
kwargs[name] = await self.container.aresolve(annotation)
|
|
52
|
+
return cast(T, await call(*args, **kwargs))
|
|
53
|
+
|
|
54
|
+
self._cache[call] = awrapper
|
|
55
|
+
|
|
56
|
+
return awrapper # type: ignore
|
|
57
|
+
|
|
58
|
+
@functools.wraps(call)
|
|
59
|
+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
60
|
+
for name, annotation in injected_params.items():
|
|
61
|
+
kwargs[name] = self.container.resolve(annotation)
|
|
62
|
+
return call(*args, **kwargs)
|
|
63
|
+
|
|
64
|
+
self._cache[call] = wrapper
|
|
65
|
+
|
|
66
|
+
return wrapper
|
|
67
|
+
|
|
68
|
+
def _get_injected_params(self, call: Callable[..., Any]) -> dict[str, Any]:
|
|
69
|
+
"""Get the injected parameters of a callable object."""
|
|
70
|
+
injected_params: dict[str, Any] = {}
|
|
71
|
+
for parameter in inspect.signature(call, eval_str=True).parameters.values():
|
|
72
|
+
interface, should_inject = self.validate_parameter(parameter, call=call)
|
|
73
|
+
if should_inject:
|
|
74
|
+
injected_params[parameter.name] = interface
|
|
75
|
+
return injected_params
|
|
76
|
+
|
|
77
|
+
def validate_parameter(
|
|
78
|
+
self, parameter: inspect.Parameter, *, call: Callable[..., Any]
|
|
79
|
+
) -> tuple[Any, bool]:
|
|
80
|
+
"""Validate an injected parameter."""
|
|
81
|
+
parameter = self.unwrap_parameter(parameter)
|
|
82
|
+
interface = parameter.annotation
|
|
83
|
+
|
|
84
|
+
if not is_provide_marker(parameter.default):
|
|
85
|
+
return interface, False
|
|
86
|
+
|
|
87
|
+
if interface is inspect.Parameter.empty:
|
|
88
|
+
raise TypeError(
|
|
89
|
+
f"Missing `{type_repr(call)}` parameter `{parameter.name}` annotation."
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
# Set inject marker interface
|
|
93
|
+
parameter.default.interface = interface
|
|
94
|
+
|
|
95
|
+
if not self.container.has_provider_for(interface):
|
|
96
|
+
raise LookupError(
|
|
97
|
+
f"`{type_repr(call)}` has an unknown dependency parameter "
|
|
98
|
+
f"`{parameter.name}` with an annotation of "
|
|
99
|
+
f"`{type_repr(interface)}`."
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
return interface, True
|
|
103
|
+
|
|
104
|
+
@staticmethod
|
|
105
|
+
def unwrap_parameter(parameter: inspect.Parameter) -> inspect.Parameter:
|
|
106
|
+
if get_origin(parameter.annotation) is not Annotated:
|
|
107
|
+
return parameter
|
|
108
|
+
|
|
109
|
+
origin, *metadata = get_args(parameter.annotation)
|
|
110
|
+
|
|
111
|
+
if not metadata or not is_provide_marker(metadata[-1]):
|
|
112
|
+
return parameter
|
|
113
|
+
|
|
114
|
+
if is_provide_marker(parameter.default):
|
|
115
|
+
raise TypeError(
|
|
116
|
+
"Cannot specify `Inject` in `Annotated` and "
|
|
117
|
+
f"default value together for '{parameter.name}'"
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
if parameter.default is not inspect.Parameter.empty:
|
|
121
|
+
return parameter
|
|
122
|
+
|
|
123
|
+
marker = metadata[-1]
|
|
124
|
+
new_metadata = metadata[:-1]
|
|
125
|
+
if new_metadata:
|
|
126
|
+
if hasattr(Annotated, "__getitem__"):
|
|
127
|
+
new_annotation = Annotated.__getitem__((origin, *new_metadata)) # type: ignore
|
|
128
|
+
else:
|
|
129
|
+
new_annotation = Annotated.__class_getitem__((origin, *new_metadata)) # type: ignore
|
|
130
|
+
else:
|
|
131
|
+
new_annotation = origin
|
|
132
|
+
return parameter.replace(annotation=new_annotation, default=marker)
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import inspect
|
|
5
|
+
import pkgutil
|
|
6
|
+
from collections.abc import Callable, Iterable, Iterator
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from types import ModuleType
|
|
9
|
+
from typing import TYPE_CHECKING, Any
|
|
10
|
+
|
|
11
|
+
from ._decorators import Provided, is_injectable, is_provided
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from ._container import Container
|
|
15
|
+
|
|
16
|
+
Package = ModuleType | str
|
|
17
|
+
PackageOrIterable = Package | Iterable[Package]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(kw_only=True)
|
|
21
|
+
class ScannedDependency:
|
|
22
|
+
member: Any
|
|
23
|
+
module: ModuleType
|
|
24
|
+
|
|
25
|
+
def __post_init__(self) -> None:
|
|
26
|
+
# Unwrap decorated functions if necessary
|
|
27
|
+
if hasattr(self.member, "__wrapped__"):
|
|
28
|
+
self.member = self.member.__wrapped__
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class Scanner:
    """Find ``@provided`` classes and ``@injectable`` callables in modules.

    ``scan`` registers each discovered ``@provided`` class with the container
    (unless already registered), then replaces each matching ``@injectable``
    callable on its module with the container's injected wrapper.
    """

    def __init__(self, container: Container) -> None:
        self._container = container

    def scan(
        self, /, packages: PackageOrIterable, *, tags: Iterable[str] | None = None
    ) -> None:
        """Scan packages or modules for decorated members and inject dependencies.

        Args:
            packages: A module, a dotted module/package name, or an iterable of
                either. Packages are scanned recursively, including their own
                ``__init__`` module.
            tags: If given, only ``@injectable`` members sharing at least one of
                these tags are injected. ``@provided`` classes are registered
                regardless of tags.
        """
        if isinstance(packages, (ModuleType, str)):
            packages = [packages]

        tags_list = list(tags) if tags else []
        provided_classes: list[type[Provided]] = []
        # (attribute name in the module, dependency) pairs.
        injectable_dependencies: list[tuple[str, ScannedDependency]] = []

        # Single pass: collect both @provided classes and @injectable callables.
        for module in self._iter_modules(packages):
            provided_classes.extend(self._scan_module_for_provided(module))
            injectable_dependencies.extend(
                self._scan_module_for_injectable(module, tags=tags_list)
            )

        # Register @provided classes first: injecting a callable checks that a
        # provider exists for each of its injected parameters.
        for cls in provided_classes:
            if not self._container.is_registered(cls):
                scope = cls.__provided__["scope"]
                self._container.register(cls, scope=scope)

        # Then wrap @injectable callables and rebind them under the attribute
        # name they were found at. That name can differ from `__name__`, e.g.
        # for lambdas (`"<lambda>"`) or aliased functions.
        for attr_name, dependency in injectable_dependencies:
            decorated = self._container.inject()(dependency.member)
            setattr(dependency.module, attr_name, decorated)

    def _iter_modules(self, packages: Iterable[Package]) -> Iterator[ModuleType]:
        """Yield every module reachable from ``packages``, each at most once.

        Strings are imported first. A package yields itself followed by all of
        its submodules (found recursively via ``pkgutil.walk_packages``). A
        plain module yields only itself.
        """
        seen: set[str] = set()
        for package in packages:
            if isinstance(package, str):
                package = importlib.import_module(package)

            for module in self._iter_package(package):
                # Overlapping inputs (e.g. "app" and "app.services") would
                # otherwise have their modules scanned twice.
                if module.__name__ in seen:
                    continue
                seen.add(module.__name__)
                yield module

    @staticmethod
    def _iter_package(package: ModuleType) -> Iterator[ModuleType]:
        """Yield ``package`` and, if it has a ``__path__``, all its submodules."""
        yield package
        # Only packages define __path__; a plain module has no submodules.
        if not hasattr(package, "__path__"):
            return
        for module_info in pkgutil.walk_packages(
            package.__path__, prefix=package.__name__ + "."
        ):
            yield importlib.import_module(module_info.name)

    def _scan_module_for_provided(self, module: ModuleType) -> list[type[Provided]]:
        """Return the ``@provided`` classes defined in (not imported into) ``module``."""
        return [
            member
            for _, member in inspect.getmembers(module, predicate=inspect.isclass)
            if getattr(member, "__module__", None) == module.__name__
            and is_provided(member)
        ]

    def _scan_module_for_injectable(
        self, module: ModuleType, *, tags: list[str]
    ) -> list[tuple[str, ScannedDependency]]:
        """Collect ``@injectable`` callables defined in ``module``.

        Returns ``(attribute_name, dependency)`` pairs. The attribute name is
        the one the callable is bound to in ``module``, so the injected wrapper
        can be rebound to exactly that name.
        """
        return [
            (name, ScannedDependency(member=member, module=module))
            for name, member in inspect.getmembers(module, predicate=callable)
            if getattr(member, "__module__", None) == module.__name__
            and self._should_include_member(member, tags=tags)
        ]

    @staticmethod
    def _should_include_member(member: Callable[..., Any], *, tags: list[str]) -> bool:
        """Determine if a member should be included based on tags or marker defaults."""
        if is_injectable(member):
            member_tags = set(member.__injectable__["tags"] or [])
            if tags:
                return bool(set(tags) & member_tags)
            return True  # No tags passed → include all injectables

        return False
|