haldi 0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
haldi-0.2/PKG-INFO ADDED
@@ -0,0 +1,3 @@
1
+ Metadata-Version: 2.1
2
+ Name: haldi
3
+ Version: 0.2
@@ -0,0 +1,294 @@
1
+ import abc
2
+ from asyncio import Lock
3
+ from dataclasses import dataclass
4
+ import dataclasses
5
+ from enum import IntFlag, auto
6
+ import functools
7
+ import types
8
+ import typing as t
9
+ from inspect import Signature, Parameter, isclass, iscoroutinefunction
10
+ from contextlib import asynccontextmanager
11
+
12
+
13
def is_lambda_function(obj):
    """Return True when *obj* is a function object created by a ``lambda`` expression."""
    if not isinstance(obj, types.LambdaType):
        return False
    # Every lambda gets the fixed name "<lambda>"; named ``def`` functions do not.
    return obj.__name__ == "<lambda>"
15
+
16
+
17
def is_context_manager(obj):
    """Return True if *obj* has both ``__aenter__`` and ``__aexit__`` (async CM protocol)."""
    return all(hasattr(obj, attr) for attr in ("__aenter__", "__aexit__"))
19
+
20
+
21
class ServiceLifetime(IntFlag):
    """How long a resolved service instance is reused.

    ``ONCE`` combines the lifetimes whose instance is cached on the provider
    after the first resolution (scoped and singleton); transient services
    are built anew on every resolution.
    """

    TRANSIENT = 1 << 0
    SCOPED = 1 << 1
    SINGLETON = 1 << 2
    ONCE = SINGLETON | SCOPED
26
+
27
+
28
# Generic placeholders: T stands for a requested (base) service type,
# U for the type a registered class or factory produces.
T, U = t.TypeVar("T"), t.TypeVar("U")
30
+
31
+
32
class DIException(Exception):
    """Base class for errors raised by the dependency-injection container."""


class CircularDependencyException(DIException):
    """Raised when resolving a type requires resolving that same type again."""


class InstantiationException(DIException):
    """Dependency-injection error reserved for instantiation failures.

    Not raised anywhere in this module at present.
    """
42
+
43
+
44
class DependencyResolutionContext:
    """Track the chain of types currently being resolved, to detect cycles.

    Typical use is ``with context(typ): ...``. Calling the context pushes
    ``typ`` onto the resolution path, and leaving the ``with`` block pops it
    again.
    """

    def __init__(self, container: "Container"):
        self._container = container
        # Entries on the current resolution path, for O(1) membership checks.
        self._chain: t.Set[t.Any] = set()
        # The same entries in push order. This lets the innermost entry be
        # popped, and the path be reported when a cycle is found.
        self._stack: t.List[t.Any] = []

    @staticmethod
    def _describe(typ) -> str:
        """Return a readable label for a chain entry (a type or a registered name)."""
        if isinstance(typ, str):
            return typ
        return getattr(typ, "__name__", repr(typ))

    def push(self, typ: "t.Type[T]"):
        """Add *typ* to the resolution path.

        Raises:
            CircularDependencyException: *typ* is already on the path; the
                message lists the path that leads back to it.
        """
        if typ in self._chain:
            path = " -> ".join(self._describe(item) for item in [*self._stack, typ])
            raise CircularDependencyException(f"Circular dependency detected: {path}")

        self._chain.add(typ)
        self._stack.append(typ)

    def remove(self, typ: "t.Type[T]"):
        """Take *typ* off the resolution path; the manual counterpart of :meth:`push`."""
        self._chain.remove(typ)
        # push() never stores duplicates, so *typ* occurs exactly once in the
        # stack. Dropping it keeps the stack in sync with the chain.
        self._stack.remove(typ)

    def __call__(self, typ: t.Type):
        """Push *typ* and return self, enabling ``with context(typ):``."""
        self.push(typ)
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Undo the push performed by the matching __call__.
        elem = self._stack.pop()
        self._chain.remove(elem)
70
+
71
+
72
@dataclasses.dataclass
class ServiceProvider:
    """A single registration: what to build for ``base_type``, and how long to keep it."""

    # The type the provider was registered under.
    base_type: t.Type
    # A class, a factory callable, or (scoped/singleton only) a ready-made instance.
    concrete_type_or_factory: t.Type | t.Callable[..., t.Any]
    lifetime: ServiceLifetime
    # Cached result for scoped/singleton lifetimes; None until first resolved.
    instance: t.Any
    # Held while the cached instance is checked and, if missing, built.
    lock: Lock
79
+
80
+
81
class Container:
    """Asynchronous dependency-injection container.

    A provider is registered against a base type, and implicitly against that
    type's ``__name__``. It may be a class, a sync or async factory, or (for
    scoped/singleton lifetimes) a ready-made instance.

    Dependencies are read from the annotated parameters of constructors and
    factories and resolved recursively.

    A container built from another container acts as a *scope*. It shares
    singleton providers with its parent and receives fresh copies of every
    other provider.
    """

    # Parameters annotated with one of these types are never injected.
    PRIMITIVE_TYPES = (
        str,
        int,
    )

    def __init__(self, container: "Container" = None):
        """Create a root container, or a child scope of *container*."""
        self._registered_services: t.Dict[t.Any, list[ServiceProvider]] = {}
        # Async context managers entered by this container, in entry order.
        self._pending_ctx_managers = []
        # Singleton providers (and hence their instances) are shared along the
        # whole scope chain. Their context managers are therefore owned by the
        # root container, not by whichever short-lived scope built them first.
        self._root: "Container" = self if container is None else container._root

        if container is not None:
            # Every provider is stored under its type *and* its type name.
            # Memoising the copies keeps both keys pointing at one provider,
            # so a scoped instance is shared however it is looked up.
            copies: t.Dict[int, "ServiceProvider"] = {}
            for base_type, providers in container.registered_services.items():
                for provider in providers:
                    self._add_provider(base_type, self._scope_copy(provider, copies))

    @staticmethod
    def _scope_copy(
        provider: "ServiceProvider", copies: t.Dict[int, "ServiceProvider"]
    ) -> "ServiceProvider":
        """Return the provider a new scope should use in place of *provider*.

        Singletons are shared as-is. Any other provider is copied with its
        cached instance cleared and a lock of its own, so scopes neither see
        each other's instances nor wait on each other's locks.
        """
        if provider.lifetime == ServiceLifetime.SINGLETON:
            return provider
        key = id(provider)
        if key not in copies:
            copies[key] = dataclasses.replace(provider, instance=None, lock=Lock())
        return copies[key]

    def _add_provider(self, key: str | t.Type, provider: ServiceProvider):
        """Append *provider* to the providers registered under *key*."""
        self._registered_services.setdefault(key, []).append(provider)

    def register(
        self,
        base_type: t.Type[T],
        concrete_type_or_factory: t.Type[U] | t.Callable[..., U] | t.Any,
        lifetime: ServiceLifetime,
    ):
        """Register a provider for *base_type* with the given *lifetime*.

        The provider can be looked up both by *base_type* and by
        ``base_type.__name__``.

        Raises:
            DIException: a transient lifetime was given a non-callable instance.
        """
        # Classes are callable, so ``callable`` alone covers "class or factory".
        if lifetime not in ServiceLifetime.ONCE and not callable(concrete_type_or_factory):
            raise DIException(
                "Transient lifetime requires a callable or class type, not a instance."
            )

        service_provider = ServiceProvider(
            base_type, concrete_type_or_factory, lifetime, None, Lock()
        )

        self._add_provider(base_type, service_provider)
        self._add_provider(base_type.__name__, service_provider)

    @property
    def registered_services(self):
        """Mapping of type (or type name) to its list of providers."""
        return self._registered_services

    def _get_service_definition(
        self, name_or_type: str | t.Type
    ) -> list[ServiceProvider] | None:
        """Return the providers registered under *name_or_type*, or None."""
        return self._registered_services.get(name_or_type, None)

    # The add_* helpers test ``is None`` rather than using ``or``: a falsy
    # instance such as ``{}`` or ``0`` must not silently fall back to *typ*.
    def add_transient(self, typ, concrete_typ=None):
        """Register *concrete_typ* (default: *typ*) as a transient provider of *typ*."""
        concrete = typ if concrete_typ is None else concrete_typ
        return self.register(typ, concrete, ServiceLifetime.TRANSIENT)

    def add_scoped(self, typ, concrete_typ=None):
        """Register *concrete_typ* (default: *typ*) as a scoped provider of *typ*."""
        concrete = typ if concrete_typ is None else concrete_typ
        return self.register(typ, concrete, ServiceLifetime.SCOPED)

    def add_singleton(self, typ, concrete_typ=None):
        """Register *concrete_typ* (default: *typ*) as a singleton provider of *typ*."""
        concrete = typ if concrete_typ is None else concrete_typ
        return self.register(typ, concrete, ServiceLifetime.SINGLETON)

    def create_scope(self):
        """Return a child container; the caller is responsible for ``close()``."""
        scoped_container = Container(self)
        return scoped_container

    @asynccontextmanager
    async def scoped(self):
        """Yield a new scope and close it on exit, even if the body raises."""
        scope = self.create_scope()
        try:
            yield scope
        finally:
            await scope.close()

    def extract_depedencies(self, callable_: t.Callable[..., t.Any]):
        """Map each annotated parameter name of *callable_* to its annotation."""
        signature = Signature.from_callable(callable_)
        return {
            name: param.annotation
            for name, param in signature.parameters.items()
            if param.annotation is not Parameter.empty
        }

    async def _resolve(
        self,
        desired_type_or_callable: t.Type[T] | t.Callable[..., T],
        context: DependencyResolutionContext,
        strict: bool = True,
    ) -> T | list[T] | None:
        """Resolve a type or registered name within *context*.

        Returns:
            The container itself when asked for ``Container``. Otherwise, a
            list when several providers are registered, the single resolved
            object, or None when nothing is registered and *strict* is False.

        Raises:
            DIException: nothing is registered and *strict* is True.
        """
        if desired_type_or_callable is self.__class__:
            return self  # type: ignore

        desired_type = desired_type_or_callable
        providers = self._get_service_definition(desired_type)

        if not providers:
            if strict:
                type_name = (
                    desired_type
                    if isinstance(desired_type, str)
                    # Typing constructs may lack __name__; fall back to repr.
                    else getattr(desired_type, "__name__", repr(desired_type))
                )
                raise DIException(f"Type {type_name} could not be resolved.")
            return None

        if len(providers) > 1:
            return [
                await self._resolve_provider(provider, desired_type, context)
                for provider in providers
            ]

        return await self._resolve_provider(providers[0], desired_type, context)

    async def _resolve_provider(self, provider, desired_type, context):
        """Produce the object for a single provider, honouring its lifetime."""
        target = provider.concrete_type_or_factory

        if not callable(target):
            # Ready-made instance (only accepted for scoped/singleton): there
            # is nothing to build or inject. Returning it directly also avoids
            # pushing an arbitrary, possibly unhashable, object onto the chain.
            return target

        concrete_type = self._concrete_type(target, desired_type)

        # Push the concrete type to the context chain, prevent circular dependency loop.
        with context(concrete_type):
            if provider.lifetime not in ServiceLifetime.ONCE:
                return await self._instantiate(target, provider.lifetime, context)

            # ``async with`` releases the lock even when the factory or a nested
            # resolution raises. A lock left held would make every later
            # resolution of this provider wait forever.
            async with provider.lock:
                if provider.instance is None:
                    provider.instance = await self._instantiate(
                        target, provider.lifetime, context
                    )
                return provider.instance

    @staticmethod
    def _concrete_type(target, desired_type):
        """Return the type *target* produces; used as its resolution-chain key.

        Raises:
            DIException: *target* is a named callable without a return annotation.
        """
        if isclass(target):
            return target
        if is_lambda_function(target):
            # Lambdas cannot carry annotations; key them by the requested type.
            return desired_type
        return_annotation = Signature.from_callable(target).return_annotation
        if not return_annotation or return_annotation is Signature.empty:
            raise DIException(
                "Callable is not a lambda function AND has not return type."
            )
        return return_annotation

    async def _instantiate(self, factory, lifetime, context):
        """Call *factory* with its injected dependencies and return the result.

        When a sync factory returns an async context manager, the manager is
        entered. The value from ``__aenter__`` is returned, and the manager is
        registered for ``close()``.
        """
        # For classes, the injectable parameters are those of ``__init__``.
        signature_source = factory.__init__ if isclass(factory) else factory
        dependencies = self.extract_depedencies(signature_source)
        kwargs = await self._resolve_kwargs(dependencies, context, True)

        if iscoroutinefunction(factory):
            return await factory(**kwargs)

        instance = factory(**kwargs)
        if is_context_manager(instance):
            entered = await instance.__aenter__()
            # Register only after a successful enter, so close() never calls
            # __aexit__ on a manager that was not entered. Singletons outlive
            # scopes, so their managers belong to the root container.
            owner = self._root if lifetime == ServiceLifetime.SINGLETON else self
            owner._pending_ctx_managers.append(instance)
            instance = entered
        return instance

    async def _resolve_kwargs(self, dependencies, context, strict):
        """Resolve annotated dependencies to keyword arguments.

        Primitive-typed parameters are skipped, as are dependencies that
        resolve to None. Those are left to their defaults or to the caller.
        """
        kwargs = {}
        for name, typ in dependencies.items():
            if typ in self.PRIMITIVE_TYPES:
                continue
            resolved = await self._resolve(typ, context, strict)
            if resolved is not None:
                kwargs[name] = resolved
        return kwargs

    async def get(self, desired_type: t.Type[T]) -> T:
        """Resolve *desired_type*, raising DIException if nothing is registered."""
        return await self.resolve(desired_type, True)  # type: ignore

    async def try_get(self, desired_type: t.Type[T]) -> T | None:
        """Resolve *desired_type*, returning None if nothing is registered."""
        return await self.resolve(desired_type, False)

    async def resolve(self, desired_type: t.Type[T], strict: bool = True) -> T | None:
        """Resolve *desired_type* with a fresh cycle-detection context."""
        context = DependencyResolutionContext(self)
        return await self._resolve(desired_type, context, strict)

    async def get_executor(
        self, callable_: t.Callable[..., T], strict: bool = True
    ) -> t.Callable[..., T]:
        """Return ``functools.partial(callable_, **injected)`` for its annotated params.

        Primitive-typed parameters, and dependencies that resolve to None,
        are left for the caller to supply.
        """
        context = DependencyResolutionContext(self)
        dependencies = self.extract_depedencies(callable_)
        kwargs = await self._resolve_kwargs(dependencies, context, strict)
        return functools.partial(callable_, **kwargs)

    async def close(self):
        """Exit every async context manager this container entered, newest first.

        Every manager is exited even if some raise. The first exception is
        re-raised once all of them have been processed.
        """
        first_error = None
        while self._pending_ctx_managers:
            # LIFO: dependents are entered after their dependencies, so they
            # must be exited before them.
            ctx_manager = self._pending_ctx_managers.pop()
            try:
                await ctx_manager.__aexit__(None, None, None)
            except Exception as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error
@@ -0,0 +1,3 @@
1
+ Metadata-Version: 2.1
2
+ Name: haldi
3
+ Version: 0.2
@@ -0,0 +1,8 @@
1
+ setup.py
2
+ haldi/__init__.py
3
+ haldi.egg-info/PKG-INFO
4
+ haldi.egg-info/SOURCES.txt
5
+ haldi.egg-info/dependency_links.txt
6
+ haldi.egg-info/top_level.txt
7
+ tests/__init__.py
8
+ tests/test_test.py
@@ -0,0 +1,2 @@
1
+ haldi
2
+ tests
haldi-0.2/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
haldi-0.2/setup.py ADDED
@@ -0,0 +1,10 @@
1
+ from setuptools import setup, find_packages
2
+
3
setup(
    name="haldi",
    version="0.2",
    # Ship only the library. A bare find_packages() also installs the test
    # suite as a top-level ``tests`` package in site-packages.
    packages=find_packages(exclude=["tests", "tests.*"]),
    # haldi/__init__.py evaluates PEP 604 unions (``X | Y``) in signatures
    # and dataclass fields at import time, which requires Python 3.10+.
    python_requires=">=3.10",
    install_requires=[],
)
File without changes
@@ -0,0 +1,76 @@
1
+ from haldi import Container
2
+ import pytest
3
+
4
+
5
class Interface:
    """Base type the tests register implementations against."""


class ImplA(Interface):
    """First concrete implementation of ``Interface``."""


class ImplB(Interface):
    """Second concrete implementation of ``Interface``."""
15
+
16
+
17
@pytest.mark.asyncio
async def test_multiple_resolves():
    """Two providers for one base type resolve to a list, in registration order."""
    di = Container()
    for impl in (ImplA, ImplB):
        di.add_transient(Interface, impl)

    resolved = await di.resolve(Interface)

    assert isinstance(resolved, list)
    first, second = resolved
    assert isinstance(first, ImplA)
    assert isinstance(second, ImplB)
30
+
31
+
32
@pytest.mark.asyncio
async def test_single_resolve():
    """A lone provider resolves to a bare instance rather than a list."""
    di = Container()
    di.add_transient(Interface, ImplA)

    resolved = await di.resolve(Interface)
    assert isinstance(resolved, ImplA)
40
+
41
@pytest.mark.asyncio
async def test_singleton_simple():
    """Every resolution of a singleton yields the very same object."""
    di = Container()
    di.add_singleton(Interface, ImplA)

    first = await di.resolve(Interface)
    second = await di.resolve(Interface)
    assert first is second
48
+
49
+
50
@pytest.mark.asyncio
async def test_transient_is_not_same_instance_simple():
    """Each resolution of a transient service builds a new object."""
    di = Container()
    di.add_transient(Interface, ImplA)

    first = await di.resolve(Interface)
    second = await di.resolve(Interface)
    assert first is not second
57
+
58
@pytest.mark.asyncio
async def test_scoped_context():
    """Scoped services are cached per container and never shared across scopes."""
    container = Container()

    container.add_scoped(Interface, ImplA)

    # The root container caches its own scoped instance.
    instance_a = await container.resolve(Interface)
    assert instance_a is await container.resolve(Interface)

    scope_b = container.create_scope()
    instance_b = await scope_b.resolve(Interface)
    assert instance_b is await scope_b.resolve(Interface)

    scope_c = container.create_scope()
    instance_c = await scope_c.resolve(Interface)
    assert instance_c is await scope_c.resolve(Interface)

    # Every pair of containers must hold distinct instances.
    assert instance_a is not instance_b
    assert instance_a is not instance_c
    assert instance_b is not instance_c