python-injection 0.10.12.post0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
injection/_core/module.py CHANGED
@@ -1,21 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import asyncio
3
4
  import inspect
4
5
  from abc import ABC, abstractmethod
5
6
  from collections import OrderedDict
6
7
  from collections.abc import (
8
+ AsyncIterator,
9
+ Awaitable,
7
10
  Callable,
8
11
  Collection,
9
12
  Iterable,
10
13
  Iterator,
11
14
  Mapping,
12
- MutableMapping,
13
15
  )
14
16
  from contextlib import contextmanager, suppress
15
17
  from dataclasses import dataclass, field
16
18
  from enum import StrEnum
17
19
  from functools import partialmethod, singledispatchmethod, update_wrapper
18
- from inspect import Signature, isclass
20
+ from inspect import Signature, isclass, iscoroutinefunction
19
21
  from logging import Logger, getLogger
20
22
  from queue import Empty, Queue
21
23
  from types import MethodType
@@ -25,22 +27,34 @@ from typing import (
25
27
  ContextManager,
26
28
  Literal,
27
29
  NamedTuple,
28
- NoReturn,
29
30
  Protocol,
30
31
  Self,
32
+ TypeGuard,
33
+ overload,
31
34
  override,
32
35
  runtime_checkable,
33
36
  )
34
37
  from uuid import uuid4
35
38
 
39
+ from injection._core.common.asynchronous import (
40
+ AsyncCaller,
41
+ Caller,
42
+ SimpleAwaitable,
43
+ SyncCaller,
44
+ )
36
45
  from injection._core.common.event import Event, EventChannel, EventListener
37
46
  from injection._core.common.invertible import Invertible, SimpleInvertible
38
47
  from injection._core.common.lazy import Lazy, LazyMapping
39
48
  from injection._core.common.threading import synchronized
40
49
  from injection._core.common.type import InputType, TypeInfo, get_return_types
41
50
  from injection._core.hook import Hook, apply_hooks
51
+ from injection._core.injectables import (
52
+ Injectable,
53
+ ShouldBeInjectable,
54
+ SimpleInjectable,
55
+ SingletonInjectable,
56
+ )
42
57
  from injection.exceptions import (
43
- InjectionError,
44
58
  ModuleError,
45
59
  ModuleLockError,
46
60
  ModuleNotUsedError,
@@ -129,86 +143,6 @@ class ModulePriorityUpdated(ModuleEvent):
129
143
  )
130
144
 
131
145
 
132
- """
133
- Injectables
134
- """
135
-
136
-
137
- @runtime_checkable
138
- class Injectable[T](Protocol):
139
- __slots__ = ()
140
-
141
- @property
142
- def is_locked(self) -> bool:
143
- return False
144
-
145
- def unlock(self) -> None:
146
- return
147
-
148
- @abstractmethod
149
- def get_instance(self) -> T:
150
- raise NotImplementedError
151
-
152
-
153
- @dataclass(repr=False, frozen=True, slots=True)
154
- class BaseInjectable[T](Injectable[T], ABC):
155
- factory: Callable[..., T]
156
-
157
-
158
- class SimpleInjectable[T](BaseInjectable[T]):
159
- __slots__ = ()
160
-
161
- @override
162
- def get_instance(self) -> T:
163
- return self.factory()
164
-
165
-
166
- class SingletonInjectable[T](BaseInjectable[T]):
167
- __slots__ = ("__dict__",)
168
-
169
- __key: ClassVar[str] = "$instance"
170
-
171
- @property
172
- def cache(self) -> MutableMapping[str, Any]:
173
- return self.__dict__
174
-
175
- @property
176
- @override
177
- def is_locked(self) -> bool:
178
- return self.__key in self.cache
179
-
180
- @override
181
- def unlock(self) -> None:
182
- self.cache.clear()
183
-
184
- @override
185
- def get_instance(self) -> T:
186
- with suppress(KeyError):
187
- return self.cache[self.__key]
188
-
189
- with synchronized():
190
- instance = self.factory()
191
- self.cache[self.__key] = instance
192
-
193
- return instance
194
-
195
-
196
- @dataclass(repr=False, frozen=True, slots=True)
197
- class ShouldBeInjectable[T](Injectable[T]):
198
- cls: type[T]
199
-
200
- @override
201
- def get_instance(self) -> NoReturn:
202
- raise InjectionError(f"`{self.cls}` should be an injectable.")
203
-
204
- @classmethod
205
- def from_callable(cls, callable: Callable[..., T]) -> Self:
206
- if not isclass(callable):
207
- raise TypeError(f"`{callable}` should be a class.")
208
-
209
- return cls(callable)
210
-
211
-
212
146
  """
213
147
  Broker
214
148
  """
@@ -235,6 +169,10 @@ class Broker(Protocol):
235
169
  def unlock(self) -> Self:
236
170
  raise NotImplementedError
237
171
 
172
+ @abstractmethod
173
+ async def all_ready(self) -> None:
174
+ raise NotImplementedError
175
+
238
176
 
239
177
  """
240
178
  Locator
@@ -257,7 +195,7 @@ class Mode(StrEnum):
257
195
 
258
196
  type ModeStr = Literal["fallback", "normal", "override"]
259
197
 
260
- type InjectableFactory[T] = Callable[[Callable[..., T]], Injectable[T]]
198
+ type InjectableFactory[T] = Callable[[Caller[..., T]], Injectable[T]]
261
199
 
262
200
 
263
201
  class Record[T](NamedTuple):
@@ -267,7 +205,7 @@ class Record[T](NamedTuple):
267
205
 
268
206
  @dataclass(repr=False, eq=False, kw_only=True, slots=True)
269
207
  class Updater[T]:
270
- factory: Callable[..., T]
208
+ factory: Caller[..., T]
271
209
  classes: Iterable[InputType[T]]
272
210
  injectable_factory: InjectableFactory[T]
273
211
  mode: Mode
@@ -354,6 +292,11 @@ class Locator(Broker):
354
292
 
355
293
  return self
356
294
 
295
+ @override
296
+ async def all_ready(self) -> None:
297
+ for injectable in self.__injectables:
298
+ await injectable.aget_instance()
299
+
357
300
  def add_listener(self, listener: EventListener) -> Self:
358
301
  self.__channel.add_listener(listener)
359
302
  return self
@@ -466,18 +409,20 @@ class Module(Broker, EventListener):
466
409
  yield from tuple(self.__modules)
467
410
  yield self.__locator
468
411
 
469
- def injectable[**P, T]( # type: ignore[no-untyped-def]
412
+ def injectable[**P, T](
470
413
  self,
471
- wrapped: Callable[P, T] | None = None,
414
+ wrapped: Callable[P, T] | Callable[P, Awaitable[T]] | None = None,
472
415
  /,
473
416
  *,
474
417
  cls: InjectableFactory[T] = SimpleInjectable,
475
418
  inject: bool = True,
476
419
  on: TypeInfo[T] = (),
477
420
  mode: Mode | ModeStr = Mode.get_default(),
478
- ):
479
- def decorator(wp): # type: ignore[no-untyped-def]
480
- factory = self.make_injected_function(wp) if inject else wp
421
+ ) -> Any:
422
+ def decorator(
423
+ wp: Callable[P, T] | Callable[P, Awaitable[T]],
424
+ ) -> Callable[P, T] | Callable[P, Awaitable[T]]:
425
+ factory = _get_caller(self.make_injected_function(wp) if inject else wp)
481
426
  classes = get_return_types(wp, on)
482
427
  updater = Updater(
483
428
  factory=factory,
@@ -492,12 +437,12 @@ class Module(Broker, EventListener):
492
437
 
493
438
  singleton = partialmethod(injectable, cls=SingletonInjectable)
494
439
 
495
- def should_be_injectable[T](self, wrapped: type[T] | None = None, /): # type: ignore[no-untyped-def]
496
- def decorator(wp): # type: ignore[no-untyped-def]
440
+ def should_be_injectable[T](self, wrapped: type[T] | None = None, /) -> Any:
441
+ def decorator(wp: type[T]) -> type[T]:
497
442
  updater = Updater(
498
- factory=wp,
443
+ factory=SyncCaller(wp),
499
444
  classes=(wp,),
500
- injectable_factory=ShouldBeInjectable.from_callable,
445
+ injectable_factory=lambda _: ShouldBeInjectable(wp),
501
446
  mode=Mode.FALLBACK,
502
447
  )
503
448
  self.update(updater)
@@ -505,15 +450,15 @@ class Module(Broker, EventListener):
505
450
 
506
451
  return decorator(wrapped) if wrapped else decorator
507
452
 
508
- def constant[T]( # type: ignore[no-untyped-def]
453
+ def constant[T](
509
454
  self,
510
455
  wrapped: type[T] | None = None,
511
456
  /,
512
457
  *,
513
458
  on: TypeInfo[T] = (),
514
459
  mode: Mode | ModeStr = Mode.get_default(),
515
- ):
516
- def decorator(wp): # type: ignore[no-untyped-def]
460
+ ) -> Any:
461
+ def decorator(wp: type[T]) -> type[T]:
517
462
  lazy_instance = Lazy(wp)
518
463
  self.injectable(
519
464
  lambda: ~lazy_instance,
@@ -545,8 +490,8 @@ class Module(Broker, EventListener):
545
490
  )
546
491
  return self
547
492
 
548
- def inject[**P, T](self, wrapped: Callable[P, T] | None = None, /): # type: ignore[no-untyped-def]
549
- def decorator(wp): # type: ignore[no-untyped-def]
493
+ def inject[**P, T](self, wrapped: Callable[P, T] | None = None, /) -> Any:
494
+ def decorator(wp: Callable[P, T]) -> Callable[P, T]:
550
495
  if isclass(wp):
551
496
  wp.__init__ = self.inject(wp.__init__)
552
497
  return wp
@@ -555,47 +500,135 @@ class Module(Broker, EventListener):
555
500
 
556
501
  return decorator(wrapped) if wrapped else decorator
557
502
 
503
+ @overload
558
504
  def make_injected_function[**P, T](
559
505
  self,
560
506
  wrapped: Callable[P, T],
561
507
  /,
562
- ) -> InjectedFunction[P, T]:
563
- injected = Injected(wrapped)
508
+ ) -> SyncInjectedFunction[P, T]: ...
509
+
510
+ @overload
511
+ def make_injected_function[**P, T](
512
+ self,
513
+ wrapped: Callable[P, Awaitable[T]],
514
+ /,
515
+ ) -> AsyncInjectedFunction[P, T]: ...
516
+
517
+ def make_injected_function(self, wrapped, /): # type: ignore[no-untyped-def]
518
+ metadata = InjectMetadata(wrapped)
564
519
 
565
- @injected.on_setup
520
+ @metadata.on_setup
566
521
  def listen() -> None:
567
- injected.update(self)
568
- self.add_listener(injected)
522
+ metadata.update(self)
523
+ self.add_listener(metadata)
524
+
525
+ if iscoroutinefunction(wrapped):
526
+ return AsyncInjectedFunction(metadata)
569
527
 
570
- return InjectedFunction(injected)
528
+ return SyncInjectedFunction(metadata)
529
+
530
+ async def afind_instance[T](self, cls: InputType[T]) -> T:
531
+ injectable = self[cls]
532
+ return await injectable.aget_instance()
571
533
 
572
534
  def find_instance[T](self, cls: InputType[T]) -> T:
573
535
  injectable = self[cls]
574
536
  return injectable.get_instance()
575
537
 
538
+ @overload
539
+ async def aget_instance[T, Default](
540
+ self,
541
+ cls: InputType[T],
542
+ default: Default,
543
+ ) -> T | Default: ...
544
+
545
+ @overload
546
+ async def aget_instance[T](
547
+ self,
548
+ cls: InputType[T],
549
+ default: None = ...,
550
+ ) -> T | None: ...
551
+
552
+ async def aget_instance(self, cls, default=None): # type: ignore[no-untyped-def]
553
+ try:
554
+ return await self.afind_instance(cls)
555
+ except KeyError:
556
+ return default
557
+
558
+ @overload
576
559
  def get_instance[T, Default](
577
560
  self,
578
561
  cls: InputType[T],
579
- default: Default | None = None,
580
- ) -> T | Default | None:
562
+ default: Default,
563
+ ) -> T | Default: ...
564
+
565
+ @overload
566
+ def get_instance[T](
567
+ self,
568
+ cls: InputType[T],
569
+ default: None = ...,
570
+ ) -> T | None: ...
571
+
572
+ def get_instance(self, cls, default=None): # type: ignore[no-untyped-def]
581
573
  try:
582
574
  return self.find_instance(cls)
583
575
  except KeyError:
584
576
  return default
585
577
 
578
+ @overload
579
+ def aget_lazy_instance[T, Default](
580
+ self,
581
+ cls: InputType[T],
582
+ default: Default,
583
+ *,
584
+ cache: bool = ...,
585
+ ) -> Awaitable[T | Default]: ...
586
+
587
+ @overload
588
+ def aget_lazy_instance[T](
589
+ self,
590
+ cls: InputType[T],
591
+ default: None = ...,
592
+ *,
593
+ cache: bool = ...,
594
+ ) -> Awaitable[T | None]: ...
595
+
596
+ def aget_lazy_instance(self, cls, default=None, *, cache=False): # type: ignore[no-untyped-def]
597
+ if cache:
598
+ coroutine = self.aget_instance(cls, default)
599
+ return asyncio.ensure_future(coroutine)
600
+
601
+ function = self.make_injected_function(lambda instance=default: instance)
602
+ metadata = function.__inject_metadata__
603
+ metadata.set_owner(cls)
604
+ return SimpleAwaitable(metadata.acall)
605
+
606
+ @overload
586
607
  def get_lazy_instance[T, Default](
587
608
  self,
588
609
  cls: InputType[T],
589
- default: Default | None = None,
610
+ default: Default,
590
611
  *,
591
- cache: bool = False,
592
- ) -> Invertible[T | Default | None]:
612
+ cache: bool = ...,
613
+ ) -> Invertible[T | Default]: ...
614
+
615
+ @overload
616
+ def get_lazy_instance[T](
617
+ self,
618
+ cls: InputType[T],
619
+ default: None = ...,
620
+ *,
621
+ cache: bool = ...,
622
+ ) -> Invertible[T | None]: ...
623
+
624
+ def get_lazy_instance(self, cls, default=None, *, cache=False): # type: ignore[no-untyped-def]
593
625
  if cache:
594
626
  return Lazy(lambda: self.get_instance(cls, default))
595
627
 
596
- function = self.inject(lambda instance=default: instance)
597
- function.__injected__.set_owner(cls)
598
- return SimpleInvertible(function)
628
+ function = self.make_injected_function(lambda instance=default: instance)
629
+ metadata = function.__inject_metadata__
630
+ metadata.set_owner(cls)
631
+ return SimpleInvertible(metadata.call)
599
632
 
600
633
  def update[T](self, updater: Updater[T]) -> Self:
601
634
  self.__locator.update(updater)
@@ -670,6 +703,11 @@ class Module(Broker, EventListener):
670
703
 
671
704
  return self
672
705
 
706
+ @override
707
+ async def all_ready(self) -> None:
708
+ for broker in self.__brokers:
709
+ await broker.all_ready()
710
+
673
711
  def add_logger(self, logger: Logger) -> Self:
674
712
  self.__loggers.append(logger)
675
713
  return self
@@ -730,6 +768,13 @@ class Module(Broker, EventListener):
730
768
  return cls.from_name("__default__")
731
769
 
732
770
 
771
+ def mod(name: str | None = None, /) -> Module:
772
+ if name is None:
773
+ return Module.default()
774
+
775
+ return Module.from_name(name)
776
+
777
+
733
778
  """
734
779
  InjectedFunction
735
780
  """
@@ -744,7 +789,13 @@ class Dependencies:
744
789
 
745
790
  def __iter__(self) -> Iterator[tuple[str, Any]]:
746
791
  for name, injectable in self.mapping.items():
747
- yield name, injectable.get_instance()
792
+ instance = injectable.get_instance()
793
+ yield name, instance
794
+
795
+ async def __aiter__(self) -> AsyncIterator[tuple[str, Any]]:
796
+ for name, injectable in self.mapping.items():
797
+ instance = await injectable.aget_instance()
798
+ yield name, instance
748
799
 
749
800
  @property
750
801
  def are_resolved(self) -> bool:
@@ -753,9 +804,11 @@ class Dependencies:
753
804
 
754
805
  return bool(self)
755
806
 
756
- @property
757
- def arguments(self) -> OrderedDict[str, Any]:
758
- return OrderedDict(self)
807
+ async def aget_arguments(self) -> dict[str, Any]:
808
+ return {key: value async for key, value in self}
809
+
810
+ def get_arguments(self) -> dict[str, Any]:
811
+ return dict(self)
759
812
 
760
813
  @classmethod
761
814
  def from_mapping(cls, mapping: Mapping[str, Injectable[Any]]) -> Self:
@@ -810,7 +863,7 @@ class Arguments(NamedTuple):
810
863
  kwargs: Mapping[str, Any]
811
864
 
812
865
 
813
- class Injected[**P, T](EventListener):
866
+ class InjectMetadata[**P, T](Caller[P, T], EventListener):
814
867
  __slots__ = (
815
868
  "__dependencies",
816
869
  "__owner",
@@ -831,11 +884,6 @@ class Injected[**P, T](EventListener):
831
884
  self.__setup_queue = Queue()
832
885
  self.__wrapped = wrapped
833
886
 
834
- def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
835
- self.__setup()
836
- arguments = self.bind(args, kwargs)
837
- return self.wrapped(*arguments.args, **arguments.kwargs)
838
-
839
887
  @property
840
888
  def signature(self) -> Signature:
841
889
  with suppress(AttributeError):
@@ -851,22 +899,33 @@ class Injected[**P, T](EventListener):
851
899
  def wrapped(self) -> Callable[P, T]:
852
900
  return self.__wrapped
853
901
 
902
+ async def abind(
903
+ self,
904
+ args: Iterable[Any] = (),
905
+ kwargs: Mapping[str, Any] | None = None,
906
+ ) -> Arguments:
907
+ additional_arguments = await self.__dependencies.aget_arguments()
908
+ return self.__bind(args, kwargs, additional_arguments)
909
+
854
910
  def bind(
855
911
  self,
856
912
  args: Iterable[Any] = (),
857
913
  kwargs: Mapping[str, Any] | None = None,
858
914
  ) -> Arguments:
859
- if kwargs is None:
860
- kwargs = {}
915
+ additional_arguments = self.__dependencies.get_arguments()
916
+ return self.__bind(args, kwargs, additional_arguments)
861
917
 
862
- if not self.__dependencies:
863
- return Arguments(args, kwargs)
918
+ @override
919
+ async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
920
+ self.__setup()
921
+ arguments = await self.abind(args, kwargs)
922
+ return self.wrapped(*arguments.args, **arguments.kwargs)
864
923
 
865
- bound = self.signature.bind_partial(*args, **kwargs)
866
- bound.arguments = (
867
- bound.arguments | self.__dependencies.arguments | bound.arguments
868
- )
869
- return Arguments(bound.args, bound.kwargs)
924
+ @override
925
+ def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
926
+ self.__setup()
927
+ arguments = self.bind(args, kwargs)
928
+ return self.wrapped(*arguments.args, **arguments.kwargs)
870
929
 
871
930
  def set_owner(self, owner: type) -> Self:
872
931
  if self.__dependencies.are_resolved:
@@ -885,8 +944,8 @@ class Injected[**P, T](EventListener):
885
944
  self.__dependencies = Dependencies.resolve(self.signature, module, self.__owner)
886
945
  return self
887
946
 
888
- def on_setup[**_P, _T](self, wrapped: Callable[_P, _T] | None = None, /): # type: ignore[no-untyped-def]
889
- def decorator(wp): # type: ignore[no-untyped-def]
947
+ def on_setup[**_P, _T](self, wrapped: Callable[_P, _T] | None = None, /) -> Any:
948
+ def decorator(wp: Callable[_P, _T]) -> Callable[_P, _T]:
890
949
  queue = self.__setup_queue
891
950
 
892
951
  if queue is None:
@@ -908,6 +967,22 @@ class Injected[**P, T](EventListener):
908
967
  yield
909
968
  self.update(event.module)
910
969
 
970
+ def __bind(
971
+ self,
972
+ args: Iterable[Any],
973
+ kwargs: Mapping[str, Any] | None,
974
+ additional_arguments: dict[str, Any] | None,
975
+ ) -> Arguments:
976
+ if kwargs is None:
977
+ kwargs = {}
978
+
979
+ if not additional_arguments:
980
+ return Arguments(args, kwargs)
981
+
982
+ bound = self.signature.bind_partial(*args, **kwargs)
983
+ bound.arguments = bound.arguments | additional_arguments | bound.arguments
984
+ return Arguments(bound.args, bound.kwargs)
985
+
911
986
  def __close_setup_queue(self) -> None:
912
987
  self.__setup_queue = None
913
988
 
@@ -930,25 +1005,26 @@ class Injected[**P, T](EventListener):
930
1005
  self.__close_setup_queue()
931
1006
 
932
1007
 
933
- class InjectedFunction[**P, T]:
934
- __slots__ = ("__dict__", "__injected__")
1008
+ class InjectedFunction[**P, T](ABC):
1009
+ __slots__ = ("__dict__", "__inject_metadata__")
935
1010
 
936
- __injected__: Injected[P, T]
1011
+ __inject_metadata__: InjectMetadata[P, T]
937
1012
 
938
- def __init__(self, injected: Injected[P, T]) -> None:
939
- update_wrapper(self, injected.wrapped)
940
- self.__injected__ = injected
1013
+ def __init__(self, metadata: InjectMetadata[P, T]) -> None:
1014
+ update_wrapper(self, metadata.wrapped)
1015
+ self.__inject_metadata__ = metadata
941
1016
 
942
1017
  @override
943
1018
  def __repr__(self) -> str: # pragma: no cover
944
- return repr(self.__injected__.wrapped)
1019
+ return repr(self.__inject_metadata__.wrapped)
945
1020
 
946
1021
  @override
947
1022
  def __str__(self) -> str: # pragma: no cover
948
- return str(self.__injected__.wrapped)
1023
+ return str(self.__inject_metadata__.wrapped)
949
1024
 
1025
+ @abstractmethod
950
1026
  def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
951
- return self.__injected__(*args, **kwargs)
1027
+ raise NotImplementedError
952
1028
 
953
1029
  def __get__(
954
1030
  self,
@@ -961,4 +1037,43 @@ class InjectedFunction[**P, T]:
961
1037
  return MethodType(self, instance)
962
1038
 
963
1039
  def __set_name__(self, owner: type, name: str) -> None:
964
- self.__injected__.set_owner(owner)
1040
+ self.__inject_metadata__.set_owner(owner)
1041
+
1042
+
1043
+ class AsyncInjectedFunction[**P, T](InjectedFunction[P, Awaitable[T]]):
1044
+ __slots__ = ()
1045
+
1046
+ @override
1047
+ async def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
1048
+ return await (await self.__inject_metadata__.acall(*args, **kwargs))
1049
+
1050
+
1051
+ class SyncInjectedFunction[**P, T](InjectedFunction[P, T]):
1052
+ __slots__ = ()
1053
+
1054
+ @override
1055
+ def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
1056
+ return self.__inject_metadata__.call(*args, **kwargs)
1057
+
1058
+
1059
+ def _is_coroutine_function[**P, T](
1060
+ function: Callable[P, T] | Callable[P, Awaitable[T]],
1061
+ ) -> TypeGuard[Callable[P, Awaitable[T]]]:
1062
+ if iscoroutinefunction(function):
1063
+ return True
1064
+
1065
+ elif isclass(function):
1066
+ return False
1067
+
1068
+ call = getattr(function, "__call__", None)
1069
+ return iscoroutinefunction(call)
1070
+
1071
+
1072
+ def _get_caller[**P, T](function: Callable[P, T]) -> Caller[P, T]:
1073
+ if _is_coroutine_function(function):
1074
+ return AsyncCaller(function)
1075
+
1076
+ elif isinstance(function, InjectedFunction):
1077
+ return function.__inject_metadata__
1078
+
1079
+ return SyncCaller(function)
@@ -1,6 +1,6 @@
1
- from collections.abc import Callable
1
+ from collections.abc import Awaitable
2
2
  from types import GenericAlias
3
- from typing import Any, ClassVar, TypeAliasType
3
+ from typing import Any, TypeAliasType
4
4
 
5
5
  from injection import Module, mod
6
6
  from injection.exceptions import InjectionError
@@ -28,20 +28,29 @@ def Inject[T]( # noqa: N802
28
28
 
29
29
 
30
30
  class InjectionDependency[T]:
31
- __slots__ = ("__call__", "__class", "__module")
31
+ __slots__ = ("__class", "__lazy_instance", "__module")
32
32
 
33
- __call__: Callable[[], T]
34
33
  __class: type[T] | TypeAliasType | GenericAlias
34
+ __lazy_instance: Awaitable[T]
35
35
  __module: Module
36
36
 
37
- __sentinel: ClassVar[object] = object()
38
-
39
- def __init__(self, cls: type[T] | TypeAliasType | GenericAlias, module: Module):
40
- lazy_instance = module.get_lazy_instance(cls, default=self.__sentinel)
41
- self.__call__ = lambda: self.__ensure(~lazy_instance)
37
+ def __init__(
38
+ self,
39
+ cls: type[T] | TypeAliasType | GenericAlias,
40
+ module: Module,
41
+ ) -> None:
42
42
  self.__class = cls
43
+ self.__lazy_instance = module.aget_lazy_instance(cls, default=NotImplemented)
43
44
  self.__module = module
44
45
 
46
+ async def __call__(self) -> T:
47
+ instance = await self.__lazy_instance
48
+
49
+ if instance is NotImplemented:
50
+ raise InjectionError(f"`{self.__class}` is an unknown dependency.")
51
+
52
+ return instance
53
+
45
54
  def __eq__(self, other: Any) -> bool:
46
55
  if isinstance(other, type(self)):
47
56
  return self.__key == other.__key
@@ -54,9 +63,3 @@ class InjectionDependency[T]:
54
63
  @property
55
64
  def __key(self) -> tuple[type[T] | TypeAliasType | GenericAlias, Module]:
56
65
  return self.__class, self.__module
57
-
58
- def __ensure(self, instance: T | Any) -> T:
59
- if instance is self.__sentinel:
60
- raise InjectionError(f"`{self.__class}` is an unknown dependency.")
61
-
62
- return instance
@@ -21,5 +21,5 @@ test_injectable = mod(_TEST_PROFILE_NAME).injectable
21
21
  test_singleton = mod(_TEST_PROFILE_NAME).singleton
22
22
 
23
23
 
24
- def load_test_profile(*other_profile_names: str) -> ContextManager[None]:
25
- return load_profile(_TEST_PROFILE_NAME, *other_profile_names)
24
+ def load_test_profile(*names: str) -> ContextManager[None]:
25
+ return load_profile(_TEST_PROFILE_NAME, *names)