anydi 0.38.1__py3-none-any.whl → 0.38.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
anydi/__init__.py CHANGED
@@ -11,8 +11,7 @@ from ._container import (
11
11
  singleton,
12
12
  transient,
13
13
  )
14
- from ._provider import Provider
15
- from ._types import Marker, Scope
14
+ from ._types import Marker, ProviderArgs as Provider, Scope
16
15
 
17
16
  # Alias for dependency auto marker
18
17
  auto = cast(Any, Marker())
anydi/_container.py CHANGED
@@ -10,23 +10,27 @@ import logging
10
10
  import pkgutil
11
11
  import threading
12
12
  import types
13
+ import uuid
13
14
  from collections import defaultdict
14
15
  from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
15
16
  from contextvars import ContextVar
16
17
  from types import ModuleType
17
- from typing import Any, Callable, TypeVar, Union, cast, overload
18
- from weakref import WeakKeyDictionary
18
+ from typing import Any, Callable, TypeVar, Union, cast, get_origin, overload
19
19
 
20
- from typing_extensions import Concatenate, ParamSpec, Self, final
20
+ from typing_extensions import Concatenate, ParamSpec, Self, final, get_args
21
21
 
22
22
  from ._context import InstanceContext
23
- from ._provider import Provider
24
23
  from ._types import (
24
+ NOT_SET,
25
25
  AnyInterface,
26
- Dependency,
26
+ Event,
27
27
  InjectableDecoratorArgs,
28
28
  InstanceProxy,
29
+ Provider,
30
+ ProviderArgs,
29
31
  ProviderDecoratorArgs,
32
+ ProviderKind,
33
+ ScannedDependency,
30
34
  Scope,
31
35
  is_event_type,
32
36
  is_marker,
@@ -34,6 +38,7 @@ from ._types import (
34
38
  from ._utils import (
35
39
  AsyncRLock,
36
40
  get_full_qualname,
41
+ get_typed_annotation,
37
42
  get_typed_parameters,
38
43
  import_string,
39
44
  is_async_context_manager,
@@ -42,6 +47,12 @@ from ._utils import (
42
47
  run_async,
43
48
  )
44
49
 
50
+ try:
51
+ from types import NoneType
52
+ except ImportError:
53
+ NoneType = type(None) # type: ignore[misc]
54
+
55
+
45
56
  T = TypeVar("T", bound=Any)
46
57
  M = TypeVar("M", bound="Module")
47
58
  P = ParamSpec("P")
@@ -74,7 +85,6 @@ class Module(metaclass=ModuleMeta):
74
85
  """Configure the AnyDI container with providers and their dependencies."""
75
86
 
76
87
 
77
- # noinspection PyShadowingNames
78
88
  @final
79
89
  class Container:
80
90
  """AnyDI is a dependency injection container."""
@@ -82,7 +92,7 @@ class Container:
82
92
  def __init__(
83
93
  self,
84
94
  *,
85
- providers: Sequence[Provider] | None = None,
95
+ providers: Sequence[ProviderArgs] | None = None,
86
96
  modules: Sequence[Module | type[Module] | Callable[[Container], None] | str]
87
97
  | None = None,
88
98
  strict: bool = False,
@@ -104,20 +114,27 @@ class Container:
104
114
  )
105
115
  self._override_instances: dict[type[Any], Any] = {}
106
116
  self._unresolved_interfaces: set[type[Any]] = set()
107
- self._inject_cache: WeakKeyDictionary[
108
- Callable[..., Any], Callable[..., Any]
109
- ] = WeakKeyDictionary()
117
+ self._inject_cache: dict[Callable[..., Any], Callable[..., Any]] = {}
110
118
 
111
119
  # Register providers
112
120
  providers = providers or []
113
121
  for provider in providers:
114
- self._register_provider(provider, False)
122
+ _provider = self._create_provider(
123
+ call=provider.call,
124
+ scope=provider.scope,
125
+ interface=provider.interface,
126
+ )
127
+ self._register_provider(_provider, False)
115
128
 
116
129
  # Register modules
117
130
  modules = modules or []
118
131
  for module in modules:
119
132
  self.register_module(module)
120
133
 
134
+ ############################
135
+ # Properties
136
+ ############################
137
+
121
138
  @property
122
139
  def strict(self) -> bool:
123
140
  """Check if strict mode is enabled."""
@@ -143,9 +160,110 @@ class Container:
143
160
  """Get the logger instance."""
144
161
  return self._logger
145
162
 
146
- def is_registered(self, interface: AnyInterface) -> bool:
147
- """Check if a provider is registered for the specified interface."""
148
- return interface in self._providers
163
+ ############################
164
+ # Lifespan/Context Methods
165
+ ############################
166
+
167
+ def __enter__(self) -> Self:
168
+ """Enter the singleton context."""
169
+ self.start()
170
+ return self
171
+
172
+ def __exit__(
173
+ self,
174
+ exc_type: type[BaseException] | None,
175
+ exc_val: BaseException | None,
176
+ exc_tb: types.TracebackType | None,
177
+ ) -> Any:
178
+ """Exit the singleton context."""
179
+ return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
180
+
181
+ def start(self) -> None:
182
+ """Start the singleton context."""
183
+ # Resolve all singleton resources
184
+ for interface in self._resources.get("singleton", []):
185
+ self.resolve(interface)
186
+
187
+ def close(self) -> None:
188
+ """Close the singleton context."""
189
+ self._singleton_context.close()
190
+
191
+ async def __aenter__(self) -> Self:
192
+ """Enter the singleton context."""
193
+ await self.astart()
194
+ return self
195
+
196
+ async def __aexit__(
197
+ self,
198
+ exc_type: type[BaseException] | None,
199
+ exc_val: BaseException | None,
200
+ exc_tb: types.TracebackType | None,
201
+ ) -> bool:
202
+ """Exit the singleton context."""
203
+ return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)
204
+
205
+ async def astart(self) -> None:
206
+ """Start the singleton context asynchronously."""
207
+ for interface in self._resources.get("singleton", []):
208
+ await self.aresolve(interface)
209
+
210
+ async def aclose(self) -> None:
211
+ """Close the singleton context asynchronously."""
212
+ await self._singleton_context.aclose()
213
+
214
+ @contextlib.contextmanager
215
+ def request_context(self) -> Iterator[InstanceContext]:
216
+ """Obtain a context manager for the request-scoped context."""
217
+ context = InstanceContext()
218
+
219
+ token = self._request_context_var.set(context)
220
+
221
+ # Resolve all request resources
222
+ for interface in self._resources.get("request", []):
223
+ if not is_event_type(interface):
224
+ continue
225
+ self.resolve(interface)
226
+
227
+ with context:
228
+ yield context
229
+ self._request_context_var.reset(token)
230
+
231
+ @contextlib.asynccontextmanager
232
+ async def arequest_context(self) -> AsyncIterator[InstanceContext]:
233
+ """Obtain an async context manager for the request-scoped context."""
234
+ context = InstanceContext()
235
+
236
+ token = self._request_context_var.set(context)
237
+
238
+ for interface in self._resources.get("request", []):
239
+ if not is_event_type(interface):
240
+ continue
241
+ await self.aresolve(interface)
242
+
243
+ async with context:
244
+ yield context
245
+ self._request_context_var.reset(token)
246
+
247
+ def _get_request_context(self) -> InstanceContext:
248
+ """Get the current request context."""
249
+ request_context = self._request_context_var.get()
250
+ if request_context is None:
251
+ raise LookupError(
252
+ "The request context has not been started. Please ensure that "
253
+ "the request context is properly initialized before attempting "
254
+ "to use it."
255
+ )
256
+ return request_context
257
+
258
+ def _get_scoped_context(self, scope: Scope) -> InstanceContext:
259
+ """Get the instance context for the specified scope."""
260
+ if scope == "singleton":
261
+ return self._singleton_context
262
+ return self._get_request_context()
263
+
264
+ ############################
265
+ # Provider Methods
266
+ ############################
149
267
 
150
268
  def register(
151
269
  self,
@@ -156,26 +274,12 @@ class Container:
156
274
  override: bool = False,
157
275
  ) -> Provider:
158
276
  """Register a provider for the specified interface."""
159
- provider = Provider(call=call, scope=scope, interface=interface)
277
+ provider = self._create_provider(call=call, scope=scope, interface=interface)
160
278
  return self._register_provider(provider, override)
161
279
 
162
- def _register_provider(
163
- self, provider: Provider, override: bool, /, **defaults: Any
164
- ) -> Provider:
165
- """Register a provider."""
166
- if provider.interface in self._providers:
167
- if override:
168
- self._set_provider(provider)
169
- return provider
170
-
171
- raise LookupError(
172
- f"The provider interface `{get_full_qualname(provider.interface)}` "
173
- "already registered."
174
- )
175
-
176
- self._validate_sub_providers(provider, **defaults)
177
- self._set_provider(provider)
178
- return provider
280
+ def is_registered(self, interface: AnyInterface) -> bool:
281
+ """Check if a provider is registered for the specified interface."""
282
+ return interface in self._providers
179
283
 
180
284
  def unregister(self, interface: AnyInterface) -> None:
181
285
  """Unregister a provider by interface."""
@@ -199,6 +303,120 @@ class Container:
199
303
  # Cleanup provider references
200
304
  self._delete_provider(provider)
201
305
 
306
+ def provider(
307
+ self, *, scope: Scope, override: bool = False
308
+ ) -> Callable[[Callable[P, T]], Callable[P, T]]:
309
+ """Decorator to register a provider function with the specified scope."""
310
+
311
+ def decorator(call: Callable[P, T]) -> Callable[P, T]:
312
+ provider = self._create_provider(call=call, scope=scope)
313
+ self._register_provider(provider, override)
314
+ return call
315
+
316
+ return decorator
317
+
318
+ def _create_provider( # noqa: C901
319
+ self, call: Callable[..., Any], *, scope: Scope, interface: Any = NOT_SET
320
+ ) -> Provider:
321
+ name = get_full_qualname(call)
322
+
323
+ # Detect the kind of callable provider
324
+ kind = ProviderKind.from_call(call)
325
+
326
+ # Validate the scope of the provider
327
+ if scope not in get_args(Scope):
328
+ raise ValueError(
329
+ f"The provider `{name}` scope is invalid. Only the following "
330
+ f"scopes are supported: {', '.join(get_args(Scope))}. "
331
+ "Please use one of the supported scopes when registering a provider."
332
+ )
333
+
334
+ # Validate the scope of the provider
335
+ if (
336
+ kind in {ProviderKind.GENERATOR, ProviderKind.ASYNC_GENERATOR}
337
+ and scope == "transient"
338
+ ):
339
+ raise TypeError(
340
+ f"The resource provider `{name}` is attempting to register "
341
+ "with a transient scope, which is not allowed."
342
+ )
343
+
344
+ # Get the signature
345
+ globalns = getattr(call, "__globals__", {})
346
+ signature = inspect.signature(call, globals=globalns)
347
+
348
+ # Detect the interface
349
+ if kind == ProviderKind.CLASS:
350
+ interface = call
351
+ else:
352
+ if interface is NOT_SET:
353
+ interface = signature.return_annotation
354
+ if interface is inspect.Signature.empty:
355
+ interface = None
356
+ else:
357
+ interface = get_typed_annotation(interface, globalns)
358
+
359
+ # If the callable is an iterator, return the actual type
360
+ iterator_types = {Iterator, AsyncIterator}
361
+ if interface in iterator_types or get_origin(interface) in iterator_types:
362
+ if args := get_args(interface):
363
+ interface = args[0]
364
+ # If the callable is a generator, return the resource type
365
+ if interface in {None, NoneType}:
366
+ interface = type(f"Event_{uuid.uuid4().hex}", (Event,), {})
367
+ else:
368
+ raise TypeError(
369
+ f"Cannot use `{name}` resource type annotation "
370
+ "without actual type argument."
371
+ )
372
+
373
+ # None interface is not allowed
374
+ if interface in {None, NoneType}:
375
+ raise TypeError(f"Missing `{name}` provider return annotation.")
376
+
377
+ # Detect the parameters
378
+ parameters = []
379
+ for parameter in signature.parameters.values():
380
+ if parameter.annotation is inspect.Parameter.empty:
381
+ raise TypeError(
382
+ f"Missing provider `{name}` "
383
+ f"dependency `{parameter.name}` annotation."
384
+ )
385
+ if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
386
+ raise TypeError(
387
+ "Positional-only parameters "
388
+ f"are not allowed in the provider `{name}`."
389
+ )
390
+ annotation = get_typed_annotation(parameter.annotation, globalns)
391
+ parameters.append(parameter.replace(annotation=annotation))
392
+
393
+ return Provider(
394
+ call=call,
395
+ scope=scope,
396
+ interface=interface,
397
+ name=name,
398
+ kind=kind,
399
+ parameters=parameters,
400
+ )
401
+
402
+ def _register_provider(
403
+ self, provider: Provider, override: bool, /, **defaults: Any
404
+ ) -> Provider:
405
+ """Register a provider."""
406
+ if provider.interface in self._providers:
407
+ if override:
408
+ self._set_provider(provider)
409
+ return provider
410
+
411
+ raise LookupError(
412
+ f"The provider interface `{get_full_qualname(provider.interface)}` "
413
+ "already registered."
414
+ )
415
+
416
+ self._validate_sub_providers(provider, **defaults)
417
+ self._set_provider(provider)
418
+ return provider
419
+
202
420
  def _get_provider(self, interface: AnyInterface) -> Provider:
203
421
  """Get provider by interface."""
204
422
  try:
@@ -227,9 +445,11 @@ class Container:
227
445
  scope = getattr(interface, "__scope__", parent_scope)
228
446
  # Try to detect scope
229
447
  if scope is None:
230
- scope = self._detect_scope(interface, **defaults)
448
+ scope = self._detect_provider_scope(interface, **defaults)
231
449
  scope = scope or self.default_scope
232
- provider = Provider(call=interface, scope=scope, interface=interface)
450
+ provider = self._create_provider(
451
+ call=interface, scope=scope, interface=interface
452
+ )
233
453
  return self._register_provider(provider, False, **defaults)
234
454
  raise
235
455
 
@@ -250,12 +470,6 @@ class Container:
250
470
  """Validate the sub-providers of a provider."""
251
471
 
252
472
  for parameter in provider.parameters:
253
- if parameter.annotation is inspect.Parameter.empty:
254
- raise TypeError(
255
- f"Missing provider `{provider}` "
256
- f"dependency `{parameter.name}` annotation."
257
- )
258
-
259
473
  try:
260
474
  sub_provider = self._get_or_register_provider(
261
475
  parameter.annotation, provider.scope
@@ -282,7 +496,9 @@ class Container:
282
496
  "Please ensure all providers are registered with matching scopes."
283
497
  )
284
498
 
285
- def _detect_scope(self, call: Callable[..., Any], **defaults: Any) -> Scope | None:
499
+ def _detect_provider_scope(
500
+ self, call: Callable[..., Any], /, **defaults: Any
501
+ ) -> Scope | None:
286
502
  """Detect the scope for a callable."""
287
503
  scopes = set()
288
504
 
@@ -320,122 +536,56 @@ class Container:
320
536
  not self.strict and parameter.default is not inspect.Parameter.empty
321
537
  )
322
538
 
323
- def register_module(
324
- self, module: Module | type[Module] | Callable[[Container], None] | str
325
- ) -> None:
326
- """Register a module as a callable, module type, or module instance."""
327
- # Callable Module
328
- if inspect.isfunction(module):
329
- module(self)
330
- return
331
-
332
- # Module path
333
- if isinstance(module, str):
334
- module = import_string(module)
335
-
336
- # Class based Module or Module type
337
- if inspect.isclass(module) and issubclass(module, Module):
338
- module = module()
339
-
340
- if isinstance(module, Module):
341
- module.configure(self)
342
- for provider_name, decorator_args in module.providers:
343
- obj = getattr(module, provider_name)
344
- self.provider(
345
- scope=decorator_args.scope,
346
- override=decorator_args.override,
347
- )(obj)
348
-
349
- def __enter__(self) -> Self:
350
- """Enter the singleton context."""
351
- self.start()
352
- return self
353
-
354
- def __exit__(
355
- self,
356
- exc_type: type[BaseException] | None,
357
- exc_val: BaseException | None,
358
- exc_tb: types.TracebackType | None,
359
- ) -> Any:
360
- """Exit the singleton context."""
361
- return self._singleton_context.__exit__(exc_type, exc_val, exc_tb)
362
-
363
- def start(self) -> None:
364
- """Start the singleton context."""
365
- # Resolve all singleton resources
366
- for interface in self._resources.get("singleton", []):
367
- self.resolve(interface)
368
-
369
- def close(self) -> None:
370
- """Close the singleton context."""
371
- self._singleton_context.close()
372
-
373
- @contextlib.contextmanager
374
- def request_context(self) -> Iterator[InstanceContext]:
375
- """Obtain a context manager for the request-scoped context."""
376
- context = InstanceContext()
377
-
378
- token = self._request_context_var.set(context)
539
+ ############################
540
+ # Instance Methods
541
+ ############################
379
542
 
380
- # Resolve all request resources
381
- for interface in self._resources.get("request", []):
382
- if not is_event_type(interface):
383
- continue
384
- self.resolve(interface)
385
-
386
- with context:
387
- yield context
388
- self._request_context_var.reset(token)
543
+ @overload
544
+ def resolve(self, interface: type[T]) -> T: ...
389
545
 
390
- async def __aenter__(self) -> Self:
391
- """Enter the singleton context."""
392
- await self.astart()
393
- return self
546
+ @overload
547
+ def resolve(self, interface: T) -> T: ...
394
548
 
395
- async def __aexit__(
396
- self,
397
- exc_type: type[BaseException] | None,
398
- exc_val: BaseException | None,
399
- exc_tb: types.TracebackType | None,
400
- ) -> bool:
401
- """Exit the singleton context."""
402
- return await self._singleton_context.__aexit__(exc_type, exc_val, exc_tb)
549
+ def resolve(self, interface: type[T]) -> T:
550
+ """Resolve an instance by interface."""
551
+ return self._resolve_or_create(interface, False)
403
552
 
404
- async def astart(self) -> None:
405
- """Start the singleton context asynchronously."""
406
- for interface in self._resources.get("singleton", []):
407
- await self.aresolve(interface)
553
+ @overload
554
+ async def aresolve(self, interface: type[T]) -> T: ...
408
555
 
409
- async def aclose(self) -> None:
410
- """Close the singleton context asynchronously."""
411
- await self._singleton_context.aclose()
556
+ @overload
557
+ async def aresolve(self, interface: T) -> T: ...
412
558
 
413
- @contextlib.asynccontextmanager
414
- async def arequest_context(self) -> AsyncIterator[InstanceContext]:
415
- """Obtain an async context manager for the request-scoped context."""
416
- context = InstanceContext()
559
+ async def aresolve(self, interface: type[T]) -> T:
560
+ """Resolve an instance by interface asynchronously."""
561
+ return await self._aresolve_or_acreate(interface, False)
417
562
 
418
- token = self._request_context_var.set(context)
563
+ def create(self, interface: type[T], /, **defaults: Any) -> T:
564
+ """Create an instance by interface."""
565
+ return self._resolve_or_create(interface, True, **defaults)
419
566
 
420
- for interface in self._resources.get("request", []):
421
- if not is_event_type(interface):
422
- continue
423
- await self.aresolve(interface)
567
+ async def acreate(self, interface: type[T], /, **defaults: Any) -> T:
568
+ """Create an instance by interface asynchronously."""
569
+ return await self._aresolve_or_acreate(interface, True, **defaults)
424
570
 
425
- async with context:
426
- yield context
427
- self._request_context_var.reset(token)
571
+ def is_resolved(self, interface: AnyInterface) -> bool:
572
+ """Check if an instance by interface exists."""
573
+ try:
574
+ provider = self._get_provider(interface)
575
+ except LookupError:
576
+ return False
577
+ if provider.scope == "transient":
578
+ return False
579
+ context = self._get_scoped_context(provider.scope)
580
+ return interface in context
428
581
 
429
- def _get_request_context(self) -> InstanceContext:
430
- """Get the current request context."""
431
- request_context = self._request_context_var.get()
432
- if request_context is None:
433
- raise LookupError(
434
- "The request context has not been started. Please ensure that "
435
- "the request context is properly initialized before attempting "
436
- "to use it."
437
- )
438
- return request_context
582
+ def release(self, interface: AnyInterface) -> None:
583
+ """Release an instance by interface."""
584
+ provider = self._get_provider(interface)
585
+ if provider.scope == "transient":
586
+ return None
587
+ context = self._get_scoped_context(provider.scope)
588
+ del context[interface]
439
589
 
440
590
  def reset(self) -> None:
441
591
  """Reset resolved instances."""
@@ -448,52 +598,10 @@ class Container:
448
598
  continue
449
599
  del context[interface]
450
600
 
451
- @overload
452
- def resolve(self, interface: type[T]) -> T: ...
453
-
454
- @overload
455
- def resolve(self, interface: T) -> T: ...
456
-
457
- def resolve(self, interface: type[T]) -> T:
458
- """Resolve an instance by interface."""
459
- provider = self._get_or_register_provider(interface, None)
460
- if provider.scope == "transient":
461
- instance = self._create_instance(provider, None)
462
- else:
463
- context = self._get_scoped_context(provider.scope)
464
- if provider.scope == "singleton":
465
- with self._singleton_lock:
466
- instance = self._get_or_create_instance(provider, context)
467
- else:
468
- instance = self._get_or_create_instance(provider, context)
469
- if self.testing:
470
- instance = self._patch_test_resolver(provider.interface, instance)
471
- return cast(T, instance)
472
-
473
- @overload
474
- async def aresolve(self, interface: type[T]) -> T: ...
475
-
476
- @overload
477
- async def aresolve(self, interface: T) -> T: ...
478
-
479
- async def aresolve(self, interface: type[T]) -> T:
480
- """Resolve an instance by interface asynchronously."""
481
- provider = self._get_or_register_provider(interface, None)
482
- if provider.scope == "transient":
483
- instance = await self._acreate_instance(provider, None)
484
- else:
485
- context = self._get_scoped_context(provider.scope)
486
- if provider.scope == "singleton":
487
- async with self._singleton_async_lock:
488
- instance = await self._aget_or_create_instance(provider, context)
489
- else:
490
- instance = await self._aget_or_create_instance(provider, context)
491
- if self.testing:
492
- instance = self._patch_test_resolver(interface, instance)
493
- return cast(T, instance)
494
-
495
- def create(self, interface: type[T], **defaults: Any) -> T:
496
- """Create an instance by interface."""
601
+ def _resolve_or_create(
602
+ self, interface: type[T], create: bool, /, **defaults: Any
603
+ ) -> T:
604
+ """Internal method to handle instance resolution and creation."""
497
605
  provider = self._get_or_register_provider(interface, None, **defaults)
498
606
  if provider.scope == "transient":
499
607
  instance = self._create_instance(provider, None, **defaults)
@@ -501,15 +609,27 @@ class Container:
501
609
  context = self._get_scoped_context(provider.scope)
502
610
  if provider.scope == "singleton":
503
611
  with self._singleton_lock:
504
- instance = self._create_instance(provider, context, **defaults)
612
+ instance = (
613
+ self._get_or_create_instance(provider, context)
614
+ if not create
615
+ else self._create_instance(provider, context, **defaults)
616
+ )
505
617
  else:
506
- instance = self._create_instance(provider, context, **defaults)
618
+ instance = (
619
+ self._get_or_create_instance(provider, context)
620
+ if not create
621
+ else self._create_instance(provider, context, **defaults)
622
+ )
623
+
507
624
  if self.testing:
508
625
  instance = self._patch_test_resolver(provider.interface, instance)
626
+
509
627
  return cast(T, instance)
510
628
 
511
- async def acreate(self, interface: type[T], **defaults: Any) -> T:
512
- """Create an instance by interface."""
629
+ async def _aresolve_or_acreate(
630
+ self, interface: type[T], create: bool, /, **defaults: Any
631
+ ) -> T:
632
+ """Internal method to handle instance resolution and creation asynchronously."""
513
633
  provider = self._get_or_register_provider(interface, None, **defaults)
514
634
  if provider.scope == "transient":
515
635
  instance = await self._acreate_instance(provider, None, **defaults)
@@ -517,13 +637,21 @@ class Container:
517
637
  context = self._get_scoped_context(provider.scope)
518
638
  if provider.scope == "singleton":
519
639
  async with self._singleton_async_lock:
520
- instance = await self._acreate_instance(
521
- provider, context, **defaults
640
+ instance = (
641
+ await self._aget_or_create_instance(provider, context)
642
+ if not create
643
+ else await self._acreate_instance(provider, context, **defaults)
522
644
  )
523
645
  else:
524
- instance = await self._acreate_instance(provider, context, **defaults)
646
+ instance = (
647
+ await self._aget_or_create_instance(provider, context)
648
+ if not create
649
+ else await self._acreate_instance(provider, context, **defaults)
650
+ )
651
+
525
652
  if self.testing:
526
653
  instance = self._patch_test_resolver(provider.interface, instance)
654
+
527
655
  return cast(T, instance)
528
656
 
529
657
  def _get_or_create_instance(
@@ -700,26 +828,48 @@ class Container:
700
828
  def _resolve_parameter(
701
829
  self, provider: Provider, parameter: inspect.Parameter
702
830
  ) -> Any:
703
- self._validate_resolvable_parameter(parameter, call=provider.call)
831
+ self._validate_resolvable_parameter(provider, parameter)
704
832
  return self.resolve(parameter.annotation)
705
833
 
706
834
  async def _aresolve_parameter(
707
835
  self, provider: Provider, parameter: inspect.Parameter
708
836
  ) -> Any:
709
- self._validate_resolvable_parameter(parameter, call=provider.call)
837
+ self._validate_resolvable_parameter(provider, parameter)
710
838
  return await self.aresolve(parameter.annotation)
711
839
 
712
840
  def _validate_resolvable_parameter(
713
- self, parameter: inspect.Parameter, call: Callable[..., Any]
841
+ self, provider: Provider, parameter: inspect.Parameter
714
842
  ) -> None:
715
843
  """Ensure that the specified interface is resolved."""
716
844
  if parameter.annotation in self._unresolved_interfaces:
717
845
  raise LookupError(
718
846
  f"You are attempting to get the parameter `{parameter.name}` with the "
719
847
  f"annotation `{get_full_qualname(parameter.annotation)}` as a "
720
- f"dependency into `{get_full_qualname(call)}` which is not registered "
721
- "or set in the scoped context."
848
+ f"dependency into `{get_full_qualname(provider.call)}` which is "
849
+ "not registered or set in the scoped context."
850
+ )
851
+
852
+ @contextlib.contextmanager
853
+ def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
854
+ """
855
+ Override the provider for the specified interface with a specific instance.
856
+ """
857
+ if not self.testing:
858
+ raise RuntimeError(
859
+ "The `override` method can only be used in testing mode."
722
860
  )
861
+ if not self.is_registered(interface) and self.strict:
862
+ raise LookupError(
863
+ f"The provider interface `{get_full_qualname(interface)}` "
864
+ "not registered."
865
+ )
866
+ self._override_instances[interface] = instance
867
+ yield
868
+ del self._override_instances[interface]
869
+
870
+ ############################
871
+ # Testing Methods
872
+ ############################
723
873
 
724
874
  def _patch_test_resolver(self, interface: type[Any], instance: Any) -> Any:
725
875
  """Patch the test resolver for the instance."""
@@ -769,60 +919,9 @@ class Container:
769
919
 
770
920
  return instance
771
921
 
772
- def is_resolved(self, interface: AnyInterface) -> bool:
773
- """Check if an instance by interface exists."""
774
- try:
775
- provider = self._get_provider(interface)
776
- except LookupError:
777
- return False
778
- if provider.scope == "transient":
779
- return False
780
- context = self._get_scoped_context(provider.scope)
781
- return interface in context
782
-
783
- def release(self, interface: AnyInterface) -> None:
784
- """Release an instance by interface."""
785
- provider = self._get_provider(interface)
786
- if provider.scope == "transient":
787
- return None
788
- context = self._get_scoped_context(provider.scope)
789
- del context[interface]
790
-
791
- def _get_scoped_context(self, scope: Scope) -> InstanceContext:
792
- """Get the instance context for the specified scope."""
793
- if scope == "singleton":
794
- return self._singleton_context
795
- return self._get_request_context()
796
-
797
- @contextlib.contextmanager
798
- def override(self, interface: AnyInterface, instance: Any) -> Iterator[None]:
799
- """
800
- Override the provider for the specified interface with a specific instance.
801
- """
802
- if not self.testing:
803
- raise RuntimeError(
804
- "The `override` method can only be used in testing mode."
805
- )
806
- if not self.is_registered(interface) and self.strict:
807
- raise LookupError(
808
- f"The provider interface `{get_full_qualname(interface)}` "
809
- "not registered."
810
- )
811
- self._override_instances[interface] = instance
812
- yield
813
- del self._override_instances[interface]
814
-
815
- def provider(
816
- self, *, scope: Scope, override: bool = False
817
- ) -> Callable[[Callable[P, T]], Callable[P, T]]:
818
- """Decorator to register a provider function with the specified scope."""
819
-
820
- def decorator(call: Callable[P, T]) -> Callable[P, T]:
821
- provider = Provider(call=call, scope=scope)
822
- self._register_provider(provider, override)
823
- return call
824
-
825
- return decorator
922
+ ############################
923
+ # Injector Methods
924
+ ############################
826
925
 
827
926
  @overload
828
927
  def inject(self, func: Callable[P, T]) -> Callable[P, T]: ...
@@ -842,6 +941,10 @@ class Container:
842
941
  return decorator
843
942
  return decorator(func)
844
943
 
944
+ def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
945
+ """Run the given function with injected dependencies."""
946
+ return self._inject(func)(*args, **kwargs)
947
+
845
948
  def _inject(self, call: Callable[P, T]) -> Callable[P, T]:
846
949
  """Inject dependencies into a callable."""
847
950
  if call in self._inject_cache:
@@ -909,9 +1012,39 @@ class Container:
909
1012
  f"`{get_full_qualname(parameter.annotation)}`."
910
1013
  )
911
1014
 
912
- def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
913
- """Run the given function with injected dependencies."""
914
- return self._inject(func)(*args, **kwargs)
1015
+ ############################
1016
+ # Module Methods
1017
+ ############################
1018
+
1019
+ def register_module(
1020
+ self, module: Module | type[Module] | Callable[[Container], None] | str
1021
+ ) -> None:
1022
+ """Register a module as a callable, module type, or module instance."""
1023
+ # Callable Module
1024
+ if inspect.isfunction(module):
1025
+ module(self)
1026
+ return
1027
+
1028
+ # Module path
1029
+ if isinstance(module, str):
1030
+ module = import_string(module)
1031
+
1032
+ # Class based Module or Module type
1033
+ if inspect.isclass(module) and issubclass(module, Module):
1034
+ module = module()
1035
+
1036
+ if isinstance(module, Module):
1037
+ module.configure(self)
1038
+ for provider_name, decorator_args in module.providers:
1039
+ obj = getattr(module, provider_name)
1040
+ self.provider(
1041
+ scope=decorator_args.scope,
1042
+ override=decorator_args.override,
1043
+ )(obj)
1044
+
1045
+ ############################
1046
+ # Scanner Methods
1047
+ ############################
915
1048
 
916
1049
  def scan(
917
1050
  self,
@@ -921,7 +1054,7 @@ class Container:
921
1054
  tags: Iterable[str] | None = None,
922
1055
  ) -> None:
923
1056
  """Scan packages or modules for decorated members and inject dependencies."""
924
- dependencies: list[Dependency] = []
1057
+ dependencies: list[ScannedDependency] = []
925
1058
 
926
1059
  if isinstance(packages, Iterable) and not isinstance(packages, str):
927
1060
  scan_packages: Iterable[ModuleType | str] = packages
@@ -940,7 +1073,7 @@ class Container:
940
1073
  package: ModuleType | str,
941
1074
  *,
942
1075
  tags: Iterable[str] | None = None,
943
- ) -> list[Dependency]:
1076
+ ) -> list[ScannedDependency]:
944
1077
  """Scan a package or module for decorated members."""
945
1078
  tags = tags or []
946
1079
  if isinstance(package, str):
@@ -951,7 +1084,7 @@ class Container:
951
1084
  if not package_path:
952
1085
  return self._scan_module(package, tags=tags)
953
1086
 
954
- dependencies: list[Dependency] = []
1087
+ dependencies: list[ScannedDependency] = []
955
1088
 
956
1089
  for module_info in pkgutil.walk_packages(
957
1090
  path=package_path, prefix=package.__name__ + "."
@@ -963,9 +1096,9 @@ class Container:
963
1096
 
964
1097
  def _scan_module(
965
1098
  self, module: ModuleType, *, tags: Iterable[str]
966
- ) -> list[Dependency]:
1099
+ ) -> list[ScannedDependency]:
967
1100
  """Scan a module for decorated members."""
968
- dependencies: list[Dependency] = []
1101
+ dependencies: list[ScannedDependency] = []
969
1102
 
970
1103
  for _, member in inspect.getmembers(module):
971
1104
  if getattr(member, "__module__", None) != module.__name__ or not callable(
@@ -988,7 +1121,7 @@ class Container:
988
1121
 
989
1122
  if decorator_args.wrapped:
990
1123
  dependencies.append(
991
- self._create_dependency(member=member, module=module)
1124
+ self._create_scanned_dependency(member=member, module=module)
992
1125
  )
993
1126
  continue
994
1127
 
@@ -996,17 +1129,24 @@ class Container:
996
1129
  for parameter in get_typed_parameters(member):
997
1130
  if is_marker(parameter.default):
998
1131
  dependencies.append(
999
- self._create_dependency(member=member, module=module)
1132
+ self._create_scanned_dependency(member=member, module=module)
1000
1133
  )
1001
1134
  continue
1002
1135
 
1003
1136
  return dependencies
1004
1137
 
1005
- def _create_dependency(self, member: Any, module: ModuleType) -> Dependency:
1138
+ def _create_scanned_dependency(
1139
+ self, member: Any, module: ModuleType
1140
+ ) -> ScannedDependency:
1006
1141
  """Create a `Dependency` object from the scanned member and module."""
1007
1142
  if hasattr(member, "__wrapped__"):
1008
1143
  member = member.__wrapped__
1009
- return Dependency(member=member, module=module)
1144
+ return ScannedDependency(member=member, module=module)
1145
+
1146
+
1147
+ ############################
1148
+ # Decorators
1149
+ ############################
1010
1150
 
1011
1151
 
1012
1152
  def transient(target: T) -> T:
anydi/_types.py CHANGED
@@ -1,9 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import enum
3
4
  import inspect
4
5
  from collections.abc import Iterable
6
+ from dataclasses import dataclass
7
+ from functools import cached_property
5
8
  from types import ModuleType
6
- from typing import Annotated, Any, NamedTuple, Union
9
+ from typing import Annotated, Any, Callable, NamedTuple, Union
7
10
 
8
11
  import wrapt
9
12
  from typing_extensions import Literal, Self, TypeAlias
@@ -12,6 +15,8 @@ Scope = Literal["transient", "singleton", "request"]
12
15
 
13
16
  AnyInterface: TypeAlias = Union[type[Any], Annotated[Any, ...]]
14
17
 
18
+ NOT_SET = object()
19
+
15
20
 
16
21
  class Marker:
17
22
  """A marker class for marking dependencies."""
@@ -53,12 +58,85 @@ class InstanceProxy(wrapt.ObjectProxy): # type: ignore[misc]
53
58
  return object.__getattribute__(self, item)
54
59
 
55
60
 
61
+ class ProviderKind(enum.IntEnum):
62
+ CLASS = 1
63
+ FUNCTION = 2
64
+ COROUTINE = 3
65
+ GENERATOR = 4
66
+ ASYNC_GENERATOR = 5
67
+
68
+ @classmethod
69
+ def from_call(cls, call: Callable[..., Any]) -> ProviderKind:
70
+ if inspect.isclass(call):
71
+ return cls.CLASS
72
+ elif inspect.iscoroutinefunction(call):
73
+ return cls.COROUTINE
74
+ elif inspect.isasyncgenfunction(call):
75
+ return cls.ASYNC_GENERATOR
76
+ elif inspect.isgeneratorfunction(call):
77
+ return cls.GENERATOR
78
+ elif inspect.isfunction(call) or inspect.ismethod(call):
79
+ return cls.FUNCTION
80
+ raise TypeError(
81
+ f"The provider `{call}` is invalid because it is not a callable "
82
+ "object. Only callable providers are allowed."
83
+ )
84
+
85
+
86
+ @dataclass(kw_only=True, frozen=True)
87
+ class ProviderParameter:
88
+ pass
89
+
90
+
91
+ @dataclass(kw_only=True, frozen=True)
92
+ class Provider:
93
+ call: Callable[..., Any]
94
+ scope: Scope
95
+ interface: Any
96
+ name: str
97
+ parameters: list[inspect.Parameter]
98
+ kind: ProviderKind
99
+
100
+ def __str__(self) -> str:
101
+ return self.name
102
+
103
+ @cached_property
104
+ def is_class(self) -> bool:
105
+ return self.kind == ProviderKind.CLASS
106
+
107
+ @cached_property
108
+ def is_coroutine(self) -> bool:
109
+ return self.kind == ProviderKind.COROUTINE
110
+
111
+ @cached_property
112
+ def is_generator(self) -> bool:
113
+ return self.kind == ProviderKind.GENERATOR
114
+
115
+ @cached_property
116
+ def is_async_generator(self) -> bool:
117
+ return self.kind == ProviderKind.ASYNC_GENERATOR
118
+
119
+ @cached_property
120
+ def is_async(self) -> bool:
121
+ return self.is_coroutine or self.is_async_generator
122
+
123
+ @cached_property
124
+ def is_resource(self) -> bool:
125
+ return self.is_generator or self.is_async_generator
126
+
127
+
128
+ class ProviderArgs(NamedTuple):
129
+ call: Callable[..., Any]
130
+ scope: Scope
131
+ interface: Any | None = None
132
+
133
+
56
134
  class ProviderDecoratorArgs(NamedTuple):
57
135
  scope: Scope
58
136
  override: bool
59
137
 
60
138
 
61
- class Dependency(NamedTuple):
139
+ class ScannedDependency(NamedTuple):
62
140
  member: Any
63
141
  module: ModuleType
64
142
 
@@ -35,11 +35,9 @@ CONTAINER_FIXTURE_NAME = "container"
35
35
 
36
36
 
37
37
  @pytest.fixture
38
- def anydi_setup_container(
39
- request: pytest.FixtureRequest,
40
- ) -> Iterator[Container]:
38
+ def anydi_setup_container(request: pytest.FixtureRequest) -> Container:
41
39
  try:
42
- container = request.getfixturevalue(CONTAINER_FIXTURE_NAME)
40
+ return cast(Container, request.getfixturevalue(CONTAINER_FIXTURE_NAME))
43
41
  except pytest.FixtureLookupError as exc:
44
42
  exc.msg = (
45
43
  "`container` fixture is not found. Make sure to define it in your test "
@@ -47,8 +45,6 @@ def anydi_setup_container(
47
45
  )
48
46
  raise exc
49
47
 
50
- yield container
51
-
52
48
 
53
49
  @pytest.fixture
54
50
  def _anydi_should_inject(request: pytest.FixtureRequest) -> bool:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: anydi
3
- Version: 0.38.1
3
+ Version: 0.38.2rc0
4
4
  Summary: Dependency Injection library
5
5
  Project-URL: Repository, https://github.com/antonrh/anydi
6
6
  Author-email: Anton Ruhlov <antonruhlov@gmail.com>
@@ -1,8 +1,7 @@
1
- anydi/__init__.py,sha256=OfRg2EfXD65pHTGQKhfkABMwUhw5LvsuTQV_Tv4V4wk,501
2
- anydi/_container.py,sha256=b2WKdDLwLj50u0DUOdAku2oYOKpw3UD7rCuytu2x_R4,38732
1
+ anydi/__init__.py,sha256=aAq10a1V_zQ3_Me3p_pll5d1O77PIgqotkOm3pshORI,495
2
+ anydi/_container.py,sha256=BqpvUPeYt6OW7TLIDm-OvMGCbcxnvXA6KyOF9XBmi7M,43072
3
3
  anydi/_context.py,sha256=7LV_SL4QWkJeiG7_4D9PZ5lmU-MPzhofxC95zCgY9Gc,2651
4
- anydi/_provider.py,sha256=1IyxHO83NHjsPDHLDIZtW1pJ7i8VpWD3EM4T6duw9zA,7661
5
- anydi/_types.py,sha256=fdO4xNXtGMxVArmlfDkFYbyR895ixkBTW6V8lMceN7Q,1562
4
+ anydi/_types.py,sha256=oFyx6jxkEsz5FZk6tdRjUmBBQ2tX7eA_bLaa2elq7Mg,3586
6
5
  anydi/_utils.py,sha256=INI0jNIXrJ6LS4zqJymMO2yUEobpxmBGASf4G_vR6AU,4378
7
6
  anydi/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
7
  anydi/ext/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -10,7 +9,7 @@ anydi/ext/_utils.py,sha256=U6sRqWzccWUu7eMhbXX1NrwcaxitQF9cO1KxnKF37gw,2566
10
9
  anydi/ext/fastapi.py,sha256=AEL3ubu-LxUPHMMt1YIn3En_JZC7nyBKmKxmhka3O3c,2436
11
10
  anydi/ext/faststream.py,sha256=qXnNGvAqWWp9kbhbQUE6EF_OPUiYQmtOH211_O7BI_0,1898
12
11
  anydi/ext/pydantic_settings.py,sha256=8IXXLuG_OvKbvKlBkBRQUHcXgbTpgQUxeWyoMcRIUQM,1488
13
- anydi/ext/pytest_plugin.py,sha256=25A93Yon0cr4y8CszQB7zOJRiAXW-ZwtO7_Ji8mqcXs,4639
12
+ anydi/ext/pytest_plugin.py,sha256=yR_vos8qt8uFS9uy_G_HJjNgudzJIXawrtpWnn3Pu_s,4613
14
13
  anydi/ext/django/__init__.py,sha256=QI1IABCVgSDTUoh7M9WMECKXwB3xvh04HfQ9TOWw1Mk,223
15
14
  anydi/ext/django/_container.py,sha256=cxVoYQG16WP0S_Yv4TnLwuaaT7NVEOhLWO-YdALJUb4,418
16
15
  anydi/ext/django/_settings.py,sha256=Z0RlAuXoO73oahWeMkK10w8c-4uCBde-DBpeKTV5USY,853
@@ -22,8 +21,8 @@ anydi/ext/django/ninja/_operation.py,sha256=wSWa7D73XTVlOibmOciv2l6JHPe1ERZcXrqI
22
21
  anydi/ext/django/ninja/_signature.py,sha256=2cSzKxBIxXLqtwNuH6GSlmjVJFftoGmleWfyk_NVEWw,2207
23
22
  anydi/ext/starlette/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
23
  anydi/ext/starlette/middleware.py,sha256=9CQtGg5ZzUz2gFSzJr8U4BWzwNjK8XMctm3n52M77Z0,792
25
- anydi-0.38.1.dist-info/METADATA,sha256=_DWptDh85lW1UwzL1AnIaCiB9THnpeaCI4AlupFYjG0,4917
26
- anydi-0.38.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
27
- anydi-0.38.1.dist-info/entry_points.txt,sha256=Nklo9f3Oe4AkNsEgC4g43nCJ-23QDngZSVDNRMdaILI,43
28
- anydi-0.38.1.dist-info/licenses/LICENSE,sha256=V6rU8a8fv6o2jQ-7ODHs0XfDFimot8Q6Km6CylRIDTo,1069
29
- anydi-0.38.1.dist-info/RECORD,,
24
+ anydi-0.38.2rc0.dist-info/METADATA,sha256=t8FL1ILCUK37XUFQ_gqgAlwX77UytT5FeP4QueIDX50,4920
25
+ anydi-0.38.2rc0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
26
+ anydi-0.38.2rc0.dist-info/entry_points.txt,sha256=Nklo9f3Oe4AkNsEgC4g43nCJ-23QDngZSVDNRMdaILI,43
27
+ anydi-0.38.2rc0.dist-info/licenses/LICENSE,sha256=V6rU8a8fv6o2jQ-7ODHs0XfDFimot8Q6Km6CylRIDTo,1069
28
+ anydi-0.38.2rc0.dist-info/RECORD,,
anydi/_provider.py DELETED
@@ -1,232 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import inspect
4
- import uuid
5
- from collections.abc import AsyncIterator, Iterator
6
- from enum import IntEnum
7
- from typing import Any, Callable
8
-
9
- from typing_extensions import get_args, get_origin
10
-
11
- try:
12
- from types import NoneType
13
- except ImportError:
14
- NoneType = type(None) # type: ignore[misc]
15
-
16
-
17
- from ._types import Event, Scope
18
- from ._utils import get_full_qualname, get_typed_annotation
19
-
20
- _sentinel = object()
21
-
22
-
23
- class CallableKind(IntEnum):
24
- CLASS = 1
25
- FUNCTION = 2
26
- COROUTINE = 3
27
- GENERATOR = 4
28
- ASYNC_GENERATOR = 5
29
-
30
-
31
- class Provider:
32
- __slots__ = (
33
- "_call",
34
- "_call_module",
35
- "_call_globals",
36
- "_scope",
37
- "_qualname",
38
- "_kind",
39
- "_interface",
40
- "_parameters",
41
- "_is_class",
42
- "_is_coroutine",
43
- "_is_generator",
44
- "_is_async_generator",
45
- "_is_async",
46
- "_is_resource",
47
- )
48
-
49
- def __init__(
50
- self, call: Callable[..., Any], *, scope: Scope, interface: Any = _sentinel
51
- ) -> None:
52
- self._call = call
53
- self._call_module = getattr(call, "__module__", None)
54
- self._call_globals = getattr(call, "__globals__", {})
55
- self._scope = scope
56
- self._qualname = get_full_qualname(call)
57
-
58
- # Detect the kind of callable provider
59
- self._detect_kind()
60
-
61
- self._is_class = self._kind == CallableKind.CLASS
62
- self._is_coroutine = self._kind == CallableKind.COROUTINE
63
- self._is_generator = self._kind == CallableKind.GENERATOR
64
- self._is_async_generator = self._kind == CallableKind.ASYNC_GENERATOR
65
- self._is_async = self._is_coroutine or self._is_async_generator
66
- self._is_resource = self._is_generator or self._is_async_generator
67
-
68
- # Validate the scope of the provider
69
- self._validate_scope()
70
-
71
- # Get the signature
72
- signature = inspect.signature(call)
73
-
74
- # Detect the interface
75
- self._detect_interface(interface, signature)
76
-
77
- # Detect the parameters
78
- self._detect_parameters(signature)
79
-
80
- def __str__(self) -> str:
81
- return self._qualname
82
-
83
- def __eq__(self, other: object) -> bool:
84
- if not isinstance(other, Provider):
85
- return NotImplemented # pragma: no cover
86
- return (
87
- self._call == other._call
88
- and self._scope == other._scope
89
- and self._interface == other._interface
90
- )
91
-
92
- @property
93
- def call(self) -> Callable[..., Any]:
94
- return self._call
95
-
96
- @property
97
- def kind(self) -> CallableKind:
98
- return self._kind
99
-
100
- @property
101
- def scope(self) -> Scope:
102
- return self._scope
103
-
104
- @property
105
- def interface(self) -> Any:
106
- return self._interface
107
-
108
- @property
109
- def parameters(self) -> list[inspect.Parameter]:
110
- return self._parameters
111
-
112
- @property
113
- def is_class(self) -> bool:
114
- """Check if the provider is a class."""
115
- return self._is_class
116
-
117
- @property
118
- def is_coroutine(self) -> bool:
119
- """Check if the provider is a coroutine."""
120
- return self._is_coroutine
121
-
122
- @property
123
- def is_generator(self) -> bool:
124
- """Check if the provider is a generator."""
125
- return self._is_generator
126
-
127
- @property
128
- def is_async_generator(self) -> bool:
129
- """Check if the provider is an async generator."""
130
- return self._is_async_generator
131
-
132
- @property
133
- def is_async(self) -> bool:
134
- """Check if the provider is an async callable."""
135
- return self._is_async
136
-
137
- @property
138
- def is_resource(self) -> bool:
139
- """Check if the provider is a resource."""
140
- return self._is_resource
141
-
142
- def _validate_scope(self) -> None:
143
- """Validate the scope of the provider."""
144
- if self.scope not in get_args(Scope):
145
- raise ValueError(
146
- "The scope provided is invalid. Only the following scopes are "
147
- f"supported: {', '.join(get_args(Scope))}. Please use one of the "
148
- "supported scopes when registering a provider."
149
- )
150
- if self.is_resource and self.scope == "transient":
151
- raise TypeError(
152
- f"The resource provider `{self}` is attempting to register "
153
- "with a transient scope, which is not allowed."
154
- )
155
-
156
- def _detect_kind(self) -> None:
157
- """Detect the kind of callable provider."""
158
- if inspect.isclass(self.call):
159
- self._kind = CallableKind.CLASS
160
- elif inspect.iscoroutinefunction(self.call):
161
- self._kind = CallableKind.COROUTINE
162
- elif inspect.isasyncgenfunction(self.call):
163
- self._kind = CallableKind.ASYNC_GENERATOR
164
- elif inspect.isgeneratorfunction(self.call):
165
- self._kind = CallableKind.GENERATOR
166
- elif inspect.isfunction(self.call) or inspect.ismethod(self.call):
167
- self._kind = CallableKind.FUNCTION
168
- else:
169
- raise TypeError(
170
- f"The provider `{self.call}` is invalid because it is not a callable "
171
- "object. Only callable providers are allowed."
172
- )
173
-
174
- def _detect_interface(self, interface: Any, signature: inspect.Signature) -> None:
175
- """Detect the interface of callable provider."""
176
- # If the callable is a class, return the class itself
177
- if self._kind == CallableKind.CLASS:
178
- self._interface = self._call
179
- return
180
-
181
- if interface is _sentinel:
182
- interface = self._resolve_interface(interface, signature)
183
-
184
- # If the callable is an iterator, return the actual type
185
- iterator_types = {Iterator, AsyncIterator}
186
- if interface in iterator_types or get_origin(interface) in iterator_types:
187
- if args := get_args(interface):
188
- interface = args[0]
189
- # If the callable is a generator, return the resource type
190
- if interface is NoneType or interface is None:
191
- self._interface = type(f"Event_{uuid.uuid4().hex}", (Event,), {})
192
- return
193
- else:
194
- raise TypeError(
195
- f"Cannot use `{self}` resource type annotation "
196
- "without actual type argument."
197
- )
198
-
199
- # None interface is not allowed
200
- if interface in {None, NoneType}:
201
- raise TypeError(f"Missing `{self}` provider return annotation.")
202
-
203
- # Set the interface
204
- self._interface = interface
205
-
206
- def _resolve_interface(self, interface: Any, signature: inspect.Signature) -> Any:
207
- """Resolve the interface of the callable provider."""
208
- interface = signature.return_annotation
209
- if interface is inspect.Signature.empty:
210
- return None
211
- return get_typed_annotation(
212
- interface,
213
- self._call_globals,
214
- module=self._call_module,
215
- )
216
-
217
- def _detect_parameters(self, signature: inspect.Signature) -> None:
218
- """Detect the parameters of the callable provider."""
219
- parameters = []
220
- for parameter in signature.parameters.values():
221
- if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
222
- raise TypeError(
223
- f"Positional-only parameter `{parameter.name}` is not allowed "
224
- f"in the provider `{self}`."
225
- )
226
- annotation = get_typed_annotation(
227
- parameter.annotation,
228
- self._call_globals,
229
- module=self._call_module,
230
- )
231
- parameters.append(parameter.replace(annotation=annotation))
232
- self._parameters = parameters