python-injection 0.11.0__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
injection/_core/module.py CHANGED
@@ -1,21 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
- import inspect
3
+ import asyncio
4
4
  from abc import ABC, abstractmethod
5
5
  from collections import OrderedDict
6
6
  from collections.abc import (
7
+ AsyncIterator,
8
+ Awaitable,
7
9
  Callable,
8
10
  Collection,
9
11
  Iterable,
10
12
  Iterator,
11
13
  Mapping,
12
- MutableMapping,
13
14
  )
14
15
  from contextlib import contextmanager, suppress
15
16
  from dataclasses import dataclass, field
16
17
  from enum import StrEnum
17
18
  from functools import partialmethod, singledispatchmethod, update_wrapper
18
- from inspect import Signature, isclass
19
+ from inspect import Signature, isclass, iscoroutinefunction, markcoroutinefunction
20
+ from inspect import signature as inspect_signature
19
21
  from logging import Logger, getLogger
20
22
  from queue import Empty, Queue
21
23
  from types import MethodType
@@ -25,22 +27,32 @@ from typing import (
25
27
  ContextManager,
26
28
  Literal,
27
29
  NamedTuple,
28
- NoReturn,
29
30
  Protocol,
30
31
  Self,
31
- override,
32
+ overload,
32
33
  runtime_checkable,
33
34
  )
34
35
  from uuid import uuid4
35
36
 
37
+ from injection._core.common.asynchronous import (
38
+ AsyncCaller,
39
+ Caller,
40
+ SimpleAwaitable,
41
+ SyncCaller,
42
+ )
36
43
  from injection._core.common.event import Event, EventChannel, EventListener
37
44
  from injection._core.common.invertible import Invertible, SimpleInvertible
38
45
  from injection._core.common.lazy import Lazy, LazyMapping
39
46
  from injection._core.common.threading import synchronized
40
47
  from injection._core.common.type import InputType, TypeInfo, get_return_types
41
48
  from injection._core.hook import Hook, apply_hooks
49
+ from injection._core.injectables import (
50
+ Injectable,
51
+ ShouldBeInjectable,
52
+ SimpleInjectable,
53
+ SingletonInjectable,
54
+ )
42
55
  from injection.exceptions import (
43
- InjectionError,
44
56
  ModuleError,
45
57
  ModuleLockError,
46
58
  ModuleNotUsedError,
@@ -62,7 +74,6 @@ class LocatorDependenciesUpdated[T](LocatorEvent):
62
74
  classes: Collection[InputType[T]]
63
75
  mode: Mode
64
76
 
65
- @override
66
77
  def __str__(self) -> str:
67
78
  length = len(self.classes)
68
79
  formatted_types = ", ".join(f"`{cls}`" for cls in self.classes)
@@ -81,7 +92,6 @@ class ModuleEvent(Event, ABC):
81
92
  class ModuleEventProxy(ModuleEvent):
82
93
  event: Event
83
94
 
84
- @override
85
95
  def __str__(self) -> str:
86
96
  return f"`{self.module}` has propagated an event: {self.origin}"
87
97
 
@@ -102,7 +112,6 @@ class ModuleAdded(ModuleEvent):
102
112
  module_added: Module
103
113
  priority: Priority
104
114
 
105
- @override
106
115
  def __str__(self) -> str:
107
116
  return f"`{self.module}` now uses `{self.module_added}`."
108
117
 
@@ -111,7 +120,6 @@ class ModuleAdded(ModuleEvent):
111
120
  class ModuleRemoved(ModuleEvent):
112
121
  module_removed: Module
113
122
 
114
- @override
115
123
  def __str__(self) -> str:
116
124
  return f"`{self.module}` no longer uses `{self.module_removed}`."
117
125
 
@@ -121,7 +129,6 @@ class ModulePriorityUpdated(ModuleEvent):
121
129
  module_updated: Module
122
130
  priority: Priority
123
131
 
124
- @override
125
132
  def __str__(self) -> str:
126
133
  return (
127
134
  f"In `{self.module}`, the priority `{self.priority}` "
@@ -129,86 +136,6 @@ class ModulePriorityUpdated(ModuleEvent):
129
136
  )
130
137
 
131
138
 
132
- """
133
- Injectables
134
- """
135
-
136
-
137
- @runtime_checkable
138
- class Injectable[T](Protocol):
139
- __slots__ = ()
140
-
141
- @property
142
- def is_locked(self) -> bool:
143
- return False
144
-
145
- def unlock(self) -> None:
146
- return
147
-
148
- @abstractmethod
149
- def get_instance(self) -> T:
150
- raise NotImplementedError
151
-
152
-
153
- @dataclass(repr=False, frozen=True, slots=True)
154
- class BaseInjectable[T](Injectable[T], ABC):
155
- factory: Callable[..., T]
156
-
157
-
158
- class SimpleInjectable[T](BaseInjectable[T]):
159
- __slots__ = ()
160
-
161
- @override
162
- def get_instance(self) -> T:
163
- return self.factory()
164
-
165
-
166
- class SingletonInjectable[T](BaseInjectable[T]):
167
- __slots__ = ("__dict__",)
168
-
169
- __key: ClassVar[str] = "$instance"
170
-
171
- @property
172
- def cache(self) -> MutableMapping[str, Any]:
173
- return self.__dict__
174
-
175
- @property
176
- @override
177
- def is_locked(self) -> bool:
178
- return self.__key in self.cache
179
-
180
- @override
181
- def unlock(self) -> None:
182
- self.cache.clear()
183
-
184
- @override
185
- def get_instance(self) -> T:
186
- with suppress(KeyError):
187
- return self.cache[self.__key]
188
-
189
- with synchronized():
190
- instance = self.factory()
191
- self.cache[self.__key] = instance
192
-
193
- return instance
194
-
195
-
196
- @dataclass(repr=False, frozen=True, slots=True)
197
- class ShouldBeInjectable[T](Injectable[T]):
198
- cls: type[T]
199
-
200
- @override
201
- def get_instance(self) -> NoReturn:
202
- raise InjectionError(f"`{self.cls}` should be an injectable.")
203
-
204
- @classmethod
205
- def from_callable(cls, callable: Callable[..., T]) -> Self:
206
- if not isclass(callable):
207
- raise TypeError(f"`{callable}` should be a class.")
208
-
209
- return cls(callable)
210
-
211
-
212
139
  """
213
140
  Broker
214
141
  """
@@ -235,6 +162,10 @@ class Broker(Protocol):
235
162
  def unlock(self) -> Self:
236
163
  raise NotImplementedError
237
164
 
165
+ @abstractmethod
166
+ async def all_ready(self) -> None:
167
+ raise NotImplementedError
168
+
238
169
 
239
170
  """
240
171
  Locator
@@ -257,7 +188,7 @@ class Mode(StrEnum):
257
188
 
258
189
  type ModeStr = Literal["fallback", "normal", "override"]
259
190
 
260
- type InjectableFactory[T] = Callable[[Callable[..., T]], Injectable[T]]
191
+ type InjectableFactory[T] = Callable[[Caller[..., T]], Injectable[T]]
261
192
 
262
193
 
263
194
  class Record[T](NamedTuple):
@@ -267,7 +198,7 @@ class Record[T](NamedTuple):
267
198
 
268
199
  @dataclass(repr=False, eq=False, kw_only=True, slots=True)
269
200
  class Updater[T]:
270
- factory: Callable[..., T]
201
+ factory: Caller[..., T]
271
202
  classes: Iterable[InputType[T]]
272
203
  injectable_factory: InjectableFactory[T]
273
204
  mode: Mode
@@ -304,7 +235,6 @@ class Locator(Broker):
304
235
 
305
236
  static_hooks: ClassVar[LocatorHooks[Any]] = LocatorHooks.default()
306
237
 
307
- @override
308
238
  def __getitem__[T](self, cls: InputType[T], /) -> Injectable[T]:
309
239
  for input_class in self.__standardize_inputs((cls,)):
310
240
  try:
@@ -316,7 +246,6 @@ class Locator(Broker):
316
246
 
317
247
  raise NoInjectable(cls)
318
248
 
319
- @override
320
249
  def __contains__(self, cls: InputType[Any], /) -> bool:
321
250
  return any(
322
251
  input_class in self.__records
@@ -324,7 +253,6 @@ class Locator(Broker):
324
253
  )
325
254
 
326
255
  @property
327
- @override
328
256
  def is_locked(self) -> bool:
329
257
  return any(injectable.is_locked for injectable in self.__injectables)
330
258
 
@@ -346,7 +274,6 @@ class Locator(Broker):
346
274
 
347
275
  return self
348
276
 
349
- @override
350
277
  @synchronized()
351
278
  def unlock(self) -> Self:
352
279
  for injectable in self.__injectables:
@@ -354,6 +281,10 @@ class Locator(Broker):
354
281
 
355
282
  return self
356
283
 
284
+ async def all_ready(self) -> None:
285
+ for injectable in self.__injectables:
286
+ await injectable.aget_instance()
287
+
357
288
  def add_listener(self, listener: EventListener) -> Self:
358
289
  self.__channel.add_listener(listener)
359
290
  return self
@@ -444,7 +375,6 @@ class Module(Broker, EventListener):
444
375
  def __post_init__(self) -> None:
445
376
  self.__locator.add_listener(self)
446
377
 
447
- @override
448
378
  def __getitem__[T](self, cls: InputType[T], /) -> Injectable[T]:
449
379
  for broker in self.__brokers:
450
380
  with suppress(KeyError):
@@ -452,12 +382,10 @@ class Module(Broker, EventListener):
452
382
 
453
383
  raise NoInjectable(cls)
454
384
 
455
- @override
456
385
  def __contains__(self, cls: InputType[Any], /) -> bool:
457
386
  return any(cls in broker for broker in self.__brokers)
458
387
 
459
388
  @property
460
- @override
461
389
  def is_locked(self) -> bool:
462
390
  return any(broker.is_locked for broker in self.__brokers)
463
391
 
@@ -466,18 +394,20 @@ class Module(Broker, EventListener):
466
394
  yield from tuple(self.__modules)
467
395
  yield self.__locator
468
396
 
469
- def injectable[**P, T]( # type: ignore[no-untyped-def]
397
+ def injectable[**P, T](
470
398
  self,
471
- wrapped: Callable[P, T] | None = None,
399
+ wrapped: Callable[P, T] | Callable[P, Awaitable[T]] | None = None,
472
400
  /,
473
401
  *,
474
402
  cls: InjectableFactory[T] = SimpleInjectable,
475
403
  inject: bool = True,
476
404
  on: TypeInfo[T] = (),
477
405
  mode: Mode | ModeStr = Mode.get_default(),
478
- ):
479
- def decorator(wp): # type: ignore[no-untyped-def]
480
- factory = self.make_injected_function(wp) if inject else wp
406
+ ) -> Any:
407
+ def decorator(
408
+ wp: Callable[P, T] | Callable[P, Awaitable[T]],
409
+ ) -> Callable[P, T] | Callable[P, Awaitable[T]]:
410
+ factory = _get_caller(self.make_injected_function(wp) if inject else wp)
481
411
  classes = get_return_types(wp, on)
482
412
  updater = Updater(
483
413
  factory=factory,
@@ -492,12 +422,12 @@ class Module(Broker, EventListener):
492
422
 
493
423
  singleton = partialmethod(injectable, cls=SingletonInjectable)
494
424
 
495
- def should_be_injectable[T](self, wrapped: type[T] | None = None, /): # type: ignore[no-untyped-def]
496
- def decorator(wp): # type: ignore[no-untyped-def]
425
+ def should_be_injectable[T](self, wrapped: type[T] | None = None, /) -> Any:
426
+ def decorator(wp: type[T]) -> type[T]:
497
427
  updater = Updater(
498
- factory=wp,
428
+ factory=SyncCaller(wp),
499
429
  classes=(wp,),
500
- injectable_factory=ShouldBeInjectable.from_callable,
430
+ injectable_factory=lambda _: ShouldBeInjectable(wp),
501
431
  mode=Mode.FALLBACK,
502
432
  )
503
433
  self.update(updater)
@@ -505,15 +435,15 @@ class Module(Broker, EventListener):
505
435
 
506
436
  return decorator(wrapped) if wrapped else decorator
507
437
 
508
- def constant[T]( # type: ignore[no-untyped-def]
438
+ def constant[T](
509
439
  self,
510
440
  wrapped: type[T] | None = None,
511
441
  /,
512
442
  *,
513
443
  on: TypeInfo[T] = (),
514
444
  mode: Mode | ModeStr = Mode.get_default(),
515
- ):
516
- def decorator(wp): # type: ignore[no-untyped-def]
445
+ ) -> Any:
446
+ def decorator(wp: type[T]) -> type[T]:
517
447
  lazy_instance = Lazy(wp)
518
448
  self.injectable(
519
449
  lambda: ~lazy_instance,
@@ -545,8 +475,8 @@ class Module(Broker, EventListener):
545
475
  )
546
476
  return self
547
477
 
548
- def inject[**P, T](self, wrapped: Callable[P, T] | None = None, /): # type: ignore[no-untyped-def]
549
- def decorator(wp): # type: ignore[no-untyped-def]
478
+ def inject[**P, T](self, wrapped: Callable[P, T] | None = None, /) -> Any:
479
+ def decorator(wp: Callable[P, T]) -> Callable[P, T]:
550
480
  if isclass(wp):
551
481
  wp.__init__ = self.inject(wp.__init__)
552
482
  return wp
@@ -555,47 +485,135 @@ class Module(Broker, EventListener):
555
485
 
556
486
  return decorator(wrapped) if wrapped else decorator
557
487
 
488
+ @overload
558
489
  def make_injected_function[**P, T](
559
490
  self,
560
491
  wrapped: Callable[P, T],
561
492
  /,
562
- ) -> InjectedFunction[P, T]:
563
- injected = Injected(wrapped)
493
+ ) -> SyncInjectedFunction[P, T]: ...
494
+
495
+ @overload
496
+ def make_injected_function[**P, T](
497
+ self,
498
+ wrapped: Callable[P, Awaitable[T]],
499
+ /,
500
+ ) -> AsyncInjectedFunction[P, T]: ...
564
501
 
565
- @injected.on_setup
502
+ def make_injected_function(self, wrapped, /): # type: ignore[no-untyped-def]
503
+ metadata = InjectMetadata(wrapped)
504
+
505
+ @metadata.on_setup
566
506
  def listen() -> None:
567
- injected.update(self)
568
- self.add_listener(injected)
507
+ metadata.update(self)
508
+ self.add_listener(metadata)
509
+
510
+ if iscoroutinefunction(wrapped):
511
+ return AsyncInjectedFunction(metadata)
569
512
 
570
- return InjectedFunction(injected)
513
+ return SyncInjectedFunction(metadata)
514
+
515
+ async def afind_instance[T](self, cls: InputType[T]) -> T:
516
+ injectable = self[cls]
517
+ return await injectable.aget_instance()
571
518
 
572
519
  def find_instance[T](self, cls: InputType[T]) -> T:
573
520
  injectable = self[cls]
574
521
  return injectable.get_instance()
575
522
 
523
+ @overload
524
+ async def aget_instance[T, Default](
525
+ self,
526
+ cls: InputType[T],
527
+ default: Default,
528
+ ) -> T | Default: ...
529
+
530
+ @overload
531
+ async def aget_instance[T](
532
+ self,
533
+ cls: InputType[T],
534
+ default: None = ...,
535
+ ) -> T | None: ...
536
+
537
+ async def aget_instance(self, cls, default=None): # type: ignore[no-untyped-def]
538
+ try:
539
+ return await self.afind_instance(cls)
540
+ except KeyError:
541
+ return default
542
+
543
+ @overload
576
544
  def get_instance[T, Default](
577
545
  self,
578
546
  cls: InputType[T],
579
- default: Default | None = None,
580
- ) -> T | Default | None:
547
+ default: Default,
548
+ ) -> T | Default: ...
549
+
550
+ @overload
551
+ def get_instance[T](
552
+ self,
553
+ cls: InputType[T],
554
+ default: None = ...,
555
+ ) -> T | None: ...
556
+
557
+ def get_instance(self, cls, default=None): # type: ignore[no-untyped-def]
581
558
  try:
582
559
  return self.find_instance(cls)
583
560
  except KeyError:
584
561
  return default
585
562
 
563
+ @overload
564
+ def aget_lazy_instance[T, Default](
565
+ self,
566
+ cls: InputType[T],
567
+ default: Default,
568
+ *,
569
+ cache: bool = ...,
570
+ ) -> Awaitable[T | Default]: ...
571
+
572
+ @overload
573
+ def aget_lazy_instance[T](
574
+ self,
575
+ cls: InputType[T],
576
+ default: None = ...,
577
+ *,
578
+ cache: bool = ...,
579
+ ) -> Awaitable[T | None]: ...
580
+
581
+ def aget_lazy_instance(self, cls, default=None, *, cache=False): # type: ignore[no-untyped-def]
582
+ if cache:
583
+ coroutine = self.aget_instance(cls, default)
584
+ return asyncio.ensure_future(coroutine)
585
+
586
+ function = self.make_injected_function(lambda instance=default: instance)
587
+ metadata = function.__inject_metadata__
588
+ metadata.set_owner(cls)
589
+ return SimpleAwaitable(metadata.acall)
590
+
591
+ @overload
586
592
  def get_lazy_instance[T, Default](
587
593
  self,
588
594
  cls: InputType[T],
589
- default: Default | None = None,
595
+ default: Default,
590
596
  *,
591
- cache: bool = False,
592
- ) -> Invertible[T | Default | None]:
597
+ cache: bool = ...,
598
+ ) -> Invertible[T | Default]: ...
599
+
600
+ @overload
601
+ def get_lazy_instance[T](
602
+ self,
603
+ cls: InputType[T],
604
+ default: None = ...,
605
+ *,
606
+ cache: bool = ...,
607
+ ) -> Invertible[T | None]: ...
608
+
609
+ def get_lazy_instance(self, cls, default=None, *, cache=False): # type: ignore[no-untyped-def]
593
610
  if cache:
594
611
  return Lazy(lambda: self.get_instance(cls, default))
595
612
 
596
- function = self.inject(lambda instance=default: instance)
597
- function.__injected__.set_owner(cls)
598
- return SimpleInvertible(function)
613
+ function = self.make_injected_function(lambda instance=default: instance)
614
+ metadata = function.__inject_metadata__
615
+ metadata.set_owner(cls)
616
+ return SimpleInvertible(metadata.call)
599
617
 
600
618
  def update[T](self, updater: Updater[T]) -> Self:
601
619
  self.__locator.update(updater)
@@ -662,7 +680,6 @@ class Module(Broker, EventListener):
662
680
 
663
681
  return self
664
682
 
665
- @override
666
683
  @synchronized()
667
684
  def unlock(self) -> Self:
668
685
  for broker in self.__brokers:
@@ -670,6 +687,10 @@ class Module(Broker, EventListener):
670
687
 
671
688
  return self
672
689
 
690
+ async def all_ready(self) -> None:
691
+ for broker in self.__brokers:
692
+ await broker.all_ready()
693
+
673
694
  def add_logger(self, logger: Logger) -> Self:
674
695
  self.__loggers.append(logger)
675
696
  return self
@@ -682,7 +703,6 @@ class Module(Broker, EventListener):
682
703
  self.__channel.remove_listener(listener)
683
704
  return self
684
705
 
685
- @override
686
706
  def on_event(self, event: Event, /) -> ContextManager[None] | None:
687
707
  self_event = ModuleEventProxy(self, event)
688
708
  return self.dispatch(self_event)
@@ -730,6 +750,13 @@ class Module(Broker, EventListener):
730
750
  return cls.from_name("__default__")
731
751
 
732
752
 
753
+ def mod(name: str | None = None, /) -> Module:
754
+ if name is None:
755
+ return Module.default()
756
+
757
+ return Module.from_name(name)
758
+
759
+
733
760
  """
734
761
  InjectedFunction
735
762
  """
@@ -744,7 +771,13 @@ class Dependencies:
744
771
 
745
772
  def __iter__(self) -> Iterator[tuple[str, Any]]:
746
773
  for name, injectable in self.mapping.items():
747
- yield name, injectable.get_instance()
774
+ instance = injectable.get_instance()
775
+ yield name, instance
776
+
777
+ async def __aiter__(self) -> AsyncIterator[tuple[str, Any]]:
778
+ for name, injectable in self.mapping.items():
779
+ instance = await injectable.aget_instance()
780
+ yield name, instance
748
781
 
749
782
  @property
750
783
  def are_resolved(self) -> bool:
@@ -753,9 +786,11 @@ class Dependencies:
753
786
 
754
787
  return bool(self)
755
788
 
756
- @property
757
- def arguments(self) -> OrderedDict[str, Any]:
758
- return OrderedDict(self)
789
+ async def aget_arguments(self) -> dict[str, Any]:
790
+ return {key: value async for key, value in self}
791
+
792
+ def get_arguments(self) -> dict[str, Any]:
793
+ return dict(self)
759
794
 
760
795
  @classmethod
761
796
  def from_mapping(cls, mapping: Mapping[str, Injectable[Any]]) -> Self:
@@ -810,7 +845,7 @@ class Arguments(NamedTuple):
810
845
  kwargs: Mapping[str, Any]
811
846
 
812
847
 
813
- class Injected[**P, T](EventListener):
848
+ class InjectMetadata[**P, T](Caller[P, T], EventListener):
814
849
  __slots__ = (
815
850
  "__dependencies",
816
851
  "__owner",
@@ -831,18 +866,13 @@ class Injected[**P, T](EventListener):
831
866
  self.__setup_queue = Queue()
832
867
  self.__wrapped = wrapped
833
868
 
834
- def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
835
- self.__setup()
836
- arguments = self.bind(args, kwargs)
837
- return self.wrapped(*arguments.args, **arguments.kwargs)
838
-
839
869
  @property
840
870
  def signature(self) -> Signature:
841
871
  with suppress(AttributeError):
842
872
  return self.__signature
843
873
 
844
874
  with synchronized():
845
- signature = inspect.signature(self.wrapped, eval_str=True)
875
+ signature = inspect_signature(self.wrapped, eval_str=True)
846
876
  self.__signature = signature
847
877
 
848
878
  return signature
@@ -851,22 +881,31 @@ class Injected[**P, T](EventListener):
851
881
  def wrapped(self) -> Callable[P, T]:
852
882
  return self.__wrapped
853
883
 
884
+ async def abind(
885
+ self,
886
+ args: Iterable[Any] = (),
887
+ kwargs: Mapping[str, Any] | None = None,
888
+ ) -> Arguments:
889
+ additional_arguments = await self.__dependencies.aget_arguments()
890
+ return self.__bind(args, kwargs, additional_arguments)
891
+
854
892
  def bind(
855
893
  self,
856
894
  args: Iterable[Any] = (),
857
895
  kwargs: Mapping[str, Any] | None = None,
858
896
  ) -> Arguments:
859
- if kwargs is None:
860
- kwargs = {}
897
+ additional_arguments = self.__dependencies.get_arguments()
898
+ return self.__bind(args, kwargs, additional_arguments)
861
899
 
862
- if not self.__dependencies:
863
- return Arguments(args, kwargs)
900
+ async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
901
+ self.__setup()
902
+ arguments = await self.abind(args, kwargs)
903
+ return self.wrapped(*arguments.args, **arguments.kwargs)
864
904
 
865
- bound = self.signature.bind_partial(*args, **kwargs)
866
- bound.arguments = (
867
- bound.arguments | self.__dependencies.arguments | bound.arguments
868
- )
869
- return Arguments(bound.args, bound.kwargs)
905
+ def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
906
+ self.__setup()
907
+ arguments = self.bind(args, kwargs)
908
+ return self.wrapped(*arguments.args, **arguments.kwargs)
870
909
 
871
910
  def set_owner(self, owner: type) -> Self:
872
911
  if self.__dependencies.are_resolved:
@@ -885,8 +924,8 @@ class Injected[**P, T](EventListener):
885
924
  self.__dependencies = Dependencies.resolve(self.signature, module, self.__owner)
886
925
  return self
887
926
 
888
- def on_setup[**_P, _T](self, wrapped: Callable[_P, _T] | None = None, /): # type: ignore[no-untyped-def]
889
- def decorator(wp): # type: ignore[no-untyped-def]
927
+ def on_setup[**_P, _T](self, wrapped: Callable[_P, _T] | None = None, /) -> Any:
928
+ def decorator(wp: Callable[_P, _T]) -> Callable[_P, _T]:
890
929
  queue = self.__setup_queue
891
930
 
892
931
  if queue is None:
@@ -898,7 +937,6 @@ class Injected[**P, T](EventListener):
898
937
  return decorator(wrapped) if wrapped else decorator
899
938
 
900
939
  @singledispatchmethod
901
- @override
902
940
  def on_event(self, event: Event, /) -> ContextManager[None] | None: # type: ignore[override]
903
941
  return None
904
942
 
@@ -908,6 +946,22 @@ class Injected[**P, T](EventListener):
908
946
  yield
909
947
  self.update(event.module)
910
948
 
949
+ def __bind(
950
+ self,
951
+ args: Iterable[Any],
952
+ kwargs: Mapping[str, Any] | None,
953
+ additional_arguments: dict[str, Any] | None,
954
+ ) -> Arguments:
955
+ if kwargs is None:
956
+ kwargs = {}
957
+
958
+ if not additional_arguments:
959
+ return Arguments(args, kwargs)
960
+
961
+ bound = self.signature.bind_partial(*args, **kwargs)
962
+ bound.arguments = bound.arguments | additional_arguments | bound.arguments
963
+ return Arguments(bound.args, bound.kwargs)
964
+
911
965
  def __close_setup_queue(self) -> None:
912
966
  self.__setup_queue = None
913
967
 
@@ -930,25 +984,24 @@ class Injected[**P, T](EventListener):
930
984
  self.__close_setup_queue()
931
985
 
932
986
 
933
- class InjectedFunction[**P, T]:
934
- __slots__ = ("__dict__", "__injected__")
987
+ class InjectedFunction[**P, T](ABC):
988
+ __slots__ = ("__dict__", "__inject_metadata__")
935
989
 
936
- __injected__: Injected[P, T]
990
+ __inject_metadata__: InjectMetadata[P, T]
937
991
 
938
- def __init__(self, injected: Injected[P, T]) -> None:
939
- update_wrapper(self, injected.wrapped)
940
- self.__injected__ = injected
992
+ def __init__(self, metadata: InjectMetadata[P, T]) -> None:
993
+ update_wrapper(self, metadata.wrapped)
994
+ self.__inject_metadata__ = metadata
941
995
 
942
- @override
943
996
  def __repr__(self) -> str: # pragma: no cover
944
- return repr(self.__injected__.wrapped)
997
+ return repr(self.__inject_metadata__.wrapped)
945
998
 
946
- @override
947
999
  def __str__(self) -> str: # pragma: no cover
948
- return str(self.__injected__.wrapped)
1000
+ return str(self.__inject_metadata__.wrapped)
949
1001
 
1002
+ @abstractmethod
950
1003
  def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
951
- return self.__injected__(*args, **kwargs)
1004
+ raise NotImplementedError
952
1005
 
953
1006
  def __get__(
954
1007
  self,
@@ -961,4 +1014,32 @@ class InjectedFunction[**P, T]:
961
1014
  return MethodType(self, instance)
962
1015
 
963
1016
  def __set_name__(self, owner: type, name: str) -> None:
964
- self.__injected__.set_owner(owner)
1017
+ self.__inject_metadata__.set_owner(owner)
1018
+
1019
+
1020
+ class AsyncInjectedFunction[**P, T](InjectedFunction[P, Awaitable[T]]):
1021
+ __slots__ = ()
1022
+
1023
+ def __init__(self, metadata: InjectMetadata[P, Awaitable[T]]) -> None:
1024
+ super().__init__(metadata)
1025
+ markcoroutinefunction(self)
1026
+
1027
+ async def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
1028
+ return await (await self.__inject_metadata__.acall(*args, **kwargs))
1029
+
1030
+
1031
+ class SyncInjectedFunction[**P, T](InjectedFunction[P, T]):
1032
+ __slots__ = ()
1033
+
1034
+ def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
1035
+ return self.__inject_metadata__.call(*args, **kwargs)
1036
+
1037
+
1038
+ def _get_caller[**P, T](function: Callable[P, T]) -> Caller[P, T]:
1039
+ if iscoroutinefunction(function):
1040
+ return AsyncCaller(function)
1041
+
1042
+ elif isinstance(function, InjectedFunction):
1043
+ return function.__inject_metadata__
1044
+
1045
+ return SyncCaller(function)