python-injection 0.10.12.post0__py3-none-any.whl → 0.12.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
injection/_core/module.py CHANGED
@@ -1,21 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import asyncio
3
4
  import inspect
4
5
  from abc import ABC, abstractmethod
5
6
  from collections import OrderedDict
6
7
  from collections.abc import (
8
+ AsyncIterator,
9
+ Awaitable,
7
10
  Callable,
8
11
  Collection,
9
12
  Iterable,
10
13
  Iterator,
11
14
  Mapping,
12
- MutableMapping,
13
15
  )
14
16
  from contextlib import contextmanager, suppress
15
17
  from dataclasses import dataclass, field
16
18
  from enum import StrEnum
17
19
  from functools import partialmethod, singledispatchmethod, update_wrapper
18
- from inspect import Signature, isclass
20
+ from inspect import Signature, isclass, iscoroutinefunction
19
21
  from logging import Logger, getLogger
20
22
  from queue import Empty, Queue
21
23
  from types import MethodType
@@ -25,22 +27,34 @@ from typing import (
25
27
  ContextManager,
26
28
  Literal,
27
29
  NamedTuple,
28
- NoReturn,
29
30
  Protocol,
30
31
  Self,
32
+ TypeGuard,
33
+ overload,
31
34
  override,
32
35
  runtime_checkable,
33
36
  )
34
37
  from uuid import uuid4
35
38
 
39
+ from injection._core.common.asynchronous import (
40
+ AsyncCaller,
41
+ Caller,
42
+ SimpleAwaitable,
43
+ SyncCaller,
44
+ )
36
45
  from injection._core.common.event import Event, EventChannel, EventListener
37
46
  from injection._core.common.invertible import Invertible, SimpleInvertible
38
47
  from injection._core.common.lazy import Lazy, LazyMapping
39
48
  from injection._core.common.threading import synchronized
40
49
  from injection._core.common.type import InputType, TypeInfo, get_return_types
41
50
  from injection._core.hook import Hook, apply_hooks
51
+ from injection._core.injectables import (
52
+ Injectable,
53
+ ShouldBeInjectable,
54
+ SimpleInjectable,
55
+ SingletonInjectable,
56
+ )
42
57
  from injection.exceptions import (
43
- InjectionError,
44
58
  ModuleError,
45
59
  ModuleLockError,
46
60
  ModuleNotUsedError,
@@ -129,86 +143,6 @@ class ModulePriorityUpdated(ModuleEvent):
129
143
  )
130
144
 
131
145
 
132
- """
133
- Injectables
134
- """
135
-
136
-
137
- @runtime_checkable
138
- class Injectable[T](Protocol):
139
- __slots__ = ()
140
-
141
- @property
142
- def is_locked(self) -> bool:
143
- return False
144
-
145
- def unlock(self) -> None:
146
- return
147
-
148
- @abstractmethod
149
- def get_instance(self) -> T:
150
- raise NotImplementedError
151
-
152
-
153
- @dataclass(repr=False, frozen=True, slots=True)
154
- class BaseInjectable[T](Injectable[T], ABC):
155
- factory: Callable[..., T]
156
-
157
-
158
- class SimpleInjectable[T](BaseInjectable[T]):
159
- __slots__ = ()
160
-
161
- @override
162
- def get_instance(self) -> T:
163
- return self.factory()
164
-
165
-
166
- class SingletonInjectable[T](BaseInjectable[T]):
167
- __slots__ = ("__dict__",)
168
-
169
- __key: ClassVar[str] = "$instance"
170
-
171
- @property
172
- def cache(self) -> MutableMapping[str, Any]:
173
- return self.__dict__
174
-
175
- @property
176
- @override
177
- def is_locked(self) -> bool:
178
- return self.__key in self.cache
179
-
180
- @override
181
- def unlock(self) -> None:
182
- self.cache.clear()
183
-
184
- @override
185
- def get_instance(self) -> T:
186
- with suppress(KeyError):
187
- return self.cache[self.__key]
188
-
189
- with synchronized():
190
- instance = self.factory()
191
- self.cache[self.__key] = instance
192
-
193
- return instance
194
-
195
-
196
- @dataclass(repr=False, frozen=True, slots=True)
197
- class ShouldBeInjectable[T](Injectable[T]):
198
- cls: type[T]
199
-
200
- @override
201
- def get_instance(self) -> NoReturn:
202
- raise InjectionError(f"`{self.cls}` should be an injectable.")
203
-
204
- @classmethod
205
- def from_callable(cls, callable: Callable[..., T]) -> Self:
206
- if not isclass(callable):
207
- raise TypeError(f"`{callable}` should be a class.")
208
-
209
- return cls(callable)
210
-
211
-
212
146
  """
213
147
  Broker
214
148
  """
@@ -235,6 +169,10 @@ class Broker(Protocol):
235
169
  def unlock(self) -> Self:
236
170
  raise NotImplementedError
237
171
 
172
+ @abstractmethod
173
+ async def all_ready(self) -> None:
174
+ raise NotImplementedError
175
+
238
176
 
239
177
  """
240
178
  Locator
@@ -257,7 +195,7 @@ class Mode(StrEnum):
257
195
 
258
196
  type ModeStr = Literal["fallback", "normal", "override"]
259
197
 
260
- type InjectableFactory[T] = Callable[[Callable[..., T]], Injectable[T]]
198
+ type InjectableFactory[T] = Callable[[Caller[..., T]], Injectable[T]]
261
199
 
262
200
 
263
201
  class Record[T](NamedTuple):
@@ -267,7 +205,7 @@ class Record[T](NamedTuple):
267
205
 
268
206
  @dataclass(repr=False, eq=False, kw_only=True, slots=True)
269
207
  class Updater[T]:
270
- factory: Callable[..., T]
208
+ factory: Caller[..., T]
271
209
  classes: Iterable[InputType[T]]
272
210
  injectable_factory: InjectableFactory[T]
273
211
  mode: Mode
@@ -354,6 +292,11 @@ class Locator(Broker):
354
292
 
355
293
  return self
356
294
 
295
+ @override
296
+ async def all_ready(self) -> None:
297
+ for injectable in self.__injectables:
298
+ await injectable.aget_instance()
299
+
357
300
  def add_listener(self, listener: EventListener) -> Self:
358
301
  self.__channel.add_listener(listener)
359
302
  return self
@@ -466,18 +409,20 @@ class Module(Broker, EventListener):
466
409
  yield from tuple(self.__modules)
467
410
  yield self.__locator
468
411
 
469
- def injectable[**P, T]( # type: ignore[no-untyped-def]
412
+ def injectable[**P, T](
470
413
  self,
471
- wrapped: Callable[P, T] | None = None,
414
+ wrapped: Callable[P, T] | Callable[P, Awaitable[T]] | None = None,
472
415
  /,
473
416
  *,
474
417
  cls: InjectableFactory[T] = SimpleInjectable,
475
418
  inject: bool = True,
476
419
  on: TypeInfo[T] = (),
477
420
  mode: Mode | ModeStr = Mode.get_default(),
478
- ):
479
- def decorator(wp): # type: ignore[no-untyped-def]
480
- factory = self.make_injected_function(wp) if inject else wp
421
+ ) -> Any:
422
+ def decorator(
423
+ wp: Callable[P, T] | Callable[P, Awaitable[T]],
424
+ ) -> Callable[P, T] | Callable[P, Awaitable[T]]:
425
+ factory = _get_caller(self.make_injected_function(wp) if inject else wp)
481
426
  classes = get_return_types(wp, on)
482
427
  updater = Updater(
483
428
  factory=factory,
@@ -492,12 +437,12 @@ class Module(Broker, EventListener):
492
437
 
493
438
  singleton = partialmethod(injectable, cls=SingletonInjectable)
494
439
 
495
- def should_be_injectable[T](self, wrapped: type[T] | None = None, /): # type: ignore[no-untyped-def]
496
- def decorator(wp): # type: ignore[no-untyped-def]
440
+ def should_be_injectable[T](self, wrapped: type[T] | None = None, /) -> Any:
441
+ def decorator(wp: type[T]) -> type[T]:
497
442
  updater = Updater(
498
- factory=wp,
443
+ factory=SyncCaller(wp),
499
444
  classes=(wp,),
500
- injectable_factory=ShouldBeInjectable.from_callable,
445
+ injectable_factory=lambda _: ShouldBeInjectable(wp),
501
446
  mode=Mode.FALLBACK,
502
447
  )
503
448
  self.update(updater)
@@ -505,15 +450,15 @@ class Module(Broker, EventListener):
505
450
 
506
451
  return decorator(wrapped) if wrapped else decorator
507
452
 
508
- def constant[T]( # type: ignore[no-untyped-def]
453
+ def constant[T](
509
454
  self,
510
455
  wrapped: type[T] | None = None,
511
456
  /,
512
457
  *,
513
458
  on: TypeInfo[T] = (),
514
459
  mode: Mode | ModeStr = Mode.get_default(),
515
- ):
516
- def decorator(wp): # type: ignore[no-untyped-def]
460
+ ) -> Any:
461
+ def decorator(wp: type[T]) -> type[T]:
517
462
  lazy_instance = Lazy(wp)
518
463
  self.injectable(
519
464
  lambda: ~lazy_instance,
@@ -545,8 +490,8 @@ class Module(Broker, EventListener):
545
490
  )
546
491
  return self
547
492
 
548
- def inject[**P, T](self, wrapped: Callable[P, T] | None = None, /): # type: ignore[no-untyped-def]
549
- def decorator(wp): # type: ignore[no-untyped-def]
493
+ def inject[**P, T](self, wrapped: Callable[P, T] | None = None, /) -> Any:
494
+ def decorator(wp: Callable[P, T]) -> Callable[P, T]:
550
495
  if isclass(wp):
551
496
  wp.__init__ = self.inject(wp.__init__)
552
497
  return wp
@@ -555,47 +500,135 @@ class Module(Broker, EventListener):
555
500
 
556
501
  return decorator(wrapped) if wrapped else decorator
557
502
 
503
+ @overload
558
504
  def make_injected_function[**P, T](
559
505
  self,
560
506
  wrapped: Callable[P, T],
561
507
  /,
562
- ) -> InjectedFunction[P, T]:
563
- injected = Injected(wrapped)
508
+ ) -> SyncInjectedFunction[P, T]: ...
509
+
510
+ @overload
511
+ def make_injected_function[**P, T](
512
+ self,
513
+ wrapped: Callable[P, Awaitable[T]],
514
+ /,
515
+ ) -> AsyncInjectedFunction[P, T]: ...
516
+
517
+ def make_injected_function(self, wrapped, /): # type: ignore[no-untyped-def]
518
+ metadata = InjectMetadata(wrapped)
564
519
 
565
- @injected.on_setup
520
+ @metadata.on_setup
566
521
  def listen() -> None:
567
- injected.update(self)
568
- self.add_listener(injected)
522
+ metadata.update(self)
523
+ self.add_listener(metadata)
524
+
525
+ if iscoroutinefunction(wrapped):
526
+ return AsyncInjectedFunction(metadata)
569
527
 
570
- return InjectedFunction(injected)
528
+ return SyncInjectedFunction(metadata)
529
+
530
+ async def afind_instance[T](self, cls: InputType[T]) -> T:
531
+ injectable = self[cls]
532
+ return await injectable.aget_instance()
571
533
 
572
534
  def find_instance[T](self, cls: InputType[T]) -> T:
573
535
  injectable = self[cls]
574
536
  return injectable.get_instance()
575
537
 
538
+ @overload
539
+ async def aget_instance[T, Default](
540
+ self,
541
+ cls: InputType[T],
542
+ default: Default,
543
+ ) -> T | Default: ...
544
+
545
+ @overload
546
+ async def aget_instance[T](
547
+ self,
548
+ cls: InputType[T],
549
+ default: None = ...,
550
+ ) -> T | None: ...
551
+
552
+ async def aget_instance(self, cls, default=None): # type: ignore[no-untyped-def]
553
+ try:
554
+ return await self.afind_instance(cls)
555
+ except KeyError:
556
+ return default
557
+
558
+ @overload
576
559
  def get_instance[T, Default](
577
560
  self,
578
561
  cls: InputType[T],
579
- default: Default | None = None,
580
- ) -> T | Default | None:
562
+ default: Default,
563
+ ) -> T | Default: ...
564
+
565
+ @overload
566
+ def get_instance[T](
567
+ self,
568
+ cls: InputType[T],
569
+ default: None = ...,
570
+ ) -> T | None: ...
571
+
572
+ def get_instance(self, cls, default=None): # type: ignore[no-untyped-def]
581
573
  try:
582
574
  return self.find_instance(cls)
583
575
  except KeyError:
584
576
  return default
585
577
 
578
+ @overload
579
+ def aget_lazy_instance[T, Default](
580
+ self,
581
+ cls: InputType[T],
582
+ default: Default,
583
+ *,
584
+ cache: bool = ...,
585
+ ) -> Awaitable[T | Default]: ...
586
+
587
+ @overload
588
+ def aget_lazy_instance[T](
589
+ self,
590
+ cls: InputType[T],
591
+ default: None = ...,
592
+ *,
593
+ cache: bool = ...,
594
+ ) -> Awaitable[T | None]: ...
595
+
596
+ def aget_lazy_instance(self, cls, default=None, *, cache=False): # type: ignore[no-untyped-def]
597
+ if cache:
598
+ coroutine = self.aget_instance(cls, default)
599
+ return asyncio.ensure_future(coroutine)
600
+
601
+ function = self.make_injected_function(lambda instance=default: instance)
602
+ metadata = function.__inject_metadata__
603
+ metadata.set_owner(cls)
604
+ return SimpleAwaitable(metadata.acall)
605
+
606
+ @overload
586
607
  def get_lazy_instance[T, Default](
587
608
  self,
588
609
  cls: InputType[T],
589
- default: Default | None = None,
610
+ default: Default,
590
611
  *,
591
- cache: bool = False,
592
- ) -> Invertible[T | Default | None]:
612
+ cache: bool = ...,
613
+ ) -> Invertible[T | Default]: ...
614
+
615
+ @overload
616
+ def get_lazy_instance[T](
617
+ self,
618
+ cls: InputType[T],
619
+ default: None = ...,
620
+ *,
621
+ cache: bool = ...,
622
+ ) -> Invertible[T | None]: ...
623
+
624
+ def get_lazy_instance(self, cls, default=None, *, cache=False): # type: ignore[no-untyped-def]
593
625
  if cache:
594
626
  return Lazy(lambda: self.get_instance(cls, default))
595
627
 
596
- function = self.inject(lambda instance=default: instance)
597
- function.__injected__.set_owner(cls)
598
- return SimpleInvertible(function)
628
+ function = self.make_injected_function(lambda instance=default: instance)
629
+ metadata = function.__inject_metadata__
630
+ metadata.set_owner(cls)
631
+ return SimpleInvertible(metadata.call)
599
632
 
600
633
  def update[T](self, updater: Updater[T]) -> Self:
601
634
  self.__locator.update(updater)
@@ -670,6 +703,11 @@ class Module(Broker, EventListener):
670
703
 
671
704
  return self
672
705
 
706
+ @override
707
+ async def all_ready(self) -> None:
708
+ for broker in self.__brokers:
709
+ await broker.all_ready()
710
+
673
711
  def add_logger(self, logger: Logger) -> Self:
674
712
  self.__loggers.append(logger)
675
713
  return self
@@ -730,6 +768,13 @@ class Module(Broker, EventListener):
730
768
  return cls.from_name("__default__")
731
769
 
732
770
 
771
+ def mod(name: str | None = None, /) -> Module:
772
+ if name is None:
773
+ return Module.default()
774
+
775
+ return Module.from_name(name)
776
+
777
+
733
778
  """
734
779
  InjectedFunction
735
780
  """
@@ -744,7 +789,13 @@ class Dependencies:
744
789
 
745
790
  def __iter__(self) -> Iterator[tuple[str, Any]]:
746
791
  for name, injectable in self.mapping.items():
747
- yield name, injectable.get_instance()
792
+ instance = injectable.get_instance()
793
+ yield name, instance
794
+
795
+ async def __aiter__(self) -> AsyncIterator[tuple[str, Any]]:
796
+ for name, injectable in self.mapping.items():
797
+ instance = await injectable.aget_instance()
798
+ yield name, instance
748
799
 
749
800
  @property
750
801
  def are_resolved(self) -> bool:
@@ -753,9 +804,11 @@ class Dependencies:
753
804
 
754
805
  return bool(self)
755
806
 
756
- @property
757
- def arguments(self) -> OrderedDict[str, Any]:
758
- return OrderedDict(self)
807
+ async def aget_arguments(self) -> dict[str, Any]:
808
+ return {key: value async for key, value in self}
809
+
810
+ def get_arguments(self) -> dict[str, Any]:
811
+ return dict(self)
759
812
 
760
813
  @classmethod
761
814
  def from_mapping(cls, mapping: Mapping[str, Injectable[Any]]) -> Self:
@@ -810,7 +863,7 @@ class Arguments(NamedTuple):
810
863
  kwargs: Mapping[str, Any]
811
864
 
812
865
 
813
- class Injected[**P, T](EventListener):
866
+ class InjectMetadata[**P, T](Caller[P, T], EventListener):
814
867
  __slots__ = (
815
868
  "__dependencies",
816
869
  "__owner",
@@ -831,11 +884,6 @@ class Injected[**P, T](EventListener):
831
884
  self.__setup_queue = Queue()
832
885
  self.__wrapped = wrapped
833
886
 
834
- def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
835
- self.__setup()
836
- arguments = self.bind(args, kwargs)
837
- return self.wrapped(*arguments.args, **arguments.kwargs)
838
-
839
887
  @property
840
888
  def signature(self) -> Signature:
841
889
  with suppress(AttributeError):
@@ -851,22 +899,33 @@ class Injected[**P, T](EventListener):
851
899
  def wrapped(self) -> Callable[P, T]:
852
900
  return self.__wrapped
853
901
 
902
+ async def abind(
903
+ self,
904
+ args: Iterable[Any] = (),
905
+ kwargs: Mapping[str, Any] | None = None,
906
+ ) -> Arguments:
907
+ additional_arguments = await self.__dependencies.aget_arguments()
908
+ return self.__bind(args, kwargs, additional_arguments)
909
+
854
910
  def bind(
855
911
  self,
856
912
  args: Iterable[Any] = (),
857
913
  kwargs: Mapping[str, Any] | None = None,
858
914
  ) -> Arguments:
859
- if kwargs is None:
860
- kwargs = {}
915
+ additional_arguments = self.__dependencies.get_arguments()
916
+ return self.__bind(args, kwargs, additional_arguments)
861
917
 
862
- if not self.__dependencies:
863
- return Arguments(args, kwargs)
918
+ @override
919
+ async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
920
+ self.__setup()
921
+ arguments = await self.abind(args, kwargs)
922
+ return self.wrapped(*arguments.args, **arguments.kwargs)
864
923
 
865
- bound = self.signature.bind_partial(*args, **kwargs)
866
- bound.arguments = (
867
- bound.arguments | self.__dependencies.arguments | bound.arguments
868
- )
869
- return Arguments(bound.args, bound.kwargs)
924
+ @override
925
+ def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
926
+ self.__setup()
927
+ arguments = self.bind(args, kwargs)
928
+ return self.wrapped(*arguments.args, **arguments.kwargs)
870
929
 
871
930
  def set_owner(self, owner: type) -> Self:
872
931
  if self.__dependencies.are_resolved:
@@ -885,8 +944,8 @@ class Injected[**P, T](EventListener):
885
944
  self.__dependencies = Dependencies.resolve(self.signature, module, self.__owner)
886
945
  return self
887
946
 
888
- def on_setup[**_P, _T](self, wrapped: Callable[_P, _T] | None = None, /): # type: ignore[no-untyped-def]
889
- def decorator(wp): # type: ignore[no-untyped-def]
947
+ def on_setup[**_P, _T](self, wrapped: Callable[_P, _T] | None = None, /) -> Any:
948
+ def decorator(wp: Callable[_P, _T]) -> Callable[_P, _T]:
890
949
  queue = self.__setup_queue
891
950
 
892
951
  if queue is None:
@@ -908,6 +967,22 @@ class Injected[**P, T](EventListener):
908
967
  yield
909
968
  self.update(event.module)
910
969
 
970
+ def __bind(
971
+ self,
972
+ args: Iterable[Any],
973
+ kwargs: Mapping[str, Any] | None,
974
+ additional_arguments: dict[str, Any] | None,
975
+ ) -> Arguments:
976
+ if kwargs is None:
977
+ kwargs = {}
978
+
979
+ if not additional_arguments:
980
+ return Arguments(args, kwargs)
981
+
982
+ bound = self.signature.bind_partial(*args, **kwargs)
983
+ bound.arguments = bound.arguments | additional_arguments | bound.arguments
984
+ return Arguments(bound.args, bound.kwargs)
985
+
911
986
  def __close_setup_queue(self) -> None:
912
987
  self.__setup_queue = None
913
988
 
@@ -930,25 +1005,26 @@ class Injected[**P, T](EventListener):
930
1005
  self.__close_setup_queue()
931
1006
 
932
1007
 
933
- class InjectedFunction[**P, T]:
934
- __slots__ = ("__dict__", "__injected__")
1008
+ class InjectedFunction[**P, T](ABC):
1009
+ __slots__ = ("__dict__", "__inject_metadata__")
935
1010
 
936
- __injected__: Injected[P, T]
1011
+ __inject_metadata__: InjectMetadata[P, T]
937
1012
 
938
- def __init__(self, injected: Injected[P, T]) -> None:
939
- update_wrapper(self, injected.wrapped)
940
- self.__injected__ = injected
1013
+ def __init__(self, metadata: InjectMetadata[P, T]) -> None:
1014
+ update_wrapper(self, metadata.wrapped)
1015
+ self.__inject_metadata__ = metadata
941
1016
 
942
1017
  @override
943
1018
  def __repr__(self) -> str: # pragma: no cover
944
- return repr(self.__injected__.wrapped)
1019
+ return repr(self.__inject_metadata__.wrapped)
945
1020
 
946
1021
  @override
947
1022
  def __str__(self) -> str: # pragma: no cover
948
- return str(self.__injected__.wrapped)
1023
+ return str(self.__inject_metadata__.wrapped)
949
1024
 
1025
+ @abstractmethod
950
1026
  def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
951
- return self.__injected__(*args, **kwargs)
1027
+ raise NotImplementedError
952
1028
 
953
1029
  def __get__(
954
1030
  self,
@@ -961,4 +1037,43 @@ class InjectedFunction[**P, T]:
961
1037
  return MethodType(self, instance)
962
1038
 
963
1039
  def __set_name__(self, owner: type, name: str) -> None:
964
- self.__injected__.set_owner(owner)
1040
+ self.__inject_metadata__.set_owner(owner)
1041
+
1042
+
1043
+ class AsyncInjectedFunction[**P, T](InjectedFunction[P, Awaitable[T]]):
1044
+ __slots__ = ()
1045
+
1046
+ @override
1047
+ async def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
1048
+ return await (await self.__inject_metadata__.acall(*args, **kwargs))
1049
+
1050
+
1051
+ class SyncInjectedFunction[**P, T](InjectedFunction[P, T]):
1052
+ __slots__ = ()
1053
+
1054
+ @override
1055
+ def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
1056
+ return self.__inject_metadata__.call(*args, **kwargs)
1057
+
1058
+
1059
+ def _is_coroutine_function[**P, T](
1060
+ function: Callable[P, T] | Callable[P, Awaitable[T]],
1061
+ ) -> TypeGuard[Callable[P, Awaitable[T]]]:
1062
+ if iscoroutinefunction(function):
1063
+ return True
1064
+
1065
+ elif isclass(function):
1066
+ return False
1067
+
1068
+ call = getattr(function, "__call__", None)
1069
+ return iscoroutinefunction(call)
1070
+
1071
+
1072
+ def _get_caller[**P, T](function: Callable[P, T]) -> Caller[P, T]:
1073
+ if _is_coroutine_function(function):
1074
+ return AsyncCaller(function)
1075
+
1076
+ elif isinstance(function, InjectedFunction):
1077
+ return function.__inject_metadata__
1078
+
1079
+ return SyncCaller(function)
@@ -1,6 +1,6 @@
1
- from collections.abc import Callable
1
+ from collections.abc import Awaitable
2
2
  from types import GenericAlias
3
- from typing import Any, ClassVar, TypeAliasType
3
+ from typing import Any, TypeAliasType
4
4
 
5
5
  from injection import Module, mod
6
6
  from injection.exceptions import InjectionError
@@ -28,20 +28,29 @@ def Inject[T]( # noqa: N802
28
28
 
29
29
 
30
30
  class InjectionDependency[T]:
31
- __slots__ = ("__call__", "__class", "__module")
31
+ __slots__ = ("__class", "__lazy_instance", "__module")
32
32
 
33
- __call__: Callable[[], T]
34
33
  __class: type[T] | TypeAliasType | GenericAlias
34
+ __lazy_instance: Awaitable[T]
35
35
  __module: Module
36
36
 
37
- __sentinel: ClassVar[object] = object()
38
-
39
- def __init__(self, cls: type[T] | TypeAliasType | GenericAlias, module: Module):
40
- lazy_instance = module.get_lazy_instance(cls, default=self.__sentinel)
41
- self.__call__ = lambda: self.__ensure(~lazy_instance)
37
+ def __init__(
38
+ self,
39
+ cls: type[T] | TypeAliasType | GenericAlias,
40
+ module: Module,
41
+ ) -> None:
42
42
  self.__class = cls
43
+ self.__lazy_instance = module.aget_lazy_instance(cls, default=NotImplemented)
43
44
  self.__module = module
44
45
 
46
+ async def __call__(self) -> T:
47
+ instance = await self.__lazy_instance
48
+
49
+ if instance is NotImplemented:
50
+ raise InjectionError(f"`{self.__class}` is an unknown dependency.")
51
+
52
+ return instance
53
+
45
54
  def __eq__(self, other: Any) -> bool:
46
55
  if isinstance(other, type(self)):
47
56
  return self.__key == other.__key
@@ -54,9 +63,3 @@ class InjectionDependency[T]:
54
63
  @property
55
64
  def __key(self) -> tuple[type[T] | TypeAliasType | GenericAlias, Module]:
56
65
  return self.__class, self.__module
57
-
58
- def __ensure(self, instance: T | Any) -> T:
59
- if instance is self.__sentinel:
60
- raise InjectionError(f"`{self.__class}` is an unknown dependency.")
61
-
62
- return instance
@@ -21,5 +21,5 @@ test_injectable = mod(_TEST_PROFILE_NAME).injectable
21
21
  test_singleton = mod(_TEST_PROFILE_NAME).singleton
22
22
 
23
23
 
24
- def load_test_profile(*other_profile_names: str) -> ContextManager[None]:
25
- return load_profile(_TEST_PROFILE_NAME, *other_profile_names)
24
+ def load_test_profile(*names: str) -> ContextManager[None]:
25
+ return load_profile(_TEST_PROFILE_NAME, *names)