anydi 0.30.0__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
anydi/_context.py CHANGED
@@ -2,19 +2,19 @@ from __future__ import annotations
2
2
 
3
3
  import abc
4
4
  import contextlib
5
+ import inspect
5
6
  from types import TracebackType
6
- from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast
7
+ from typing import TYPE_CHECKING, Any, Callable, ClassVar
7
8
 
8
9
  from typing_extensions import Self, final
9
10
 
10
- from ._types import AnyInterface, Interface, Provider, Scope, is_event_type
11
- from ._utils import run_async
11
+ from ._provider import CallableKind, Provider
12
+ from ._types import AnyInterface, Scope, is_event_type
13
+ from ._utils import get_full_qualname, run_async
12
14
 
13
15
  if TYPE_CHECKING:
14
16
  from ._container import Container
15
17
 
16
- T = TypeVar("T")
17
-
18
18
 
19
19
  class ScopedContext(abc.ABC):
20
20
  """ScopedContext base class."""
@@ -23,104 +23,91 @@ class ScopedContext(abc.ABC):
23
23
 
24
24
  def __init__(self, container: Container) -> None:
25
25
  self.container = container
26
+ self._instances: dict[Any, Any] = {}
26
27
 
27
- @abc.abstractmethod
28
- def get(self, interface: Interface[T], provider: Provider) -> T:
29
- """Get an instance of a dependency from the scoped context.
30
-
31
- Args:
32
- interface: The interface of the dependency.
33
- provider: The provider for the instance.
34
-
35
- Returns:
36
- An instance of the dependency.
37
- """
28
+ def set(self, interface: AnyInterface, instance: Any) -> None:
29
+ """Set an instance of a dependency in the scoped context."""
30
+ self._instances[interface] = instance
38
31
 
39
32
  @abc.abstractmethod
40
- async def aget(self, interface: Interface[T], provider: Provider) -> T:
41
- """Get an async instance of a dependency from the scoped context.
42
-
43
- Args:
44
- interface: The interface of the dependency.
45
- provider: The provider for the instance.
33
+ def get(self, provider: Provider) -> Any:
34
+ """Get an instance of a dependency from the scoped context."""
46
35
 
47
- Returns:
48
- An async instance of the dependency.
49
- """
36
+ @abc.abstractmethod
37
+ async def aget(self, provider: Provider) -> Any:
38
+ """Get an async instance of a dependency from the scoped context."""
50
39
 
51
40
  def _create_instance(self, provider: Provider) -> Any:
52
- """Create an instance using the provider.
53
-
54
- Args:
55
- provider: The provider for the instance.
56
-
57
- Returns:
58
- The created instance.
59
-
60
- Raises:
61
- TypeError: If the provider's instance is a coroutine provider
62
- and synchronous mode is used.
63
- """
64
- if provider.is_coroutine:
41
+ """Create an instance using the provider."""
42
+ if provider.kind == CallableKind.COROUTINE:
65
43
  raise TypeError(
66
44
  f"The instance for the coroutine provider `{provider}` cannot be "
67
45
  "created in synchronous mode."
68
46
  )
69
- args, kwargs = self._get_provider_arguments(provider)
70
- return provider.obj(*args, **kwargs)
47
+ args, kwargs = self._get_provider_params(provider)
48
+ return provider.call(*args, **kwargs)
71
49
 
72
50
  async def _acreate_instance(self, provider: Provider) -> Any:
73
- """Create an instance asynchronously using the provider.
74
-
75
- Args:
76
- provider: The provider for the instance.
77
-
78
- Returns:
79
- The created instance.
80
-
81
- Raises:
82
- TypeError: If the provider's instance is a coroutine provider
83
- and asynchronous mode is used.
84
- """
85
- args, kwargs = await self._aget_provider_arguments(provider)
86
- if provider.is_coroutine:
87
- return await provider.obj(*args, **kwargs)
88
- return await run_async(provider.obj, *args, **kwargs)
51
+ """Create an instance asynchronously using the provider."""
52
+ args, kwargs = await self._aget_provider_params(provider)
53
+ if provider.kind == CallableKind.COROUTINE:
54
+ return await provider.call(*args, **kwargs)
55
+ return await run_async(provider.call, *args, **kwargs)
56
+
57
+ def _resolve_parameter(
58
+ self, provider: Provider, parameter: inspect.Parameter
59
+ ) -> Any:
60
+ self._validate_resolvable_parameter(parameter, call=provider.call)
61
+ return self.container.resolve(parameter.annotation)
62
+
63
+ async def _aresolve_parameter(
64
+ self, provider: Provider, parameter: inspect.Parameter
65
+ ) -> Any:
66
+ self._validate_resolvable_parameter(parameter, call=provider.call)
67
+ return await self.container.aresolve(parameter.annotation)
68
+
69
+ def _validate_resolvable_parameter(
70
+ self, parameter: inspect.Parameter, call: Callable[..., Any]
71
+ ) -> None:
72
+ """Ensure that the specified interface is resolved."""
73
+ if parameter.annotation in self.container._unresolved_interfaces: # noqa
74
+ raise LookupError(
75
+ f"You are attempting to get the parameter `{parameter.name}` with the "
76
+ f"annotation `{get_full_qualname(parameter.annotation)}` as a "
77
+ f"dependency into `{get_full_qualname(call)}` which is not registered "
78
+ "or set in the scoped context."
79
+ )
89
80
 
90
- def _get_provider_arguments(
81
+ def _get_provider_params(
91
82
  self, provider: Provider
92
83
  ) -> tuple[list[Any], dict[str, Any]]:
93
- """Retrieve the arguments for a provider.
84
+ """Retrieve the arguments for a provider."""
85
+ args: list[Any] = []
86
+ kwargs: dict[str, Any] = {}
94
87
 
95
- Args:
96
- provider: The provider object.
97
-
98
- Returns:
99
- The arguments for the provider.
100
- """
101
- args, kwargs = [], {}
102
88
  for parameter in provider.parameters:
103
- instance = self.container.resolve(parameter.annotation)
89
+ if parameter.annotation in self._instances:
90
+ instance = self._instances[parameter.annotation]
91
+ else:
92
+ instance = self._resolve_parameter(provider, parameter)
104
93
  if parameter.kind == parameter.POSITIONAL_ONLY:
105
94
  args.append(instance)
106
95
  else:
107
96
  kwargs[parameter.name] = instance
108
97
  return args, kwargs
109
98
 
110
- async def _aget_provider_arguments(
99
+ async def _aget_provider_params(
111
100
  self, provider: Provider
112
101
  ) -> tuple[list[Any], dict[str, Any]]:
113
- """Asynchronously retrieve the arguments for a provider.
114
-
115
- Args:
116
- provider: The provider object.
102
+ """Asynchronously retrieve the arguments for a provider."""
103
+ args: list[Any] = []
104
+ kwargs: dict[str, Any] = {}
117
105
 
118
- Returns:
119
- The arguments for the provider.
120
- """
121
- args, kwargs = [], {}
122
106
  for parameter in provider.parameters:
123
- instance = await self.container.aresolve(parameter.annotation)
107
+ if parameter.annotation in self._instances:
108
+ instance = self._instances[parameter.annotation]
109
+ else:
110
+ instance = await self._aresolve_parameter(provider, parameter)
124
111
  if parameter.kind == parameter.POSITIONAL_ONLY:
125
112
  args.append(instance)
126
113
  else:
@@ -134,25 +121,16 @@ class ResourceScopedContext(ScopedContext):
134
121
  def __init__(self, container: Container) -> None:
135
122
  """Initialize the ScopedContext."""
136
123
  super().__init__(container)
137
- self._instances: dict[type[Any], Any] = {}
138
124
  self._stack = contextlib.ExitStack()
139
125
  self._async_stack = contextlib.AsyncExitStack()
140
126
 
141
- def get(self, interface: Interface[T], provider: Provider) -> T:
142
- """Get an instance of a dependency from the scoped context.
143
-
144
- Args:
145
- interface: The interface of the dependency.
146
- provider: The provider for the instance.
147
-
148
- Returns:
149
- An instance of the dependency.
150
- """
151
- instance = self._instances.get(interface)
127
+ def get(self, provider: Provider) -> Any:
128
+ """Get an instance of a dependency from the scoped context."""
129
+ instance = self._instances.get(provider.interface)
152
130
  if instance is None:
153
- if provider.is_generator:
131
+ if provider.kind == CallableKind.GENERATOR:
154
132
  instance = self._create_resource(provider)
155
- elif provider.is_async_generator:
133
+ elif provider.kind == CallableKind.ASYNC_GENERATOR:
156
134
  raise TypeError(
157
135
  f"The provider `{provider}` cannot be started in synchronous mode "
158
136
  "because it is an asynchronous provider. Please start the provider "
@@ -160,39 +138,24 @@ class ResourceScopedContext(ScopedContext):
160
138
  )
161
139
  else:
162
140
  instance = self._create_instance(provider)
163
- self._instances[interface] = instance
164
- return cast(T, instance)
165
-
166
- async def aget(self, interface: Interface[T], provider: Provider) -> T:
167
- """Get an async instance of a dependency from the scoped context.
168
-
169
- Args:
170
- interface: The interface of the dependency.
171
- provider: The provider for the instance.
141
+ self._instances[provider.interface] = instance
142
+ return instance
172
143
 
173
- Returns:
174
- An async instance of the dependency.
175
- """
176
- instance = self._instances.get(interface)
144
+ async def aget(self, provider: Provider) -> Any:
145
+ """Get an async instance of a dependency from the scoped context."""
146
+ instance = self._instances.get(provider.interface)
177
147
  if instance is None:
178
- if provider.is_generator:
148
+ if provider.kind == CallableKind.GENERATOR:
179
149
  instance = await run_async(self._create_resource, provider)
180
- elif provider.is_async_generator:
150
+ elif provider.kind == CallableKind.ASYNC_GENERATOR:
181
151
  instance = await self._acreate_resource(provider)
182
152
  else:
183
153
  instance = await self._acreate_instance(provider)
184
- self._instances[interface] = instance
185
- return cast(T, instance)
154
+ self._instances[provider.interface] = instance
155
+ return instance
186
156
 
187
157
  def has(self, interface: AnyInterface) -> bool:
188
- """Check if the scoped context has an instance of the dependency.
189
-
190
- Args:
191
- interface: The interface of the dependency.
192
-
193
- Returns:
194
- Whether the scoped context has an instance of the dependency.
195
- """
158
+ """Check if the scoped context has an instance of the dependency."""
196
159
  return interface in self._instances
197
160
 
198
161
  def _create_instance(self, provider: Provider) -> Any:
@@ -204,16 +167,9 @@ class ResourceScopedContext(ScopedContext):
204
167
  return instance
205
168
 
206
169
  def _create_resource(self, provider: Provider) -> Any:
207
- """Create a resource using the provider.
208
-
209
- Args:
210
- provider: The provider for the resource.
211
-
212
- Returns:
213
- The created resource.
214
- """
215
- args, kwargs = self._get_provider_arguments(provider)
216
- cm = contextlib.contextmanager(provider.obj)(*args, **kwargs)
170
+ """Create a resource using the provider."""
171
+ args, kwargs = self._get_provider_params(provider)
172
+ cm = contextlib.contextmanager(provider.call)(*args, **kwargs)
217
173
  return self._stack.enter_context(cm)
218
174
 
219
175
  async def _acreate_instance(self, provider: Provider) -> Any:
@@ -225,32 +181,17 @@ class ResourceScopedContext(ScopedContext):
225
181
  return instance
226
182
 
227
183
  async def _acreate_resource(self, provider: Provider) -> Any:
228
- """Create a resource asynchronously using the provider.
229
-
230
- Args:
231
- provider: The provider for the resource.
232
-
233
- Returns:
234
- The created resource.
235
- """
236
- args, kwargs = await self._aget_provider_arguments(provider)
237
- cm = contextlib.asynccontextmanager(provider.obj)(*args, **kwargs)
184
+ """Create a resource asynchronously using the provider."""
185
+ args, kwargs = await self._aget_provider_params(provider)
186
+ cm = contextlib.asynccontextmanager(provider.call)(*args, **kwargs)
238
187
  return await self._async_stack.enter_async_context(cm)
239
188
 
240
189
  def delete(self, interface: AnyInterface) -> None:
241
- """Delete a dependency instance from the scoped context.
242
-
243
- Args:
244
- interface: The interface of the dependency.
245
- """
190
+ """Delete a dependency instance from the scoped context."""
246
191
  self._instances.pop(interface, None)
247
192
 
248
193
  def __enter__(self) -> Self:
249
- """Enter the context.
250
-
251
- Returns:
252
- The scoped context.
253
- """
194
+ """Enter the context."""
254
195
  self.start()
255
196
  return self
256
197
 
@@ -260,13 +201,7 @@ class ResourceScopedContext(ScopedContext):
260
201
  exc_val: BaseException | None,
261
202
  exc_tb: TracebackType | None,
262
203
  ) -> bool:
263
- """Exit the context.
264
-
265
- Args:
266
- exc_type: The type of the exception, if any.
267
- exc_val: The exception instance, if any.
268
- exc_tb: The traceback, if any.
269
- """
204
+ """Exit the context."""
270
205
  return self._stack.__exit__(exc_type, exc_val, exc_tb) # type: ignore[return-value]
271
206
 
272
207
  @abc.abstractmethod
@@ -280,11 +215,7 @@ class ResourceScopedContext(ScopedContext):
280
215
  self._stack.__exit__(None, None, None)
281
216
 
282
217
  async def __aenter__(self) -> Self:
283
- """Enter the context asynchronously.
284
-
285
- Returns:
286
- The scoped context.
287
- """
218
+ """Enter the context asynchronously."""
288
219
  await self.astart()
289
220
  return self
290
221
 
@@ -294,13 +225,7 @@ class ResourceScopedContext(ScopedContext):
294
225
  exc_val: BaseException | None,
295
226
  exc_tb: TracebackType | None,
296
227
  ) -> bool:
297
- """Exit the context asynchronously.
298
-
299
- Args:
300
- exc_type: The type of the exception, if any.
301
- exc_val: The exception instance, if any.
302
- exc_tb: The traceback, if any.
303
- """
228
+ """Exit the context asynchronously."""
304
229
  return await run_async(
305
230
  self.__exit__, exc_type, exc_val, exc_tb
306
231
  ) or await self._async_stack.__aexit__(exc_type, exc_val, exc_tb)
@@ -358,28 +283,10 @@ class TransientContext(ScopedContext):
358
283
 
359
284
  scope = "transient"
360
285
 
361
- def get(self, interface: Interface[T], provider: Provider) -> T:
362
- """Get an instance of a dependency from the transient context.
363
-
364
- Args:
365
- interface: The interface of the dependency.
366
- provider: The provider for the instance.
367
-
368
- Returns:
369
- An instance of the dependency.
370
- """
371
- instance = self._create_instance(provider)
372
- return cast(T, instance)
373
-
374
- async def aget(self, interface: Interface[T], provider: Provider) -> T:
375
- """Get an async instance of a dependency from the transient context.
376
-
377
- Args:
378
- interface: The interface of the dependency.
379
- provider: The provider for the instance.
286
+ def get(self, provider: Provider) -> Any:
287
+ """Get an instance of a dependency from the transient context."""
288
+ return self._create_instance(provider)
380
289
 
381
- Returns:
382
- An instance of the dependency.
383
- """
384
- instance = await self._acreate_instance(provider)
385
- return cast(T, instance)
290
+ async def aget(self, provider: Provider) -> Any:
291
+ """Get an async instance of a dependency from the transient context."""
292
+ return await self._acreate_instance(provider)
anydi/_injector.py ADDED
@@ -0,0 +1,94 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from collections.abc import Awaitable
5
+ from functools import wraps
6
+ from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast
7
+
8
+ from typing_extensions import ParamSpec
9
+
10
+ from ._logger import logger
11
+ from ._types import is_marker
12
+ from ._utils import get_full_qualname, get_typed_parameters
13
+
14
+ if TYPE_CHECKING:
15
+ from ._container import Container
16
+
17
+
18
+ T = TypeVar("T", bound=Any)
19
+ P = ParamSpec("P")
20
+
21
+
22
+ class Injector:
23
+ def __init__(self, container: Container) -> None:
24
+ self.container = container
25
+
26
+ def inject(
27
+ self,
28
+ call: Callable[P, T | Awaitable[T]],
29
+ ) -> Callable[P, T | Awaitable[T]]:
30
+ # Check if the inner callable has already been wrapped
31
+ if hasattr(call, "__inject_wrapper__"):
32
+ return cast(Callable[P, Union[T, Awaitable[T]]], call.__inject_wrapper__)
33
+
34
+ injected_params = self._get_injected_params(call)
35
+
36
+ if inspect.iscoroutinefunction(call):
37
+
38
+ @wraps(call)
39
+ async def awrapper(*args: P.args, **kwargs: P.kwargs) -> T:
40
+ for name, annotation in injected_params.items():
41
+ kwargs[name] = await self.container.aresolve(annotation)
42
+ return cast(T, await call(*args, **kwargs))
43
+
44
+ call.__inject_wrapper__ = awrapper # type: ignore[attr-defined]
45
+
46
+ return awrapper
47
+
48
+ @wraps(call)
49
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
50
+ for name, annotation in injected_params.items():
51
+ kwargs[name] = self.container.resolve(annotation)
52
+ return cast(T, call(*args, **kwargs))
53
+
54
+ call.__inject_wrapper__ = wrapper # type: ignore[attr-defined]
55
+
56
+ return wrapper
57
+
58
+ def _get_injected_params(self, call: Callable[..., Any]) -> dict[str, Any]:
59
+ """Get the injected parameters of a callable object."""
60
+ injected_params = {}
61
+ for parameter in get_typed_parameters(call):
62
+ if not is_marker(parameter.default):
63
+ continue
64
+ try:
65
+ self._validate_injected_parameter(call, parameter)
66
+ except LookupError as exc:
67
+ if not self.container.strict:
68
+ logger.debug(
69
+ f"Cannot validate the `{get_full_qualname(call)}` parameter "
70
+ f"`{parameter.name}` with an annotation of "
71
+ f"`{get_full_qualname(parameter.annotation)} due to being "
72
+ "in non-strict mode. It will be validated at the first call."
73
+ )
74
+ else:
75
+ raise exc
76
+ injected_params[parameter.name] = parameter.annotation
77
+ return injected_params
78
+
79
+ def _validate_injected_parameter(
80
+ self, call: Callable[..., Any], parameter: inspect.Parameter
81
+ ) -> None:
82
+ """Validate an injected parameter."""
83
+ if parameter.annotation is inspect.Parameter.empty:
84
+ raise TypeError(
85
+ f"Missing `{get_full_qualname(call)}` parameter "
86
+ f"`{parameter.name}` annotation."
87
+ )
88
+
89
+ if not self.container.is_registered(parameter.annotation):
90
+ raise LookupError(
91
+ f"`{get_full_qualname(call)}` has an unknown dependency parameter "
92
+ f"`{parameter.name}` with an annotation of "
93
+ f"`{get_full_qualname(parameter.annotation)}`."
94
+ )
anydi/_module.py CHANGED
@@ -19,26 +19,9 @@ P = ParamSpec("P")
19
19
 
20
20
 
21
21
  class ModuleMeta(type):
22
- """A metaclass used for the Module base class.
23
-
24
- This metaclass extracts provider information from the class attributes
25
- and stores it in the `providers` attribute.
26
- """
22
+ """A metaclass used for the Module base class."""
27
23
 
28
24
  def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Any:
29
- """Create a new instance of the ModuleMeta class.
30
-
31
- This method extracts provider information from the class attributes and
32
- stores it in the `providers` attribute.
33
-
34
- Args:
35
- name: The name of the class.
36
- bases: The base classes of the class.
37
- attrs: The attributes of the class.
38
-
39
- Returns:
40
- The new instance of the class.
41
- """
42
25
  attrs["providers"] = [
43
26
  (name, getattr(value, "__provider__"))
44
27
  for name, value in attrs.items()
@@ -53,14 +36,7 @@ class Module(metaclass=ModuleMeta):
53
36
  providers: list[tuple[str, ProviderDecoratorArgs]]
54
37
 
55
38
  def configure(self, container: Container) -> None:
56
- """Configure the AnyDI container with providers and their dependencies.
57
-
58
- This method can be overridden in derived classes to provide the
59
- configuration logic.
60
-
61
- Args:
62
- container: The AnyDI container to be configured.
63
- """
39
+ """Configure the AnyDI container with providers and their dependencies."""
64
40
 
65
41
 
66
42
  class ModuleRegistry:
@@ -70,11 +46,7 @@ class ModuleRegistry:
70
46
  def register(
71
47
  self, module: Module | type[Module] | Callable[[Container], None] | str
72
48
  ) -> None:
73
- """Register a module as a callable, module type, or module instance.
74
-
75
- Args:
76
- module: The module to register.
77
- """
49
+ """Register a module as a callable, module type, or module instance."""
78
50
 
79
51
  # Callable Module
80
52
  if inspect.isfunction(module):
@@ -107,16 +79,7 @@ class ProviderDecoratorArgs(NamedTuple):
107
79
  def provider(
108
80
  *, scope: Scope, override: bool = False
109
81
  ) -> Callable[[Callable[Concatenate[M, P], T]], Callable[Concatenate[M, P], T]]:
110
- """Decorator for marking a function or method as a provider in a AnyDI module.
111
-
112
- Args:
113
- scope: The scope in which the provided instance should be managed.
114
- override: Whether the provider should override existing providers
115
- with the same interface.
116
-
117
- Returns:
118
- A decorator that marks the target function or method as a provider.
119
- """
82
+ """Decorator for marking a function or method as a provider in a AnyDI module."""
120
83
 
121
84
  def decorator(
122
85
  target: Callable[Concatenate[M, P], T],