anydi 0.38.1__py3-none-any.whl → 0.38.2rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anydi/__init__.py +1 -2
- anydi/_container.py +483 -354
- anydi/_types.py +80 -2
- anydi/_utils.py +3 -0
- anydi/ext/pytest_plugin.py +2 -6
- {anydi-0.38.1.dist-info → anydi-0.38.2rc1.dist-info}/METADATA +1 -1
- {anydi-0.38.1.dist-info → anydi-0.38.2rc1.dist-info}/RECORD +10 -11
- anydi/_provider.py +0 -232
- {anydi-0.38.1.dist-info → anydi-0.38.2rc1.dist-info}/WHEEL +0 -0
- {anydi-0.38.1.dist-info → anydi-0.38.2rc1.dist-info}/entry_points.txt +0 -0
- {anydi-0.38.1.dist-info → anydi-0.38.2rc1.dist-info}/licenses/LICENSE +0 -0
anydi/_container.py
CHANGED
|
@@ -10,23 +10,27 @@ import logging
|
|
|
10
10
|
import pkgutil
|
|
11
11
|
import threading
|
|
12
12
|
import types
|
|
13
|
+
import uuid
|
|
13
14
|
from collections import defaultdict
|
|
14
15
|
from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
|
|
15
16
|
from contextvars import ContextVar
|
|
16
17
|
from types import ModuleType
|
|
17
|
-
from typing import Any, Callable, TypeVar, Union, cast, overload
|
|
18
|
-
from weakref import WeakKeyDictionary
|
|
18
|
+
from typing import Annotated, Any, Callable, TypeVar, Union, cast, overload
|
|
19
19
|
|
|
20
|
-
from typing_extensions import Concatenate, ParamSpec, Self, final
|
|
20
|
+
from typing_extensions import Concatenate, ParamSpec, Self, final, get_args, get_origin
|
|
21
21
|
|
|
22
22
|
from ._context import InstanceContext
|
|
23
|
-
from ._provider import Provider
|
|
24
23
|
from ._types import (
|
|
24
|
+
NOT_SET,
|
|
25
25
|
AnyInterface,
|
|
26
|
-
|
|
26
|
+
Event,
|
|
27
27
|
InjectableDecoratorArgs,
|
|
28
28
|
InstanceProxy,
|
|
29
|
+
Provider,
|
|
30
|
+
ProviderArgs,
|
|
29
31
|
ProviderDecoratorArgs,
|
|
32
|
+
ProviderKind,
|
|
33
|
+
ScannedDependency,
|
|
30
34
|
Scope,
|
|
31
35
|
is_event_type,
|
|
32
36
|
is_marker,
|
|
@@ -34,6 +38,7 @@ from ._types import (
|
|
|
34
38
|
from ._utils import (
|
|
35
39
|
AsyncRLock,
|
|
36
40
|
get_full_qualname,
|
|
41
|
+
get_typed_annotation,
|
|
37
42
|
get_typed_parameters,
|
|
38
43
|
import_string,
|
|
39
44
|
is_async_context_manager,
|
|
@@ -42,6 +47,12 @@ from ._utils import (
|
|
|
42
47
|
run_async,
|
|
43
48
|
)
|
|
44
49
|
|
|
50
|
+
try:
|
|
51
|
+
from types import NoneType
|
|
52
|
+
except ImportError:
|
|
53
|
+
NoneType = type(None) # type: ignore[misc]
|
|
54
|
+
|
|
55
|
+
|
|
45
56
|
T = TypeVar("T", bound=Any)
|
|
46
57
|
M = TypeVar("M", bound="Module")
|
|
47
58
|
P = ParamSpec("P")
|
|
@@ -74,7 +85,6 @@ class Module(metaclass=ModuleMeta):
|
|
|
74
85
|
"""Configure the AnyDI container with providers and their dependencies."""
|
|
75
86
|
|
|
76
87
|
|
|
77
|
-
# noinspection PyShadowingNames
|
|
78
88
|
@final
|
|
79
89
|
class Container:
|
|
80
90
|
"""AnyDI is a dependency injection container."""
|
|
@@ -82,7 +92,7 @@ class Container:
|
|
|
82
92
|
def __init__(
|
|
83
93
|
self,
|
|
84
94
|
*,
|
|
85
|
-
providers: Sequence[
|
|
95
|
+
providers: Sequence[ProviderArgs] | None = None,
|
|
86
96
|
modules: Sequence[Module | type[Module] | Callable[[Container], None] | str]
|
|
87
97
|
| None = None,
|
|
88
98
|
strict: bool = False,
|
|
@@ -104,20 +114,26 @@ class Container:
|
|
|
104
114
|
)
|
|
105
115
|
self._override_instances: dict[type[Any], Any] = {}
|
|
106
116
|
self._unresolved_interfaces: set[type[Any]] = set()
|
|
107
|
-
self._inject_cache:
|
|
108
|
-
Callable[..., Any], Callable[..., Any]
|
|
109
|
-
] = WeakKeyDictionary()
|
|
117
|
+
self._inject_cache: dict[Callable[..., Any], Callable[..., Any]] = {}
|
|
110
118
|
|
|
111
119
|
# Register providers
|
|
112
120
|
providers = providers or []
|
|
113
121
|
for provider in providers:
|
|
114
|
-
self._register_provider(
|
|
122
|
+
self._register_provider(
|
|
123
|
+
provider.call,
|
|
124
|
+
provider.scope,
|
|
125
|
+
provider.interface,
|
|
126
|
+
)
|
|
115
127
|
|
|
116
128
|
# Register modules
|
|
117
129
|
modules = modules or []
|
|
118
130
|
for module in modules:
|
|
119
131
|
self.register_module(module)
|
|
120
132
|
|
|
133
|
+
############################
|
|
134
|
+
# Properties
|
|
135
|
+
############################
|
|
136
|
+
|
|
121
137
|
@property
|
|
122
138
|
def strict(self) -> bool:
|
|
123
139
|
"""Check if strict mode is enabled."""
|
|
@@ -143,9 +159,110 @@ class Container:
|
|
|
143
159
|
"""Get the logger instance."""
|
|
144
160
|
return self._logger
|
|
145
161
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
162
|
+
############################
|
|
163
|
+
# Lifespan/Context Methods
|
|
164
|
+
############################
|
|
165
|
+
|
|
166
|
+
def __enter__(self) -> Self:
|
|
167
|
+
"""Enter the singleton context."""
|
|
168
|
+
self.start()
|
|
169
|
+
return self
|
|
170
|
+
|
|
171
|
+
def __exit__(
|
|
172
|
+
self,
|
|
173
|
+
exc_type: type[BaseException] | None,
|
|
174
|
+
exc_val: BaseException | None,
|
|
175
|
+
exc_tb: types.TracebackType | None,
|
|
176
|
+
) -> Any:
|
|
177
|
+
"""Exit the singleton context."""
|
|
178
|
+
return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
|
|
179
|
+
|
|
180
|
+
def start(self) -> None:
|
|
181
|
+
"""Start the singleton context."""
|
|
182
|
+
# Resolve all singleton resources
|
|
183
|
+
for interface in self._resources.get("singleton", []):
|
|
184
|
+
self.resolve(interface)
|
|
185
|
+
|
|
186
|
+
def close(self) -> None:
|
|
187
|
+
"""Close the singleton context."""
|
|
188
|
+
self._singleton_context.close()
|
|
189
|
+
|
|
190
|
+
async def __aenter__(self) -> Self:
|
|
191
|
+
"""Enter the singleton context."""
|
|
192
|
+
await self.astart()
|
|
193
|
+
return self
|
|
194
|
+
|
|
195
|
+
async def __aexit__(
|
|
196
|
+
self,
|
|
197
|
+
exc_type: type[BaseException] | None,
|
|
198
|
+
exc_val: BaseException | None,
|
|
199
|
+
exc_tb: types.TracebackType | None,
|
|
200
|
+
) -> bool:
|
|
201
|
+
"""Exit the singleton context."""
|
|
202
|
+
return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)
|
|
203
|
+
|
|
204
|
+
async def astart(self) -> None:
|
|
205
|
+
"""Start the singleton context asynchronously."""
|
|
206
|
+
for interface in self._resources.get("singleton", []):
|
|
207
|
+
await self.aresolve(interface)
|
|
208
|
+
|
|
209
|
+
async def aclose(self) -> None:
|
|
210
|
+
"""Close the singleton context asynchronously."""
|
|
211
|
+
await self._singleton_context.aclose()
|
|
212
|
+
|
|
213
|
+
@contextlib.contextmanager
|
|
214
|
+
def request_context(self) -> Iterator[InstanceContext]:
|
|
215
|
+
"""Obtain a context manager for the request-scoped context."""
|
|
216
|
+
context = InstanceContext()
|
|
217
|
+
|
|
218
|
+
token = self._request_context_var.set(context)
|
|
219
|
+
|
|
220
|
+
# Resolve all request resources
|
|
221
|
+
for interface in self._resources.get("request", []):
|
|
222
|
+
if not is_event_type(interface):
|
|
223
|
+
continue
|
|
224
|
+
self.resolve(interface)
|
|
225
|
+
|
|
226
|
+
with context:
|
|
227
|
+
yield context
|
|
228
|
+
self._request_context_var.reset(token)
|
|
229
|
+
|
|
230
|
+
@contextlib.asynccontextmanager
|
|
231
|
+
async def arequest_context(self) -> AsyncIterator[InstanceContext]:
|
|
232
|
+
"""Obtain an async context manager for the request-scoped context."""
|
|
233
|
+
context = InstanceContext()
|
|
234
|
+
|
|
235
|
+
token = self._request_context_var.set(context)
|
|
236
|
+
|
|
237
|
+
for interface in self._resources.get("request", []):
|
|
238
|
+
if not is_event_type(interface):
|
|
239
|
+
continue
|
|
240
|
+
await self.aresolve(interface)
|
|
241
|
+
|
|
242
|
+
async with context:
|
|
243
|
+
yield context
|
|
244
|
+
self._request_context_var.reset(token)
|
|
245
|
+
|
|
246
|
+
def _get_request_context(self) -> InstanceContext:
|
|
247
|
+
"""Get the current request context."""
|
|
248
|
+
request_context = self._request_context_var.get()
|
|
249
|
+
if request_context is None:
|
|
250
|
+
raise LookupError(
|
|
251
|
+
"The request context has not been started. Please ensure that "
|
|
252
|
+
"the request context is properly initialized before attempting "
|
|
253
|
+
"to use it."
|
|
254
|
+
)
|
|
255
|
+
return request_context
|
|
256
|
+
|
|
257
|
+
def _get_scoped_context(self, scope: Scope) -> InstanceContext:
|
|
258
|
+
"""Get the instance context for the specified scope."""
|
|
259
|
+
if scope == "singleton":
|
|
260
|
+
return self._singleton_context
|
|
261
|
+
return self._get_request_context()
|
|
262
|
+
|
|
263
|
+
############################
|
|
264
|
+
# Provider Methods
|
|
265
|
+
############################
|
|
149
266
|
|
|
150
267
|
def register(
|
|
151
268
|
self,
|
|
@@ -156,26 +273,11 @@ class Container:
|
|
|
156
273
|
override: bool = False,
|
|
157
274
|
) -> Provider:
|
|
158
275
|
"""Register a provider for the specified interface."""
|
|
159
|
-
|
|
160
|
-
return self._register_provider(provider, override)
|
|
276
|
+
return self._register_provider(call, scope, interface, override)
|
|
161
277
|
|
|
162
|
-
def
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
"""Register a provider."""
|
|
166
|
-
if provider.interface in self._providers:
|
|
167
|
-
if override:
|
|
168
|
-
self._set_provider(provider)
|
|
169
|
-
return provider
|
|
170
|
-
|
|
171
|
-
raise LookupError(
|
|
172
|
-
f"The provider interface `{get_full_qualname(provider.interface)}` "
|
|
173
|
-
"already registered."
|
|
174
|
-
)
|
|
175
|
-
|
|
176
|
-
self._validate_sub_providers(provider, **defaults)
|
|
177
|
-
self._set_provider(provider)
|
|
178
|
-
return provider
|
|
278
|
+
def is_registered(self, interface: AnyInterface) -> bool:
|
|
279
|
+
"""Check if a provider is registered for the specified interface."""
|
|
280
|
+
return interface in self._providers
|
|
179
281
|
|
|
180
282
|
def unregister(self, interface: AnyInterface) -> None:
|
|
181
283
|
"""Unregister a provider by interface."""
|
|
@@ -199,254 +301,233 @@ class Container:
|
|
|
199
301
|
# Cleanup provider references
|
|
200
302
|
self._delete_provider(provider)
|
|
201
303
|
|
|
202
|
-
def
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
except KeyError as exc:
|
|
207
|
-
raise LookupError(
|
|
208
|
-
f"The provider interface for `{get_full_qualname(interface)}` has "
|
|
209
|
-
"not been registered. Please ensure that the provider interface is "
|
|
210
|
-
"properly registered before attempting to use it."
|
|
211
|
-
) from exc
|
|
304
|
+
def provider(
|
|
305
|
+
self, *, scope: Scope, override: bool = False
|
|
306
|
+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
|
307
|
+
"""Decorator to register a provider function with the specified scope."""
|
|
212
308
|
|
|
213
|
-
|
|
214
|
-
|
|
309
|
+
def decorator(call: Callable[P, T]) -> Callable[P, T]:
|
|
310
|
+
self._register_provider(call, scope, NOT_SET, override)
|
|
311
|
+
return call
|
|
312
|
+
|
|
313
|
+
return decorator
|
|
314
|
+
|
|
315
|
+
def _register_provider( # noqa: C901
|
|
316
|
+
self,
|
|
317
|
+
call: Callable[..., Any],
|
|
318
|
+
scope: Scope | None,
|
|
319
|
+
interface: Any = NOT_SET,
|
|
320
|
+
override: bool = False,
|
|
321
|
+
/,
|
|
322
|
+
**defaults: Any,
|
|
215
323
|
) -> Provider:
|
|
216
|
-
"""
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
324
|
+
"""Register a provider with the specified scope."""
|
|
325
|
+
name = get_full_qualname(call)
|
|
326
|
+
kind = ProviderKind.from_call(call)
|
|
327
|
+
detected_scope = scope
|
|
328
|
+
|
|
329
|
+
# Validate scope if it provided
|
|
330
|
+
if scope:
|
|
331
|
+
self._validate_provider_scope(scope, name, kind)
|
|
332
|
+
|
|
333
|
+
# Get the signature
|
|
334
|
+
globalns = getattr(call, "__globals__", {})
|
|
335
|
+
module = getattr(call, "__module__", None)
|
|
336
|
+
signature = inspect.signature(call, globals=globalns)
|
|
337
|
+
|
|
338
|
+
# Detect the interface
|
|
339
|
+
if interface is NOT_SET:
|
|
340
|
+
if kind == ProviderKind.CLASS:
|
|
341
|
+
interface = call
|
|
342
|
+
else:
|
|
343
|
+
interface = signature.return_annotation
|
|
344
|
+
if interface is inspect.Signature.empty:
|
|
345
|
+
interface = None
|
|
346
|
+
else:
|
|
347
|
+
interface = get_typed_annotation(interface, globalns, module)
|
|
348
|
+
|
|
349
|
+
# If the callable is an iterator, return the actual type
|
|
350
|
+
iterator_types = {Iterator, AsyncIterator}
|
|
351
|
+
if interface in iterator_types or get_origin(interface) in iterator_types:
|
|
352
|
+
if args := get_args(interface):
|
|
353
|
+
interface = args[0]
|
|
354
|
+
# If the callable is a generator, return the resource type
|
|
355
|
+
if interface in {None, NoneType}:
|
|
356
|
+
interface = type(f"Event_{uuid.uuid4().hex}", (Event,), {})
|
|
357
|
+
else:
|
|
358
|
+
raise TypeError(
|
|
359
|
+
f"Cannot use `{name}` resource type annotation "
|
|
360
|
+
"without actual type argument."
|
|
361
|
+
)
|
|
235
362
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
if provider.is_resource:
|
|
240
|
-
self._resources[provider.scope].append(provider.interface)
|
|
363
|
+
# None interface is not allowed
|
|
364
|
+
if interface in {None, NoneType}:
|
|
365
|
+
raise TypeError(f"Missing `{name}` provider return annotation.")
|
|
241
366
|
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
367
|
+
# Check for existing provider
|
|
368
|
+
if interface in self._providers and not override:
|
|
369
|
+
raise LookupError(
|
|
370
|
+
f"The provider interface `{get_full_qualname(interface)}` "
|
|
371
|
+
"already registered."
|
|
372
|
+
)
|
|
248
373
|
|
|
249
|
-
|
|
250
|
-
|
|
374
|
+
unresolved_parameter = None
|
|
375
|
+
parameters = []
|
|
376
|
+
scopes = {}
|
|
251
377
|
|
|
252
|
-
for parameter in
|
|
378
|
+
for parameter in signature.parameters.values():
|
|
253
379
|
if parameter.annotation is inspect.Parameter.empty:
|
|
254
380
|
raise TypeError(
|
|
255
|
-
f"Missing provider `{
|
|
381
|
+
f"Missing provider `{name}` "
|
|
256
382
|
f"dependency `{parameter.name}` annotation."
|
|
257
383
|
)
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
)
|
|
263
|
-
except LookupError:
|
|
264
|
-
if self._parameter_has_default(parameter, **defaults):
|
|
265
|
-
continue
|
|
266
|
-
|
|
267
|
-
if provider.scope not in {"singleton", "transient"}:
|
|
268
|
-
self._unresolved_interfaces.add(provider.interface)
|
|
269
|
-
continue
|
|
270
|
-
raise ValueError(
|
|
271
|
-
f"The provider `{provider}` depends on `{parameter.name}` of type "
|
|
272
|
-
f"`{get_full_qualname(parameter.annotation)}`, which "
|
|
273
|
-
"has not been registered or set. To resolve this, ensure that "
|
|
274
|
-
f"`{parameter.name}` is registered before attempting to use it."
|
|
275
|
-
) from None
|
|
276
|
-
|
|
277
|
-
# Check scope compatibility
|
|
278
|
-
if sub_provider.scope not in ALLOWED_SCOPES.get(provider.scope, []):
|
|
279
|
-
raise ValueError(
|
|
280
|
-
f"The provider `{provider}` with a `{provider.scope}` scope cannot "
|
|
281
|
-
f"depend on `{sub_provider}` with a `{sub_provider.scope}` scope. "
|
|
282
|
-
"Please ensure all providers are registered with matching scopes."
|
|
384
|
+
if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
|
|
385
|
+
raise TypeError(
|
|
386
|
+
"Positional-only parameters "
|
|
387
|
+
f"are not allowed in the provider `{name}`."
|
|
283
388
|
)
|
|
284
389
|
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
390
|
+
parameter = parameter.replace(
|
|
391
|
+
annotation=get_typed_annotation(parameter.annotation, globalns, module)
|
|
392
|
+
)
|
|
288
393
|
|
|
289
|
-
for parameter in get_typed_parameters(call):
|
|
290
394
|
try:
|
|
291
395
|
sub_provider = self._get_or_register_provider(
|
|
292
|
-
parameter.annotation,
|
|
396
|
+
parameter.annotation, scope
|
|
293
397
|
)
|
|
294
398
|
except LookupError:
|
|
295
399
|
if self._parameter_has_default(parameter, **defaults):
|
|
296
400
|
continue
|
|
297
|
-
|
|
298
|
-
scope = sub_provider.scope
|
|
299
|
-
|
|
300
|
-
if scope == "transient":
|
|
301
|
-
return "transient"
|
|
302
|
-
scopes.add(scope)
|
|
303
|
-
|
|
304
|
-
# If all scopes are found, we can return based on priority order
|
|
305
|
-
if {"transient", "request", "singleton"}.issubset(scopes):
|
|
306
|
-
break # pragma: no cover
|
|
307
|
-
|
|
308
|
-
# Determine scope based on priority
|
|
309
|
-
if "request" in scopes:
|
|
310
|
-
return "request"
|
|
311
|
-
if "singleton" in scopes:
|
|
312
|
-
return "singleton"
|
|
313
|
-
|
|
314
|
-
return None
|
|
315
|
-
|
|
316
|
-
def _parameter_has_default(
|
|
317
|
-
self, parameter: inspect.Parameter, /, **defaults: Any
|
|
318
|
-
) -> bool:
|
|
319
|
-
return (defaults and parameter.name in defaults) or (
|
|
320
|
-
not self.strict and parameter.default is not inspect.Parameter.empty
|
|
321
|
-
)
|
|
322
|
-
|
|
323
|
-
def register_module(
|
|
324
|
-
self, module: Module | type[Module] | Callable[[Container], None] | str
|
|
325
|
-
) -> None:
|
|
326
|
-
"""Register a module as a callable, module type, or module instance."""
|
|
327
|
-
# Callable Module
|
|
328
|
-
if inspect.isfunction(module):
|
|
329
|
-
module(self)
|
|
330
|
-
return
|
|
331
|
-
|
|
332
|
-
# Module path
|
|
333
|
-
if isinstance(module, str):
|
|
334
|
-
module = import_string(module)
|
|
335
|
-
|
|
336
|
-
# Class based Module or Module type
|
|
337
|
-
if inspect.isclass(module) and issubclass(module, Module):
|
|
338
|
-
module = module()
|
|
339
|
-
|
|
340
|
-
if isinstance(module, Module):
|
|
341
|
-
module.configure(self)
|
|
342
|
-
for provider_name, decorator_args in module.providers:
|
|
343
|
-
obj = getattr(module, provider_name)
|
|
344
|
-
self.provider(
|
|
345
|
-
scope=decorator_args.scope,
|
|
346
|
-
override=decorator_args.override,
|
|
347
|
-
)(obj)
|
|
348
|
-
|
|
349
|
-
def __enter__(self) -> Self:
|
|
350
|
-
"""Enter the singleton context."""
|
|
351
|
-
self.start()
|
|
352
|
-
return self
|
|
353
|
-
|
|
354
|
-
def __exit__(
|
|
355
|
-
self,
|
|
356
|
-
exc_type: type[BaseException] | None,
|
|
357
|
-
exc_val: BaseException | None,
|
|
358
|
-
exc_tb: types.TracebackType | None,
|
|
359
|
-
) -> Any:
|
|
360
|
-
"""Exit the singleton context."""
|
|
361
|
-
return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
|
|
362
|
-
|
|
363
|
-
def start(self) -> None:
|
|
364
|
-
"""Start the singleton context."""
|
|
365
|
-
# Resolve all singleton resources
|
|
366
|
-
for interface in self._resources.get("singleton", []):
|
|
367
|
-
self.resolve(interface)
|
|
368
|
-
|
|
369
|
-
def close(self) -> None:
|
|
370
|
-
"""Close the singleton context."""
|
|
371
|
-
self._singleton_context.close()
|
|
372
|
-
|
|
373
|
-
@contextlib.contextmanager
|
|
374
|
-
def request_context(self) -> Iterator[InstanceContext]:
|
|
375
|
-
"""Obtain a context manager for the request-scoped context."""
|
|
376
|
-
context = InstanceContext()
|
|
377
|
-
|
|
378
|
-
token = self._request_context_var.set(context)
|
|
379
|
-
|
|
380
|
-
# Resolve all request resources
|
|
381
|
-
for interface in self._resources.get("request", []):
|
|
382
|
-
if not is_event_type(interface):
|
|
401
|
+
unresolved_parameter = parameter
|
|
383
402
|
continue
|
|
384
|
-
self.resolve(interface)
|
|
385
403
|
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
404
|
+
# Store first provider for each scope
|
|
405
|
+
if sub_provider.scope not in scopes:
|
|
406
|
+
scopes[sub_provider.scope] = sub_provider
|
|
389
407
|
|
|
390
|
-
|
|
391
|
-
"""Enter the singleton context."""
|
|
392
|
-
await self.astart()
|
|
393
|
-
return self
|
|
408
|
+
parameters.append(parameter)
|
|
394
409
|
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
410
|
+
# Set detected scope
|
|
411
|
+
if detected_scope is None:
|
|
412
|
+
if "transient" in scopes:
|
|
413
|
+
detected_scope = "transient"
|
|
414
|
+
elif "request" in scopes:
|
|
415
|
+
detected_scope = "request"
|
|
416
|
+
elif "singleton" in scopes:
|
|
417
|
+
detected_scope = "singleton"
|
|
418
|
+
else:
|
|
419
|
+
detected_scope = self.default_scope
|
|
403
420
|
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
await self.aresolve(interface)
|
|
421
|
+
# Validate the provider scope after detection
|
|
422
|
+
if scope is None:
|
|
423
|
+
self._validate_provider_scope(detected_scope, name, kind)
|
|
408
424
|
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
425
|
+
# Check for unresolved parameters
|
|
426
|
+
if unresolved_parameter:
|
|
427
|
+
if detected_scope not in {"singleton", "transient"}:
|
|
428
|
+
self._unresolved_interfaces.add(interface)
|
|
429
|
+
else:
|
|
430
|
+
raise LookupError(
|
|
431
|
+
f"The provider `{name}` depends on `{unresolved_parameter.name}` "
|
|
432
|
+
f"of type `{get_full_qualname(unresolved_parameter.annotation)}`, "
|
|
433
|
+
"which has not been registered or set. To resolve this, ensure "
|
|
434
|
+
f"that `{unresolved_parameter.name}` is registered before "
|
|
435
|
+
f"attempting to use it."
|
|
436
|
+
) from None
|
|
412
437
|
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
438
|
+
# Check scope compatibility
|
|
439
|
+
for sub_provider in scopes.values():
|
|
440
|
+
if sub_provider.scope not in ALLOWED_SCOPES.get(detected_scope, []):
|
|
441
|
+
raise ValueError(
|
|
442
|
+
f"The provider `{name}` with a `{detected_scope}` scope cannot "
|
|
443
|
+
f"depend on `{sub_provider}` with a `{sub_provider.scope}` scope. "
|
|
444
|
+
"Please ensure all providers are registered with matching scopes."
|
|
445
|
+
)
|
|
417
446
|
|
|
418
|
-
|
|
447
|
+
provider = Provider(
|
|
448
|
+
call=call,
|
|
449
|
+
scope=detected_scope,
|
|
450
|
+
interface=interface,
|
|
451
|
+
name=name,
|
|
452
|
+
kind=kind,
|
|
453
|
+
parameters=parameters,
|
|
454
|
+
)
|
|
419
455
|
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
continue
|
|
423
|
-
await self.aresolve(interface)
|
|
456
|
+
self._set_provider(provider)
|
|
457
|
+
return provider
|
|
424
458
|
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
459
|
+
def _validate_provider_scope(
|
|
460
|
+
self, scope: Scope, name: str, kind: ProviderKind
|
|
461
|
+
) -> None:
|
|
462
|
+
"""Validate the provider scope."""
|
|
463
|
+
if scope not in (allowed_scopes := get_args(Scope)):
|
|
464
|
+
raise ValueError(
|
|
465
|
+
f"The provider `{name}` scope is invalid. Only the following "
|
|
466
|
+
f"scopes are supported: {', '.join(allowed_scopes)}. "
|
|
467
|
+
"Please use one of the supported scopes when registering a provider."
|
|
468
|
+
)
|
|
469
|
+
if (
|
|
470
|
+
kind in {ProviderKind.GENERATOR, ProviderKind.ASYNC_GENERATOR}
|
|
471
|
+
and scope == "transient"
|
|
472
|
+
):
|
|
473
|
+
raise TypeError(
|
|
474
|
+
f"The resource provider `{name}` is attempting to register "
|
|
475
|
+
"with a transient scope, which is not allowed."
|
|
476
|
+
)
|
|
428
477
|
|
|
429
|
-
def
|
|
430
|
-
"""Get
|
|
431
|
-
|
|
432
|
-
|
|
478
|
+
def _get_provider(self, interface: AnyInterface) -> Provider:
|
|
479
|
+
"""Get provider by interface."""
|
|
480
|
+
try:
|
|
481
|
+
return self._providers[interface]
|
|
482
|
+
except KeyError as exc:
|
|
433
483
|
raise LookupError(
|
|
434
|
-
"The
|
|
435
|
-
"
|
|
436
|
-
"to use it."
|
|
437
|
-
)
|
|
438
|
-
return request_context
|
|
484
|
+
f"The provider interface for `{get_full_qualname(interface)}` has "
|
|
485
|
+
"not been registered. Please ensure that the provider interface is "
|
|
486
|
+
"properly registered before attempting to use it."
|
|
487
|
+
) from exc
|
|
439
488
|
|
|
440
|
-
def
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
489
|
+
def _get_or_register_provider(
|
|
490
|
+
self, interface: AnyInterface, parent_scope: Scope | None, /, **defaults: Any
|
|
491
|
+
) -> Provider:
|
|
492
|
+
"""Get or register a provider by interface."""
|
|
493
|
+
try:
|
|
494
|
+
return self._get_provider(interface)
|
|
495
|
+
except LookupError:
|
|
496
|
+
if self.strict or interface is inspect.Parameter.empty:
|
|
497
|
+
raise
|
|
498
|
+
if get_origin(interface) is Annotated and (args := get_args(interface)):
|
|
499
|
+
call = args[0]
|
|
500
|
+
else:
|
|
501
|
+
call = interface
|
|
502
|
+
if inspect.isclass(call) and not is_builtin_type(call):
|
|
503
|
+
# Try to get defined scope
|
|
504
|
+
scope = getattr(interface, "__scope__", parent_scope)
|
|
505
|
+
return self._register_provider(call, scope, interface, **defaults)
|
|
506
|
+
raise
|
|
507
|
+
|
|
508
|
+
def _set_provider(self, provider: Provider) -> None:
|
|
509
|
+
"""Set a provider by interface."""
|
|
510
|
+
self._providers[provider.interface] = provider
|
|
511
|
+
if provider.is_resource:
|
|
512
|
+
self._resources[provider.scope].append(provider.interface)
|
|
513
|
+
|
|
514
|
+
def _delete_provider(self, provider: Provider) -> None:
|
|
515
|
+
"""Delete a provider."""
|
|
516
|
+
if provider.interface in self._providers:
|
|
517
|
+
del self._providers[provider.interface]
|
|
518
|
+
if provider.is_resource:
|
|
519
|
+
self._resources[provider.scope].remove(provider.interface)
|
|
520
|
+
|
|
521
|
+
def _parameter_has_default(
|
|
522
|
+
self, parameter: inspect.Parameter, /, **defaults: Any
|
|
523
|
+
) -> bool:
|
|
524
|
+
return (defaults and parameter.name in defaults) or (
|
|
525
|
+
not self.strict and parameter.default is not inspect.Parameter.empty
|
|
526
|
+
)
|
|
527
|
+
|
|
528
|
+
############################
|
|
529
|
+
# Instance Methods
|
|
530
|
+
############################
|
|
450
531
|
|
|
451
532
|
@overload
|
|
452
533
|
def resolve(self, interface: type[T]) -> T: ...
|
|
@@ -456,19 +537,7 @@ class Container:
|
|
|
456
537
|
|
|
457
538
|
def resolve(self, interface: type[T]) -> T:
|
|
458
539
|
"""Resolve an instance by interface."""
|
|
459
|
-
|
|
460
|
-
if provider.scope == "transient":
|
|
461
|
-
instance = self._create_instance(provider, None)
|
|
462
|
-
else:
|
|
463
|
-
context = self._get_scoped_context(provider.scope)
|
|
464
|
-
if provider.scope == "singleton":
|
|
465
|
-
with self._singleton_lock:
|
|
466
|
-
instance = self._get_or_create_instance(provider, context)
|
|
467
|
-
else:
|
|
468
|
-
instance = self._get_or_create_instance(provider, context)
|
|
469
|
-
if self.testing:
|
|
470
|
-
instance = self._patch_test_resolver(provider.interface, instance)
|
|
471
|
-
return cast(T, instance)
|
|
540
|
+
return self._resolve_or_create(interface, False)
|
|
472
541
|
|
|
473
542
|
@overload
|
|
474
543
|
async def aresolve(self, interface: type[T]) -> T: ...
|
|
@@ -478,22 +547,50 @@ class Container:
|
|
|
478
547
|
|
|
479
548
|
async def aresolve(self, interface: type[T]) -> T:
|
|
480
549
|
"""Resolve an instance by interface asynchronously."""
|
|
481
|
-
|
|
482
|
-
if provider.scope == "transient":
|
|
483
|
-
instance = await self._acreate_instance(provider, None)
|
|
484
|
-
else:
|
|
485
|
-
context = self._get_scoped_context(provider.scope)
|
|
486
|
-
if provider.scope == "singleton":
|
|
487
|
-
async with self._singleton_async_lock:
|
|
488
|
-
instance = await self._aget_or_create_instance(provider, context)
|
|
489
|
-
else:
|
|
490
|
-
instance = await self._aget_or_create_instance(provider, context)
|
|
491
|
-
if self.testing:
|
|
492
|
-
instance = self._patch_test_resolver(interface, instance)
|
|
493
|
-
return cast(T, instance)
|
|
550
|
+
return await self._aresolve_or_acreate(interface, False)
|
|
494
551
|
|
|
495
|
-
def create(self, interface: type[T], **defaults: Any) -> T:
|
|
552
|
+
def create(self, interface: type[T], /, **defaults: Any) -> T:
|
|
496
553
|
"""Create an instance by interface."""
|
|
554
|
+
return self._resolve_or_create(interface, True, **defaults)
|
|
555
|
+
|
|
556
|
+
async def acreate(self, interface: type[T], /, **defaults: Any) -> T:
|
|
557
|
+
"""Create an instance by interface asynchronously."""
|
|
558
|
+
return await self._aresolve_or_acreate(interface, True, **defaults)
|
|
559
|
+
|
|
560
|
+
def is_resolved(self, interface: AnyInterface) -> bool:
|
|
561
|
+
"""Check if an instance by interface exists."""
|
|
562
|
+
try:
|
|
563
|
+
provider = self._get_provider(interface)
|
|
564
|
+
except LookupError:
|
|
565
|
+
return False
|
|
566
|
+
if provider.scope == "transient":
|
|
567
|
+
return False
|
|
568
|
+
context = self._get_scoped_context(provider.scope)
|
|
569
|
+
return interface in context
|
|
570
|
+
|
|
571
|
+
def release(self, interface: AnyInterface) -> None:
|
|
572
|
+
"""Release an instance by interface."""
|
|
573
|
+
provider = self._get_provider(interface)
|
|
574
|
+
if provider.scope == "transient":
|
|
575
|
+
return None
|
|
576
|
+
context = self._get_scoped_context(provider.scope)
|
|
577
|
+
del context[interface]
|
|
578
|
+
|
|
579
|
+
def reset(self) -> None:
|
|
580
|
+
"""Reset resolved instances."""
|
|
581
|
+
for interface, provider in self._providers.items():
|
|
582
|
+
if provider.scope == "transient":
|
|
583
|
+
continue
|
|
584
|
+
try:
|
|
585
|
+
context = self._get_scoped_context(provider.scope)
|
|
586
|
+
except LookupError:
|
|
587
|
+
continue
|
|
588
|
+
del context[interface]
|
|
589
|
+
|
|
590
|
+
def _resolve_or_create(
|
|
591
|
+
self, interface: type[T], create: bool, /, **defaults: Any
|
|
592
|
+
) -> T:
|
|
593
|
+
"""Internal method to handle instance resolution and creation."""
|
|
497
594
|
provider = self._get_or_register_provider(interface, None, **defaults)
|
|
498
595
|
if provider.scope == "transient":
|
|
499
596
|
instance = self._create_instance(provider, None, **defaults)
|
|
@@ -501,15 +598,27 @@ class Container:
|
|
|
501
598
|
context = self._get_scoped_context(provider.scope)
|
|
502
599
|
if provider.scope == "singleton":
|
|
503
600
|
with self._singleton_lock:
|
|
504
|
-
instance =
|
|
601
|
+
instance = (
|
|
602
|
+
self._get_or_create_instance(provider, context)
|
|
603
|
+
if not create
|
|
604
|
+
else self._create_instance(provider, context, **defaults)
|
|
605
|
+
)
|
|
505
606
|
else:
|
|
506
|
-
instance =
|
|
607
|
+
instance = (
|
|
608
|
+
self._get_or_create_instance(provider, context)
|
|
609
|
+
if not create
|
|
610
|
+
else self._create_instance(provider, context, **defaults)
|
|
611
|
+
)
|
|
612
|
+
|
|
507
613
|
if self.testing:
|
|
508
614
|
instance = self._patch_test_resolver(provider.interface, instance)
|
|
615
|
+
|
|
509
616
|
return cast(T, instance)
|
|
510
617
|
|
|
511
|
-
async def
|
|
512
|
-
|
|
618
|
+
async def _aresolve_or_acreate(
|
|
619
|
+
self, interface: type[T], create: bool, /, **defaults: Any
|
|
620
|
+
) -> T:
|
|
621
|
+
"""Internal method to handle instance resolution and creation asynchronously."""
|
|
513
622
|
provider = self._get_or_register_provider(interface, None, **defaults)
|
|
514
623
|
if provider.scope == "transient":
|
|
515
624
|
instance = await self._acreate_instance(provider, None, **defaults)
|
|
@@ -517,13 +626,21 @@ class Container:
|
|
|
517
626
|
context = self._get_scoped_context(provider.scope)
|
|
518
627
|
if provider.scope == "singleton":
|
|
519
628
|
async with self._singleton_async_lock:
|
|
520
|
-
instance =
|
|
521
|
-
provider, context
|
|
629
|
+
instance = (
|
|
630
|
+
await self._aget_or_create_instance(provider, context)
|
|
631
|
+
if not create
|
|
632
|
+
else await self._acreate_instance(provider, context, **defaults)
|
|
522
633
|
)
|
|
523
634
|
else:
|
|
524
|
-
instance =
|
|
635
|
+
instance = (
|
|
636
|
+
await self._aget_or_create_instance(provider, context)
|
|
637
|
+
if not create
|
|
638
|
+
else await self._acreate_instance(provider, context, **defaults)
|
|
639
|
+
)
|
|
640
|
+
|
|
525
641
|
if self.testing:
|
|
526
642
|
instance = self._patch_test_resolver(provider.interface, instance)
|
|
643
|
+
|
|
527
644
|
return cast(T, instance)
|
|
528
645
|
|
|
529
646
|
def _get_or_create_instance(
|
|
@@ -700,26 +817,48 @@ class Container:
|
|
|
700
817
|
def _resolve_parameter(
|
|
701
818
|
self, provider: Provider, parameter: inspect.Parameter
|
|
702
819
|
) -> Any:
|
|
703
|
-
self._validate_resolvable_parameter(
|
|
820
|
+
self._validate_resolvable_parameter(provider, parameter)
|
|
704
821
|
return self.resolve(parameter.annotation)
|
|
705
822
|
|
|
706
823
|
async def _aresolve_parameter(
|
|
707
824
|
self, provider: Provider, parameter: inspect.Parameter
|
|
708
825
|
) -> Any:
|
|
709
|
-
self._validate_resolvable_parameter(
|
|
826
|
+
self._validate_resolvable_parameter(provider, parameter)
|
|
710
827
|
return await self.aresolve(parameter.annotation)
|
|
711
828
|
|
|
712
829
|
def _validate_resolvable_parameter(
|
|
713
|
-
self, parameter: inspect.Parameter
|
|
830
|
+
self, provider: Provider, parameter: inspect.Parameter
|
|
714
831
|
) -> None:
|
|
715
832
|
"""Ensure that the specified interface is resolved."""
|
|
716
833
|
if parameter.annotation in self._unresolved_interfaces:
|
|
717
834
|
raise LookupError(
|
|
718
835
|
f"You are attempting to get the parameter `{parameter.name}` with the "
|
|
719
836
|
f"annotation `{get_full_qualname(parameter.annotation)}` as a "
|
|
720
|
-
f"dependency into `{get_full_qualname(call)}` which is
|
|
721
|
-
"or set in the scoped context."
|
|
837
|
+
f"dependency into `{get_full_qualname(provider.call)}` which is "
|
|
838
|
+
"not registered or set in the scoped context."
|
|
839
|
+
)
|
|
840
|
+
|
|
841
|
+
@contextlib.contextmanager
|
|
842
|
+
def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
|
|
843
|
+
"""
|
|
844
|
+
Override the provider for the specified interface with a specific instance.
|
|
845
|
+
"""
|
|
846
|
+
if not self.testing:
|
|
847
|
+
raise RuntimeError(
|
|
848
|
+
"The `override` method can only be used in testing mode."
|
|
849
|
+
)
|
|
850
|
+
if not self.is_registered(interface) and self.strict:
|
|
851
|
+
raise LookupError(
|
|
852
|
+
f"The provider interface `{get_full_qualname(interface)}` "
|
|
853
|
+
"not registered."
|
|
722
854
|
)
|
|
855
|
+
self._override_instances[interface] = instance
|
|
856
|
+
yield
|
|
857
|
+
del self._override_instances[interface]
|
|
858
|
+
|
|
859
|
+
############################
|
|
860
|
+
# Testing Methods
|
|
861
|
+
############################
|
|
723
862
|
|
|
724
863
|
def _patch_test_resolver(self, interface: type[Any], instance: Any) -> Any:
|
|
725
864
|
"""Patch the test resolver for the instance."""
|
|
@@ -769,60 +908,9 @@ class Container:
|
|
|
769
908
|
|
|
770
909
|
return instance
|
|
771
910
|
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
provider = self._get_provider(interface)
|
|
776
|
-
except LookupError:
|
|
777
|
-
return False
|
|
778
|
-
if provider.scope == "transient":
|
|
779
|
-
return False
|
|
780
|
-
context = self._get_scoped_context(provider.scope)
|
|
781
|
-
return interface in context
|
|
782
|
-
|
|
783
|
-
def release(self, interface: AnyInterface) -> None:
|
|
784
|
-
"""Release an instance by interface."""
|
|
785
|
-
provider = self._get_provider(interface)
|
|
786
|
-
if provider.scope == "transient":
|
|
787
|
-
return None
|
|
788
|
-
context = self._get_scoped_context(provider.scope)
|
|
789
|
-
del context[interface]
|
|
790
|
-
|
|
791
|
-
def _get_scoped_context(self, scope: Scope) -> InstanceContext:
|
|
792
|
-
"""Get the instance context for the specified scope."""
|
|
793
|
-
if scope == "singleton":
|
|
794
|
-
return self._singleton_context
|
|
795
|
-
return self._get_request_context()
|
|
796
|
-
|
|
797
|
-
@contextlib.contextmanager
|
|
798
|
-
def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
|
|
799
|
-
"""
|
|
800
|
-
Override the provider for the specified interface with a specific instance.
|
|
801
|
-
"""
|
|
802
|
-
if not self.testing:
|
|
803
|
-
raise RuntimeError(
|
|
804
|
-
"The `override` method can only be used in testing mode."
|
|
805
|
-
)
|
|
806
|
-
if not self.is_registered(interface) and self.strict:
|
|
807
|
-
raise LookupError(
|
|
808
|
-
f"The provider interface `{get_full_qualname(interface)}` "
|
|
809
|
-
"not registered."
|
|
810
|
-
)
|
|
811
|
-
self._override_instances[interface] = instance
|
|
812
|
-
yield
|
|
813
|
-
del self._override_instances[interface]
|
|
814
|
-
|
|
815
|
-
def provider(
|
|
816
|
-
self, *, scope: Scope, override: bool = False
|
|
817
|
-
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
|
818
|
-
"""Decorator to register a provider function with the specified scope."""
|
|
819
|
-
|
|
820
|
-
def decorator(call: Callable[P, T]) -> Callable[P, T]:
|
|
821
|
-
provider = Provider(call=call, scope=scope)
|
|
822
|
-
self._register_provider(provider, override)
|
|
823
|
-
return call
|
|
824
|
-
|
|
825
|
-
return decorator
|
|
911
|
+
############################
|
|
912
|
+
# Injector Methods
|
|
913
|
+
############################
|
|
826
914
|
|
|
827
915
|
@overload
|
|
828
916
|
def inject(self, func: Callable[P, T]) -> Callable[P, T]: ...
|
|
@@ -842,6 +930,10 @@ class Container:
|
|
|
842
930
|
return decorator
|
|
843
931
|
return decorator(func)
|
|
844
932
|
|
|
933
|
+
def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
|
|
934
|
+
"""Run the given function with injected dependencies."""
|
|
935
|
+
return self._inject(func)(*args, **kwargs)
|
|
936
|
+
|
|
845
937
|
def _inject(self, call: Callable[P, T]) -> Callable[P, T]:
|
|
846
938
|
"""Inject dependencies into a callable."""
|
|
847
939
|
if call in self._inject_cache:
|
|
@@ -909,9 +1001,39 @@ class Container:
|
|
|
909
1001
|
f"`{get_full_qualname(parameter.annotation)}`."
|
|
910
1002
|
)
|
|
911
1003
|
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
1004
|
+
############################
|
|
1005
|
+
# Module Methods
|
|
1006
|
+
############################
|
|
1007
|
+
|
|
1008
|
+
def register_module(
|
|
1009
|
+
self, module: Module | type[Module] | Callable[[Container], None] | str
|
|
1010
|
+
) -> None:
|
|
1011
|
+
"""Register a module as a callable, module type, or module instance."""
|
|
1012
|
+
# Callable Module
|
|
1013
|
+
if inspect.isfunction(module):
|
|
1014
|
+
module(self)
|
|
1015
|
+
return
|
|
1016
|
+
|
|
1017
|
+
# Module path
|
|
1018
|
+
if isinstance(module, str):
|
|
1019
|
+
module = import_string(module)
|
|
1020
|
+
|
|
1021
|
+
# Class based Module or Module type
|
|
1022
|
+
if inspect.isclass(module) and issubclass(module, Module):
|
|
1023
|
+
module = module()
|
|
1024
|
+
|
|
1025
|
+
if isinstance(module, Module):
|
|
1026
|
+
module.configure(self)
|
|
1027
|
+
for provider_name, decorator_args in module.providers:
|
|
1028
|
+
obj = getattr(module, provider_name)
|
|
1029
|
+
self.provider(
|
|
1030
|
+
scope=decorator_args.scope,
|
|
1031
|
+
override=decorator_args.override,
|
|
1032
|
+
)(obj)
|
|
1033
|
+
|
|
1034
|
+
############################
|
|
1035
|
+
# Scanner Methods
|
|
1036
|
+
############################
|
|
915
1037
|
|
|
916
1038
|
def scan(
|
|
917
1039
|
self,
|
|
@@ -921,7 +1043,7 @@ class Container:
|
|
|
921
1043
|
tags: Iterable[str] | None = None,
|
|
922
1044
|
) -> None:
|
|
923
1045
|
"""Scan packages or modules for decorated members and inject dependencies."""
|
|
924
|
-
dependencies: list[
|
|
1046
|
+
dependencies: list[ScannedDependency] = []
|
|
925
1047
|
|
|
926
1048
|
if isinstance(packages, Iterable) and not isinstance(packages, str):
|
|
927
1049
|
scan_packages: Iterable[ModuleType | str] = packages
|
|
@@ -940,7 +1062,7 @@ class Container:
|
|
|
940
1062
|
package: ModuleType | str,
|
|
941
1063
|
*,
|
|
942
1064
|
tags: Iterable[str] | None = None,
|
|
943
|
-
) -> list[
|
|
1065
|
+
) -> list[ScannedDependency]:
|
|
944
1066
|
"""Scan a package or module for decorated members."""
|
|
945
1067
|
tags = tags or []
|
|
946
1068
|
if isinstance(package, str):
|
|
@@ -951,7 +1073,7 @@ class Container:
|
|
|
951
1073
|
if not package_path:
|
|
952
1074
|
return self._scan_module(package, tags=tags)
|
|
953
1075
|
|
|
954
|
-
dependencies: list[
|
|
1076
|
+
dependencies: list[ScannedDependency] = []
|
|
955
1077
|
|
|
956
1078
|
for module_info in pkgutil.walk_packages(
|
|
957
1079
|
path=package_path, prefix=package.__name__ + "."
|
|
@@ -963,9 +1085,9 @@ class Container:
|
|
|
963
1085
|
|
|
964
1086
|
def _scan_module(
|
|
965
1087
|
self, module: ModuleType, *, tags: Iterable[str]
|
|
966
|
-
) -> list[
|
|
1088
|
+
) -> list[ScannedDependency]:
|
|
967
1089
|
"""Scan a module for decorated members."""
|
|
968
|
-
dependencies: list[
|
|
1090
|
+
dependencies: list[ScannedDependency] = []
|
|
969
1091
|
|
|
970
1092
|
for _, member in inspect.getmembers(module):
|
|
971
1093
|
if getattr(member, "__module__", None) != module.__name__ or not callable(
|
|
@@ -988,7 +1110,7 @@ class Container:
|
|
|
988
1110
|
|
|
989
1111
|
if decorator_args.wrapped:
|
|
990
1112
|
dependencies.append(
|
|
991
|
-
self.
|
|
1113
|
+
self._create_scanned_dependency(member=member, module=module)
|
|
992
1114
|
)
|
|
993
1115
|
continue
|
|
994
1116
|
|
|
@@ -996,17 +1118,24 @@ class Container:
|
|
|
996
1118
|
for parameter in get_typed_parameters(member):
|
|
997
1119
|
if is_marker(parameter.default):
|
|
998
1120
|
dependencies.append(
|
|
999
|
-
self.
|
|
1121
|
+
self._create_scanned_dependency(member=member, module=module)
|
|
1000
1122
|
)
|
|
1001
1123
|
continue
|
|
1002
1124
|
|
|
1003
1125
|
return dependencies
|
|
1004
1126
|
|
|
1005
|
-
def
|
|
1127
|
+
def _create_scanned_dependency(
|
|
1128
|
+
self, member: Any, module: ModuleType
|
|
1129
|
+
) -> ScannedDependency:
|
|
1006
1130
|
"""Create a `Dependency` object from the scanned member and module."""
|
|
1007
1131
|
if hasattr(member, "__wrapped__"):
|
|
1008
1132
|
member = member.__wrapped__
|
|
1009
|
-
return
|
|
1133
|
+
return ScannedDependency(member=member, module=module)
|
|
1134
|
+
|
|
1135
|
+
|
|
1136
|
+
############################
|
|
1137
|
+
# Decorators
|
|
1138
|
+
############################
|
|
1010
1139
|
|
|
1011
1140
|
|
|
1012
1141
|
def transient(target: T) -> T:
|