anydi 0.53.0__tar.gz → 0.54.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: anydi
3
- Version: 0.53.0
3
+ Version: 0.54.1
4
4
  Summary: Dependency Injection library
5
5
  Keywords: dependency injection,dependencies,di,async,asyncio,application
6
6
  Author: Anton Ruhlov
@@ -19,7 +19,7 @@ from ._async import run_sync
19
19
  from ._context import InstanceContext
20
20
  from ._decorators import is_provided
21
21
  from ._module import ModuleDef, ModuleRegistrar
22
- from ._provider import Provider, ProviderDef, ProviderKind
22
+ from ._provider import Provider, ProviderDef, ProviderKind, ProviderParameter
23
23
  from ._scan import PackageOrIterable, Scanner
24
24
  from ._scope import ALLOWED_SCOPES, Scope
25
25
  from ._typing import (
@@ -298,7 +298,7 @@ class Container:
298
298
 
299
299
  unresolved_parameter = None
300
300
  unresolved_exc: LookupError | None = None
301
- parameters: list[inspect.Parameter] = []
301
+ parameters: list[ProviderParameter] = []
302
302
  scopes: dict[Scope, Provider] = {}
303
303
 
304
304
  for parameter in signature.parameters.values():
@@ -326,7 +326,21 @@ class Container:
326
326
  if sub_provider.scope not in scopes:
327
327
  scopes[sub_provider.scope] = sub_provider
328
328
 
329
- parameters.append(parameter)
329
+ default = (
330
+ parameter.default
331
+ if parameter.default is not inspect.Parameter.empty
332
+ else NOT_SET
333
+ )
334
+ parameters.append(
335
+ ProviderParameter(
336
+ name=parameter.name,
337
+ annotation=parameter.annotation,
338
+ default=default,
339
+ has_default=default is not NOT_SET,
340
+ provider=sub_provider,
341
+ shared_scope=sub_provider.scope == scope and scope != "transient",
342
+ )
343
+ )
330
344
 
331
345
  # Check for unresolved parameters
332
346
  if unresolved_parameter:
@@ -356,7 +370,7 @@ class Container:
356
370
  interface=interface,
357
371
  name=name,
358
372
  kind=kind,
359
- parameters=parameters,
373
+ parameters=tuple(parameters),
360
374
  )
361
375
 
362
376
  self._set_provider(provider)
@@ -495,9 +509,36 @@ class Container:
495
509
  ) -> Any:
496
510
  """Internal method to handle instance resolution and creation."""
497
511
  provider = self._get_or_register_provider(interface, **defaults)
512
+ return self._resolve_with_provider(provider, create, **defaults)
513
+
514
+ async def _aresolve_or_create(
515
+ self, interface: Any, create: bool, /, **defaults: Any
516
+ ) -> Any:
517
+ """Internal method to handle instance resolution and creation asynchronously."""
518
+ provider = self._get_or_register_provider(interface, **defaults)
519
+ return await self._aresolve_with_provider(provider, create, **defaults)
520
+
521
+ def _resolve_with_provider(
522
+ self, provider: Provider, create: bool, /, **defaults: Any
523
+ ) -> Any:
498
524
  if provider.scope == "transient":
499
525
  return self._create_instance(provider, None, **defaults)
526
+
527
+ if provider.scope == "request":
528
+ context = self._get_request_context()
529
+ if not create:
530
+ cached = context.get(provider.interface, NOT_SET)
531
+ if cached is not NOT_SET:
532
+ return cached
533
+ if not create:
534
+ return self._get_or_create_instance(provider, context)
535
+ return self._create_instance(provider, context, **defaults)
536
+
500
537
  context = self._get_instance_context(provider.scope)
538
+ if not create:
539
+ cached = context.get(provider.interface, NOT_SET)
540
+ if cached is not NOT_SET:
541
+ return cached
501
542
  with context.lock():
502
543
  return (
503
544
  self._get_or_create_instance(provider, context)
@@ -505,14 +546,27 @@ class Container:
505
546
  else self._create_instance(provider, context, **defaults)
506
547
  )
507
548
 
508
- async def _aresolve_or_create(
509
- self, interface: Any, create: bool, /, **defaults: Any
549
+ async def _aresolve_with_provider(
550
+ self, provider: Provider, create: bool, /, **defaults: Any
510
551
  ) -> Any:
511
- """Internal method to handle instance resolution and creation asynchronously."""
512
- provider = self._get_or_register_provider(interface, **defaults)
513
552
  if provider.scope == "transient":
514
553
  return await self._acreate_instance(provider, None, **defaults)
554
+
555
+ if provider.scope == "request":
556
+ context = self._get_request_context()
557
+ if not create:
558
+ cached = context.get(provider.interface, NOT_SET)
559
+ if cached is not NOT_SET:
560
+ return cached
561
+ if not create:
562
+ return await self._aget_or_create_instance(provider, context)
563
+ return await self._acreate_instance(provider, context, **defaults)
564
+
515
565
  context = self._get_instance_context(provider.scope)
566
+ if not create:
567
+ cached = context.get(provider.interface, NOT_SET)
568
+ if cached is not NOT_SET:
569
+ return cached
516
570
  async with context.alock():
517
571
  return (
518
572
  await self._aget_or_create_instance(provider, context)
@@ -524,8 +578,8 @@ class Container:
524
578
  self, provider: Provider, context: InstanceContext
525
579
  ) -> Any:
526
580
  """Get an instance of a dependency from the scoped context."""
527
- instance = context.get(provider.interface)
528
- if instance is None:
581
+ instance = context.get(provider.interface, NOT_SET)
582
+ if instance is NOT_SET:
529
583
  instance = self._create_instance(provider, context)
530
584
  context.set(provider.interface, instance)
531
585
  return instance
@@ -535,8 +589,8 @@ class Container:
535
589
  self, provider: Provider, context: InstanceContext
536
590
  ) -> Any:
537
591
  """Get an async instance of a dependency from the scoped context."""
538
- instance = context.get(provider.interface)
539
- if instance is None:
592
+ instance = context.get(provider.interface, NOT_SET)
593
+ if instance is NOT_SET:
540
594
  instance = await self._acreate_instance(provider, context)
541
595
  context.set(provider.interface, instance)
542
596
  return instance
@@ -552,7 +606,9 @@ class Container:
552
606
  "synchronous mode."
553
607
  )
554
608
 
555
- provider_kwargs = self._get_provided_kwargs(provider, context, **defaults)
609
+ provider_kwargs = self._get_provided_kwargs(
610
+ provider, context, defaults=defaults if defaults else None
611
+ )
556
612
 
557
613
  if provider.is_generator:
558
614
  if context is None:
@@ -570,7 +626,7 @@ class Container:
570
626
  ) -> Any:
571
627
  """Create an instance asynchronously using the provider."""
572
628
  provider_kwargs = await self._aget_provided_kwargs(
573
- provider, context, **defaults
629
+ provider, context, defaults=defaults if defaults else None
574
630
  )
575
631
 
576
632
  if provider.is_coroutine:
@@ -604,97 +660,143 @@ class Container:
604
660
  return instance
605
661
 
606
662
  def _get_provided_kwargs(
607
- self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
663
+ self,
664
+ provider: Provider,
665
+ context: InstanceContext | None,
666
+ /,
667
+ defaults: dict[str, Any] | None = None,
608
668
  ) -> dict[str, Any]:
609
669
  """Retrieve the arguments for a provider."""
610
- provided_kwargs = {}
670
+ if not provider.parameters:
671
+ return defaults if defaults else {}
672
+
673
+ provided_kwargs = dict(defaults) if defaults else {}
611
674
  for parameter in provider.parameters:
612
675
  provided_kwargs[parameter.name] = self._get_provider_instance(
613
- provider, parameter, context, **defaults
676
+ provider,
677
+ parameter,
678
+ context,
679
+ defaults=defaults,
614
680
  )
615
- return {**defaults, **provided_kwargs}
681
+ return provided_kwargs
616
682
 
617
- def _get_provider_instance(
683
+ def _get_provider_instance( # noqa: C901
618
684
  self,
619
685
  provider: Provider,
620
- parameter: inspect.Parameter,
686
+ parameter: ProviderParameter,
621
687
  context: InstanceContext | None,
622
688
  /,
623
- **defaults: Any,
689
+ *,
690
+ defaults: dict[str, Any] | None = None,
624
691
  ) -> Any:
625
692
  """Retrieve an instance of a dependency from the scoped context."""
626
693
 
627
- # Try to get instance from defaults
628
- if parameter.name in defaults:
694
+ if defaults and parameter.name in defaults:
629
695
  return defaults[parameter.name]
630
696
 
631
- # Try to get instance from context
632
- elif context and parameter.annotation in context:
633
- instance = context[parameter.annotation]
697
+ sub_provider = parameter.provider
634
698
 
635
- # Resolve new instance
636
- else:
637
- try:
638
- instance = self._resolve_parameter(provider, parameter)
639
- except LookupError:
640
- if parameter.default is inspect.Parameter.empty:
641
- raise
642
- return parameter.default
643
- return instance
699
+ if context and parameter.shared_scope and sub_provider is not None:
700
+ existing = context.get(sub_provider.interface, NOT_SET)
701
+ if existing is not NOT_SET:
702
+ return existing
703
+
704
+ if context:
705
+ cached = context.get(parameter.annotation, NOT_SET)
706
+ if cached is not NOT_SET:
707
+ return cached
708
+
709
+ sub_provider = parameter.provider
710
+
711
+ if sub_provider:
712
+ if sub_provider.scope == "transient":
713
+ return self._create_instance(sub_provider, None)
714
+ if sub_provider.scope == "singleton" and sub_provider is not provider:
715
+ return self._resolve_with_provider(sub_provider, False)
716
+
717
+ try:
718
+ return self._resolve_parameter(provider, parameter)
719
+ except LookupError:
720
+ if not parameter.has_default:
721
+ raise
722
+ return parameter.default
644
723
 
645
724
  async def _aget_provided_kwargs(
646
- self, provider: Provider, context: InstanceContext | None, /, **defaults: Any
725
+ self,
726
+ provider: Provider,
727
+ context: InstanceContext | None,
728
+ /,
729
+ defaults: dict[str, Any] | None = None,
647
730
  ) -> dict[str, Any]:
648
731
  """Asynchronously retrieve the arguments for a provider."""
649
- provided_kwargs = {}
732
+ if not provider.parameters:
733
+ return defaults if defaults else {}
734
+
735
+ provided_kwargs = dict(defaults) if defaults else {}
650
736
  for parameter in provider.parameters:
651
737
  provided_kwargs[parameter.name] = await self._aget_provider_instance(
652
- provider, parameter, context, **defaults
738
+ provider,
739
+ parameter,
740
+ context,
741
+ defaults=defaults,
653
742
  )
654
- return {**defaults, **provided_kwargs}
743
+ return provided_kwargs
655
744
 
656
- async def _aget_provider_instance(
745
+ async def _aget_provider_instance( # noqa: C901
657
746
  self,
658
747
  provider: Provider,
659
- parameter: inspect.Parameter,
748
+ parameter: ProviderParameter,
660
749
  context: InstanceContext | None,
661
750
  /,
662
- **defaults: Any,
751
+ *,
752
+ defaults: dict[str, Any] | None = None,
663
753
  ) -> Any:
664
754
  """Asynchronously retrieve an instance of a dependency from the context."""
665
755
 
666
- # Try to get instance from defaults
667
- if parameter.name in defaults:
756
+ if defaults and parameter.name in defaults:
668
757
  return defaults[parameter.name]
669
758
 
670
- # Try to get instance from context
671
- elif context and parameter.annotation in context:
672
- instance = context[parameter.annotation]
759
+ sub_provider = parameter.provider
673
760
 
674
- # Resolve new instance
675
- else:
676
- try:
677
- instance = await self._aresolve_parameter(provider, parameter)
678
- except LookupError:
679
- if parameter.default is inspect.Parameter.empty:
680
- raise
681
- return parameter.default
682
- return instance
761
+ if context and parameter.shared_scope and sub_provider is not None:
762
+ existing = context.get(sub_provider.interface, NOT_SET)
763
+ if existing is not NOT_SET:
764
+ return existing
765
+
766
+ if context:
767
+ cached = context.get(parameter.annotation, NOT_SET)
768
+ if cached is not NOT_SET:
769
+ return cached
770
+
771
+ sub_provider = parameter.provider
772
+
773
+ if sub_provider:
774
+ if sub_provider.scope == "transient":
775
+ return await self._acreate_instance(sub_provider, None)
776
+ if sub_provider.scope == "singleton" and sub_provider is not provider:
777
+ return await self._aresolve_with_provider(sub_provider, False)
778
+
779
+ try:
780
+ return await self._aresolve_parameter(provider, parameter)
781
+ except LookupError:
782
+ if not parameter.has_default:
783
+ raise
784
+ return parameter.default
683
785
 
684
786
  def _resolve_parameter(
685
- self, provider: Provider, parameter: inspect.Parameter
787
+ self, provider: Provider, parameter: ProviderParameter
686
788
  ) -> Any:
687
789
  self._validate_resolvable_parameter(provider, parameter)
688
790
  return self._resolve_or_create(parameter.annotation, False)
689
791
 
690
792
  async def _aresolve_parameter(
691
- self, provider: Provider, parameter: inspect.Parameter
793
+ self, provider: Provider, parameter: ProviderParameter
692
794
  ) -> Any:
693
795
  self._validate_resolvable_parameter(provider, parameter)
694
796
  return await self._aresolve_or_create(parameter.annotation, False)
695
797
 
696
798
  def _validate_resolvable_parameter(
697
- self, provider: Provider, parameter: inspect.Parameter
799
+ self, provider: Provider, parameter: ProviderParameter
698
800
  ) -> None:
699
801
  """Ensure that the specified interface is resolved."""
700
802
  if parameter.annotation in self._unresolved_interfaces:
@@ -8,6 +8,7 @@ from typing import Any
8
8
  from typing_extensions import Self
9
9
 
10
10
  from ._async import AsyncRLock, run_sync
11
+ from ._typing import NOT_SET
11
12
 
12
13
 
13
14
  class InstanceContext:
@@ -22,9 +23,9 @@ class InstanceContext:
22
23
  self._lock = threading.RLock()
23
24
  self._async_lock = AsyncRLock()
24
25
 
25
- def get(self, interface: Any) -> Any | None:
26
+ def get(self, interface: Any, default: Any = NOT_SET) -> Any:
26
27
  """Get an instance from the context."""
27
- return self._instances.get(interface)
28
+ return self._instances.get(interface, default)
28
29
 
29
30
  def set(self, interface: Any, value: Any) -> None:
30
31
  """Set an instance in the context."""
@@ -5,7 +5,7 @@ import inspect
5
5
  from collections.abc import Callable
6
6
  from dataclasses import dataclass
7
7
  from functools import cached_property
8
- from typing import Any, NamedTuple
8
+ from typing import Any
9
9
 
10
10
  from ._scope import Scope
11
11
  from ._typing import NOT_SET
@@ -39,13 +39,23 @@ class ProviderKind(enum.IntEnum):
39
39
  return kind in (cls.GENERATOR, cls.ASYNC_GENERATOR)
40
40
 
41
41
 
42
+ @dataclass(kw_only=True, frozen=True, slots=True)
43
+ class ProviderParameter:
44
+ name: str
45
+ annotation: Any
46
+ default: Any
47
+ has_default: bool
48
+ provider: Provider | None = None
49
+ shared_scope: bool = False
50
+
51
+
42
52
  @dataclass(kw_only=True, frozen=True)
43
53
  class Provider:
44
54
  call: Callable[..., Any]
45
55
  scope: Scope
46
56
  interface: Any
47
57
  name: str
48
- parameters: list[inspect.Parameter]
58
+ parameters: tuple[ProviderParameter, ...]
49
59
  kind: ProviderKind
50
60
 
51
61
  def __str__(self) -> str:
@@ -76,7 +86,8 @@ class Provider:
76
86
  return ProviderKind.is_resource(self.kind)
77
87
 
78
88
 
79
- class ProviderDef(NamedTuple):
89
+ @dataclass(frozen=True, slots=True)
90
+ class ProviderDef:
80
91
  call: Callable[..., Any]
81
92
  scope: Scope
82
93
  interface: Any = NOT_SET
@@ -7,6 +7,7 @@ from typing import Any, cast
7
7
 
8
8
  import pytest
9
9
  from anyio.pytest_plugin import extract_backend_and_options, get_runner
10
+ from typing_extensions import get_annotations
10
11
 
11
12
  from anydi import Container
12
13
 
@@ -60,16 +61,14 @@ def _anydi_injected_parameter_iterator(
60
61
  )
61
62
 
62
63
  def _iterator() -> Iterator[tuple[str, inspect.Parameter]]:
63
- for parameter in inspect.signature(
64
+ for name, annotation in get_annotations(
64
65
  request.function, eval_str=True
65
- ).parameters.values():
66
- interface = parameter.annotation
67
- if (
68
- interface is inspect.Parameter.empty
69
- or parameter.name not in fixturenames
70
- ):
66
+ ).items():
67
+ if name == "return":
71
68
  continue
72
- yield parameter.name, interface
69
+ if name not in fixturenames:
70
+ continue
71
+ yield name, annotation
73
72
 
74
73
  return _iterator
75
74
 
@@ -1,5 +1,4 @@
1
1
  import contextlib
2
- import inspect
3
2
  import logging
4
3
  from collections.abc import Iterable, Iterator, Sequence
5
4
  from typing import Any, TypeVar
@@ -10,7 +9,7 @@ from typing_extensions import Self, type_repr
10
9
  from ._container import Container
11
10
  from ._context import InstanceContext
12
11
  from ._module import ModuleDef
13
- from ._provider import Provider, ProviderDef
12
+ from ._provider import Provider, ProviderDef, ProviderParameter
14
13
  from ._scope import Scope
15
14
 
16
15
  T = TypeVar("T")
@@ -76,7 +75,7 @@ class TestContainer(Container):
76
75
  def _get_provider_instance(
77
76
  self,
78
77
  provider: Provider,
79
- parameter: inspect.Parameter,
78
+ parameter: ProviderParameter,
80
79
  context: InstanceContext | None,
81
80
  /,
82
81
  **defaults: Any,
@@ -90,7 +89,7 @@ class TestContainer(Container):
90
89
  async def _aget_provider_instance(
91
90
  self,
92
91
  provider: Provider,
93
- parameter: inspect.Parameter,
92
+ parameter: ProviderParameter,
94
93
  context: InstanceContext | None,
95
94
  /,
96
95
  **defaults: Any,
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "anydi"
3
- version = "0.53.0"
3
+ version = "0.54.1"
4
4
  description = "Dependency Injection library"
5
5
  authors = [{ name = "Anton Ruhlov", email = "antonruhlov@gmail.com" }]
6
6
  requires-python = ">=3.10.0, <3.15"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes