anydi 0.54.1__py3-none-any.whl → 0.55.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
anydi/_resolver.py ADDED
@@ -0,0 +1,571 @@
1
+ """Resolver compilation module for AnyDI."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import contextlib
6
+ from typing import TYPE_CHECKING, Any, NamedTuple
7
+
8
+ import anyio.to_thread
9
+ from typing_extensions import type_repr
10
+
11
+ from ._provider import Provider
12
+ from ._types import NOT_SET, is_async_context_manager, is_context_manager
13
+
14
+ if TYPE_CHECKING:
15
+ from ._container import Container
16
+
17
+
18
+ class CompiledResolver(NamedTuple):  # pair of generated callables returned by Resolver._compile_resolver
19
+ resolve: Any  # generated _resolver(container, context=None)
20
+ create: Any  # generated _resolver_create(container, defaults=None)
21
+
22
+
23
+ class Resolver:
24
+ def __init__(self, container: Container) -> None:
25
+ self._container = container
26
+ self._unresolved_interfaces: set[Any] = set()
27
+ self._cache: dict[Any, CompiledResolver] = {}
28
+ self._async_cache: dict[Any, CompiledResolver] = {}
29
+
30
+ # Determine compilation flags based on whether methods are overridden
31
+ self._has_override_support = callable(
32
+ getattr(self._container, "_hook_override_for", None)
33
+ )
34
+ self._wrap_dependencies = callable(
35
+ getattr(self._container, "_hook_wrap_dependency", None)
36
+ )
37
+ self._wrap_instance = callable(
38
+ getattr(self._container, "_hook_post_resolve", None)
39
+ )
40
+
41
+ def add_unresolved(self, interface: Any) -> None:
42
+ self._unresolved_interfaces.update((interface,))  # mutate in place: compiled code holds this same set object
43
+
44
+ def get_cached(self, interface: Any, *, is_async: bool) -> CompiledResolver | None:
45
+ """Get cached resolver if it exists."""
46
+ cache = self._async_cache if is_async else self._cache
47
+ return cache.get(interface)
48
+
49
+ def compile(self, provider: Provider, *, is_async: bool) -> CompiledResolver:
50
+ """Compile an optimized resolver function for the given provider."""
51
+ # Select the appropriate cache based on sync/async mode
52
+ cache = self._async_cache if is_async else self._cache
53
+
54
+ # Check if already compiled in cache
55
+ if provider.interface in cache:
56
+ return cache[provider.interface]
57
+
58
+ # Recursively compile dependencies first
59
+ for p in provider.parameters:
60
+ if p.provider is not None:
61
+ self.compile(p.provider, is_async=is_async)
62
+
63
+ # Compile the resolver and creator functions
64
+ compiled = self._compile_resolver(provider, is_async=is_async)
65
+
66
+ # Store the compiled functions in the cache
67
+ cache[provider.interface] = compiled
68
+
69
+ return compiled
70
+
71
+ def _compile_resolver( # noqa: C901
72
+ self, provider: Provider, *, is_async: bool
73
+ ) -> CompiledResolver:
74
+ """Compile optimized resolver functions for the given provider."""
75
+ has_override_support = self._has_override_support
76
+ wrap_dependencies = self._wrap_dependencies
77
+ wrap_instance = self._wrap_instance
78
+ num_params = len(provider.parameters)
79
+ param_resolvers: list[Any] = [None] * num_params
80
+ param_annotations: list[Any] = [None] * num_params
81
+ param_defaults: list[Any] = [None] * num_params
82
+ param_has_default: list[bool] = [False] * num_params
83
+ param_names: list[str] = [""] * num_params
84
+ param_shared_scopes: list[bool] = [False] * num_params
85
+ unresolved_messages: list[str] = [""] * num_params
86
+
87
+ cache = self._async_cache if is_async else self._cache
88
+
89
+ for idx, p in enumerate(provider.parameters):
90
+ param_annotations[idx] = p.annotation
91
+ param_defaults[idx] = p.default
92
+ param_has_default[idx] = p.has_default
93
+ param_names[idx] = p.name
94
+ param_shared_scopes[idx] = p.shared_scope
95
+
96
+ if p.provider is not None:
97
+ compiled = cache.get(p.provider.interface)
98
+ if compiled is None:
99
+ compiled = self.compile(p.provider, is_async=is_async)
100
+ cache[p.provider.interface] = compiled
101
+ param_resolvers[idx] = compiled.resolve
102
+
103
+ msg = (
104
+ f"You are attempting to get the parameter `{p.name}` with the "
105
+ f"annotation `{type_repr(p.annotation)}` as a dependency into "
106
+ f"`{type_repr(provider.call)}` which is not registered or set in the "
107
+ "scoped context."
108
+ )
109
+ unresolved_messages[idx] = msg
110
+
111
+ scope = provider.scope
112
+ is_generator = provider.is_generator
113
+ is_async_generator = provider.is_async_generator if is_async else False
114
+ is_coroutine = provider.is_coroutine if is_async else False
115
+ no_params = len(param_names) == 0
116
+
117
+ create_lines: list[str] = []
118
+ if is_async:
119
+ create_lines.append(
120
+ "async def _create_instance(container, context, store, defaults):"
121
+ )
122
+ else:
123
+ create_lines.append(
124
+ "def _create_instance(container, context, store, defaults):"
125
+ )
126
+
127
+ if no_params:
128
+ # Fast path: no parameters to resolve, skip NOT_SET check
129
+ if not is_async:
130
+ create_lines.append(" if _is_async:")
131
+ create_lines.append(
132
+ " raise TypeError("
133
+ 'f"The instance for the provider `{_provider_name}` '
134
+ 'cannot be created in synchronous mode."'
135
+ ")"
136
+ )
137
+ else:
138
+ # Need NOT_SET for parameter resolution
139
+ create_lines.append(" NOT_SET_ = _NOT_SET")
140
+ if not is_async:
141
+ create_lines.append(" if _is_async:")
142
+ create_lines.append(
143
+ " raise TypeError("
144
+ 'f"The instance for the provider `{_provider_name}` '
145
+ 'cannot be created in synchronous mode."'
146
+ ")"
147
+ )
148
+ # Cache the resolver cache for faster repeated access
149
+ create_lines.append(" cache = _cache")
150
+
151
+ if not no_params:
152
+ # Only generate parameter resolution logic if there are parameters
153
+ for idx, name in enumerate(param_names):
154
+ create_lines.append(f" # resolve param `{name}`")
155
+ create_lines.append(
156
+ f" if defaults is not None and '{name}' in defaults:"
157
+ )
158
+ create_lines.append(f" arg_{idx} = defaults['{name}']")
159
+ create_lines.append(" else:")
160
+ create_lines.append(" cached = NOT_SET_")
161
+ create_lines.append(" if context is not None:")
162
+ create_lines.append(
163
+ f" cached = context.get(_param_annotations[{idx}])"
164
+ )
165
+ create_lines.append(" if cached is NOT_SET_:")
166
+ create_lines.append(
167
+ f" if _param_annotations[{idx}] in "
168
+ "_unresolved_interfaces:"
169
+ )
170
+ create_lines.append(
171
+ f" raise LookupError(_unresolved_messages[{idx}])"
172
+ )
173
+ create_lines.append(f" resolver = _param_resolvers[{idx}]")
174
+ create_lines.append(" if resolver is None:")
175
+ create_lines.append(" try:")
176
+ if is_async:
177
+ create_lines.append(
178
+ f" compiled = "
179
+ f"cache.get(_param_annotations[{idx}])"
180
+ )
181
+ create_lines.append(" if compiled is None:")
182
+ create_lines.append(
183
+ " provider = "
184
+ "container._get_or_register_provider("
185
+ f"_param_annotations[{idx}])"
186
+ )
187
+ create_lines.append(
188
+ " compiled = "
189
+ "_compile(provider, is_async=True)"
190
+ )
191
+ create_lines.append(
192
+ " cache[provider.interface] = compiled"
193
+ )
194
+ create_lines.append(
195
+ f" arg_{idx} = "
196
+ f"await compiled[0](container, "
197
+ f"context if _param_shared_scopes[{idx}] else None)"
198
+ )
199
+ else:
200
+ create_lines.append(
201
+ f" compiled = "
202
+ f"cache.get(_param_annotations[{idx}])"
203
+ )
204
+ create_lines.append(" if compiled is None:")
205
+ create_lines.append(
206
+ " provider = "
207
+ "container._get_or_register_provider(_param_annotations[{idx}])"
208
+ )
209
+ create_lines.append(
210
+ " compiled = "
211
+ "_compile(provider, is_async=False)"
212
+ )
213
+ create_lines.append(
214
+ " cache[provider.interface] = compiled"
215
+ )
216
+ create_lines.append(
217
+ f" arg_{idx} = "
218
+ f"compiled[0](container, "
219
+ f"context if _param_shared_scopes[{idx}] else None)"
220
+ )
221
+ create_lines.append(" except LookupError:")
222
+ create_lines.append(
223
+ f" if _param_has_default[{idx}]:"
224
+ )
225
+ create_lines.append(
226
+ f" arg_{idx} = _param_defaults[{idx}]"
227
+ )
228
+ create_lines.append(" else:")
229
+ create_lines.append(" raise")
230
+ create_lines.append(" else:")
231
+ if is_async:
232
+ create_lines.append(
233
+ f" arg_{idx} = await resolver("
234
+ f"container, "
235
+ f"context if _param_shared_scopes[{idx}] else None)"
236
+ )
237
+ else:
238
+ create_lines.append(
239
+ f" arg_{idx} = resolver("
240
+ f"container, "
241
+ f"context if _param_shared_scopes[{idx}] else None)"
242
+ )
243
+ create_lines.append(" else:")
244
+ create_lines.append(f" arg_{idx} = cached")
245
+ if wrap_dependencies:
246
+ create_lines.append(
247
+ f" arg_{idx} = container._hook_wrap_dependency("
248
+ f"_param_annotations[{idx}], arg_{idx})"
249
+ )
250
+
251
+ # Handle different provider types
252
+ if is_async and is_coroutine:
253
+ # Async function - call with await
254
+ if param_names:
255
+ call_args = ", ".join(
256
+ f"{name}=arg_{idx}" for idx, name in enumerate(param_names)
257
+ )
258
+ create_lines.append(" if defaults is not None:")
259
+ create_lines.append(" call_kwargs = {")
260
+ for idx, name in enumerate(param_names):
261
+ create_lines.append(f" '{name}': arg_{idx},")
262
+ create_lines.append(" }")
263
+ create_lines.append(" call_kwargs.update(defaults)")
264
+ create_lines.append(
265
+ " inst = await _provider_call(**call_kwargs)"
266
+ )
267
+ create_lines.append(" else:")
268
+ create_lines.append(f" inst = await _provider_call({call_args})")
269
+ else:
270
+ create_lines.append(" if defaults is not None:")
271
+ create_lines.append(" inst = await _provider_call(**defaults)")
272
+ create_lines.append(" else:")
273
+ create_lines.append(" inst = await _provider_call()")
274
+ elif is_async and is_async_generator:
275
+ # Async generator - use async context manager
276
+ create_lines.append(" if context is None:")
277
+ create_lines.append(
278
+ " raise ValueError("
279
+ '"The async stack is required for async generator providers.")'
280
+ )
281
+ if param_names:
282
+ call_args = ", ".join(
283
+ f"{name}=arg_{idx}" for idx, name in enumerate(param_names)
284
+ )
285
+ create_lines.append(" if defaults is not None:")
286
+ create_lines.append(" call_kwargs = {")
287
+ for idx, name in enumerate(param_names):
288
+ create_lines.append(f" '{name}': arg_{idx},")
289
+ create_lines.append(" }")
290
+ create_lines.append(" call_kwargs.update(defaults)")
291
+ create_lines.append(
292
+ " cm = _asynccontextmanager(_provider_call)(**call_kwargs)"
293
+ )
294
+ create_lines.append(" else:")
295
+ create_lines.append(
296
+ f" cm = _asynccontextmanager(_provider_call)({call_args})"
297
+ )
298
+ else:
299
+ create_lines.append(" if defaults is not None:")
300
+ create_lines.append(
301
+ " cm = _asynccontextmanager(_provider_call)(**defaults)"
302
+ )
303
+ create_lines.append(" else:")
304
+ create_lines.append(
305
+ " cm = _asynccontextmanager(_provider_call)()"
306
+ )
307
+ create_lines.append(" inst = await context.aenter(cm)")
308
+ elif is_generator:
309
+ # Sync generator - use sync context manager
310
+ create_lines.append(" if context is None:")
311
+ create_lines.append(
312
+ " raise ValueError("
313
+ '"The context is required for generator providers.")'
314
+ )
315
+ if param_names:
316
+ call_args = ", ".join(
317
+ f"{name}=arg_{idx}" for idx, name in enumerate(param_names)
318
+ )
319
+ create_lines.append(" if defaults is not None:")
320
+ create_lines.append(" call_kwargs = {")
321
+ for idx, name in enumerate(param_names):
322
+ create_lines.append(f" '{name}': arg_{idx},")
323
+ create_lines.append(" }")
324
+ create_lines.append(" call_kwargs.update(defaults)")
325
+ create_lines.append(
326
+ " cm = _contextmanager(_provider_call)(**call_kwargs)"
327
+ )
328
+ create_lines.append(" else:")
329
+ create_lines.append(
330
+ f" cm = _contextmanager(_provider_call)({call_args})"
331
+ )
332
+ else:
333
+ create_lines.append(" if defaults is not None:")
334
+ create_lines.append(
335
+ " cm = _contextmanager(_provider_call)(**defaults)"
336
+ )
337
+ create_lines.append(" else:")
338
+ create_lines.append(" cm = _contextmanager(_provider_call)()")
339
+ if is_async:
340
+ # In async mode, run sync context manager enter in thread
341
+ create_lines.append(" inst = await _run_sync(context.enter, cm)")
342
+ else:
343
+ create_lines.append(" inst = context.enter(cm)")
344
+ else:
345
+ if param_names:
346
+ call_args = ", ".join(
347
+ f"{name}=arg_{idx}" for idx, name in enumerate(param_names)
348
+ )
349
+ create_lines.append(" if defaults is not None:")
350
+ create_lines.append(" call_kwargs = {")
351
+ for idx, name in enumerate(param_names):
352
+ create_lines.append(f" '{name}': arg_{idx},")
353
+ create_lines.append(" }")
354
+ create_lines.append(" call_kwargs.update(defaults)")
355
+ create_lines.append(" inst = _provider_call(**call_kwargs)")
356
+ create_lines.append(" else:")
357
+ create_lines.append(f" inst = _provider_call({call_args})")
358
+ else:
359
+ create_lines.append(" if defaults is not None:")
360
+ create_lines.append(" inst = _provider_call(**defaults)")
361
+ create_lines.append(" else:")
362
+ create_lines.append(" inst = _provider_call()")
363
+
364
+ # Handle context managers
365
+ if is_async:
366
+ create_lines.append(
367
+ " if context is not None and _is_class and _is_acm(inst):"
368
+ )
369
+ create_lines.append(" await context.aenter(inst)")
370
+ create_lines.append(
371
+ " elif context is not None and _is_class and _is_cm(inst):"
372
+ )
373
+ create_lines.append(" await _run_sync(context.enter, inst)")
374
+ else:
375
+ create_lines.append(
376
+ " if context is not None and _is_class and _is_cm(inst):"
377
+ )
378
+ create_lines.append(" context.enter(inst)")
379
+
380
+ create_lines.append(" if context is not None and store:")
381
+ create_lines.append(" context.set(_interface, inst)")
382
+
383
+ if wrap_instance:
384
+ create_lines.append(
385
+ " inst = container._hook_post_resolve(_interface, inst)"
386
+ )
387
+ create_lines.append(" return inst")
388
+
389
+ resolver_lines: list[str] = []
390
+ if is_async:
391
+ resolver_lines.append("async def _resolver(container, context=None):")
392
+ else:
393
+ resolver_lines.append("def _resolver(container, context=None):")
394
+
395
+ # Only define NOT_SET_ if we actually need it
396
+ needs_not_set = has_override_support or scope in ("singleton", "request")
397
+ if needs_not_set:
398
+ resolver_lines.append(" NOT_SET_ = _NOT_SET")
399
+
400
+ if scope == "singleton":
401
+ resolver_lines.append(" if context is None:")
402
+ resolver_lines.append(" context = container._singleton_context")
403
+ elif scope == "request":
404
+ resolver_lines.append(" if context is None:")
405
+ resolver_lines.append(" context = container._get_request_context()")
406
+ else:
407
+ resolver_lines.append(" context = None")
408
+
409
+ if has_override_support:
410
+ resolver_lines.append(
411
+ " override = container._hook_override_for(_interface)"
412
+ )
413
+ resolver_lines.append(" if override is not NOT_SET_:")
414
+ resolver_lines.append(" return override")
415
+
416
+ if scope == "singleton":
417
+ resolver_lines.append(" inst = context.get(_interface)")
418
+ resolver_lines.append(" if inst is not NOT_SET_:")
419
+ if wrap_instance:
420
+ resolver_lines.append(
421
+ " return container._hook_post_resolve(_provider, inst)"
422
+ )
423
+ else:
424
+ resolver_lines.append(" return inst")
425
+
426
+ if is_async:
427
+ resolver_lines.append(" async with context.alock():")
428
+ else:
429
+ resolver_lines.append(" with context.lock():")
430
+ resolver_lines.append(" inst = context.get(_interface)")
431
+ resolver_lines.append(" if inst is not NOT_SET_:")
432
+ if wrap_instance:
433
+ resolver_lines.append(
434
+ " return container._hook_post_resolve(_provider, inst)"
435
+ )
436
+ else:
437
+ resolver_lines.append(" return inst")
438
+ if is_async:
439
+ resolver_lines.append(
440
+ " return await "
441
+ "_create_instance(container, context, True, None)"
442
+ )
443
+ else:
444
+ resolver_lines.append(
445
+ " return _create_instance(container, context, True, None)"
446
+ )
447
+ elif scope == "request":
448
+ resolver_lines.append(" inst = context.get(_interface)")
449
+ resolver_lines.append(" if inst is not NOT_SET_:")
450
+ if wrap_instance:
451
+ resolver_lines.append(
452
+ " return container._hook_post_resolve(_provider, inst)"
453
+ )
454
+ else:
455
+ resolver_lines.append(" return inst")
456
+ if is_async:
457
+ resolver_lines.append(
458
+ " return await _create_instance(container, context, True, None)"
459
+ )
460
+ else:
461
+ resolver_lines.append(
462
+ " return _create_instance(container, context, True, None)"
463
+ )
464
+ else:
465
+ if is_async:
466
+ resolver_lines.append(
467
+ " return await _create_instance(container, None, False, None)"
468
+ )
469
+ else:
470
+ resolver_lines.append(
471
+ " return _create_instance(container, None, False, None)"
472
+ )
473
+
474
+ create_resolver_lines: list[str] = []
475
+ if is_async:
476
+ create_resolver_lines.append(
477
+ "async def _resolver_create(container, defaults=None):"
478
+ )
479
+ else:
480
+ create_resolver_lines.append(
481
+ "def _resolver_create(container, defaults=None):"
482
+ )
483
+
484
+ # Only define NOT_SET_ if needed for override support
485
+ if has_override_support:
486
+ create_resolver_lines.append(" NOT_SET_ = _NOT_SET")
487
+
488
+ if scope == "singleton":
489
+ create_resolver_lines.append(" context = container._singleton_context")
490
+ elif scope == "request":
491
+ create_resolver_lines.append(
492
+ " context = container._get_request_context()"
493
+ )
494
+ else:
495
+ create_resolver_lines.append(" context = None")
496
+
497
+ if has_override_support:
498
+ create_resolver_lines.append(
499
+ " override = container._hook_override_for(_interface)"
500
+ )
501
+ create_resolver_lines.append(" if override is not NOT_SET_:")
502
+ create_resolver_lines.append(" return override")
503
+
504
+ if scope == "singleton":
505
+ if is_async:
506
+ create_resolver_lines.append(
507
+ " return await "
508
+ "_create_instance(container, context, False, defaults)"
509
+ )
510
+ else:
511
+ create_resolver_lines.append(
512
+ " return _create_instance(container, context, False, defaults)"
513
+ )
514
+ elif scope == "request":
515
+ if is_async:
516
+ create_resolver_lines.append(
517
+ " return await "
518
+ "_create_instance(container, context, False, defaults)"
519
+ )
520
+ else:
521
+ create_resolver_lines.append(
522
+ " return _create_instance(container, context, False, defaults)"
523
+ )
524
+ else:
525
+ if is_async:
526
+ create_resolver_lines.append(
527
+ " return await "
528
+ "_create_instance(container, None, False, defaults)"
529
+ )
530
+ else:
531
+ create_resolver_lines.append(
532
+ " return _create_instance(container, None, False, defaults)"
533
+ )
534
+
535
+ lines = create_lines + [""] + resolver_lines + [""] + create_resolver_lines
536
+
537
+ src = "\n".join(lines)
538
+
539
+ ns: dict[str, Any] = {
540
+ "_provider": provider,
541
+ "_interface": provider.interface,
542
+ "_provider_call": provider.call,
543
+ "_provider_name": provider.name,
544
+ "_is_class": provider.is_class,
545
+ "_param_annotations": param_annotations,
546
+ "_param_defaults": param_defaults,
547
+ "_param_has_default": param_has_default,
548
+ "_param_resolvers": param_resolvers,
549
+ "_param_shared_scopes": param_shared_scopes,
550
+ "_unresolved_messages": unresolved_messages,
551
+ "_unresolved_interfaces": self._unresolved_interfaces,
552
+ "_NOT_SET": NOT_SET,
553
+ "_contextmanager": contextlib.contextmanager,
554
+ "_is_cm": is_context_manager,
555
+ "_cache": self._async_cache if is_async else self._cache,
556
+ "_compile": self._compile_resolver,
557
+ }
558
+
559
+ # Add async-specific namespace entries
560
+ if is_async:
561
+ ns["_asynccontextmanager"] = contextlib.asynccontextmanager
562
+ ns["_is_acm"] = is_async_context_manager
563
+ ns["_run_sync"] = anyio.to_thread.run_sync
564
+ else:
565
+ ns["_is_async"] = provider.is_async
566
+
567
+ exec(src, ns)
568
+ resolver = ns["_resolver"]
569
+ creator = ns["_resolver_create"]
570
+
571
+ return CompiledResolver(resolver, creator)
@@ -9,7 +9,7 @@ from types import ModuleType
9
9
  from typing import TYPE_CHECKING, Any
10
10
 
11
11
  from ._decorators import is_injectable
12
- from ._typing import is_inject_marker
12
+ from ._types import is_inject_marker
13
13
 
14
14
  if TYPE_CHECKING:
15
15
  from ._container import Container
@@ -5,30 +5,11 @@ from __future__ import annotations
5
5
  import inspect
6
6
  from collections.abc import AsyncIterator, Iterator
7
7
  from types import NoneType
8
- from typing import Any
8
+ from typing import Any, Literal
9
9
 
10
10
  from typing_extensions import Sentinel
11
11
 
12
-
13
- def is_context_manager(obj: Any) -> bool:
14
- """Check if the given object is a context manager."""
15
- return hasattr(obj, "__enter__") and hasattr(obj, "__exit__")
16
-
17
-
18
- def is_async_context_manager(obj: Any) -> bool:
19
- """Check if the given object is an async context manager."""
20
- return hasattr(obj, "__aenter__") and hasattr(obj, "__aexit__")
21
-
22
-
23
- def is_none_type(tp: Any) -> bool:
24
- """Check if the given object is a None type."""
25
- return tp in (None, NoneType)
26
-
27
-
28
- def is_iterator_type(tp: Any) -> bool:
29
- """Check if the given object is an iterator type."""
30
- return tp in (Iterator, AsyncIterator)
31
-
12
+ Scope = Literal["transient", "singleton", "request"]  # provider scope names; _resolver.py branches on "singleton"/"request", any other scope gets no context
32
13
 
33
14
  NOT_SET = Sentinel("NOT_SET")  # "no value" marker; _resolver.py compares against it by identity (is / is not)
34
15
 
@@ -69,3 +50,23 @@ class Event:
69
50
  def is_event_type(obj: Any) -> bool:
70
51
  """Checks if an object is an event type."""
71
52
  return inspect.isclass(obj) and issubclass(obj, Event)
53
+
54
+
55
+ def is_context_manager(obj: Any) -> bool:
56
+ """Check if the given object is a context manager."""
57
+ return hasattr(obj, "__enter__") and hasattr(obj, "__exit__")
58
+
59
+
60
+ def is_async_context_manager(obj: Any) -> bool:
61
+ """Check if the given object is an async context manager."""
62
+ return hasattr(obj, "__aenter__") and hasattr(obj, "__aexit__")
63
+
64
+
65
+ def is_none_type(tp: Any) -> bool:
66
+ """Check if the given object is a None type."""
67
+ return tp in (None, NoneType)
68
+
69
+
70
+ def is_iterator_type(tp: Any) -> bool:
71
+ """Check if the given object is an iterator type."""
72
+ return tp in (Iterator, AsyncIterator)
anydi/ext/fastapi.py CHANGED
@@ -12,7 +12,7 @@ from fastapi.routing import APIRoute
12
12
  from starlette.requests import Request
13
13
 
14
14
  from anydi._container import Container
15
- from anydi._typing import InjectMarker
15
+ from anydi._types import InjectMarker
16
16
 
17
17
  from .starlette.middleware import RequestScopedMiddleware
18
18
 
anydi/ext/faststream.py CHANGED
@@ -10,7 +10,7 @@ from faststream import ContextRepo
10
10
  from faststream.broker.core.usecase import BrokerUsecase
11
11
 
12
12
  from anydi import Container
13
- from anydi._typing import InjectMarker
13
+ from anydi._types import InjectMarker
14
14
 
15
15
 
16
16
  def install(broker: BrokerUsecase[Any, Any], container: Container) -> None: