anydi 0.38.1__py3-none-any.whl → 0.38.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anydi/__init__.py +1 -2
- anydi/_container.py +411 -271
- anydi/_types.py +80 -2
- anydi/ext/pytest_plugin.py +2 -6
- {anydi-0.38.1.dist-info → anydi-0.38.2rc0.dist-info}/METADATA +1 -1
- {anydi-0.38.1.dist-info → anydi-0.38.2rc0.dist-info}/RECORD +9 -10
- anydi/_provider.py +0 -232
- {anydi-0.38.1.dist-info → anydi-0.38.2rc0.dist-info}/WHEEL +0 -0
- {anydi-0.38.1.dist-info → anydi-0.38.2rc0.dist-info}/entry_points.txt +0 -0
- {anydi-0.38.1.dist-info → anydi-0.38.2rc0.dist-info}/licenses/LICENSE +0 -0
anydi/__init__.py
CHANGED
|
@@ -11,8 +11,7 @@ from ._container import (
|
|
|
11
11
|
singleton,
|
|
12
12
|
transient,
|
|
13
13
|
)
|
|
14
|
-
from .
|
|
15
|
-
from ._types import Marker, Scope
|
|
14
|
+
from ._types import Marker, ProviderArgs as Provider, Scope
|
|
16
15
|
|
|
17
16
|
# Alias for dependency auto marker
|
|
18
17
|
auto = cast(Any, Marker())
|
anydi/_container.py
CHANGED
|
@@ -10,23 +10,27 @@ import logging
|
|
|
10
10
|
import pkgutil
|
|
11
11
|
import threading
|
|
12
12
|
import types
|
|
13
|
+
import uuid
|
|
13
14
|
from collections import defaultdict
|
|
14
15
|
from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
|
|
15
16
|
from contextvars import ContextVar
|
|
16
17
|
from types import ModuleType
|
|
17
|
-
from typing import Any, Callable, TypeVar, Union, cast, overload
|
|
18
|
-
from weakref import WeakKeyDictionary
|
|
18
|
+
from typing import Any, Callable, TypeVar, Union, cast, get_origin, overload
|
|
19
19
|
|
|
20
|
-
from typing_extensions import Concatenate, ParamSpec, Self, final
|
|
20
|
+
from typing_extensions import Concatenate, ParamSpec, Self, final, get_args
|
|
21
21
|
|
|
22
22
|
from ._context import InstanceContext
|
|
23
|
-
from ._provider import Provider
|
|
24
23
|
from ._types import (
|
|
24
|
+
NOT_SET,
|
|
25
25
|
AnyInterface,
|
|
26
|
-
|
|
26
|
+
Event,
|
|
27
27
|
InjectableDecoratorArgs,
|
|
28
28
|
InstanceProxy,
|
|
29
|
+
Provider,
|
|
30
|
+
ProviderArgs,
|
|
29
31
|
ProviderDecoratorArgs,
|
|
32
|
+
ProviderKind,
|
|
33
|
+
ScannedDependency,
|
|
30
34
|
Scope,
|
|
31
35
|
is_event_type,
|
|
32
36
|
is_marker,
|
|
@@ -34,6 +38,7 @@ from ._types import (
|
|
|
34
38
|
from ._utils import (
|
|
35
39
|
AsyncRLock,
|
|
36
40
|
get_full_qualname,
|
|
41
|
+
get_typed_annotation,
|
|
37
42
|
get_typed_parameters,
|
|
38
43
|
import_string,
|
|
39
44
|
is_async_context_manager,
|
|
@@ -42,6 +47,12 @@ from ._utils import (
|
|
|
42
47
|
run_async,
|
|
43
48
|
)
|
|
44
49
|
|
|
50
|
+
try:
|
|
51
|
+
from types import NoneType
|
|
52
|
+
except ImportError:
|
|
53
|
+
NoneType = type(None) # type: ignore[misc]
|
|
54
|
+
|
|
55
|
+
|
|
45
56
|
T = TypeVar("T", bound=Any)
|
|
46
57
|
M = TypeVar("M", bound="Module")
|
|
47
58
|
P = ParamSpec("P")
|
|
@@ -74,7 +85,6 @@ class Module(metaclass=ModuleMeta):
|
|
|
74
85
|
"""Configure the AnyDI container with providers and their dependencies."""
|
|
75
86
|
|
|
76
87
|
|
|
77
|
-
# noinspection PyShadowingNames
|
|
78
88
|
@final
|
|
79
89
|
class Container:
|
|
80
90
|
"""AnyDI is a dependency injection container."""
|
|
@@ -82,7 +92,7 @@ class Container:
|
|
|
82
92
|
def __init__(
|
|
83
93
|
self,
|
|
84
94
|
*,
|
|
85
|
-
providers: Sequence[
|
|
95
|
+
providers: Sequence[ProviderArgs] | None = None,
|
|
86
96
|
modules: Sequence[Module | type[Module] | Callable[[Container], None] | str]
|
|
87
97
|
| None = None,
|
|
88
98
|
strict: bool = False,
|
|
@@ -104,20 +114,27 @@ class Container:
|
|
|
104
114
|
)
|
|
105
115
|
self._override_instances: dict[type[Any], Any] = {}
|
|
106
116
|
self._unresolved_interfaces: set[type[Any]] = set()
|
|
107
|
-
self._inject_cache:
|
|
108
|
-
Callable[..., Any], Callable[..., Any]
|
|
109
|
-
] = WeakKeyDictionary()
|
|
117
|
+
self._inject_cache: dict[Callable[..., Any], Callable[..., Any]] = {}
|
|
110
118
|
|
|
111
119
|
# Register providers
|
|
112
120
|
providers = providers or []
|
|
113
121
|
for provider in providers:
|
|
114
|
-
self.
|
|
122
|
+
_provider = self._create_provider(
|
|
123
|
+
call=provider.call,
|
|
124
|
+
scope=provider.scope,
|
|
125
|
+
interface=provider.interface,
|
|
126
|
+
)
|
|
127
|
+
self._register_provider(_provider, False)
|
|
115
128
|
|
|
116
129
|
# Register modules
|
|
117
130
|
modules = modules or []
|
|
118
131
|
for module in modules:
|
|
119
132
|
self.register_module(module)
|
|
120
133
|
|
|
134
|
+
############################
|
|
135
|
+
# Properties
|
|
136
|
+
############################
|
|
137
|
+
|
|
121
138
|
@property
|
|
122
139
|
def strict(self) -> bool:
|
|
123
140
|
"""Check if strict mode is enabled."""
|
|
@@ -143,9 +160,110 @@ class Container:
|
|
|
143
160
|
"""Get the logger instance."""
|
|
144
161
|
return self._logger
|
|
145
162
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
163
|
+
############################
|
|
164
|
+
# Lifespan/Context Methods
|
|
165
|
+
############################
|
|
166
|
+
|
|
167
|
+
def __enter__(self) -> Self:
|
|
168
|
+
"""Enter the singleton context."""
|
|
169
|
+
self.start()
|
|
170
|
+
return self
|
|
171
|
+
|
|
172
|
+
def __exit__(
|
|
173
|
+
self,
|
|
174
|
+
exc_type: type[BaseException] | None,
|
|
175
|
+
exc_val: BaseException | None,
|
|
176
|
+
exc_tb: types.TracebackType | None,
|
|
177
|
+
) -> Any:
|
|
178
|
+
"""Exit the singleton context."""
|
|
179
|
+
return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
|
|
180
|
+
|
|
181
|
+
def start(self) -> None:
|
|
182
|
+
"""Start the singleton context."""
|
|
183
|
+
# Resolve all singleton resources
|
|
184
|
+
for interface in self._resources.get("singleton", []):
|
|
185
|
+
self.resolve(interface)
|
|
186
|
+
|
|
187
|
+
def close(self) -> None:
|
|
188
|
+
"""Close the singleton context."""
|
|
189
|
+
self._singleton_context.close()
|
|
190
|
+
|
|
191
|
+
async def __aenter__(self) -> Self:
|
|
192
|
+
"""Enter the singleton context."""
|
|
193
|
+
await self.astart()
|
|
194
|
+
return self
|
|
195
|
+
|
|
196
|
+
async def __aexit__(
|
|
197
|
+
self,
|
|
198
|
+
exc_type: type[BaseException] | None,
|
|
199
|
+
exc_val: BaseException | None,
|
|
200
|
+
exc_tb: types.TracebackType | None,
|
|
201
|
+
) -> bool:
|
|
202
|
+
"""Exit the singleton context."""
|
|
203
|
+
return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)
|
|
204
|
+
|
|
205
|
+
async def astart(self) -> None:
|
|
206
|
+
"""Start the singleton context asynchronously."""
|
|
207
|
+
for interface in self._resources.get("singleton", []):
|
|
208
|
+
await self.aresolve(interface)
|
|
209
|
+
|
|
210
|
+
async def aclose(self) -> None:
|
|
211
|
+
"""Close the singleton context asynchronously."""
|
|
212
|
+
await self._singleton_context.aclose()
|
|
213
|
+
|
|
214
|
+
@contextlib.contextmanager
|
|
215
|
+
def request_context(self) -> Iterator[InstanceContext]:
|
|
216
|
+
"""Obtain a context manager for the request-scoped context."""
|
|
217
|
+
context = InstanceContext()
|
|
218
|
+
|
|
219
|
+
token = self._request_context_var.set(context)
|
|
220
|
+
|
|
221
|
+
# Resolve all request resources
|
|
222
|
+
for interface in self._resources.get("request", []):
|
|
223
|
+
if not is_event_type(interface):
|
|
224
|
+
continue
|
|
225
|
+
self.resolve(interface)
|
|
226
|
+
|
|
227
|
+
with context:
|
|
228
|
+
yield context
|
|
229
|
+
self._request_context_var.reset(token)
|
|
230
|
+
|
|
231
|
+
@contextlib.asynccontextmanager
|
|
232
|
+
async def arequest_context(self) -> AsyncIterator[InstanceContext]:
|
|
233
|
+
"""Obtain an async context manager for the request-scoped context."""
|
|
234
|
+
context = InstanceContext()
|
|
235
|
+
|
|
236
|
+
token = self._request_context_var.set(context)
|
|
237
|
+
|
|
238
|
+
for interface in self._resources.get("request", []):
|
|
239
|
+
if not is_event_type(interface):
|
|
240
|
+
continue
|
|
241
|
+
await self.aresolve(interface)
|
|
242
|
+
|
|
243
|
+
async with context:
|
|
244
|
+
yield context
|
|
245
|
+
self._request_context_var.reset(token)
|
|
246
|
+
|
|
247
|
+
def _get_request_context(self) -> InstanceContext:
|
|
248
|
+
"""Get the current request context."""
|
|
249
|
+
request_context = self._request_context_var.get()
|
|
250
|
+
if request_context is None:
|
|
251
|
+
raise LookupError(
|
|
252
|
+
"The request context has not been started. Please ensure that "
|
|
253
|
+
"the request context is properly initialized before attempting "
|
|
254
|
+
"to use it."
|
|
255
|
+
)
|
|
256
|
+
return request_context
|
|
257
|
+
|
|
258
|
+
def _get_scoped_context(self, scope: Scope) -> InstanceContext:
|
|
259
|
+
"""Get the instance context for the specified scope."""
|
|
260
|
+
if scope == "singleton":
|
|
261
|
+
return self._singleton_context
|
|
262
|
+
return self._get_request_context()
|
|
263
|
+
|
|
264
|
+
############################
|
|
265
|
+
# Provider Methods
|
|
266
|
+
############################
|
|
149
267
|
|
|
150
268
|
def register(
|
|
151
269
|
self,
|
|
@@ -156,26 +274,12 @@ class Container:
|
|
|
156
274
|
override: bool = False,
|
|
157
275
|
) -> Provider:
|
|
158
276
|
"""Register a provider for the specified interface."""
|
|
159
|
-
provider =
|
|
277
|
+
provider = self._create_provider(call=call, scope=scope, interface=interface)
|
|
160
278
|
return self._register_provider(provider, override)
|
|
161
279
|
|
|
162
|
-
def
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
"""Register a provider."""
|
|
166
|
-
if provider.interface in self._providers:
|
|
167
|
-
if override:
|
|
168
|
-
self._set_provider(provider)
|
|
169
|
-
return provider
|
|
170
|
-
|
|
171
|
-
raise LookupError(
|
|
172
|
-
f"The provider interface `{get_full_qualname(provider.interface)}` "
|
|
173
|
-
"already registered."
|
|
174
|
-
)
|
|
175
|
-
|
|
176
|
-
self._validate_sub_providers(provider, **defaults)
|
|
177
|
-
self._set_provider(provider)
|
|
178
|
-
return provider
|
|
280
|
+
def is_registered(self, interface: AnyInterface) -> bool:
|
|
281
|
+
"""Check if a provider is registered for the specified interface."""
|
|
282
|
+
return interface in self._providers
|
|
179
283
|
|
|
180
284
|
def unregister(self, interface: AnyInterface) -> None:
|
|
181
285
|
"""Unregister a provider by interface."""
|
|
@@ -199,6 +303,120 @@ class Container:
|
|
|
199
303
|
# Cleanup provider references
|
|
200
304
|
self._delete_provider(provider)
|
|
201
305
|
|
|
306
|
+
def provider(
|
|
307
|
+
self, *, scope: Scope, override: bool = False
|
|
308
|
+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
|
309
|
+
"""Decorator to register a provider function with the specified scope."""
|
|
310
|
+
|
|
311
|
+
def decorator(call: Callable[P, T]) -> Callable[P, T]:
|
|
312
|
+
provider = self._create_provider(call=call, scope=scope)
|
|
313
|
+
self._register_provider(provider, override)
|
|
314
|
+
return call
|
|
315
|
+
|
|
316
|
+
return decorator
|
|
317
|
+
|
|
318
|
+
def _create_provider( # noqa: C901
|
|
319
|
+
self, call: Callable[..., Any], *, scope: Scope, interface: Any = NOT_SET
|
|
320
|
+
) -> Provider:
|
|
321
|
+
name = get_full_qualname(call)
|
|
322
|
+
|
|
323
|
+
# Detect the kind of callable provider
|
|
324
|
+
kind = ProviderKind.from_call(call)
|
|
325
|
+
|
|
326
|
+
# Validate the scope of the provider
|
|
327
|
+
if scope not in get_args(Scope):
|
|
328
|
+
raise ValueError(
|
|
329
|
+
f"The provider `{name}` scope is invalid. Only the following "
|
|
330
|
+
f"scopes are supported: {', '.join(get_args(Scope))}. "
|
|
331
|
+
"Please use one of the supported scopes when registering a provider."
|
|
332
|
+
)
|
|
333
|
+
|
|
334
|
+
# Validate the scope of the provider
|
|
335
|
+
if (
|
|
336
|
+
kind in {ProviderKind.GENERATOR, ProviderKind.ASYNC_GENERATOR}
|
|
337
|
+
and scope == "transient"
|
|
338
|
+
):
|
|
339
|
+
raise TypeError(
|
|
340
|
+
f"The resource provider `{name}` is attempting to register "
|
|
341
|
+
"with a transient scope, which is not allowed."
|
|
342
|
+
)
|
|
343
|
+
|
|
344
|
+
# Get the signature
|
|
345
|
+
globalns = getattr(call, "__globals__", {})
|
|
346
|
+
signature = inspect.signature(call, globals=globalns)
|
|
347
|
+
|
|
348
|
+
# Detect the interface
|
|
349
|
+
if kind == ProviderKind.CLASS:
|
|
350
|
+
interface = call
|
|
351
|
+
else:
|
|
352
|
+
if interface is NOT_SET:
|
|
353
|
+
interface = signature.return_annotation
|
|
354
|
+
if interface is inspect.Signature.empty:
|
|
355
|
+
interface = None
|
|
356
|
+
else:
|
|
357
|
+
interface = get_typed_annotation(interface, globalns)
|
|
358
|
+
|
|
359
|
+
# If the callable is an iterator, return the actual type
|
|
360
|
+
iterator_types = {Iterator, AsyncIterator}
|
|
361
|
+
if interface in iterator_types or get_origin(interface) in iterator_types:
|
|
362
|
+
if args := get_args(interface):
|
|
363
|
+
interface = args[0]
|
|
364
|
+
# If the callable is a generator, return the resource type
|
|
365
|
+
if interface in {None, NoneType}:
|
|
366
|
+
interface = type(f"Event_{uuid.uuid4().hex}", (Event,), {})
|
|
367
|
+
else:
|
|
368
|
+
raise TypeError(
|
|
369
|
+
f"Cannot use `{name}` resource type annotation "
|
|
370
|
+
"without actual type argument."
|
|
371
|
+
)
|
|
372
|
+
|
|
373
|
+
# None interface is not allowed
|
|
374
|
+
if interface in {None, NoneType}:
|
|
375
|
+
raise TypeError(f"Missing `{name}` provider return annotation.")
|
|
376
|
+
|
|
377
|
+
# Detect the parameters
|
|
378
|
+
parameters = []
|
|
379
|
+
for parameter in signature.parameters.values():
|
|
380
|
+
if parameter.annotation is inspect.Parameter.empty:
|
|
381
|
+
raise TypeError(
|
|
382
|
+
f"Missing provider `{name}` "
|
|
383
|
+
f"dependency `{parameter.name}` annotation."
|
|
384
|
+
)
|
|
385
|
+
if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
|
|
386
|
+
raise TypeError(
|
|
387
|
+
"Positional-only parameters "
|
|
388
|
+
f"are not allowed in the provider `{name}`."
|
|
389
|
+
)
|
|
390
|
+
annotation = get_typed_annotation(parameter.annotation, globalns)
|
|
391
|
+
parameters.append(parameter.replace(annotation=annotation))
|
|
392
|
+
|
|
393
|
+
return Provider(
|
|
394
|
+
call=call,
|
|
395
|
+
scope=scope,
|
|
396
|
+
interface=interface,
|
|
397
|
+
name=name,
|
|
398
|
+
kind=kind,
|
|
399
|
+
parameters=parameters,
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
def _register_provider(
|
|
403
|
+
self, provider: Provider, override: bool, /, **defaults: Any
|
|
404
|
+
) -> Provider:
|
|
405
|
+
"""Register a provider."""
|
|
406
|
+
if provider.interface in self._providers:
|
|
407
|
+
if override:
|
|
408
|
+
self._set_provider(provider)
|
|
409
|
+
return provider
|
|
410
|
+
|
|
411
|
+
raise LookupError(
|
|
412
|
+
f"The provider interface `{get_full_qualname(provider.interface)}` "
|
|
413
|
+
"already registered."
|
|
414
|
+
)
|
|
415
|
+
|
|
416
|
+
self._validate_sub_providers(provider, **defaults)
|
|
417
|
+
self._set_provider(provider)
|
|
418
|
+
return provider
|
|
419
|
+
|
|
202
420
|
def _get_provider(self, interface: AnyInterface) -> Provider:
|
|
203
421
|
"""Get provider by interface."""
|
|
204
422
|
try:
|
|
@@ -227,9 +445,11 @@ class Container:
|
|
|
227
445
|
scope = getattr(interface, "__scope__", parent_scope)
|
|
228
446
|
# Try to detect scope
|
|
229
447
|
if scope is None:
|
|
230
|
-
scope = self.
|
|
448
|
+
scope = self._detect_provider_scope(interface, **defaults)
|
|
231
449
|
scope = scope or self.default_scope
|
|
232
|
-
provider =
|
|
450
|
+
provider = self._create_provider(
|
|
451
|
+
call=interface, scope=scope, interface=interface
|
|
452
|
+
)
|
|
233
453
|
return self._register_provider(provider, False, **defaults)
|
|
234
454
|
raise
|
|
235
455
|
|
|
@@ -250,12 +470,6 @@ class Container:
|
|
|
250
470
|
"""Validate the sub-providers of a provider."""
|
|
251
471
|
|
|
252
472
|
for parameter in provider.parameters:
|
|
253
|
-
if parameter.annotation is inspect.Parameter.empty:
|
|
254
|
-
raise TypeError(
|
|
255
|
-
f"Missing provider `{provider}` "
|
|
256
|
-
f"dependency `{parameter.name}` annotation."
|
|
257
|
-
)
|
|
258
|
-
|
|
259
473
|
try:
|
|
260
474
|
sub_provider = self._get_or_register_provider(
|
|
261
475
|
parameter.annotation, provider.scope
|
|
@@ -282,7 +496,9 @@ class Container:
|
|
|
282
496
|
"Please ensure all providers are registered with matching scopes."
|
|
283
497
|
)
|
|
284
498
|
|
|
285
|
-
def
|
|
499
|
+
def _detect_provider_scope(
|
|
500
|
+
self, call: Callable[..., Any], /, **defaults: Any
|
|
501
|
+
) -> Scope | None:
|
|
286
502
|
"""Detect the scope for a callable."""
|
|
287
503
|
scopes = set()
|
|
288
504
|
|
|
@@ -320,122 +536,56 @@ class Container:
|
|
|
320
536
|
not self.strict and parameter.default is not inspect.Parameter.empty
|
|
321
537
|
)
|
|
322
538
|
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
"""Register a module as a callable, module type, or module instance."""
|
|
327
|
-
# Callable Module
|
|
328
|
-
if inspect.isfunction(module):
|
|
329
|
-
module(self)
|
|
330
|
-
return
|
|
331
|
-
|
|
332
|
-
# Module path
|
|
333
|
-
if isinstance(module, str):
|
|
334
|
-
module = import_string(module)
|
|
335
|
-
|
|
336
|
-
# Class based Module or Module type
|
|
337
|
-
if inspect.isclass(module) and issubclass(module, Module):
|
|
338
|
-
module = module()
|
|
339
|
-
|
|
340
|
-
if isinstance(module, Module):
|
|
341
|
-
module.configure(self)
|
|
342
|
-
for provider_name, decorator_args in module.providers:
|
|
343
|
-
obj = getattr(module, provider_name)
|
|
344
|
-
self.provider(
|
|
345
|
-
scope=decorator_args.scope,
|
|
346
|
-
override=decorator_args.override,
|
|
347
|
-
)(obj)
|
|
348
|
-
|
|
349
|
-
def __enter__(self) -> Self:
|
|
350
|
-
"""Enter the singleton context."""
|
|
351
|
-
self.start()
|
|
352
|
-
return self
|
|
353
|
-
|
|
354
|
-
def __exit__(
|
|
355
|
-
self,
|
|
356
|
-
exc_type: type[BaseException] | None,
|
|
357
|
-
exc_val: BaseException | None,
|
|
358
|
-
exc_tb: types.TracebackType | None,
|
|
359
|
-
) -> Any:
|
|
360
|
-
"""Exit the singleton context."""
|
|
361
|
-
return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
|
|
362
|
-
|
|
363
|
-
def start(self) -> None:
|
|
364
|
-
"""Start the singleton context."""
|
|
365
|
-
# Resolve all singleton resources
|
|
366
|
-
for interface in self._resources.get("singleton", []):
|
|
367
|
-
self.resolve(interface)
|
|
368
|
-
|
|
369
|
-
def close(self) -> None:
|
|
370
|
-
"""Close the singleton context."""
|
|
371
|
-
self._singleton_context.close()
|
|
372
|
-
|
|
373
|
-
@contextlib.contextmanager
|
|
374
|
-
def request_context(self) -> Iterator[InstanceContext]:
|
|
375
|
-
"""Obtain a context manager for the request-scoped context."""
|
|
376
|
-
context = InstanceContext()
|
|
377
|
-
|
|
378
|
-
token = self._request_context_var.set(context)
|
|
539
|
+
############################
|
|
540
|
+
# Instance Methods
|
|
541
|
+
############################
|
|
379
542
|
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
if not is_event_type(interface):
|
|
383
|
-
continue
|
|
384
|
-
self.resolve(interface)
|
|
385
|
-
|
|
386
|
-
with context:
|
|
387
|
-
yield context
|
|
388
|
-
self._request_context_var.reset(token)
|
|
543
|
+
@overload
|
|
544
|
+
def resolve(self, interface: type[T]) -> T: ...
|
|
389
545
|
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
await self.astart()
|
|
393
|
-
return self
|
|
546
|
+
@overload
|
|
547
|
+
def resolve(self, interface: T) -> T: ...
|
|
394
548
|
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
exc_val: BaseException | None,
|
|
399
|
-
exc_tb: types.TracebackType | None,
|
|
400
|
-
) -> bool:
|
|
401
|
-
"""Exit the singleton context."""
|
|
402
|
-
return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)
|
|
549
|
+
def resolve(self, interface: type[T]) -> T:
|
|
550
|
+
"""Resolve an instance by interface."""
|
|
551
|
+
return self._resolve_or_create(interface, False)
|
|
403
552
|
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
for interface in self._resources.get("singleton", []):
|
|
407
|
-
await self.aresolve(interface)
|
|
553
|
+
@overload
|
|
554
|
+
async def aresolve(self, interface: type[T]) -> T: ...
|
|
408
555
|
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
await self._singleton_context.aclose()
|
|
556
|
+
@overload
|
|
557
|
+
async def aresolve(self, interface: T) -> T: ...
|
|
412
558
|
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
context = InstanceContext()
|
|
559
|
+
async def aresolve(self, interface: type[T]) -> T:
|
|
560
|
+
"""Resolve an instance by interface asynchronously."""
|
|
561
|
+
return await self._aresolve_or_acreate(interface, False)
|
|
417
562
|
|
|
418
|
-
|
|
563
|
+
def create(self, interface: type[T], /, **defaults: Any) -> T:
|
|
564
|
+
"""Create an instance by interface."""
|
|
565
|
+
return self._resolve_or_create(interface, True, **defaults)
|
|
419
566
|
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
await self.aresolve(interface)
|
|
567
|
+
async def acreate(self, interface: type[T], /, **defaults: Any) -> T:
|
|
568
|
+
"""Create an instance by interface asynchronously."""
|
|
569
|
+
return await self._aresolve_or_acreate(interface, True, **defaults)
|
|
424
570
|
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
571
|
+
def is_resolved(self, interface: AnyInterface) -> bool:
|
|
572
|
+
"""Check if an instance by interface exists."""
|
|
573
|
+
try:
|
|
574
|
+
provider = self._get_provider(interface)
|
|
575
|
+
except LookupError:
|
|
576
|
+
return False
|
|
577
|
+
if provider.scope == "transient":
|
|
578
|
+
return False
|
|
579
|
+
context = self._get_scoped_context(provider.scope)
|
|
580
|
+
return interface in context
|
|
428
581
|
|
|
429
|
-
def
|
|
430
|
-
"""
|
|
431
|
-
|
|
432
|
-
if
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
"to use it."
|
|
437
|
-
)
|
|
438
|
-
return request_context
|
|
582
|
+
def release(self, interface: AnyInterface) -> None:
|
|
583
|
+
"""Release an instance by interface."""
|
|
584
|
+
provider = self._get_provider(interface)
|
|
585
|
+
if provider.scope == "transient":
|
|
586
|
+
return None
|
|
587
|
+
context = self._get_scoped_context(provider.scope)
|
|
588
|
+
del context[interface]
|
|
439
589
|
|
|
440
590
|
def reset(self) -> None:
|
|
441
591
|
"""Reset resolved instances."""
|
|
@@ -448,52 +598,10 @@ class Container:
|
|
|
448
598
|
continue
|
|
449
599
|
del context[interface]
|
|
450
600
|
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
def resolve(self, interface: T) -> T: ...
|
|
456
|
-
|
|
457
|
-
def resolve(self, interface: type[T]) -> T:
|
|
458
|
-
"""Resolve an instance by interface."""
|
|
459
|
-
provider = self._get_or_register_provider(interface, None)
|
|
460
|
-
if provider.scope == "transient":
|
|
461
|
-
instance = self._create_instance(provider, None)
|
|
462
|
-
else:
|
|
463
|
-
context = self._get_scoped_context(provider.scope)
|
|
464
|
-
if provider.scope == "singleton":
|
|
465
|
-
with self._singleton_lock:
|
|
466
|
-
instance = self._get_or_create_instance(provider, context)
|
|
467
|
-
else:
|
|
468
|
-
instance = self._get_or_create_instance(provider, context)
|
|
469
|
-
if self.testing:
|
|
470
|
-
instance = self._patch_test_resolver(provider.interface, instance)
|
|
471
|
-
return cast(T, instance)
|
|
472
|
-
|
|
473
|
-
@overload
|
|
474
|
-
async def aresolve(self, interface: type[T]) -> T: ...
|
|
475
|
-
|
|
476
|
-
@overload
|
|
477
|
-
async def aresolve(self, interface: T) -> T: ...
|
|
478
|
-
|
|
479
|
-
async def aresolve(self, interface: type[T]) -> T:
|
|
480
|
-
"""Resolve an instance by interface asynchronously."""
|
|
481
|
-
provider = self._get_or_register_provider(interface, None)
|
|
482
|
-
if provider.scope == "transient":
|
|
483
|
-
instance = await self._acreate_instance(provider, None)
|
|
484
|
-
else:
|
|
485
|
-
context = self._get_scoped_context(provider.scope)
|
|
486
|
-
if provider.scope == "singleton":
|
|
487
|
-
async with self._singleton_async_lock:
|
|
488
|
-
instance = await self._aget_or_create_instance(provider, context)
|
|
489
|
-
else:
|
|
490
|
-
instance = await self._aget_or_create_instance(provider, context)
|
|
491
|
-
if self.testing:
|
|
492
|
-
instance = self._patch_test_resolver(interface, instance)
|
|
493
|
-
return cast(T, instance)
|
|
494
|
-
|
|
495
|
-
def create(self, interface: type[T], **defaults: Any) -> T:
|
|
496
|
-
"""Create an instance by interface."""
|
|
601
|
+
def _resolve_or_create(
|
|
602
|
+
self, interface: type[T], create: bool, /, **defaults: Any
|
|
603
|
+
) -> T:
|
|
604
|
+
"""Internal method to handle instance resolution and creation."""
|
|
497
605
|
provider = self._get_or_register_provider(interface, None, **defaults)
|
|
498
606
|
if provider.scope == "transient":
|
|
499
607
|
instance = self._create_instance(provider, None, **defaults)
|
|
@@ -501,15 +609,27 @@ class Container:
|
|
|
501
609
|
context = self._get_scoped_context(provider.scope)
|
|
502
610
|
if provider.scope == "singleton":
|
|
503
611
|
with self._singleton_lock:
|
|
504
|
-
instance =
|
|
612
|
+
instance = (
|
|
613
|
+
self._get_or_create_instance(provider, context)
|
|
614
|
+
if not create
|
|
615
|
+
else self._create_instance(provider, context, **defaults)
|
|
616
|
+
)
|
|
505
617
|
else:
|
|
506
|
-
instance =
|
|
618
|
+
instance = (
|
|
619
|
+
self._get_or_create_instance(provider, context)
|
|
620
|
+
if not create
|
|
621
|
+
else self._create_instance(provider, context, **defaults)
|
|
622
|
+
)
|
|
623
|
+
|
|
507
624
|
if self.testing:
|
|
508
625
|
instance = self._patch_test_resolver(provider.interface, instance)
|
|
626
|
+
|
|
509
627
|
return cast(T, instance)
|
|
510
628
|
|
|
511
|
-
async def
|
|
512
|
-
|
|
629
|
+
async def _aresolve_or_acreate(
|
|
630
|
+
self, interface: type[T], create: bool, /, **defaults: Any
|
|
631
|
+
) -> T:
|
|
632
|
+
"""Internal method to handle instance resolution and creation asynchronously."""
|
|
513
633
|
provider = self._get_or_register_provider(interface, None, **defaults)
|
|
514
634
|
if provider.scope == "transient":
|
|
515
635
|
instance = await self._acreate_instance(provider, None, **defaults)
|
|
@@ -517,13 +637,21 @@ class Container:
|
|
|
517
637
|
context = self._get_scoped_context(provider.scope)
|
|
518
638
|
if provider.scope == "singleton":
|
|
519
639
|
async with self._singleton_async_lock:
|
|
520
|
-
instance =
|
|
521
|
-
provider, context
|
|
640
|
+
instance = (
|
|
641
|
+
await self._aget_or_create_instance(provider, context)
|
|
642
|
+
if not create
|
|
643
|
+
else await self._acreate_instance(provider, context, **defaults)
|
|
522
644
|
)
|
|
523
645
|
else:
|
|
524
|
-
instance =
|
|
646
|
+
instance = (
|
|
647
|
+
await self._aget_or_create_instance(provider, context)
|
|
648
|
+
if not create
|
|
649
|
+
else await self._acreate_instance(provider, context, **defaults)
|
|
650
|
+
)
|
|
651
|
+
|
|
525
652
|
if self.testing:
|
|
526
653
|
instance = self._patch_test_resolver(provider.interface, instance)
|
|
654
|
+
|
|
527
655
|
return cast(T, instance)
|
|
528
656
|
|
|
529
657
|
def _get_or_create_instance(
|
|
@@ -700,26 +828,48 @@ class Container:
|
|
|
700
828
|
def _resolve_parameter(
|
|
701
829
|
self, provider: Provider, parameter: inspect.Parameter
|
|
702
830
|
) -> Any:
|
|
703
|
-
self._validate_resolvable_parameter(
|
|
831
|
+
self._validate_resolvable_parameter(provider, parameter)
|
|
704
832
|
return self.resolve(parameter.annotation)
|
|
705
833
|
|
|
706
834
|
async def _aresolve_parameter(
|
|
707
835
|
self, provider: Provider, parameter: inspect.Parameter
|
|
708
836
|
) -> Any:
|
|
709
|
-
self._validate_resolvable_parameter(
|
|
837
|
+
self._validate_resolvable_parameter(provider, parameter)
|
|
710
838
|
return await self.aresolve(parameter.annotation)
|
|
711
839
|
|
|
712
840
|
def _validate_resolvable_parameter(
|
|
713
|
-
self, parameter: inspect.Parameter
|
|
841
|
+
self, provider: Provider, parameter: inspect.Parameter
|
|
714
842
|
) -> None:
|
|
715
843
|
"""Ensure that the specified interface is resolved."""
|
|
716
844
|
if parameter.annotation in self._unresolved_interfaces:
|
|
717
845
|
raise LookupError(
|
|
718
846
|
f"You are attempting to get the parameter `{parameter.name}` with the "
|
|
719
847
|
f"annotation `{get_full_qualname(parameter.annotation)}` as a "
|
|
720
|
-
f"dependency into `{get_full_qualname(call)}` which is
|
|
721
|
-
"or set in the scoped context."
|
|
848
|
+
f"dependency into `{get_full_qualname(provider.call)}` which is "
|
|
849
|
+
"not registered or set in the scoped context."
|
|
850
|
+
)
|
|
851
|
+
|
|
852
|
+
@contextlib.contextmanager
|
|
853
|
+
def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
|
|
854
|
+
"""
|
|
855
|
+
Override the provider for the specified interface with a specific instance.
|
|
856
|
+
"""
|
|
857
|
+
if not self.testing:
|
|
858
|
+
raise RuntimeError(
|
|
859
|
+
"The `override` method can only be used in testing mode."
|
|
722
860
|
)
|
|
861
|
+
if not self.is_registered(interface) and self.strict:
|
|
862
|
+
raise LookupError(
|
|
863
|
+
f"The provider interface `{get_full_qualname(interface)}` "
|
|
864
|
+
"not registered."
|
|
865
|
+
)
|
|
866
|
+
self._override_instances[interface] = instance
|
|
867
|
+
yield
|
|
868
|
+
del self._override_instances[interface]
|
|
869
|
+
|
|
870
|
+
############################
|
|
871
|
+
# Testing Methods
|
|
872
|
+
############################
|
|
723
873
|
|
|
724
874
|
def _patch_test_resolver(self, interface: type[Any], instance: Any) -> Any:
|
|
725
875
|
"""Patch the test resolver for the instance."""
|
|
@@ -769,60 +919,9 @@ class Container:
|
|
|
769
919
|
|
|
770
920
|
return instance
|
|
771
921
|
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
provider = self._get_provider(interface)
|
|
776
|
-
except LookupError:
|
|
777
|
-
return False
|
|
778
|
-
if provider.scope == "transient":
|
|
779
|
-
return False
|
|
780
|
-
context = self._get_scoped_context(provider.scope)
|
|
781
|
-
return interface in context
|
|
782
|
-
|
|
783
|
-
def release(self, interface: AnyInterface) -> None:
|
|
784
|
-
"""Release an instance by interface."""
|
|
785
|
-
provider = self._get_provider(interface)
|
|
786
|
-
if provider.scope == "transient":
|
|
787
|
-
return None
|
|
788
|
-
context = self._get_scoped_context(provider.scope)
|
|
789
|
-
del context[interface]
|
|
790
|
-
|
|
791
|
-
def _get_scoped_context(self, scope: Scope) -> InstanceContext:
|
|
792
|
-
"""Get the instance context for the specified scope."""
|
|
793
|
-
if scope == "singleton":
|
|
794
|
-
return self._singleton_context
|
|
795
|
-
return self._get_request_context()
|
|
796
|
-
|
|
797
|
-
@contextlib.contextmanager
|
|
798
|
-
def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
|
|
799
|
-
"""
|
|
800
|
-
Override the provider for the specified interface with a specific instance.
|
|
801
|
-
"""
|
|
802
|
-
if not self.testing:
|
|
803
|
-
raise RuntimeError(
|
|
804
|
-
"The `override` method can only be used in testing mode."
|
|
805
|
-
)
|
|
806
|
-
if not self.is_registered(interface) and self.strict:
|
|
807
|
-
raise LookupError(
|
|
808
|
-
f"The provider interface `{get_full_qualname(interface)}` "
|
|
809
|
-
"not registered."
|
|
810
|
-
)
|
|
811
|
-
self._override_instances[interface] = instance
|
|
812
|
-
yield
|
|
813
|
-
del self._override_instances[interface]
|
|
814
|
-
|
|
815
|
-
def provider(
|
|
816
|
-
self, *, scope: Scope, override: bool = False
|
|
817
|
-
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
|
818
|
-
"""Decorator to register a provider function with the specified scope."""
|
|
819
|
-
|
|
820
|
-
def decorator(call: Callable[P, T]) -> Callable[P, T]:
|
|
821
|
-
provider = Provider(call=call, scope=scope)
|
|
822
|
-
self._register_provider(provider, override)
|
|
823
|
-
return call
|
|
824
|
-
|
|
825
|
-
return decorator
|
|
922
|
+
############################
|
|
923
|
+
# Injector Methods
|
|
924
|
+
############################
|
|
826
925
|
|
|
827
926
|
@overload
|
|
828
927
|
def inject(self, func: Callable[P, T]) -> Callable[P, T]: ...
|
|
@@ -842,6 +941,10 @@ class Container:
|
|
|
842
941
|
return decorator
|
|
843
942
|
return decorator(func)
|
|
844
943
|
|
|
944
|
+
def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
|
|
945
|
+
"""Run the given function with injected dependencies."""
|
|
946
|
+
return self._inject(func)(*args, **kwargs)
|
|
947
|
+
|
|
845
948
|
def _inject(self, call: Callable[P, T]) -> Callable[P, T]:
|
|
846
949
|
"""Inject dependencies into a callable."""
|
|
847
950
|
if call in self._inject_cache:
|
|
@@ -909,9 +1012,39 @@ class Container:
|
|
|
909
1012
|
f"`{get_full_qualname(parameter.annotation)}`."
|
|
910
1013
|
)
|
|
911
1014
|
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
1015
|
+
############################
|
|
1016
|
+
# Module Methods
|
|
1017
|
+
############################
|
|
1018
|
+
|
|
1019
|
+
def register_module(
|
|
1020
|
+
self, module: Module | type[Module] | Callable[[Container], None] | str
|
|
1021
|
+
) -> None:
|
|
1022
|
+
"""Register a module as a callable, module type, or module instance."""
|
|
1023
|
+
# Callable Module
|
|
1024
|
+
if inspect.isfunction(module):
|
|
1025
|
+
module(self)
|
|
1026
|
+
return
|
|
1027
|
+
|
|
1028
|
+
# Module path
|
|
1029
|
+
if isinstance(module, str):
|
|
1030
|
+
module = import_string(module)
|
|
1031
|
+
|
|
1032
|
+
# Class based Module or Module type
|
|
1033
|
+
if inspect.isclass(module) and issubclass(module, Module):
|
|
1034
|
+
module = module()
|
|
1035
|
+
|
|
1036
|
+
if isinstance(module, Module):
|
|
1037
|
+
module.configure(self)
|
|
1038
|
+
for provider_name, decorator_args in module.providers:
|
|
1039
|
+
obj = getattr(module, provider_name)
|
|
1040
|
+
self.provider(
|
|
1041
|
+
scope=decorator_args.scope,
|
|
1042
|
+
override=decorator_args.override,
|
|
1043
|
+
)(obj)
|
|
1044
|
+
|
|
1045
|
+
############################
|
|
1046
|
+
# Scanner Methods
|
|
1047
|
+
############################
|
|
915
1048
|
|
|
916
1049
|
def scan(
|
|
917
1050
|
self,
|
|
@@ -921,7 +1054,7 @@ class Container:
|
|
|
921
1054
|
tags: Iterable[str] | None = None,
|
|
922
1055
|
) -> None:
|
|
923
1056
|
"""Scan packages or modules for decorated members and inject dependencies."""
|
|
924
|
-
dependencies: list[
|
|
1057
|
+
dependencies: list[ScannedDependency] = []
|
|
925
1058
|
|
|
926
1059
|
if isinstance(packages, Iterable) and not isinstance(packages, str):
|
|
927
1060
|
scan_packages: Iterable[ModuleType | str] = packages
|
|
@@ -940,7 +1073,7 @@ class Container:
|
|
|
940
1073
|
package: ModuleType | str,
|
|
941
1074
|
*,
|
|
942
1075
|
tags: Iterable[str] | None = None,
|
|
943
|
-
) -> list[
|
|
1076
|
+
) -> list[ScannedDependency]:
|
|
944
1077
|
"""Scan a package or module for decorated members."""
|
|
945
1078
|
tags = tags or []
|
|
946
1079
|
if isinstance(package, str):
|
|
@@ -951,7 +1084,7 @@ class Container:
|
|
|
951
1084
|
if not package_path:
|
|
952
1085
|
return self._scan_module(package, tags=tags)
|
|
953
1086
|
|
|
954
|
-
dependencies: list[
|
|
1087
|
+
dependencies: list[ScannedDependency] = []
|
|
955
1088
|
|
|
956
1089
|
for module_info in pkgutil.walk_packages(
|
|
957
1090
|
path=package_path, prefix=package.__name__ + "."
|
|
@@ -963,9 +1096,9 @@ class Container:
|
|
|
963
1096
|
|
|
964
1097
|
def _scan_module(
|
|
965
1098
|
self, module: ModuleType, *, tags: Iterable[str]
|
|
966
|
-
) -> list[
|
|
1099
|
+
) -> list[ScannedDependency]:
|
|
967
1100
|
"""Scan a module for decorated members."""
|
|
968
|
-
dependencies: list[
|
|
1101
|
+
dependencies: list[ScannedDependency] = []
|
|
969
1102
|
|
|
970
1103
|
for _, member in inspect.getmembers(module):
|
|
971
1104
|
if getattr(member, "__module__", None) != module.__name__ or not callable(
|
|
@@ -988,7 +1121,7 @@ class Container:
|
|
|
988
1121
|
|
|
989
1122
|
if decorator_args.wrapped:
|
|
990
1123
|
dependencies.append(
|
|
991
|
-
self.
|
|
1124
|
+
self._create_scanned_dependency(member=member, module=module)
|
|
992
1125
|
)
|
|
993
1126
|
continue
|
|
994
1127
|
|
|
@@ -996,17 +1129,24 @@ class Container:
|
|
|
996
1129
|
for parameter in get_typed_parameters(member):
|
|
997
1130
|
if is_marker(parameter.default):
|
|
998
1131
|
dependencies.append(
|
|
999
|
-
self.
|
|
1132
|
+
self._create_scanned_dependency(member=member, module=module)
|
|
1000
1133
|
)
|
|
1001
1134
|
continue
|
|
1002
1135
|
|
|
1003
1136
|
return dependencies
|
|
1004
1137
|
|
|
1005
|
-
def
|
|
1138
|
+
def _create_scanned_dependency(
|
|
1139
|
+
self, member: Any, module: ModuleType
|
|
1140
|
+
) -> ScannedDependency:
|
|
1006
1141
|
"""Create a `Dependency` object from the scanned member and module."""
|
|
1007
1142
|
if hasattr(member, "__wrapped__"):
|
|
1008
1143
|
member = member.__wrapped__
|
|
1009
|
-
return
|
|
1144
|
+
return ScannedDependency(member=member, module=module)
|
|
1145
|
+
|
|
1146
|
+
|
|
1147
|
+
############################
|
|
1148
|
+
# Decorators
|
|
1149
|
+
############################
|
|
1010
1150
|
|
|
1011
1151
|
|
|
1012
1152
|
def transient(target: T) -> T:
|
anydi/_types.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import enum
|
|
3
4
|
import inspect
|
|
4
5
|
from collections.abc import Iterable
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from functools import cached_property
|
|
5
8
|
from types import ModuleType
|
|
6
|
-
from typing import Annotated, Any, NamedTuple, Union
|
|
9
|
+
from typing import Annotated, Any, Callable, NamedTuple, Union
|
|
7
10
|
|
|
8
11
|
import wrapt
|
|
9
12
|
from typing_extensions import Literal, Self, TypeAlias
|
|
@@ -12,6 +15,8 @@ Scope = Literal["transient", "singleton", "request"]
|
|
|
12
15
|
|
|
13
16
|
AnyInterface: TypeAlias = Union[type[Any], Annotated[Any, ...]]
|
|
14
17
|
|
|
18
|
+
NOT_SET = object()
|
|
19
|
+
|
|
15
20
|
|
|
16
21
|
class Marker:
|
|
17
22
|
"""A marker class for marking dependencies."""
|
|
@@ -53,12 +58,85 @@ class InstanceProxy(wrapt.ObjectProxy): # type: ignore[misc]
|
|
|
53
58
|
return object.__getattribute__(self, item)
|
|
54
59
|
|
|
55
60
|
|
|
61
|
+
class ProviderKind(enum.IntEnum):
|
|
62
|
+
CLASS = 1
|
|
63
|
+
FUNCTION = 2
|
|
64
|
+
COROUTINE = 3
|
|
65
|
+
GENERATOR = 4
|
|
66
|
+
ASYNC_GENERATOR = 5
|
|
67
|
+
|
|
68
|
+
@classmethod
|
|
69
|
+
def from_call(cls, call: Callable[..., Any]) -> ProviderKind:
|
|
70
|
+
if inspect.isclass(call):
|
|
71
|
+
return cls.CLASS
|
|
72
|
+
elif inspect.iscoroutinefunction(call):
|
|
73
|
+
return cls.COROUTINE
|
|
74
|
+
elif inspect.isasyncgenfunction(call):
|
|
75
|
+
return cls.ASYNC_GENERATOR
|
|
76
|
+
elif inspect.isgeneratorfunction(call):
|
|
77
|
+
return cls.GENERATOR
|
|
78
|
+
elif inspect.isfunction(call) or inspect.ismethod(call):
|
|
79
|
+
return cls.FUNCTION
|
|
80
|
+
raise TypeError(
|
|
81
|
+
f"The provider `{call}` is invalid because it is not a callable "
|
|
82
|
+
"object. Only callable providers are allowed."
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@dataclass(kw_only=True, frozen=True)
|
|
87
|
+
class ProviderParameter:
|
|
88
|
+
pass
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@dataclass(kw_only=True, frozen=True)
|
|
92
|
+
class Provider:
|
|
93
|
+
call: Callable[..., Any]
|
|
94
|
+
scope: Scope
|
|
95
|
+
interface: Any
|
|
96
|
+
name: str
|
|
97
|
+
parameters: list[inspect.Parameter]
|
|
98
|
+
kind: ProviderKind
|
|
99
|
+
|
|
100
|
+
def __str__(self) -> str:
|
|
101
|
+
return self.name
|
|
102
|
+
|
|
103
|
+
@cached_property
|
|
104
|
+
def is_class(self) -> bool:
|
|
105
|
+
return self.kind == ProviderKind.CLASS
|
|
106
|
+
|
|
107
|
+
@cached_property
|
|
108
|
+
def is_coroutine(self) -> bool:
|
|
109
|
+
return self.kind == ProviderKind.COROUTINE
|
|
110
|
+
|
|
111
|
+
@cached_property
|
|
112
|
+
def is_generator(self) -> bool:
|
|
113
|
+
return self.kind == ProviderKind.GENERATOR
|
|
114
|
+
|
|
115
|
+
@cached_property
|
|
116
|
+
def is_async_generator(self) -> bool:
|
|
117
|
+
return self.kind == ProviderKind.ASYNC_GENERATOR
|
|
118
|
+
|
|
119
|
+
@cached_property
|
|
120
|
+
def is_async(self) -> bool:
|
|
121
|
+
return self.is_coroutine or self.is_async_generator
|
|
122
|
+
|
|
123
|
+
@cached_property
|
|
124
|
+
def is_resource(self) -> bool:
|
|
125
|
+
return self.is_generator or self.is_async_generator
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class ProviderArgs(NamedTuple):
|
|
129
|
+
call: Callable[..., Any]
|
|
130
|
+
scope: Scope
|
|
131
|
+
interface: Any | None = None
|
|
132
|
+
|
|
133
|
+
|
|
56
134
|
class ProviderDecoratorArgs(NamedTuple):
|
|
57
135
|
scope: Scope
|
|
58
136
|
override: bool
|
|
59
137
|
|
|
60
138
|
|
|
61
|
-
class
|
|
139
|
+
class ScannedDependency(NamedTuple):
|
|
62
140
|
member: Any
|
|
63
141
|
module: ModuleType
|
|
64
142
|
|
anydi/ext/pytest_plugin.py
CHANGED
|
@@ -35,11 +35,9 @@ CONTAINER_FIXTURE_NAME = "container"
|
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
@pytest.fixture
|
|
38
|
-
def anydi_setup_container(
|
|
39
|
-
request: pytest.FixtureRequest,
|
|
40
|
-
) -> Iterator[Container]:
|
|
38
|
+
def anydi_setup_container(request: pytest.FixtureRequest) -> Container:
|
|
41
39
|
try:
|
|
42
|
-
|
|
40
|
+
return cast(Container, request.getfixturevalue(CONTAINER_FIXTURE_NAME))
|
|
43
41
|
except pytest.FixtureLookupError as exc:
|
|
44
42
|
exc.msg = (
|
|
45
43
|
"`container` fixture is not found. Make sure to define it in your test "
|
|
@@ -47,8 +45,6 @@ def anydi_setup_container(
|
|
|
47
45
|
)
|
|
48
46
|
raise exc
|
|
49
47
|
|
|
50
|
-
yield container
|
|
51
|
-
|
|
52
48
|
|
|
53
49
|
@pytest.fixture
|
|
54
50
|
def _anydi_should_inject(request: pytest.FixtureRequest) -> bool:
|
|
@@ -1,8 +1,7 @@
|
|
|
1
|
-
anydi/__init__.py,sha256=
|
|
2
|
-
anydi/_container.py,sha256=
|
|
1
|
+
anydi/__init__.py,sha256=aAq10a1V_zQ3_Me3p_pll5d1O77PIgqotkOm3pshORI,495
|
|
2
|
+
anydi/_container.py,sha256=BqpvUPeYt6OW7TLIDm-OvMGCbcxnvXA6KyOF9XBmi7M,43072
|
|
3
3
|
anydi/_context.py,sha256=7LV_SL4QWkJeiG7_4D9PZ5lmU-MPzhofxC95zCgY9Gc,2651
|
|
4
|
-
anydi/
|
|
5
|
-
anydi/_types.py,sha256=fdO4xNXtGMxVArmlfDkFYbyR895ixkBTW6V8lMceN7Q,1562
|
|
4
|
+
anydi/_types.py,sha256=oFyx6jxkEsz5FZk6tdRjUmBBQ2tX7eA_bLaa2elq7Mg,3586
|
|
6
5
|
anydi/_utils.py,sha256=INI0jNIXrJ6LS4zqJymMO2yUEobpxmBGASf4G_vR6AU,4378
|
|
7
6
|
anydi/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
7
|
anydi/ext/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -10,7 +9,7 @@ anydi/ext/_utils.py,sha256=U6sRqWzccWUu7eMhbXX1NrwcaxitQF9cO1KxnKF37gw,2566
|
|
|
10
9
|
anydi/ext/fastapi.py,sha256=AEL3ubu-LxUPHMMt1YIn3En_JZC7nyBKmKxmhka3O3c,2436
|
|
11
10
|
anydi/ext/faststream.py,sha256=qXnNGvAqWWp9kbhbQUE6EF_OPUiYQmtOH211_O7BI_0,1898
|
|
12
11
|
anydi/ext/pydantic_settings.py,sha256=8IXXLuG_OvKbvKlBkBRQUHcXgbTpgQUxeWyoMcRIUQM,1488
|
|
13
|
-
anydi/ext/pytest_plugin.py,sha256=
|
|
12
|
+
anydi/ext/pytest_plugin.py,sha256=yR_vos8qt8uFS9uy_G_HJjNgudzJIXawrtpWnn3Pu_s,4613
|
|
14
13
|
anydi/ext/django/__init__.py,sha256=QI1IABCVgSDTUoh7M9WMECKXwB3xvh04HfQ9TOWw1Mk,223
|
|
15
14
|
anydi/ext/django/_container.py,sha256=cxVoYQG16WP0S_Yv4TnLwuaaT7NVEOhLWO-YdALJUb4,418
|
|
16
15
|
anydi/ext/django/_settings.py,sha256=Z0RlAuXoO73oahWeMkK10w8c-4uCBde-DBpeKTV5USY,853
|
|
@@ -22,8 +21,8 @@ anydi/ext/django/ninja/_operation.py,sha256=wSWa7D73XTVlOibmOciv2l6JHPe1ERZcXrqI
|
|
|
22
21
|
anydi/ext/django/ninja/_signature.py,sha256=2cSzKxBIxXLqtwNuH6GSlmjVJFftoGmleWfyk_NVEWw,2207
|
|
23
22
|
anydi/ext/starlette/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
24
23
|
anydi/ext/starlette/middleware.py,sha256=9CQtGg5ZzUz2gFSzJr8U4BWzwNjK8XMctm3n52M77Z0,792
|
|
25
|
-
anydi-0.38.
|
|
26
|
-
anydi-0.38.
|
|
27
|
-
anydi-0.38.
|
|
28
|
-
anydi-0.38.
|
|
29
|
-
anydi-0.38.
|
|
24
|
+
anydi-0.38.2rc0.dist-info/METADATA,sha256=t8FL1ILCUK37XUFQ_gqgAlwX77UytT5FeP4QueIDX50,4920
|
|
25
|
+
anydi-0.38.2rc0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
26
|
+
anydi-0.38.2rc0.dist-info/entry_points.txt,sha256=Nklo9f3Oe4AkNsEgC4g43nCJ-23QDngZSVDNRMdaILI,43
|
|
27
|
+
anydi-0.38.2rc0.dist-info/licenses/LICENSE,sha256=V6rU8a8fv6o2jQ-7ODHs0XfDFimot8Q6Km6CylRIDTo,1069
|
|
28
|
+
anydi-0.38.2rc0.dist-info/RECORD,,
|
anydi/_provider.py
DELETED
|
@@ -1,232 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import inspect
|
|
4
|
-
import uuid
|
|
5
|
-
from collections.abc import AsyncIterator, Iterator
|
|
6
|
-
from enum import IntEnum
|
|
7
|
-
from typing import Any, Callable
|
|
8
|
-
|
|
9
|
-
from typing_extensions import get_args, get_origin
|
|
10
|
-
|
|
11
|
-
try:
|
|
12
|
-
from types import NoneType
|
|
13
|
-
except ImportError:
|
|
14
|
-
NoneType = type(None) # type: ignore[misc]
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
from ._types import Event, Scope
|
|
18
|
-
from ._utils import get_full_qualname, get_typed_annotation
|
|
19
|
-
|
|
20
|
-
_sentinel = object()
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class CallableKind(IntEnum):
|
|
24
|
-
CLASS = 1
|
|
25
|
-
FUNCTION = 2
|
|
26
|
-
COROUTINE = 3
|
|
27
|
-
GENERATOR = 4
|
|
28
|
-
ASYNC_GENERATOR = 5
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
class Provider:
|
|
32
|
-
__slots__ = (
|
|
33
|
-
"_call",
|
|
34
|
-
"_call_module",
|
|
35
|
-
"_call_globals",
|
|
36
|
-
"_scope",
|
|
37
|
-
"_qualname",
|
|
38
|
-
"_kind",
|
|
39
|
-
"_interface",
|
|
40
|
-
"_parameters",
|
|
41
|
-
"_is_class",
|
|
42
|
-
"_is_coroutine",
|
|
43
|
-
"_is_generator",
|
|
44
|
-
"_is_async_generator",
|
|
45
|
-
"_is_async",
|
|
46
|
-
"_is_resource",
|
|
47
|
-
)
|
|
48
|
-
|
|
49
|
-
def __init__(
|
|
50
|
-
self, call: Callable[..., Any], *, scope: Scope, interface: Any = _sentinel
|
|
51
|
-
) -> None:
|
|
52
|
-
self._call = call
|
|
53
|
-
self._call_module = getattr(call, "__module__", None)
|
|
54
|
-
self._call_globals = getattr(call, "__globals__", {})
|
|
55
|
-
self._scope = scope
|
|
56
|
-
self._qualname = get_full_qualname(call)
|
|
57
|
-
|
|
58
|
-
# Detect the kind of callable provider
|
|
59
|
-
self._detect_kind()
|
|
60
|
-
|
|
61
|
-
self._is_class = self._kind == CallableKind.CLASS
|
|
62
|
-
self._is_coroutine = self._kind == CallableKind.COROUTINE
|
|
63
|
-
self._is_generator = self._kind == CallableKind.GENERATOR
|
|
64
|
-
self._is_async_generator = self._kind == CallableKind.ASYNC_GENERATOR
|
|
65
|
-
self._is_async = self._is_coroutine or self._is_async_generator
|
|
66
|
-
self._is_resource = self._is_generator or self._is_async_generator
|
|
67
|
-
|
|
68
|
-
# Validate the scope of the provider
|
|
69
|
-
self._validate_scope()
|
|
70
|
-
|
|
71
|
-
# Get the signature
|
|
72
|
-
signature = inspect.signature(call)
|
|
73
|
-
|
|
74
|
-
# Detect the interface
|
|
75
|
-
self._detect_interface(interface, signature)
|
|
76
|
-
|
|
77
|
-
# Detect the parameters
|
|
78
|
-
self._detect_parameters(signature)
|
|
79
|
-
|
|
80
|
-
def __str__(self) -> str:
|
|
81
|
-
return self._qualname
|
|
82
|
-
|
|
83
|
-
def __eq__(self, other: object) -> bool:
|
|
84
|
-
if not isinstance(other, Provider):
|
|
85
|
-
return NotImplemented # pragma: no cover
|
|
86
|
-
return (
|
|
87
|
-
self._call == other._call
|
|
88
|
-
and self._scope == other._scope
|
|
89
|
-
and self._interface == other._interface
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
@property
|
|
93
|
-
def call(self) -> Callable[..., Any]:
|
|
94
|
-
return self._call
|
|
95
|
-
|
|
96
|
-
@property
|
|
97
|
-
def kind(self) -> CallableKind:
|
|
98
|
-
return self._kind
|
|
99
|
-
|
|
100
|
-
@property
|
|
101
|
-
def scope(self) -> Scope:
|
|
102
|
-
return self._scope
|
|
103
|
-
|
|
104
|
-
@property
|
|
105
|
-
def interface(self) -> Any:
|
|
106
|
-
return self._interface
|
|
107
|
-
|
|
108
|
-
@property
|
|
109
|
-
def parameters(self) -> list[inspect.Parameter]:
|
|
110
|
-
return self._parameters
|
|
111
|
-
|
|
112
|
-
@property
|
|
113
|
-
def is_class(self) -> bool:
|
|
114
|
-
"""Check if the provider is a class."""
|
|
115
|
-
return self._is_class
|
|
116
|
-
|
|
117
|
-
@property
|
|
118
|
-
def is_coroutine(self) -> bool:
|
|
119
|
-
"""Check if the provider is a coroutine."""
|
|
120
|
-
return self._is_coroutine
|
|
121
|
-
|
|
122
|
-
@property
|
|
123
|
-
def is_generator(self) -> bool:
|
|
124
|
-
"""Check if the provider is a generator."""
|
|
125
|
-
return self._is_generator
|
|
126
|
-
|
|
127
|
-
@property
|
|
128
|
-
def is_async_generator(self) -> bool:
|
|
129
|
-
"""Check if the provider is an async generator."""
|
|
130
|
-
return self._is_async_generator
|
|
131
|
-
|
|
132
|
-
@property
|
|
133
|
-
def is_async(self) -> bool:
|
|
134
|
-
"""Check if the provider is an async callable."""
|
|
135
|
-
return self._is_async
|
|
136
|
-
|
|
137
|
-
@property
|
|
138
|
-
def is_resource(self) -> bool:
|
|
139
|
-
"""Check if the provider is a resource."""
|
|
140
|
-
return self._is_resource
|
|
141
|
-
|
|
142
|
-
def _validate_scope(self) -> None:
|
|
143
|
-
"""Validate the scope of the provider."""
|
|
144
|
-
if self.scope not in get_args(Scope):
|
|
145
|
-
raise ValueError(
|
|
146
|
-
"The scope provided is invalid. Only the following scopes are "
|
|
147
|
-
f"supported: {', '.join(get_args(Scope))}. Please use one of the "
|
|
148
|
-
"supported scopes when registering a provider."
|
|
149
|
-
)
|
|
150
|
-
if self.is_resource and self.scope == "transient":
|
|
151
|
-
raise TypeError(
|
|
152
|
-
f"The resource provider `{self}` is attempting to register "
|
|
153
|
-
"with a transient scope, which is not allowed."
|
|
154
|
-
)
|
|
155
|
-
|
|
156
|
-
def _detect_kind(self) -> None:
|
|
157
|
-
"""Detect the kind of callable provider."""
|
|
158
|
-
if inspect.isclass(self.call):
|
|
159
|
-
self._kind = CallableKind.CLASS
|
|
160
|
-
elif inspect.iscoroutinefunction(self.call):
|
|
161
|
-
self._kind = CallableKind.COROUTINE
|
|
162
|
-
elif inspect.isasyncgenfunction(self.call):
|
|
163
|
-
self._kind = CallableKind.ASYNC_GENERATOR
|
|
164
|
-
elif inspect.isgeneratorfunction(self.call):
|
|
165
|
-
self._kind = CallableKind.GENERATOR
|
|
166
|
-
elif inspect.isfunction(self.call) or inspect.ismethod(self.call):
|
|
167
|
-
self._kind = CallableKind.FUNCTION
|
|
168
|
-
else:
|
|
169
|
-
raise TypeError(
|
|
170
|
-
f"The provider `{self.call}` is invalid because it is not a callable "
|
|
171
|
-
"object. Only callable providers are allowed."
|
|
172
|
-
)
|
|
173
|
-
|
|
174
|
-
def _detect_interface(self, interface: Any, signature: inspect.Signature) -> None:
|
|
175
|
-
"""Detect the interface of callable provider."""
|
|
176
|
-
# If the callable is a class, return the class itself
|
|
177
|
-
if self._kind == CallableKind.CLASS:
|
|
178
|
-
self._interface = self._call
|
|
179
|
-
return
|
|
180
|
-
|
|
181
|
-
if interface is _sentinel:
|
|
182
|
-
interface = self._resolve_interface(interface, signature)
|
|
183
|
-
|
|
184
|
-
# If the callable is an iterator, return the actual type
|
|
185
|
-
iterator_types = {Iterator, AsyncIterator}
|
|
186
|
-
if interface in iterator_types or get_origin(interface) in iterator_types:
|
|
187
|
-
if args := get_args(interface):
|
|
188
|
-
interface = args[0]
|
|
189
|
-
# If the callable is a generator, return the resource type
|
|
190
|
-
if interface is NoneType or interface is None:
|
|
191
|
-
self._interface = type(f"Event_{uuid.uuid4().hex}", (Event,), {})
|
|
192
|
-
return
|
|
193
|
-
else:
|
|
194
|
-
raise TypeError(
|
|
195
|
-
f"Cannot use `{self}` resource type annotation "
|
|
196
|
-
"without actual type argument."
|
|
197
|
-
)
|
|
198
|
-
|
|
199
|
-
# None interface is not allowed
|
|
200
|
-
if interface in {None, NoneType}:
|
|
201
|
-
raise TypeError(f"Missing `{self}` provider return annotation.")
|
|
202
|
-
|
|
203
|
-
# Set the interface
|
|
204
|
-
self._interface = interface
|
|
205
|
-
|
|
206
|
-
def _resolve_interface(self, interface: Any, signature: inspect.Signature) -> Any:
|
|
207
|
-
"""Resolve the interface of the callable provider."""
|
|
208
|
-
interface = signature.return_annotation
|
|
209
|
-
if interface is inspect.Signature.empty:
|
|
210
|
-
return None
|
|
211
|
-
return get_typed_annotation(
|
|
212
|
-
interface,
|
|
213
|
-
self._call_globals,
|
|
214
|
-
module=self._call_module,
|
|
215
|
-
)
|
|
216
|
-
|
|
217
|
-
def _detect_parameters(self, signature: inspect.Signature) -> None:
|
|
218
|
-
"""Detect the parameters of the callable provider."""
|
|
219
|
-
parameters = []
|
|
220
|
-
for parameter in signature.parameters.values():
|
|
221
|
-
if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
|
|
222
|
-
raise TypeError(
|
|
223
|
-
f"Positional-only parameter `{parameter.name}` is not allowed "
|
|
224
|
-
f"in the provider `{self}`."
|
|
225
|
-
)
|
|
226
|
-
annotation = get_typed_annotation(
|
|
227
|
-
parameter.annotation,
|
|
228
|
-
self._call_globals,
|
|
229
|
-
module=self._call_module,
|
|
230
|
-
)
|
|
231
|
-
parameters.append(parameter.replace(annotation=annotation))
|
|
232
|
-
self._parameters = parameters
|
|
File without changes
|
|
File without changes
|
|
File without changes
|