anydi 0.55.1__py3-none-any.whl → 0.57.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anydi/__init__.py +4 -2
- anydi/_container.py +67 -110
- anydi/_injector.py +132 -0
- anydi/_resolver.py +225 -122
- anydi/_scanner.py +52 -44
- anydi/_types.py +48 -7
- anydi/ext/fastapi.py +31 -33
- anydi/ext/faststream.py +25 -31
- anydi/ext/pydantic_settings.py +2 -1
- anydi/ext/pytest_plugin.py +380 -50
- anydi/testing.py +34 -113
- anydi-0.57.0.dist-info/METADATA +266 -0
- anydi-0.57.0.dist-info/RECORD +25 -0
- anydi-0.55.1.dist-info/METADATA +0 -193
- anydi-0.55.1.dist-info/RECORD +0 -24
- {anydi-0.55.1.dist-info → anydi-0.57.0.dist-info}/WHEEL +0 -0
- {anydi-0.55.1.dist-info → anydi-0.57.0.dist-info}/entry_points.txt +0 -0
anydi/_resolver.py
CHANGED
|
@@ -6,6 +6,7 @@ import contextlib
|
|
|
6
6
|
from typing import TYPE_CHECKING, Any, NamedTuple
|
|
7
7
|
|
|
8
8
|
import anyio.to_thread
|
|
9
|
+
import wrapt # type: ignore
|
|
9
10
|
from typing_extensions import type_repr
|
|
10
11
|
|
|
11
12
|
from ._provider import Provider
|
|
@@ -15,6 +16,18 @@ if TYPE_CHECKING:
|
|
|
15
16
|
from ._container import Container
|
|
16
17
|
|
|
17
18
|
|
|
19
|
+
class InstanceProxy(wrapt.ObjectProxy): # type: ignore
|
|
20
|
+
"""Proxy for dependency instances to enable override support."""
|
|
21
|
+
|
|
22
|
+
def __init__(self, wrapped: Any, *, interface: type[Any]) -> None:
|
|
23
|
+
super().__init__(wrapped) # type: ignore
|
|
24
|
+
self._self_interface = interface
|
|
25
|
+
|
|
26
|
+
@property
|
|
27
|
+
def interface(self) -> type[Any]:
|
|
28
|
+
return self._self_interface
|
|
29
|
+
|
|
30
|
+
|
|
18
31
|
class CompiledResolver(NamedTuple):
|
|
19
32
|
resolve: Any
|
|
20
33
|
create: Any
|
|
@@ -24,32 +37,46 @@ class Resolver:
|
|
|
24
37
|
def __init__(self, container: Container) -> None:
|
|
25
38
|
self._container = container
|
|
26
39
|
self._unresolved_interfaces: set[Any] = set()
|
|
40
|
+
# Normal caches (fast path, no override checks)
|
|
27
41
|
self._cache: dict[Any, CompiledResolver] = {}
|
|
28
42
|
self._async_cache: dict[Any, CompiledResolver] = {}
|
|
43
|
+
# Override caches (with override support)
|
|
44
|
+
self._override_cache: dict[Any, CompiledResolver] = {}
|
|
45
|
+
self._async_override_cache: dict[Any, CompiledResolver] = {}
|
|
46
|
+
# Override instances storage
|
|
47
|
+
self._override_instances: dict[Any, Any] = {}
|
|
29
48
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
)
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
self.
|
|
38
|
-
|
|
39
|
-
|
|
49
|
+
@property
|
|
50
|
+
def override_mode(self) -> bool:
|
|
51
|
+
"""Check if override mode is enabled."""
|
|
52
|
+
return bool(self._override_instances)
|
|
53
|
+
|
|
54
|
+
def add_override(self, interface: Any, instance: Any) -> None:
|
|
55
|
+
"""Add an override instance for an interface."""
|
|
56
|
+
self._override_instances[interface] = instance
|
|
57
|
+
|
|
58
|
+
def remove_override(self, interface: Any) -> None:
|
|
59
|
+
"""Remove an override instance for an interface."""
|
|
60
|
+
self._override_instances.pop(interface, None)
|
|
40
61
|
|
|
41
62
|
def add_unresolved(self, interface: Any) -> None:
|
|
42
63
|
self._unresolved_interfaces.add(interface)
|
|
43
64
|
|
|
44
65
|
def get_cached(self, interface: Any, *, is_async: bool) -> CompiledResolver | None:
|
|
45
66
|
"""Get cached resolver if it exists."""
|
|
46
|
-
|
|
67
|
+
if self.override_mode:
|
|
68
|
+
cache = self._async_override_cache if is_async else self._override_cache
|
|
69
|
+
else:
|
|
70
|
+
cache = self._async_cache if is_async else self._cache
|
|
47
71
|
return cache.get(interface)
|
|
48
72
|
|
|
49
73
|
def compile(self, provider: Provider, *, is_async: bool) -> CompiledResolver:
|
|
50
74
|
"""Compile an optimized resolver function for the given provider."""
|
|
51
|
-
# Select the appropriate cache based on sync/async mode
|
|
52
|
-
|
|
75
|
+
# Select the appropriate cache based on sync/async mode and override mode
|
|
76
|
+
if self.override_mode:
|
|
77
|
+
cache = self._async_override_cache if is_async else self._override_cache
|
|
78
|
+
else:
|
|
79
|
+
cache = self._async_cache if is_async else self._cache
|
|
53
80
|
|
|
54
81
|
# Check if already compiled in cache
|
|
55
82
|
if provider.interface in cache:
|
|
@@ -61,20 +88,58 @@ class Resolver:
|
|
|
61
88
|
self.compile(p.provider, is_async=is_async)
|
|
62
89
|
|
|
63
90
|
# Compile the resolver and creator functions
|
|
64
|
-
compiled = self._compile_resolver(
|
|
91
|
+
compiled = self._compile_resolver(
|
|
92
|
+
provider, is_async=is_async, with_override=self.override_mode
|
|
93
|
+
)
|
|
65
94
|
|
|
66
95
|
# Store the compiled functions in the cache
|
|
67
96
|
cache[provider.interface] = compiled
|
|
68
97
|
|
|
69
98
|
return compiled
|
|
70
99
|
|
|
100
|
+
def _add_override_check(
|
|
101
|
+
self, lines: list[str], *, include_not_set: bool = False
|
|
102
|
+
) -> None:
|
|
103
|
+
"""Add override checking code to generated resolver."""
|
|
104
|
+
lines.append(" override_mode = resolver.override_mode")
|
|
105
|
+
lines.append(" if override_mode:")
|
|
106
|
+
if include_not_set:
|
|
107
|
+
lines.append(" NOT_SET_ = _NOT_SET")
|
|
108
|
+
lines.append(" override = resolver._get_override_for(_interface)")
|
|
109
|
+
lines.append(" if override is not NOT_SET_:")
|
|
110
|
+
lines.append(" return override")
|
|
111
|
+
|
|
112
|
+
def _add_create_call(
|
|
113
|
+
self,
|
|
114
|
+
lines: list[str],
|
|
115
|
+
*,
|
|
116
|
+
is_async: bool,
|
|
117
|
+
with_override: bool,
|
|
118
|
+
context: str,
|
|
119
|
+
store: bool,
|
|
120
|
+
defaults: str = "None",
|
|
121
|
+
indent: str = " ",
|
|
122
|
+
) -> None:
|
|
123
|
+
"""Add _create_instance call to generated resolver."""
|
|
124
|
+
override_arg = "override_mode" if with_override else "False"
|
|
125
|
+
context_arg = context if context else "None"
|
|
126
|
+
store_arg = "True" if store else "False"
|
|
127
|
+
|
|
128
|
+
if is_async:
|
|
129
|
+
lines.append(
|
|
130
|
+
f"{indent}return await _create_instance("
|
|
131
|
+
f"container, {context_arg}, {store_arg}, {defaults}, {override_arg})"
|
|
132
|
+
)
|
|
133
|
+
else:
|
|
134
|
+
lines.append(
|
|
135
|
+
f"{indent}return _create_instance("
|
|
136
|
+
f"container, {context_arg}, {store_arg}, {defaults}, {override_arg})"
|
|
137
|
+
)
|
|
138
|
+
|
|
71
139
|
def _compile_resolver( # noqa: C901
|
|
72
|
-
self, provider: Provider, *, is_async: bool
|
|
140
|
+
self, provider: Provider, *, is_async: bool, with_override: bool = False
|
|
73
141
|
) -> CompiledResolver:
|
|
74
142
|
"""Compile optimized resolver functions for the given provider."""
|
|
75
|
-
has_override_support = self._has_override_support
|
|
76
|
-
wrap_dependencies = self._wrap_dependencies
|
|
77
|
-
wrap_instance = self._wrap_instance
|
|
78
143
|
num_params = len(provider.parameters)
|
|
79
144
|
param_resolvers: list[Any] = [None] * num_params
|
|
80
145
|
param_annotations: list[Any] = [None] * num_params
|
|
@@ -84,7 +149,11 @@ class Resolver:
|
|
|
84
149
|
param_shared_scopes: list[bool] = [False] * num_params
|
|
85
150
|
unresolved_messages: list[str] = [""] * num_params
|
|
86
151
|
|
|
87
|
-
cache =
|
|
152
|
+
cache = (
|
|
153
|
+
(self._async_override_cache if is_async else self._override_cache)
|
|
154
|
+
if with_override
|
|
155
|
+
else (self._async_cache if is_async else self._cache)
|
|
156
|
+
)
|
|
88
157
|
|
|
89
158
|
for idx, p in enumerate(provider.parameters):
|
|
90
159
|
param_annotations[idx] = p.annotation
|
|
@@ -100,13 +169,13 @@ class Resolver:
|
|
|
100
169
|
cache[p.provider.interface] = compiled
|
|
101
170
|
param_resolvers[idx] = compiled.resolve
|
|
102
171
|
|
|
103
|
-
|
|
172
|
+
unresolved_message = (
|
|
104
173
|
f"You are attempting to get the parameter `{p.name}` with the "
|
|
105
174
|
f"annotation `{type_repr(p.annotation)}` as a dependency into "
|
|
106
175
|
f"`{type_repr(provider.call)}` which is not registered or set in the "
|
|
107
176
|
"scoped context."
|
|
108
177
|
)
|
|
109
|
-
unresolved_messages[idx] =
|
|
178
|
+
unresolved_messages[idx] = unresolved_message
|
|
110
179
|
|
|
111
180
|
scope = provider.scope
|
|
112
181
|
is_generator = provider.is_generator
|
|
@@ -117,11 +186,13 @@ class Resolver:
|
|
|
117
186
|
create_lines: list[str] = []
|
|
118
187
|
if is_async:
|
|
119
188
|
create_lines.append(
|
|
120
|
-
"async def _create_instance(
|
|
189
|
+
"async def _create_instance("
|
|
190
|
+
"container, context, store, defaults, override_mode):"
|
|
121
191
|
)
|
|
122
192
|
else:
|
|
123
193
|
create_lines.append(
|
|
124
|
-
"def _create_instance(
|
|
194
|
+
"def _create_instance("
|
|
195
|
+
"container, context, store, defaults, override_mode):"
|
|
125
196
|
)
|
|
126
197
|
|
|
127
198
|
if no_params:
|
|
@@ -170,8 +241,10 @@ class Resolver:
|
|
|
170
241
|
create_lines.append(
|
|
171
242
|
f" raise LookupError(_unresolved_messages[{idx}])"
|
|
172
243
|
)
|
|
173
|
-
create_lines.append(
|
|
174
|
-
|
|
244
|
+
create_lines.append(
|
|
245
|
+
f" _dep_resolver = _param_resolvers[{idx}]"
|
|
246
|
+
)
|
|
247
|
+
create_lines.append(" if _dep_resolver is None:")
|
|
175
248
|
create_lines.append(" try:")
|
|
176
249
|
if is_async:
|
|
177
250
|
create_lines.append(
|
|
@@ -230,21 +303,23 @@ class Resolver:
|
|
|
230
303
|
create_lines.append(" else:")
|
|
231
304
|
if is_async:
|
|
232
305
|
create_lines.append(
|
|
233
|
-
f" arg_{idx} = await
|
|
306
|
+
f" arg_{idx} = await _dep_resolver("
|
|
234
307
|
f"container, "
|
|
235
308
|
f"context if _param_shared_scopes[{idx}] else None)"
|
|
236
309
|
)
|
|
237
310
|
else:
|
|
238
311
|
create_lines.append(
|
|
239
|
-
f" arg_{idx} =
|
|
312
|
+
f" arg_{idx} = _dep_resolver("
|
|
240
313
|
f"container, "
|
|
241
314
|
f"context if _param_shared_scopes[{idx}] else None)"
|
|
242
315
|
)
|
|
243
316
|
create_lines.append(" else:")
|
|
244
317
|
create_lines.append(f" arg_{idx} = cached")
|
|
245
|
-
if
|
|
318
|
+
# Wrap dependencies if in override mode (only for override version)
|
|
319
|
+
if with_override:
|
|
320
|
+
create_lines.append(" if override_mode:")
|
|
246
321
|
create_lines.append(
|
|
247
|
-
f"
|
|
322
|
+
f" arg_{idx} = resolver._wrap_for_override("
|
|
248
323
|
f"_param_annotations[{idx}], arg_{idx})"
|
|
249
324
|
)
|
|
250
325
|
|
|
@@ -380,9 +455,11 @@ class Resolver:
|
|
|
380
455
|
create_lines.append(" if context is not None and store:")
|
|
381
456
|
create_lines.append(" context.set(_interface, inst)")
|
|
382
457
|
|
|
383
|
-
if
|
|
458
|
+
# Wrap instance if in override mode (only for override version)
|
|
459
|
+
if with_override:
|
|
460
|
+
create_lines.append(" if override_mode:")
|
|
384
461
|
create_lines.append(
|
|
385
|
-
"
|
|
462
|
+
" inst = resolver._post_resolve_override(_interface, inst)"
|
|
386
463
|
)
|
|
387
464
|
create_lines.append(" return inst")
|
|
388
465
|
|
|
@@ -393,7 +470,7 @@ class Resolver:
|
|
|
393
470
|
resolver_lines.append("def _resolver(container, context=None):")
|
|
394
471
|
|
|
395
472
|
# Only define NOT_SET_ if we actually need it
|
|
396
|
-
needs_not_set =
|
|
473
|
+
needs_not_set = scope in ("singleton", "request")
|
|
397
474
|
if needs_not_set:
|
|
398
475
|
resolver_lines.append(" NOT_SET_ = _NOT_SET")
|
|
399
476
|
|
|
@@ -406,22 +483,14 @@ class Resolver:
|
|
|
406
483
|
else:
|
|
407
484
|
resolver_lines.append(" context = None")
|
|
408
485
|
|
|
409
|
-
if has_override_support:
|
|
410
|
-
resolver_lines.append(
|
|
411
|
-
" override = container._hook_override_for(_interface)"
|
|
412
|
-
)
|
|
413
|
-
resolver_lines.append(" if override is not NOT_SET_:")
|
|
414
|
-
resolver_lines.append(" return override")
|
|
415
|
-
|
|
416
486
|
if scope == "singleton":
|
|
487
|
+
if with_override:
|
|
488
|
+
self._add_override_check(resolver_lines)
|
|
489
|
+
|
|
490
|
+
# Fast path: check cached instance
|
|
417
491
|
resolver_lines.append(" inst = context.get(_interface)")
|
|
418
492
|
resolver_lines.append(" if inst is not NOT_SET_:")
|
|
419
|
-
|
|
420
|
-
resolver_lines.append(
|
|
421
|
-
" return container._hook_post_resolve(_provider, inst)"
|
|
422
|
-
)
|
|
423
|
-
else:
|
|
424
|
-
resolver_lines.append(" return inst")
|
|
493
|
+
resolver_lines.append(" return inst")
|
|
425
494
|
|
|
426
495
|
if is_async:
|
|
427
496
|
resolver_lines.append(" async with context.alock():")
|
|
@@ -429,47 +498,43 @@ class Resolver:
|
|
|
429
498
|
resolver_lines.append(" with context.lock():")
|
|
430
499
|
resolver_lines.append(" inst = context.get(_interface)")
|
|
431
500
|
resolver_lines.append(" if inst is not NOT_SET_:")
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
"_create_instance(container, context, True, None)"
|
|
442
|
-
)
|
|
443
|
-
else:
|
|
444
|
-
resolver_lines.append(
|
|
445
|
-
" return _create_instance(container, context, True, None)"
|
|
446
|
-
)
|
|
501
|
+
resolver_lines.append(" return inst")
|
|
502
|
+
self._add_create_call(
|
|
503
|
+
resolver_lines,
|
|
504
|
+
is_async=is_async,
|
|
505
|
+
with_override=with_override,
|
|
506
|
+
context="context",
|
|
507
|
+
store=True,
|
|
508
|
+
indent=" ",
|
|
509
|
+
)
|
|
447
510
|
elif scope == "request":
|
|
511
|
+
if with_override:
|
|
512
|
+
self._add_override_check(resolver_lines)
|
|
513
|
+
|
|
514
|
+
# Fast path: check cached instance
|
|
448
515
|
resolver_lines.append(" inst = context.get(_interface)")
|
|
449
516
|
resolver_lines.append(" if inst is not NOT_SET_:")
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
)
|
|
460
|
-
else:
|
|
461
|
-
resolver_lines.append(
|
|
462
|
-
" return _create_instance(container, context, True, None)"
|
|
463
|
-
)
|
|
517
|
+
resolver_lines.append(" return inst")
|
|
518
|
+
|
|
519
|
+
self._add_create_call(
|
|
520
|
+
resolver_lines,
|
|
521
|
+
is_async=is_async,
|
|
522
|
+
with_override=with_override,
|
|
523
|
+
context="context",
|
|
524
|
+
store=True,
|
|
525
|
+
)
|
|
464
526
|
else:
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
resolver_lines
|
|
471
|
-
|
|
472
|
-
|
|
527
|
+
# Transient scope
|
|
528
|
+
if with_override:
|
|
529
|
+
self._add_override_check(resolver_lines, include_not_set=True)
|
|
530
|
+
|
|
531
|
+
self._add_create_call(
|
|
532
|
+
resolver_lines,
|
|
533
|
+
is_async=is_async,
|
|
534
|
+
with_override=with_override,
|
|
535
|
+
context="",
|
|
536
|
+
store=False,
|
|
537
|
+
)
|
|
473
538
|
|
|
474
539
|
create_resolver_lines: list[str] = []
|
|
475
540
|
if is_async:
|
|
@@ -481,9 +546,9 @@ class Resolver:
|
|
|
481
546
|
"def _resolver_create(container, defaults=None):"
|
|
482
547
|
)
|
|
483
548
|
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
create_resolver_lines.append("
|
|
549
|
+
if with_override:
|
|
550
|
+
# Cache override mode check
|
|
551
|
+
create_resolver_lines.append(" override_mode = resolver.override_mode")
|
|
487
552
|
|
|
488
553
|
if scope == "singleton":
|
|
489
554
|
create_resolver_lines.append(" context = container._singleton_context")
|
|
@@ -494,43 +559,20 @@ class Resolver:
|
|
|
494
559
|
else:
|
|
495
560
|
create_resolver_lines.append(" context = None")
|
|
496
561
|
|
|
497
|
-
if
|
|
498
|
-
|
|
499
|
-
" override = container._hook_override_for(_interface)"
|
|
500
|
-
)
|
|
501
|
-
create_resolver_lines.append(" if override is not NOT_SET_:")
|
|
502
|
-
create_resolver_lines.append(" return override")
|
|
562
|
+
if with_override:
|
|
563
|
+
self._add_override_check(create_resolver_lines, include_not_set=True)
|
|
503
564
|
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
if is_async:
|
|
516
|
-
create_resolver_lines.append(
|
|
517
|
-
" return await "
|
|
518
|
-
"_create_instance(container, context, False, defaults)"
|
|
519
|
-
)
|
|
520
|
-
else:
|
|
521
|
-
create_resolver_lines.append(
|
|
522
|
-
" return _create_instance(container, context, False, defaults)"
|
|
523
|
-
)
|
|
524
|
-
else:
|
|
525
|
-
if is_async:
|
|
526
|
-
create_resolver_lines.append(
|
|
527
|
-
" return await "
|
|
528
|
-
"_create_instance(container, None, False, defaults)"
|
|
529
|
-
)
|
|
530
|
-
else:
|
|
531
|
-
create_resolver_lines.append(
|
|
532
|
-
" return _create_instance(container, None, False, defaults)"
|
|
533
|
-
)
|
|
565
|
+
# Determine context for create call
|
|
566
|
+
context_arg = "context" if scope in ("singleton", "request") else ""
|
|
567
|
+
|
|
568
|
+
self._add_create_call(
|
|
569
|
+
create_resolver_lines,
|
|
570
|
+
is_async=is_async,
|
|
571
|
+
with_override=with_override,
|
|
572
|
+
context=context_arg,
|
|
573
|
+
store=False,
|
|
574
|
+
defaults="defaults",
|
|
575
|
+
)
|
|
534
576
|
|
|
535
577
|
lines = create_lines + [""] + resolver_lines + [""] + create_resolver_lines
|
|
536
578
|
|
|
@@ -552,8 +594,13 @@ class Resolver:
|
|
|
552
594
|
"_NOT_SET": NOT_SET,
|
|
553
595
|
"_contextmanager": contextlib.contextmanager,
|
|
554
596
|
"_is_cm": is_context_manager,
|
|
555
|
-
"_cache":
|
|
597
|
+
"_cache": (
|
|
598
|
+
(self._async_override_cache if is_async else self._override_cache)
|
|
599
|
+
if with_override
|
|
600
|
+
else (self._async_cache if is_async else self._cache)
|
|
601
|
+
),
|
|
556
602
|
"_compile": self._compile_resolver,
|
|
603
|
+
"resolver": self,
|
|
557
604
|
}
|
|
558
605
|
|
|
559
606
|
# Add async-specific namespace entries
|
|
@@ -569,3 +616,59 @@ class Resolver:
|
|
|
569
616
|
creator = ns["_resolver_create"]
|
|
570
617
|
|
|
571
618
|
return CompiledResolver(resolver, creator)
|
|
619
|
+
|
|
620
|
+
def _get_override_for(self, interface: Any) -> Any:
|
|
621
|
+
"""Hook for checking if an interface has an override."""
|
|
622
|
+
return self._override_instances.get(interface, NOT_SET)
|
|
623
|
+
|
|
624
|
+
def _wrap_for_override(self, annotation: Any, value: Any) -> Any:
|
|
625
|
+
"""Hook for wrapping dependencies to enable override patching."""
|
|
626
|
+
return InstanceProxy(value, interface=annotation)
|
|
627
|
+
|
|
628
|
+
def _post_resolve_override(self, interface: Any, instance: Any) -> Any: # noqa: C901
|
|
629
|
+
"""Hook for patching resolved instances to support override."""
|
|
630
|
+
if interface in self._override_instances:
|
|
631
|
+
return self._override_instances[interface]
|
|
632
|
+
|
|
633
|
+
if not hasattr(instance, "__dict__") or hasattr(
|
|
634
|
+
instance, "__resolver_getter__"
|
|
635
|
+
):
|
|
636
|
+
return instance
|
|
637
|
+
|
|
638
|
+
wrapped = {
|
|
639
|
+
name: value.interface
|
|
640
|
+
for name, value in instance.__dict__.items()
|
|
641
|
+
if isinstance(value, InstanceProxy)
|
|
642
|
+
}
|
|
643
|
+
|
|
644
|
+
def __resolver_getter__(name: str) -> Any:
|
|
645
|
+
if name in wrapped:
|
|
646
|
+
_interface = wrapped[name]
|
|
647
|
+
# Resolve the dependency if it's wrapped
|
|
648
|
+
return self._container.resolve(_interface)
|
|
649
|
+
raise LookupError
|
|
650
|
+
|
|
651
|
+
# Attach the resolver getter to the instance
|
|
652
|
+
instance.__resolver_getter__ = __resolver_getter__
|
|
653
|
+
|
|
654
|
+
if not hasattr(instance.__class__, "__getattribute_patched__"):
|
|
655
|
+
|
|
656
|
+
def __getattribute__(_self: Any, name: str) -> Any:
|
|
657
|
+
# Skip the resolver getter
|
|
658
|
+
if name in {"__resolver_getter__", "__class__"}:
|
|
659
|
+
return object.__getattribute__(_self, name)
|
|
660
|
+
|
|
661
|
+
if hasattr(_self, "__resolver_getter__"):
|
|
662
|
+
try:
|
|
663
|
+
return _self.__resolver_getter__(name)
|
|
664
|
+
except LookupError:
|
|
665
|
+
pass
|
|
666
|
+
|
|
667
|
+
# Fall back to default behavior
|
|
668
|
+
return object.__getattribute__(_self, name)
|
|
669
|
+
|
|
670
|
+
# Apply the patched resolver if wrapped attributes exist
|
|
671
|
+
instance.__class__.__getattribute__ = __getattribute__
|
|
672
|
+
instance.__class__.__getattribute_patched__ = True
|
|
673
|
+
|
|
674
|
+
return instance
|
anydi/_scanner.py
CHANGED
|
@@ -3,13 +3,12 @@ from __future__ import annotations
|
|
|
3
3
|
import importlib
|
|
4
4
|
import inspect
|
|
5
5
|
import pkgutil
|
|
6
|
-
from collections.abc import Callable, Iterable
|
|
6
|
+
from collections.abc import Callable, Iterable, Iterator
|
|
7
7
|
from dataclasses import dataclass
|
|
8
8
|
from types import ModuleType
|
|
9
9
|
from typing import TYPE_CHECKING, Any
|
|
10
10
|
|
|
11
|
-
from ._decorators import is_injectable
|
|
12
|
-
from ._types import is_inject_marker
|
|
11
|
+
from ._decorators import Provided, is_injectable, is_provided
|
|
13
12
|
|
|
14
13
|
if TYPE_CHECKING:
|
|
15
14
|
from ._container import Container
|
|
@@ -38,45 +37,64 @@ class Scanner:
|
|
|
38
37
|
) -> None:
|
|
39
38
|
"""Scan packages or modules for decorated members and inject dependencies."""
|
|
40
39
|
if isinstance(packages, (ModuleType, str)):
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
40
|
+
packages = [packages]
|
|
41
|
+
|
|
42
|
+
tags_list = list(tags) if tags else []
|
|
43
|
+
provided_classes: list[type[Provided]] = []
|
|
44
|
+
injectable_dependencies: list[ScannedDependency] = []
|
|
45
|
+
|
|
46
|
+
# Single pass: collect both @provided classes and @injectable functions
|
|
47
|
+
for module in self._iter_modules(packages):
|
|
48
|
+
provided_classes.extend(self._scan_module_for_provided(module))
|
|
49
|
+
injectable_dependencies.extend(
|
|
50
|
+
self._scan_module_for_injectable(module, tags=tags_list)
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
# First: register @provided classes
|
|
54
|
+
for cls in provided_classes:
|
|
55
|
+
if not self._container.is_registered(cls):
|
|
56
|
+
scope = cls.__provided__["scope"]
|
|
57
|
+
self._container.register(cls, scope=scope)
|
|
58
|
+
|
|
59
|
+
# Second: inject @injectable functions
|
|
60
|
+
for dependency in injectable_dependencies:
|
|
52
61
|
decorated = self._container.inject()(dependency.member)
|
|
53
62
|
setattr(dependency.module, dependency.member.__name__, decorated)
|
|
54
63
|
|
|
55
|
-
def
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
64
|
+
def _iter_modules(self, packages: Iterable[Package]) -> Iterator[ModuleType]:
|
|
65
|
+
"""Iterate over all modules in the given packages."""
|
|
66
|
+
for package in packages:
|
|
67
|
+
if isinstance(package, str):
|
|
68
|
+
package = importlib.import_module(package)
|
|
60
69
|
|
|
61
|
-
|
|
62
|
-
|
|
70
|
+
# Single module (not a package)
|
|
71
|
+
if not hasattr(package, "__path__"):
|
|
72
|
+
yield package
|
|
73
|
+
continue
|
|
63
74
|
|
|
64
|
-
|
|
65
|
-
|
|
75
|
+
# Package - walk all submodules
|
|
76
|
+
for module_info in pkgutil.walk_packages(
|
|
77
|
+
package.__path__, prefix=package.__name__ + "."
|
|
78
|
+
):
|
|
79
|
+
yield importlib.import_module(module_info.name)
|
|
66
80
|
|
|
67
|
-
|
|
68
|
-
for
|
|
69
|
-
|
|
70
|
-
):
|
|
71
|
-
module = importlib.import_module(module_info.name)
|
|
72
|
-
dependencies.extend(self._scan_module(module, tags=tags))
|
|
81
|
+
def _scan_module_for_provided(self, module: ModuleType) -> list[type[Provided]]:
|
|
82
|
+
"""Scan a module for @provided classes."""
|
|
83
|
+
provided_classes: list[type[Provided]] = []
|
|
73
84
|
|
|
74
|
-
|
|
85
|
+
for _, member in inspect.getmembers(module, predicate=inspect.isclass):
|
|
86
|
+
if getattr(member, "__module__", None) != module.__name__:
|
|
87
|
+
continue
|
|
88
|
+
|
|
89
|
+
if is_provided(member):
|
|
90
|
+
provided_classes.append(member)
|
|
75
91
|
|
|
76
|
-
|
|
77
|
-
|
|
92
|
+
return provided_classes
|
|
93
|
+
|
|
94
|
+
def _scan_module_for_injectable(
|
|
95
|
+
self, module: ModuleType, *, tags: list[str]
|
|
78
96
|
) -> list[ScannedDependency]:
|
|
79
|
-
"""Scan a module for
|
|
97
|
+
"""Scan a module for @injectable functions."""
|
|
80
98
|
dependencies: list[ScannedDependency] = []
|
|
81
99
|
|
|
82
100
|
for _, member in inspect.getmembers(module, predicate=callable):
|
|
@@ -89,22 +107,12 @@ class Scanner:
|
|
|
89
107
|
return dependencies
|
|
90
108
|
|
|
91
109
|
@staticmethod
|
|
92
|
-
def _should_include_member(
|
|
93
|
-
member: Callable[..., Any], *, tags: Iterable[str]
|
|
94
|
-
) -> bool:
|
|
110
|
+
def _should_include_member(member: Callable[..., Any], *, tags: list[str]) -> bool:
|
|
95
111
|
"""Determine if a member should be included based on tags or marker defaults."""
|
|
96
|
-
|
|
97
112
|
if is_injectable(member):
|
|
98
113
|
member_tags = set(member.__injectable__["tags"] or [])
|
|
99
114
|
if tags:
|
|
100
115
|
return bool(set(tags) & member_tags)
|
|
101
116
|
return True # No tags passed → include all injectables
|
|
102
117
|
|
|
103
|
-
# If no tags are passed and not explicitly injectable,
|
|
104
|
-
# check for parameter markers
|
|
105
|
-
if not tags:
|
|
106
|
-
for parameter in inspect.signature(member).parameters.values():
|
|
107
|
-
if is_inject_marker(parameter.default):
|
|
108
|
-
return True
|
|
109
|
-
|
|
110
118
|
return False
|