anydi 0.67.2__py3-none-any.whl → 0.69.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anydi/_cli.py +80 -0
- anydi/_container.py +697 -263
- anydi/_context.py +14 -14
- anydi/_decorators.py +115 -8
- anydi/_graph.py +217 -0
- anydi/_injector.py +12 -10
- anydi/_marker.py +11 -13
- anydi/_provider.py +46 -8
- anydi/_resolver.py +205 -159
- anydi/_scanner.py +46 -7
- anydi/ext/fastapi.py +1 -1
- anydi/ext/faststream.py +1 -1
- anydi/ext/pydantic_settings.py +3 -3
- anydi/ext/pytest_plugin.py +125 -380
- anydi/ext/typer.py +4 -4
- {anydi-0.67.2.dist-info → anydi-0.69.0.dist-info}/METADATA +1 -1
- anydi-0.69.0.dist-info/RECORD +29 -0
- {anydi-0.67.2.dist-info → anydi-0.69.0.dist-info}/entry_points.txt +3 -0
- anydi-0.67.2.dist-info/RECORD +0 -27
- {anydi-0.67.2.dist-info → anydi-0.69.0.dist-info}/WHEEL +0 -0
anydi/_resolver.py
CHANGED
|
@@ -19,13 +19,13 @@ if TYPE_CHECKING:
|
|
|
19
19
|
class InstanceProxy(wrapt.ObjectProxy): # type: ignore
|
|
20
20
|
"""Proxy for dependency instances to enable override support."""
|
|
21
21
|
|
|
22
|
-
def __init__(self, wrapped: Any, *,
|
|
22
|
+
def __init__(self, wrapped: Any, *, dependency_type: Any) -> None:
|
|
23
23
|
super().__init__(wrapped) # type: ignore
|
|
24
|
-
self.
|
|
24
|
+
self._self_dependency_type = dependency_type
|
|
25
25
|
|
|
26
26
|
@property
|
|
27
|
-
def
|
|
28
|
-
return self.
|
|
27
|
+
def dependency_type(self) -> Any:
|
|
28
|
+
return self._self_dependency_type
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
class CompiledResolver(NamedTuple):
|
|
@@ -36,7 +36,6 @@ class CompiledResolver(NamedTuple):
|
|
|
36
36
|
class Resolver:
|
|
37
37
|
def __init__(self, container: Container) -> None:
|
|
38
38
|
self._container = container
|
|
39
|
-
self._unresolved_interfaces: set[Any] = set()
|
|
40
39
|
# Normal caches (fast path, no override checks)
|
|
41
40
|
self._cache: dict[Any, CompiledResolver] = {}
|
|
42
41
|
self._async_cache: dict[Any, CompiledResolver] = {}
|
|
@@ -44,23 +43,20 @@ class Resolver:
|
|
|
44
43
|
self._override_cache: dict[Any, CompiledResolver] = {}
|
|
45
44
|
self._async_override_cache: dict[Any, CompiledResolver] = {}
|
|
46
45
|
# Override instances storage
|
|
47
|
-
self.
|
|
46
|
+
self._overrides: dict[Any, Any] = {}
|
|
48
47
|
|
|
49
48
|
@property
|
|
50
49
|
def override_mode(self) -> bool:
|
|
51
50
|
"""Check if override mode is enabled."""
|
|
52
|
-
return bool(self.
|
|
51
|
+
return bool(self._overrides) or getattr(self._container, "_test_mode", False)
|
|
53
52
|
|
|
54
|
-
def add_override(self,
|
|
55
|
-
"""Add an override instance for
|
|
56
|
-
self.
|
|
53
|
+
def add_override(self, dependency_type: Any, instance: Any) -> None:
|
|
54
|
+
"""Add an override instance for a dependency type."""
|
|
55
|
+
self._overrides[dependency_type] = instance
|
|
57
56
|
|
|
58
|
-
def remove_override(self,
|
|
59
|
-
"""Remove an override instance for
|
|
60
|
-
self.
|
|
61
|
-
|
|
62
|
-
def add_unresolved(self, interface: Any) -> None:
|
|
63
|
-
self._unresolved_interfaces.add(interface)
|
|
57
|
+
def remove_override(self, dependency_type: Any) -> None:
|
|
58
|
+
"""Remove an override instance for a dependency type."""
|
|
59
|
+
self._overrides.pop(dependency_type, None)
|
|
64
60
|
|
|
65
61
|
def clear_caches(self) -> None:
|
|
66
62
|
"""Clear all cached resolvers."""
|
|
@@ -69,13 +65,15 @@ class Resolver:
|
|
|
69
65
|
self._override_cache.clear()
|
|
70
66
|
self._async_override_cache.clear()
|
|
71
67
|
|
|
72
|
-
def get_cached(
|
|
68
|
+
def get_cached(
|
|
69
|
+
self, dependency_type: Any, *, is_async: bool
|
|
70
|
+
) -> CompiledResolver | None:
|
|
73
71
|
"""Get cached resolver if it exists."""
|
|
74
72
|
if self.override_mode:
|
|
75
73
|
cache = self._async_override_cache if is_async else self._override_cache
|
|
76
74
|
else:
|
|
77
75
|
cache = self._async_cache if is_async else self._cache
|
|
78
|
-
return cache.get(
|
|
76
|
+
return cache.get(dependency_type)
|
|
79
77
|
|
|
80
78
|
def compile(self, provider: Provider, *, is_async: bool) -> CompiledResolver:
|
|
81
79
|
"""Compile an optimized resolver function for the given provider."""
|
|
@@ -86,14 +84,14 @@ class Resolver:
|
|
|
86
84
|
cache = self._async_cache if is_async else self._cache
|
|
87
85
|
|
|
88
86
|
# Check if already compiled in cache
|
|
89
|
-
if provider.
|
|
90
|
-
return cache[provider.
|
|
87
|
+
if provider.dependency_type in cache:
|
|
88
|
+
return cache[provider.dependency_type]
|
|
91
89
|
|
|
92
90
|
# Recursively compile dependencies first
|
|
93
91
|
for param in provider.parameters:
|
|
94
92
|
if param.provider is not None:
|
|
95
93
|
# Look up the current provider to handle overrides
|
|
96
|
-
current_provider = self._container.providers.get(param.
|
|
94
|
+
current_provider = self._container.providers.get(param.dependency_type)
|
|
97
95
|
if current_provider is not None:
|
|
98
96
|
self.compile(current_provider, is_async=is_async)
|
|
99
97
|
else:
|
|
@@ -105,7 +103,7 @@ class Resolver:
|
|
|
105
103
|
)
|
|
106
104
|
|
|
107
105
|
# Store the compiled functions in the cache
|
|
108
|
-
cache[provider.
|
|
106
|
+
cache[provider.dependency_type] = compiled
|
|
109
107
|
|
|
110
108
|
return compiled
|
|
111
109
|
|
|
@@ -117,7 +115,7 @@ class Resolver:
|
|
|
117
115
|
lines.append(" if override_mode:")
|
|
118
116
|
if include_not_set:
|
|
119
117
|
lines.append(" NOT_SET_ = _NOT_SET")
|
|
120
|
-
lines.append(" override = resolver._get_override_for(
|
|
118
|
+
lines.append(" override = resolver._get_override_for(_dependency_type)")
|
|
121
119
|
lines.append(" if override is not NOT_SET_:")
|
|
122
120
|
lines.append(" return override")
|
|
123
121
|
|
|
@@ -152,14 +150,21 @@ class Resolver:
|
|
|
152
150
|
self, provider: Provider, *, is_async: bool, with_override: bool = False
|
|
153
151
|
) -> CompiledResolver:
|
|
154
152
|
"""Compile optimized resolver functions for the given provider."""
|
|
153
|
+
# Handle from_context providers with simplified code generation
|
|
154
|
+
if provider.from_context:
|
|
155
|
+
return self._compile_from_context_resolver(
|
|
156
|
+
provider, is_async=is_async, with_override=with_override
|
|
157
|
+
)
|
|
158
|
+
|
|
155
159
|
num_params = len(provider.parameters)
|
|
156
160
|
param_resolvers: list[Any] = [None] * num_params
|
|
157
|
-
|
|
161
|
+
param_types: list[Any] = [None] * num_params
|
|
158
162
|
param_defaults: list[Any] = [None] * num_params
|
|
159
163
|
param_has_default: list[bool] = [False] * num_params
|
|
160
164
|
param_names: list[str] = [""] * num_params
|
|
161
165
|
param_shared_scopes: list[bool] = [False] * num_params
|
|
162
|
-
|
|
166
|
+
# Track unresolved messages for params with provider=None
|
|
167
|
+
unresolved_messages: dict[int, str] = {}
|
|
163
168
|
|
|
164
169
|
cache = (
|
|
165
170
|
(self._async_override_cache if is_async else self._override_cache)
|
|
@@ -168,7 +173,7 @@ class Resolver:
|
|
|
168
173
|
)
|
|
169
174
|
|
|
170
175
|
for idx, param in enumerate(provider.parameters):
|
|
171
|
-
|
|
176
|
+
param_types[idx] = param.dependency_type
|
|
172
177
|
param_defaults[idx] = param.default
|
|
173
178
|
param_has_default[idx] = param.has_default
|
|
174
179
|
param_names[idx] = param.name
|
|
@@ -176,24 +181,24 @@ class Resolver:
|
|
|
176
181
|
|
|
177
182
|
if param.provider is not None:
|
|
178
183
|
# Look up the current provider from the container to handle overrides
|
|
179
|
-
current_provider = self._container.providers.get(param.
|
|
184
|
+
current_provider = self._container.providers.get(param.dependency_type)
|
|
180
185
|
if current_provider is not None:
|
|
181
|
-
compiled = cache.get(current_provider.
|
|
186
|
+
compiled = cache.get(current_provider.dependency_type)
|
|
182
187
|
else:
|
|
183
188
|
# Fallback to the original provider if not in container
|
|
184
|
-
compiled = cache.get(param.provider.
|
|
189
|
+
compiled = cache.get(param.provider.dependency_type)
|
|
185
190
|
if compiled is None:
|
|
186
191
|
compiled = self.compile(param.provider, is_async=is_async)
|
|
187
|
-
cache[param.provider.
|
|
192
|
+
cache[param.provider.dependency_type] = compiled
|
|
188
193
|
param_resolvers[idx] = compiled.resolve
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
194
|
+
else:
|
|
195
|
+
# Generate unresolved message for params without a provider
|
|
196
|
+
unresolved_messages[idx] = (
|
|
197
|
+
f"You are attempting to get the parameter `{param.name}` with the "
|
|
198
|
+
f"annotation `{type_repr(param.dependency_type)}` as a dependency "
|
|
199
|
+
f"into `{provider}` which is not registered or set in the "
|
|
200
|
+
f"scoped context."
|
|
201
|
+
)
|
|
197
202
|
|
|
198
203
|
scope = provider.scope
|
|
199
204
|
is_generator = provider.is_generator
|
|
@@ -219,7 +224,7 @@ class Resolver:
|
|
|
219
224
|
create_lines.append(" if _is_async:")
|
|
220
225
|
create_lines.append(
|
|
221
226
|
" raise TypeError("
|
|
222
|
-
'f"The instance for the provider `{
|
|
227
|
+
'f"The instance for the provider `{_dependency_repr}` '
|
|
223
228
|
'cannot be created in synchronous mode."'
|
|
224
229
|
")"
|
|
225
230
|
)
|
|
@@ -230,7 +235,7 @@ class Resolver:
|
|
|
230
235
|
create_lines.append(" if _is_async:")
|
|
231
236
|
create_lines.append(
|
|
232
237
|
" raise TypeError("
|
|
233
|
-
'f"The instance for the provider `{
|
|
238
|
+
'f"The instance for the provider `{_dependency_repr}` '
|
|
234
239
|
'cannot be created in synchronous mode."'
|
|
235
240
|
")"
|
|
236
241
|
)
|
|
@@ -240,6 +245,8 @@ class Resolver:
|
|
|
240
245
|
if not no_params:
|
|
241
246
|
# Only generate parameter resolution logic if there are parameters
|
|
242
247
|
for idx, name in enumerate(param_names):
|
|
248
|
+
is_from_context = idx in unresolved_messages
|
|
249
|
+
|
|
243
250
|
create_lines.append(f" # resolve param `{name}`")
|
|
244
251
|
create_lines.append(
|
|
245
252
|
f" if defaults is not None and '{name}' in defaults:"
|
|
@@ -249,92 +256,45 @@ class Resolver:
|
|
|
249
256
|
# Direct dict access for shared scope params (avoids method call)
|
|
250
257
|
if param_shared_scopes[idx]:
|
|
251
258
|
create_lines.append(
|
|
252
|
-
f" cached = (context.
|
|
253
|
-
f"
|
|
259
|
+
f" cached = (context._items.get("
|
|
260
|
+
f"_param_types[{idx}], NOT_SET_) "
|
|
254
261
|
f"if context is not None else NOT_SET_)"
|
|
255
262
|
)
|
|
256
263
|
else:
|
|
257
264
|
create_lines.append(" cached = NOT_SET_")
|
|
258
265
|
create_lines.append(" if cached is NOT_SET_:")
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
f" compiled = "
|
|
274
|
-
f"cache.get(_param_annotations[{idx}])"
|
|
275
|
-
)
|
|
276
|
-
create_lines.append(" if compiled is None:")
|
|
277
|
-
create_lines.append(
|
|
278
|
-
" provider = "
|
|
279
|
-
"container._get_or_register_provider("
|
|
280
|
-
f"_param_annotations[{idx}])"
|
|
281
|
-
)
|
|
282
|
-
create_lines.append(
|
|
283
|
-
" compiled = "
|
|
284
|
-
"_compile(provider, is_async=True)"
|
|
285
|
-
)
|
|
286
|
-
create_lines.append(
|
|
287
|
-
" cache[provider.interface] = compiled"
|
|
288
|
-
)
|
|
289
|
-
create_lines.append(
|
|
290
|
-
f" arg_{idx} = "
|
|
291
|
-
f"await compiled[0](container, "
|
|
292
|
-
f"context if _param_shared_scopes[{idx}] else None)"
|
|
293
|
-
)
|
|
294
|
-
else:
|
|
295
|
-
create_lines.append(
|
|
296
|
-
f" compiled = "
|
|
297
|
-
f"cache.get(_param_annotations[{idx}])"
|
|
298
|
-
)
|
|
299
|
-
create_lines.append(" if compiled is None:")
|
|
300
|
-
create_lines.append(
|
|
301
|
-
" provider = "
|
|
302
|
-
f"container._get_or_register_provider(_param_annotations[{idx}])"
|
|
303
|
-
)
|
|
304
|
-
create_lines.append(
|
|
305
|
-
" compiled = "
|
|
306
|
-
"_compile(provider, is_async=False)"
|
|
307
|
-
)
|
|
308
|
-
create_lines.append(
|
|
309
|
-
" cache[provider.interface] = compiled"
|
|
310
|
-
)
|
|
311
|
-
create_lines.append(
|
|
312
|
-
f" arg_{idx} = "
|
|
313
|
-
f"compiled[0](container, "
|
|
314
|
-
f"context if _param_shared_scopes[{idx}] else None)"
|
|
315
|
-
)
|
|
316
|
-
create_lines.append(" except LookupError:")
|
|
317
|
-
create_lines.append(
|
|
318
|
-
f" if _param_has_default[{idx}]:"
|
|
319
|
-
)
|
|
320
|
-
create_lines.append(
|
|
321
|
-
f" arg_{idx} = _param_defaults[{idx}]"
|
|
322
|
-
)
|
|
323
|
-
create_lines.append(" else:")
|
|
324
|
-
create_lines.append(" raise")
|
|
325
|
-
create_lines.append(" else:")
|
|
326
|
-
if is_async:
|
|
327
|
-
create_lines.append(
|
|
328
|
-
f" arg_{idx} = await _dep_resolver("
|
|
329
|
-
f"container, "
|
|
330
|
-
f"context if _param_shared_scopes[{idx}] else None)"
|
|
331
|
-
)
|
|
266
|
+
|
|
267
|
+
if is_from_context:
|
|
268
|
+
# Unresolved param without provider
|
|
269
|
+
if param_has_default[idx]:
|
|
270
|
+
# Has default, use it
|
|
271
|
+
create_lines.append(
|
|
272
|
+
f" arg_{idx} = _param_defaults[{idx}]"
|
|
273
|
+
)
|
|
274
|
+
else:
|
|
275
|
+
# No default, raise
|
|
276
|
+
create_lines.append(
|
|
277
|
+
" raise LookupError("
|
|
278
|
+
f"_unresolved_messages[{idx}])"
|
|
279
|
+
)
|
|
332
280
|
else:
|
|
281
|
+
# Has a pre-compiled resolver, use it directly
|
|
333
282
|
create_lines.append(
|
|
334
|
-
f"
|
|
335
|
-
f"container, "
|
|
336
|
-
f"context if _param_shared_scopes[{idx}] else None)"
|
|
283
|
+
f" _dep_resolver = _param_resolvers[{idx}]"
|
|
337
284
|
)
|
|
285
|
+
if is_async:
|
|
286
|
+
create_lines.append(
|
|
287
|
+
f" arg_{idx} = await _dep_resolver("
|
|
288
|
+
f"container, context if _param_shared_scopes[{idx}] "
|
|
289
|
+
"else None)"
|
|
290
|
+
)
|
|
291
|
+
else:
|
|
292
|
+
create_lines.append(
|
|
293
|
+
f" arg_{idx} = _dep_resolver("
|
|
294
|
+
f"container, context if _param_shared_scopes[{idx}] "
|
|
295
|
+
"else None)"
|
|
296
|
+
)
|
|
297
|
+
|
|
338
298
|
create_lines.append(" else:")
|
|
339
299
|
create_lines.append(f" arg_{idx} = cached")
|
|
340
300
|
# Wrap dependencies if in override mode (only for override version)
|
|
@@ -342,7 +302,7 @@ class Resolver:
|
|
|
342
302
|
create_lines.append(" if override_mode:")
|
|
343
303
|
create_lines.append(
|
|
344
304
|
f" arg_{idx} = resolver._wrap_for_override("
|
|
345
|
-
f"
|
|
305
|
+
f"_param_types[{idx}], arg_{idx})"
|
|
346
306
|
)
|
|
347
307
|
|
|
348
308
|
# Handle different provider types
|
|
@@ -359,15 +319,19 @@ class Resolver:
|
|
|
359
319
|
create_lines.append(" }")
|
|
360
320
|
create_lines.append(" call_kwargs.update(defaults)")
|
|
361
321
|
create_lines.append(
|
|
362
|
-
" inst = await
|
|
322
|
+
" inst = await _provider_factory(**call_kwargs)"
|
|
363
323
|
)
|
|
364
324
|
create_lines.append(" else:")
|
|
365
|
-
create_lines.append(
|
|
325
|
+
create_lines.append(
|
|
326
|
+
f" inst = await _provider_factory({call_args})"
|
|
327
|
+
)
|
|
366
328
|
else:
|
|
367
329
|
create_lines.append(" if defaults is not None:")
|
|
368
|
-
create_lines.append(
|
|
330
|
+
create_lines.append(
|
|
331
|
+
" inst = await _provider_factory(**defaults)"
|
|
332
|
+
)
|
|
369
333
|
create_lines.append(" else:")
|
|
370
|
-
create_lines.append(" inst = await
|
|
334
|
+
create_lines.append(" inst = await _provider_factory()")
|
|
371
335
|
elif is_async and is_async_generator:
|
|
372
336
|
# Async generator - use async context manager
|
|
373
337
|
create_lines.append(" if context is None:")
|
|
@@ -386,20 +350,21 @@ class Resolver:
|
|
|
386
350
|
create_lines.append(" }")
|
|
387
351
|
create_lines.append(" call_kwargs.update(defaults)")
|
|
388
352
|
create_lines.append(
|
|
389
|
-
" cm =
|
|
353
|
+
" cm = "
|
|
354
|
+
"_asynccontextmanager(_provider_factory)(**call_kwargs)"
|
|
390
355
|
)
|
|
391
356
|
create_lines.append(" else:")
|
|
392
357
|
create_lines.append(
|
|
393
|
-
f" cm = _asynccontextmanager(
|
|
358
|
+
f" cm = _asynccontextmanager(_provider_factory)({call_args})"
|
|
394
359
|
)
|
|
395
360
|
else:
|
|
396
361
|
create_lines.append(" if defaults is not None:")
|
|
397
362
|
create_lines.append(
|
|
398
|
-
" cm = _asynccontextmanager(
|
|
363
|
+
" cm = _asynccontextmanager(_provider_factory)(**defaults)"
|
|
399
364
|
)
|
|
400
365
|
create_lines.append(" else:")
|
|
401
366
|
create_lines.append(
|
|
402
|
-
" cm = _asynccontextmanager(
|
|
367
|
+
" cm = _asynccontextmanager(_provider_factory)()"
|
|
403
368
|
)
|
|
404
369
|
create_lines.append(" inst = await context.aenter(cm)")
|
|
405
370
|
elif is_generator:
|
|
@@ -420,19 +385,19 @@ class Resolver:
|
|
|
420
385
|
create_lines.append(" }")
|
|
421
386
|
create_lines.append(" call_kwargs.update(defaults)")
|
|
422
387
|
create_lines.append(
|
|
423
|
-
" cm = _contextmanager(
|
|
388
|
+
" cm = _contextmanager(_provider_factory)(**call_kwargs)"
|
|
424
389
|
)
|
|
425
390
|
create_lines.append(" else:")
|
|
426
391
|
create_lines.append(
|
|
427
|
-
f" cm = _contextmanager(
|
|
392
|
+
f" cm = _contextmanager(_provider_factory)({call_args})"
|
|
428
393
|
)
|
|
429
394
|
else:
|
|
430
395
|
create_lines.append(" if defaults is not None:")
|
|
431
396
|
create_lines.append(
|
|
432
|
-
" cm = _contextmanager(
|
|
397
|
+
" cm = _contextmanager(_provider_factory)(**defaults)"
|
|
433
398
|
)
|
|
434
399
|
create_lines.append(" else:")
|
|
435
|
-
create_lines.append(" cm = _contextmanager(
|
|
400
|
+
create_lines.append(" cm = _contextmanager(_provider_factory)()")
|
|
436
401
|
if is_async:
|
|
437
402
|
# In async mode, run sync context manager enter in thread
|
|
438
403
|
create_lines.append(" inst = await _run_sync(context.enter, cm)")
|
|
@@ -449,14 +414,14 @@ class Resolver:
|
|
|
449
414
|
create_lines.append(f" '{name}': arg_{idx},")
|
|
450
415
|
create_lines.append(" }")
|
|
451
416
|
create_lines.append(" call_kwargs.update(defaults)")
|
|
452
|
-
create_lines.append(" inst =
|
|
417
|
+
create_lines.append(" inst = _provider_factory(**call_kwargs)")
|
|
453
418
|
create_lines.append(" else:")
|
|
454
|
-
create_lines.append(f" inst =
|
|
419
|
+
create_lines.append(f" inst = _provider_factory({call_args})")
|
|
455
420
|
else:
|
|
456
421
|
create_lines.append(" if defaults is not None:")
|
|
457
|
-
create_lines.append(" inst =
|
|
422
|
+
create_lines.append(" inst = _provider_factory(**defaults)")
|
|
458
423
|
create_lines.append(" else:")
|
|
459
|
-
create_lines.append(" inst =
|
|
424
|
+
create_lines.append(" inst = _provider_factory()")
|
|
460
425
|
|
|
461
426
|
# Handle context managers
|
|
462
427
|
if is_async:
|
|
@@ -475,13 +440,13 @@ class Resolver:
|
|
|
475
440
|
create_lines.append(" context.enter(inst)")
|
|
476
441
|
|
|
477
442
|
create_lines.append(" if context is not None and store:")
|
|
478
|
-
create_lines.append(" context.
|
|
443
|
+
create_lines.append(" context._items[_dependency_type] = inst")
|
|
479
444
|
|
|
480
445
|
# Wrap instance if in override mode (only for override version)
|
|
481
446
|
if with_override:
|
|
482
447
|
create_lines.append(" if override_mode:")
|
|
483
448
|
create_lines.append(
|
|
484
|
-
" inst = resolver._post_resolve_override(
|
|
449
|
+
" inst = resolver._post_resolve_override(_dependency_type, inst)"
|
|
485
450
|
)
|
|
486
451
|
create_lines.append(" return inst")
|
|
487
452
|
|
|
@@ -520,7 +485,7 @@ class Resolver:
|
|
|
520
485
|
self._add_override_check(resolver_lines)
|
|
521
486
|
|
|
522
487
|
# Fast path: check cached instance
|
|
523
|
-
resolver_lines.append(" inst = context.get(
|
|
488
|
+
resolver_lines.append(" inst = context.get(_dependency_type)")
|
|
524
489
|
resolver_lines.append(" if inst is not NOT_SET_:")
|
|
525
490
|
resolver_lines.append(" return inst")
|
|
526
491
|
|
|
@@ -528,7 +493,7 @@ class Resolver:
|
|
|
528
493
|
resolver_lines.append(" async with context.alock():")
|
|
529
494
|
else:
|
|
530
495
|
resolver_lines.append(" with context.lock():")
|
|
531
|
-
resolver_lines.append(" inst = context.get(
|
|
496
|
+
resolver_lines.append(" inst = context.get(_dependency_type)")
|
|
532
497
|
resolver_lines.append(" if inst is not NOT_SET_:")
|
|
533
498
|
resolver_lines.append(" return inst")
|
|
534
499
|
self._add_create_call(
|
|
@@ -558,7 +523,7 @@ class Resolver:
|
|
|
558
523
|
|
|
559
524
|
# Fast path: check cached instance (inline dict access for speed)
|
|
560
525
|
resolver_lines.append(
|
|
561
|
-
" inst = context.
|
|
526
|
+
" inst = context._items.get(_dependency_type, NOT_SET_)"
|
|
562
527
|
)
|
|
563
528
|
resolver_lines.append(" if inst is not NOT_SET_:")
|
|
564
529
|
resolver_lines.append(" return inst")
|
|
@@ -622,18 +587,16 @@ class Resolver:
|
|
|
622
587
|
src = "\n".join(lines)
|
|
623
588
|
|
|
624
589
|
ns: dict[str, Any] = {
|
|
625
|
-
"
|
|
626
|
-
"
|
|
627
|
-
"
|
|
628
|
-
"_provider_name": provider.name,
|
|
590
|
+
"_dependency_type": provider.dependency_type,
|
|
591
|
+
"_dependency_repr": type_repr(provider.dependency_type),
|
|
592
|
+
"_provider_factory": provider.factory,
|
|
629
593
|
"_is_class": provider.is_class,
|
|
630
|
-
"
|
|
594
|
+
"_param_types": param_types,
|
|
631
595
|
"_param_defaults": param_defaults,
|
|
632
596
|
"_param_has_default": param_has_default,
|
|
633
597
|
"_param_resolvers": param_resolvers,
|
|
634
598
|
"_param_shared_scopes": param_shared_scopes,
|
|
635
599
|
"_unresolved_messages": unresolved_messages,
|
|
636
|
-
"_unresolved_interfaces": self._unresolved_interfaces,
|
|
637
600
|
"_NOT_SET": NOT_SET,
|
|
638
601
|
"_contextmanager": contextlib.contextmanager,
|
|
639
602
|
"_is_cm": is_context_manager,
|
|
@@ -666,20 +629,103 @@ class Resolver:
|
|
|
666
629
|
|
|
667
630
|
return CompiledResolver(resolver, creator)
|
|
668
631
|
|
|
669
|
-
def
|
|
670
|
-
|
|
671
|
-
|
|
632
|
+
def _compile_from_context_resolver(
|
|
633
|
+
self, provider: Provider, *, is_async: bool, with_override: bool = False
|
|
634
|
+
) -> CompiledResolver:
|
|
635
|
+
"""Compile a resolver for from_context providers.
|
|
636
|
+
|
|
637
|
+
from_context providers get their instances from the scoped context
|
|
638
|
+
via context.set(), not from a factory call.
|
|
639
|
+
"""
|
|
640
|
+
scope = provider.scope
|
|
641
|
+
dependency_repr = type_repr(provider.dependency_type)
|
|
642
|
+
|
|
643
|
+
# Build resolver function
|
|
644
|
+
resolver_lines: list[str] = []
|
|
645
|
+
if is_async:
|
|
646
|
+
resolver_lines.append("async def _resolver(container, context=None):")
|
|
647
|
+
else:
|
|
648
|
+
resolver_lines.append("def _resolver(container, context=None):")
|
|
649
|
+
|
|
650
|
+
resolver_lines.append(" NOT_SET_ = _NOT_SET")
|
|
651
|
+
|
|
652
|
+
# Get context from context variable
|
|
653
|
+
resolver_lines.append(" if context is None:")
|
|
654
|
+
resolver_lines.append(" try:")
|
|
655
|
+
resolver_lines.append(" context = _scoped_context_var.get()")
|
|
656
|
+
resolver_lines.append(" except LookupError:")
|
|
657
|
+
resolver_lines.append(
|
|
658
|
+
f" raise LookupError("
|
|
659
|
+
f"'The {scope} context has not been started. "
|
|
660
|
+
f"Please ensure that the {scope} context is properly initialized "
|
|
661
|
+
f"before attempting to use it.')"
|
|
662
|
+
)
|
|
663
|
+
|
|
664
|
+
if with_override:
|
|
665
|
+
self._add_override_check(resolver_lines)
|
|
666
|
+
|
|
667
|
+
# Check if instance is set in context
|
|
668
|
+
resolver_lines.append(
|
|
669
|
+
" inst = context._items.get(_dependency_type, NOT_SET_)"
|
|
670
|
+
)
|
|
671
|
+
resolver_lines.append(" if inst is NOT_SET_:")
|
|
672
|
+
resolver_lines.append(
|
|
673
|
+
f" raise LookupError("
|
|
674
|
+
f"'The provider `{dependency_repr}` is registered with from_context=True "
|
|
675
|
+
f"but has not been set in the {scope} context. "
|
|
676
|
+
f"Please call context.set({dependency_repr}, instance) before "
|
|
677
|
+
f"attempting to resolve it.')"
|
|
678
|
+
)
|
|
679
|
+
resolver_lines.append(" return inst")
|
|
680
|
+
|
|
681
|
+
# Build creator function (not typically used for from_context, but needed)
|
|
682
|
+
create_resolver_lines: list[str] = []
|
|
683
|
+
if is_async:
|
|
684
|
+
create_resolver_lines.append(
|
|
685
|
+
"async def _resolver_create(container, defaults=None):"
|
|
686
|
+
)
|
|
687
|
+
else:
|
|
688
|
+
create_resolver_lines.append(
|
|
689
|
+
"def _resolver_create(container, defaults=None):"
|
|
690
|
+
)
|
|
691
|
+
create_resolver_lines.append(
|
|
692
|
+
f" raise TypeError("
|
|
693
|
+
f"'Cannot create instance for from_context provider `{dependency_repr}`. "
|
|
694
|
+
f"Use context.set() instead.')"
|
|
695
|
+
)
|
|
696
|
+
|
|
697
|
+
lines = resolver_lines + [""] + create_resolver_lines
|
|
698
|
+
src = "\n".join(lines)
|
|
699
|
+
|
|
700
|
+
ns: dict[str, Any] = {
|
|
701
|
+
"_dependency_type": provider.dependency_type,
|
|
702
|
+
"_NOT_SET": NOT_SET,
|
|
703
|
+
"_scoped_context_var": self._container._get_scoped_context_var( # type: ignore[reportPrivateUsage]
|
|
704
|
+
scope
|
|
705
|
+
),
|
|
706
|
+
"resolver": self,
|
|
707
|
+
}
|
|
708
|
+
|
|
709
|
+
exec(src, ns)
|
|
710
|
+
resolver = ns["_resolver"]
|
|
711
|
+
creator = ns["_resolver_create"]
|
|
712
|
+
|
|
713
|
+
return CompiledResolver(resolver, creator)
|
|
714
|
+
|
|
715
|
+
def _get_override_for(self, dependency_type: Any) -> Any:
|
|
716
|
+
"""Hook for checking if a dependency type has an override."""
|
|
717
|
+
return self._overrides.get(dependency_type, NOT_SET)
|
|
672
718
|
|
|
673
|
-
def _wrap_for_override(self,
|
|
719
|
+
def _wrap_for_override(self, dependency_type: Any, instance: Any) -> Any:
|
|
674
720
|
"""Hook for wrapping dependencies to enable override patching."""
|
|
675
|
-
if isinstance(
|
|
676
|
-
return
|
|
677
|
-
return InstanceProxy(
|
|
721
|
+
if isinstance(instance, InstanceProxy):
|
|
722
|
+
return instance
|
|
723
|
+
return InstanceProxy(instance, dependency_type=dependency_type)
|
|
678
724
|
|
|
679
|
-
def _post_resolve_override(self,
|
|
725
|
+
def _post_resolve_override(self, dependency_type: Any, instance: Any) -> Any: # noqa: C901
|
|
680
726
|
"""Hook for patching resolved instances to support override."""
|
|
681
|
-
if
|
|
682
|
-
return self.
|
|
727
|
+
if dependency_type in self._overrides:
|
|
728
|
+
return self._overrides[dependency_type]
|
|
683
729
|
|
|
684
730
|
if not hasattr(instance, "__dict__") or hasattr(
|
|
685
731
|
instance, "__resolver_getter__"
|
|
@@ -687,16 +733,16 @@ class Resolver:
|
|
|
687
733
|
return instance
|
|
688
734
|
|
|
689
735
|
wrapped = {
|
|
690
|
-
name: value.
|
|
736
|
+
name: value.dependency_type
|
|
691
737
|
for name, value in instance.__dict__.items()
|
|
692
738
|
if isinstance(value, InstanceProxy)
|
|
693
739
|
}
|
|
694
740
|
|
|
695
741
|
def __resolver_getter__(name: str) -> Any:
|
|
696
742
|
if name in wrapped:
|
|
697
|
-
|
|
743
|
+
_dependency_type = wrapped[name]
|
|
698
744
|
# Resolve the dependency if it's wrapped
|
|
699
|
-
return self._container.resolve(
|
|
745
|
+
return self._container.resolve(_dependency_type)
|
|
700
746
|
raise LookupError
|
|
701
747
|
|
|
702
748
|
# Attach the resolver getter to the instance
|