anydi 0.32.2__tar.gz → 0.33.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {anydi-0.32.2 → anydi-0.33.0}/PKG-INFO +1 -1
  2. {anydi-0.32.2 → anydi-0.33.0}/anydi/_container.py +132 -17
  3. {anydi-0.32.2 → anydi-0.33.0}/anydi/_context.py +43 -23
  4. {anydi-0.32.2 → anydi-0.33.0}/anydi/_provider.py +1 -1
  5. {anydi-0.32.2 → anydi-0.33.0}/anydi/_types.py +6 -0
  6. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/_utils.py +2 -2
  7. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/fastapi.py +1 -1
  8. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/faststream.py +1 -1
  9. {anydi-0.32.2 → anydi-0.33.0}/pyproject.toml +3 -3
  10. anydi-0.32.2/anydi/_injector.py +0 -94
  11. {anydi-0.32.2 → anydi-0.33.0}/LICENSE +0 -0
  12. {anydi-0.32.2 → anydi-0.33.0}/README.md +0 -0
  13. {anydi-0.32.2 → anydi-0.33.0}/anydi/__init__.py +0 -0
  14. {anydi-0.32.2 → anydi-0.33.0}/anydi/_logger.py +0 -0
  15. {anydi-0.32.2 → anydi-0.33.0}/anydi/_module.py +0 -0
  16. {anydi-0.32.2 → anydi-0.33.0}/anydi/_scanner.py +0 -0
  17. {anydi-0.32.2 → anydi-0.33.0}/anydi/_utils.py +0 -0
  18. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/__init__.py +0 -0
  19. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/__init__.py +0 -0
  20. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/_container.py +0 -0
  21. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/_settings.py +0 -0
  22. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/_utils.py +0 -0
  23. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/apps.py +0 -0
  24. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/middleware.py +0 -0
  25. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/ninja/__init__.py +0 -0
  26. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/ninja/_operation.py +0 -0
  27. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/django/ninja/_signature.py +0 -0
  28. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/pydantic_settings.py +0 -0
  29. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/pytest_plugin.py +0 -0
  30. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/starlette/__init__.py +0 -0
  31. {anydi-0.32.2 → anydi-0.33.0}/anydi/ext/starlette/middleware.py +0 -0
  32. {anydi-0.32.2 → anydi-0.33.0}/anydi/py.typed +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: anydi
3
- Version: 0.32.2
3
+ Version: 0.33.0
4
4
  Summary: Dependency Injection library
5
5
  Home-page: https://github.com/antonrh/anydi
6
6
  License: MIT
@@ -3,12 +3,14 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import contextlib
6
+ import functools
6
7
  import inspect
7
8
  import types
8
9
  from collections import defaultdict
9
- from collections.abc import AsyncIterator, Awaitable, Iterable, Iterator, Sequence
10
+ from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
10
11
  from contextvars import ContextVar
11
12
  from typing import Any, Callable, TypeVar, cast, overload
13
+ from weakref import WeakKeyDictionary
12
14
 
13
15
  from typing_extensions import ParamSpec, Self, final
14
16
 
@@ -19,11 +21,11 @@ from ._context import (
19
21
  SingletonContext,
20
22
  TransientContext,
21
23
  )
22
- from ._injector import Injector
24
+ from ._logger import logger
23
25
  from ._module import Module, ModuleRegistry
24
26
  from ._provider import Provider
25
27
  from ._scanner import Scanner
26
- from ._types import AnyInterface, Interface, Scope
28
+ from ._types import AnyInterface, Interface, Scope, TestInterface, is_marker
27
29
  from ._utils import get_full_qualname, get_typed_parameters, is_builtin_type
28
30
 
29
31
  T = TypeVar("T", bound=Any)
@@ -47,6 +49,7 @@ class Container:
47
49
  modules: Sequence[Module | type[Module] | Callable[[Container], None] | str]
48
50
  | None = None,
49
51
  strict: bool = False,
52
+ testing: bool = False,
50
53
  ) -> None:
51
54
  self._providers: dict[type[Any], Provider] = {}
52
55
  self._resource_cache: dict[Scope, list[type[Any]]] = defaultdict(list)
@@ -57,10 +60,13 @@ class Container:
57
60
  )
58
61
  self._override_instances: dict[type[Any], Any] = {}
59
62
  self._strict = strict
63
+ self._testing = testing
60
64
  self._unresolved_interfaces: set[type[Any]] = set()
65
+ self._inject_cache: WeakKeyDictionary[
66
+ Callable[..., Any], Callable[..., Any]
67
+ ] = WeakKeyDictionary()
61
68
 
62
69
  # Components
63
- self._injector = Injector(self)
64
70
  self._modules = ModuleRegistry(self)
65
71
  self._scanner = Scanner(self)
66
72
 
@@ -79,6 +85,11 @@ class Container:
79
85
  """Check if strict mode is enabled."""
80
86
  return self._strict
81
87
 
88
+ @property
89
+ def testing(self) -> bool:
90
+ """Check if testing mode is enabled."""
91
+ return self._testing
92
+
82
93
  @property
83
94
  def providers(self) -> dict[type[Any], Provider]:
84
95
  """Get the registered providers."""
@@ -202,6 +213,10 @@ class Container:
202
213
  annotation, parent_scope=provider.scope
203
214
  )
204
215
  except LookupError:
216
+ # Skip unresolved interfaces in non-strict mode
217
+ if not self.strict and parameter.default is not inspect.Parameter.empty:
218
+ continue
219
+
205
220
  if provider.scope not in {"singleton", "transient"}:
206
221
  self._unresolved_interfaces.add(provider.interface)
207
222
  continue
@@ -225,7 +240,12 @@ class Container:
225
240
  scopes = set()
226
241
 
227
242
  for parameter in get_typed_parameters(call):
228
- sub_provider = self._get_or_register_provider(parameter.annotation)
243
+ try:
244
+ sub_provider = self._get_or_register_provider(parameter.annotation)
245
+ except LookupError:
246
+ if not self.strict and parameter.default is not inspect.Parameter.empty:
247
+ continue
248
+ raise
229
249
  scope = sub_provider.scope
230
250
 
231
251
  if scope == "transient":
@@ -234,7 +254,7 @@ class Container:
234
254
 
235
255
  # If all scopes are found, we can return based on priority order
236
256
  if {"transient", "request", "singleton"}.issubset(scopes):
237
- break
257
+ break # pragma: no cover
238
258
 
239
259
  # Determine scope based on priority
240
260
  if "request" in scopes:
@@ -346,7 +366,32 @@ class Container:
346
366
 
347
367
  provider = self._get_or_register_provider(interface)
348
368
  scoped_context = self._get_scoped_context(provider.scope)
349
- return cast(T, scoped_context.get(provider))
369
+ instance, created = scoped_context.get_or_create(provider)
370
+ if self.testing and created:
371
+ self._patch_test_resolver(instance)
372
+ return cast(T, instance)
373
+
374
+ def _patch_test_resolver(self, instance: Any) -> None:
375
+ """Patch the test resolver for the instance."""
376
+
377
+ def _resolver(_self: Any, _name: str) -> Any:
378
+ try:
379
+ test_interfaces = object.__getattribute__(_self, "__test_interfaces__")
380
+ except AttributeError:
381
+ test_interfaces = {
382
+ name: value.interface
383
+ for name, value in object.__getattribute__(
384
+ _self, "__dict__"
385
+ ).items()
386
+ if isinstance(value, TestInterface)
387
+ }
388
+ object.__setattr__(_self, "__test_interfaces__", test_interfaces)
389
+ if _name in test_interfaces:
390
+ return self.resolve(test_interfaces[_name])
391
+ return object.__getattribute__(_self, _name)
392
+
393
+ if hasattr(instance, "__class__") and not is_builtin_type(instance.__class__):
394
+ instance.__class__.__getattribute__ = _resolver
350
395
 
351
396
  @overload
352
397
  async def aresolve(self, interface: Interface[T]) -> T: ...
@@ -361,7 +406,10 @@ class Container:
361
406
 
362
407
  provider = self._get_or_register_provider(interface)
363
408
  scoped_context = self._get_scoped_context(provider.scope)
364
- return cast(T, await scoped_context.aget(provider))
409
+ instance, created = await scoped_context.aget_or_create(provider)
410
+ if self.testing and created:
411
+ self._patch_test_resolver(instance)
412
+ return cast(T, instance)
365
413
 
366
414
  def is_resolved(self, interface: AnyInterface) -> bool:
367
415
  """Check if an instance by interface exists."""
@@ -424,25 +472,92 @@ class Container:
424
472
  def inject(self) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
425
473
 
426
474
  def inject(
427
- self, func: Callable[P, T | Awaitable[T]] | None = None
428
- ) -> (
429
- Callable[[Callable[P, T | Awaitable[T]]], Callable[P, T | Awaitable[T]]]
430
- | Callable[P, T | Awaitable[T]]
431
- ):
475
+ self, func: Callable[P, T] | None = None
476
+ ) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
432
477
  """Decorator to inject dependencies into a callable."""
433
478
 
434
479
  def decorator(
435
- inner: Callable[P, T | Awaitable[T]],
436
- ) -> Callable[P, T | Awaitable[T]]:
437
- return self._injector.inject(inner)
480
+ call: Callable[P, T],
481
+ ) -> Callable[P, T]:
482
+ return self._inject(call)
438
483
 
439
484
  if func is None:
440
485
  return decorator
441
486
  return decorator(func)
442
487
 
488
+ def _inject(
489
+ self,
490
+ call: Callable[P, T],
491
+ ) -> Callable[P, T]:
492
+ """Inject dependencies into a callable."""
493
+ if call in self._inject_cache:
494
+ return cast(Callable[P, T], self._inject_cache[call])
495
+
496
+ injected_params = self._get_injected_params(call)
497
+
498
+ if inspect.iscoroutinefunction(call):
499
+
500
+ @functools.wraps(call)
501
+ async def awrapper(*args: P.args, **kwargs: P.kwargs) -> T:
502
+ for name, annotation in injected_params.items():
503
+ kwargs[name] = await self.aresolve(annotation)
504
+ return cast(T, await call(*args, **kwargs))
505
+
506
+ self._inject_cache[call] = awrapper
507
+
508
+ return awrapper # type: ignore[return-value]
509
+
510
+ @functools.wraps(call)
511
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
512
+ for name, annotation in injected_params.items():
513
+ kwargs[name] = self.resolve(annotation)
514
+ return call(*args, **kwargs)
515
+
516
+ self._inject_cache[call] = wrapper
517
+
518
+ return wrapper
519
+
520
+ def _get_injected_params(self, call: Callable[..., Any]) -> dict[str, Any]:
521
+ """Get the injected parameters of a callable object."""
522
+ injected_params = {}
523
+ for parameter in get_typed_parameters(call):
524
+ if not is_marker(parameter.default):
525
+ continue
526
+ try:
527
+ self._validate_injected_parameter(call, parameter)
528
+ except LookupError as exc:
529
+ if not self.strict:
530
+ logger.debug(
531
+ f"Cannot validate the `{get_full_qualname(call)}` parameter "
532
+ f"`{parameter.name}` with an annotation of "
533
+ f"`{get_full_qualname(parameter.annotation)} due to being "
534
+ "in non-strict mode. It will be validated at the first call."
535
+ )
536
+ else:
537
+ raise exc
538
+ injected_params[parameter.name] = parameter.annotation
539
+ return injected_params
540
+
541
+ def _validate_injected_parameter(
542
+ self, call: Callable[..., Any], parameter: inspect.Parameter
543
+ ) -> None:
544
+ """Validate an injected parameter."""
545
+ if parameter.annotation is inspect.Parameter.empty:
546
+ raise TypeError(
547
+ f"Missing `{get_full_qualname(call)}` parameter "
548
+ f"`{parameter.name}` annotation."
549
+ )
550
+
551
+ if not self.is_registered(parameter.annotation):
552
+ raise LookupError(
553
+ f"`{get_full_qualname(call)}` has an unknown dependency parameter "
554
+ f"`{parameter.name}` with an annotation of "
555
+ f"`{get_full_qualname(parameter.annotation)}`."
556
+ )
557
+
443
558
  def run(self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
444
559
  """Run the given function with injected dependencies."""
445
- return cast(T, self._injector.inject(func)(*args, **kwargs))
560
+ return self._inject(func)(*args, **kwargs)
446
561
 
447
562
  def scan(
448
563
  self,
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar
9
9
  from typing_extensions import Self, final
10
10
 
11
11
  from ._provider import CallableKind, Provider
12
- from ._types import AnyInterface, Scope, is_event_type
12
+ from ._types import AnyInterface, Scope, TestInterface, is_event_type
13
13
  from ._utils import get_full_qualname, run_async
14
14
 
15
15
  if TYPE_CHECKING:
@@ -30,12 +30,12 @@ class ScopedContext(abc.ABC):
30
30
  self._instances[interface] = instance
31
31
 
32
32
  @abc.abstractmethod
33
- def get(self, provider: Provider) -> Any:
34
- """Get an instance of a dependency from the scoped context."""
33
+ def get_or_create(self, provider: Provider) -> tuple[Any, bool]:
34
+ """Get or create an instance of a dependency from the scoped context."""
35
35
 
36
36
  @abc.abstractmethod
37
- async def aget(self, provider: Provider) -> Any:
38
- """Get an async instance of a dependency from the scoped context."""
37
+ async def aget_or_create(self, provider: Provider) -> tuple[Any, bool]:
38
+ """Get or create an async instance of a dependency from the scoped context."""
39
39
 
40
40
  def _create_instance(self, provider: Provider) -> Any:
41
41
  """Create an instance using the provider."""
@@ -44,12 +44,12 @@ class ScopedContext(abc.ABC):
44
44
  f"The instance for the coroutine provider `{provider}` cannot be "
45
45
  "created in synchronous mode."
46
46
  )
47
- args, kwargs = self._get_provider_params(provider)
47
+ args, kwargs = self._get_provided_args(provider)
48
48
  return provider.call(*args, **kwargs)
49
49
 
50
50
  async def _acreate_instance(self, provider: Provider) -> Any:
51
51
  """Create an instance asynchronously using the provider."""
52
- args, kwargs = await self._aget_provider_params(provider)
52
+ args, kwargs = await self._aget_provided_args(provider)
53
53
  if provider.kind == CallableKind.COROUTINE:
54
54
  return await provider.call(*args, **kwargs)
55
55
  return await run_async(provider.call, *args, **kwargs)
@@ -78,7 +78,7 @@ class ScopedContext(abc.ABC):
78
78
  "or set in the scoped context."
79
79
  )
80
80
 
81
- def _get_provider_params(
81
+ def _get_provided_args(
82
82
  self, provider: Provider
83
83
  ) -> tuple[list[Any], dict[str, Any]]:
84
84
  """Retrieve the arguments for a provider."""
@@ -91,14 +91,22 @@ class ScopedContext(abc.ABC):
91
91
  elif parameter.annotation in self._instances:
92
92
  instance = self._instances[parameter.annotation]
93
93
  else:
94
- instance = self._resolve_parameter(provider, parameter)
94
+ try:
95
+ instance = self._resolve_parameter(provider, parameter)
96
+ except LookupError:
97
+ if parameter.default is inspect.Parameter.empty:
98
+ raise
99
+ instance = parameter.default
100
+ else:
101
+ if self.container.testing:
102
+ instance = TestInterface(interface=parameter.annotation)
95
103
  if parameter.kind == parameter.POSITIONAL_ONLY:
96
104
  args.append(instance)
97
105
  else:
98
106
  kwargs[parameter.name] = instance
99
107
  return args, kwargs
100
108
 
101
- async def _aget_provider_params(
109
+ async def _aget_provided_args(
102
110
  self, provider: Provider
103
111
  ) -> tuple[list[Any], dict[str, Any]]:
104
112
  """Asynchronously retrieve the arguments for a provider."""
@@ -111,7 +119,15 @@ class ScopedContext(abc.ABC):
111
119
  elif parameter.annotation in self._instances:
112
120
  instance = self._instances[parameter.annotation]
113
121
  else:
114
- instance = await self._aresolve_parameter(provider, parameter)
122
+ try:
123
+ instance = await self._aresolve_parameter(provider, parameter)
124
+ except LookupError:
125
+ if parameter.default is inspect.Parameter.empty:
126
+ raise
127
+ instance = parameter.default
128
+ else:
129
+ if self.container.testing:
130
+ instance = TestInterface(interface=parameter.annotation)
115
131
  if parameter.kind == parameter.POSITIONAL_ONLY:
116
132
  args.append(instance)
117
133
  else:
@@ -128,7 +144,7 @@ class ResourceScopedContext(ScopedContext):
128
144
  self._stack = contextlib.ExitStack()
129
145
  self._async_stack = contextlib.AsyncExitStack()
130
146
 
131
- def get(self, provider: Provider) -> Any:
147
+ def get_or_create(self, provider: Provider) -> tuple[Any, bool]:
132
148
  """Get an instance of a dependency from the scoped context."""
133
149
  instance = self._instances.get(provider.interface)
134
150
  if instance is None:
@@ -143,9 +159,10 @@ class ResourceScopedContext(ScopedContext):
143
159
  else:
144
160
  instance = self._create_instance(provider)
145
161
  self._instances[provider.interface] = instance
146
- return instance
162
+ return instance, True
163
+ return instance, False
147
164
 
148
- async def aget(self, provider: Provider) -> Any:
165
+ async def aget_or_create(self, provider: Provider) -> tuple[Any, bool]:
149
166
  """Get an async instance of a dependency from the scoped context."""
150
167
  instance = self._instances.get(provider.interface)
151
168
  if instance is None:
@@ -156,7 +173,8 @@ class ResourceScopedContext(ScopedContext):
156
173
  else:
157
174
  instance = await self._acreate_instance(provider)
158
175
  self._instances[provider.interface] = instance
159
- return instance
176
+ return instance, True
177
+ return instance, False
160
178
 
161
179
  def has(self, interface: AnyInterface) -> bool:
162
180
  """Check if the scoped context has an instance of the dependency."""
@@ -172,7 +190,7 @@ class ResourceScopedContext(ScopedContext):
172
190
 
173
191
  def _create_resource(self, provider: Provider) -> Any:
174
192
  """Create a resource using the provider."""
175
- args, kwargs = self._get_provider_params(provider)
193
+ args, kwargs = self._get_provided_args(provider)
176
194
  cm = contextlib.contextmanager(provider.call)(*args, **kwargs)
177
195
  return self._stack.enter_context(cm)
178
196
 
@@ -186,7 +204,7 @@ class ResourceScopedContext(ScopedContext):
186
204
 
187
205
  async def _acreate_resource(self, provider: Provider) -> Any:
188
206
  """Create a resource asynchronously using the provider."""
189
- args, kwargs = await self._aget_provider_params(provider)
207
+ args, kwargs = await self._aget_provided_args(provider)
190
208
  cm = contextlib.asynccontextmanager(provider.call)(*args, **kwargs)
191
209
  return await self._async_stack.enter_async_context(cm)
192
210
 
@@ -287,10 +305,12 @@ class TransientContext(ScopedContext):
287
305
 
288
306
  scope = "transient"
289
307
 
290
- def get(self, provider: Provider) -> Any:
291
- """Get an instance of a dependency from the transient context."""
292
- return self._create_instance(provider)
308
+ def get_or_create(self, provider: Provider) -> tuple[Any, bool]:
309
+ """Get or create an instance of a dependency from the transient context."""
310
+ return self._create_instance(provider), True
293
311
 
294
- async def aget(self, provider: Provider) -> Any:
295
- """Get an async instance of a dependency from the transient context."""
296
- return await self._acreate_instance(provider)
312
+ async def aget_or_create(self, provider: Provider) -> tuple[Any, bool]:
313
+ """
314
+ Get or create an async instance of a dependency from the transient context.
315
+ """
316
+ return await self._acreate_instance(provider), True
@@ -69,7 +69,7 @@ class Provider:
69
69
 
70
70
  def __eq__(self, other: object) -> bool:
71
71
  if not isinstance(other, Provider):
72
- return NotImplemented
72
+ return NotImplemented # pragma: no cover
73
73
  return (
74
74
  self._call == other._call
75
75
  and self._scope == other._scope
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
+ from dataclasses import dataclass
4
5
  from typing import Annotated, Any, TypeVar, Union
5
6
 
6
7
  from typing_extensions import Literal, Self, TypeAlias
@@ -35,3 +36,8 @@ class Event:
35
36
  def is_event_type(obj: Any) -> bool:
36
37
  """Checks if an object is an event type."""
37
38
  return inspect.isclass(obj) and issubclass(obj, Event)
39
+
40
+
41
+ @dataclass(frozen=True)
42
+ class TestInterface:
43
+ interface: type[Any]
@@ -61,7 +61,7 @@ def patch_annotated_parameter(parameter: inspect.Parameter) -> inspect.Parameter
61
61
 
62
62
 
63
63
  def patch_call_parameter(
64
- call: Callable[..., Any], parameter: inspect.Parameter, container: Container
64
+ container: Container, call: Callable[..., Any], parameter: inspect.Parameter
65
65
  ) -> None:
66
66
  """Patch a parameter to inject dependencies using AnyDI."""
67
67
  parameter = patch_annotated_parameter(parameter)
@@ -78,6 +78,6 @@ def patch_call_parameter(
78
78
  "first call because it is running in non-strict mode."
79
79
  )
80
80
  else:
81
- container._injector._validate_injected_parameter(call, parameter) # noqa
81
+ container._validate_injected_parameter(call, parameter) # noqa
82
82
 
83
83
  parameter.default.interface = parameter.annotation
@@ -41,7 +41,7 @@ def install(app: FastAPI, container: Container) -> None:
41
41
  if not call:
42
42
  continue # pragma: no cover
43
43
  for parameter in get_typed_parameters(call):
44
- patch_call_parameter(call, parameter, container)
44
+ patch_call_parameter(container, call, parameter)
45
45
 
46
46
 
47
47
  def get_container(request: Request) -> Container:
@@ -26,7 +26,7 @@ def install(broker: BrokerUsecase[Any, Any], container: Container) -> None:
26
26
  for handler in _get_broken_handlers(broker):
27
27
  call = handler._original_call # noqa
28
28
  for parameter in get_typed_parameters(call):
29
- patch_call_parameter(call, parameter, container)
29
+ patch_call_parameter(container, call, parameter)
30
30
 
31
31
 
32
32
  def _get_broken_handlers(broker: BrokerUsecase[Any, Any]) -> list[Any]:
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "anydi"
3
- version = "0.32.2"
3
+ version = "0.33.0"
4
4
  description = "Dependency Injection library"
5
5
  authors = ["Anton Ruhlov <antonruhlov@gmail.com>"]
6
6
  license = "MIT"
@@ -45,8 +45,8 @@ docs = ["mkdocs", "mkdocs-material"]
45
45
  async = ["anyio"]
46
46
 
47
47
  [tool.poetry.group.dev.dependencies]
48
- mypy = "^1.13.0"
49
- ruff = "^0.7.1"
48
+ mypy = { version = "^1.14.0", extras = ["faster-cache"] }
49
+ ruff = "^0.8.4"
50
50
  pytest = "^8.3.1"
51
51
  pytest-cov = "^5.0.0"
52
52
  fastapi = "^0.100.0"
@@ -1,94 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import inspect
4
- from collections.abc import Awaitable
5
- from functools import wraps
6
- from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast
7
-
8
- from typing_extensions import ParamSpec
9
-
10
- from ._logger import logger
11
- from ._types import is_marker
12
- from ._utils import get_full_qualname, get_typed_parameters
13
-
14
- if TYPE_CHECKING:
15
- from ._container import Container
16
-
17
-
18
- T = TypeVar("T", bound=Any)
19
- P = ParamSpec("P")
20
-
21
-
22
- class Injector:
23
- def __init__(self, container: Container) -> None:
24
- self.container = container
25
-
26
- def inject(
27
- self,
28
- call: Callable[P, T | Awaitable[T]],
29
- ) -> Callable[P, T | Awaitable[T]]:
30
- # Check if the inner callable has already been wrapped
31
- if hasattr(call, "__inject_wrapper__"):
32
- return cast(Callable[P, Union[T, Awaitable[T]]], call.__inject_wrapper__)
33
-
34
- injected_params = self._get_injected_params(call)
35
-
36
- if inspect.iscoroutinefunction(call):
37
-
38
- @wraps(call)
39
- async def awrapper(*args: P.args, **kwargs: P.kwargs) -> T:
40
- for name, annotation in injected_params.items():
41
- kwargs[name] = await self.container.aresolve(annotation)
42
- return cast(T, await call(*args, **kwargs))
43
-
44
- call.__inject_wrapper__ = awrapper # type: ignore[attr-defined]
45
-
46
- return awrapper
47
-
48
- @wraps(call)
49
- def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
50
- for name, annotation in injected_params.items():
51
- kwargs[name] = self.container.resolve(annotation)
52
- return cast(T, call(*args, **kwargs))
53
-
54
- call.__inject_wrapper__ = wrapper # type: ignore[attr-defined]
55
-
56
- return wrapper
57
-
58
- def _get_injected_params(self, call: Callable[..., Any]) -> dict[str, Any]:
59
- """Get the injected parameters of a callable object."""
60
- injected_params = {}
61
- for parameter in get_typed_parameters(call):
62
- if not is_marker(parameter.default):
63
- continue
64
- try:
65
- self._validate_injected_parameter(call, parameter)
66
- except LookupError as exc:
67
- if not self.container.strict:
68
- logger.debug(
69
- f"Cannot validate the `{get_full_qualname(call)}` parameter "
70
- f"`{parameter.name}` with an annotation of "
71
- f"`{get_full_qualname(parameter.annotation)} due to being "
72
- "in non-strict mode. It will be validated at the first call."
73
- )
74
- else:
75
- raise exc
76
- injected_params[parameter.name] = parameter.annotation
77
- return injected_params
78
-
79
- def _validate_injected_parameter(
80
- self, call: Callable[..., Any], parameter: inspect.Parameter
81
- ) -> None:
82
- """Validate an injected parameter."""
83
- if parameter.annotation is inspect.Parameter.empty:
84
- raise TypeError(
85
- f"Missing `{get_full_qualname(call)}` parameter "
86
- f"`{parameter.name}` annotation."
87
- )
88
-
89
- if not self.container.is_registered(parameter.annotation):
90
- raise LookupError(
91
- f"`{get_full_qualname(call)}` has an unknown dependency parameter "
92
- f"`{parameter.name}` with an annotation of "
93
- f"`{get_full_qualname(parameter.annotation)}`."
94
- )
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes