haiway 0.3.2__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {haiway-0.3.2/src/haiway.egg-info → haiway-0.4.0}/PKG-INFO +1 -1
  2. {haiway-0.3.2 → haiway-0.4.0}/pyproject.toml +1 -1
  3. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/context/access.py +81 -33
  4. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/context/metrics.py +90 -41
  5. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/utils/queue.py +6 -9
  6. {haiway-0.3.2 → haiway-0.4.0/src/haiway.egg-info}/PKG-INFO +1 -1
  7. {haiway-0.3.2 → haiway-0.4.0}/src/haiway.egg-info/SOURCES.txt +1 -0
  8. {haiway-0.3.2 → haiway-0.4.0}/tests/test_context.py +63 -7
  9. haiway-0.4.0/tests/test_streaming.py +160 -0
  10. {haiway-0.3.2 → haiway-0.4.0}/LICENSE +0 -0
  11. {haiway-0.3.2 → haiway-0.4.0}/README.md +0 -0
  12. {haiway-0.3.2 → haiway-0.4.0}/setup.cfg +0 -0
  13. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/__init__.py +0 -0
  14. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/context/__init__.py +0 -0
  15. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/context/disposables.py +0 -0
  16. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/context/state.py +0 -0
  17. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/context/tasks.py +0 -0
  18. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/context/types.py +0 -0
  19. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/helpers/__init__.py +0 -0
  20. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/helpers/asynchrony.py +0 -0
  21. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/helpers/caching.py +0 -0
  22. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/helpers/retries.py +0 -0
  23. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/helpers/throttling.py +0 -0
  24. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/helpers/timeouted.py +0 -0
  25. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/helpers/tracing.py +0 -0
  26. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/py.typed +0 -0
  27. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/state/__init__.py +0 -0
  28. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/state/attributes.py +0 -0
  29. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/state/structure.py +0 -0
  30. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/state/validation.py +0 -0
  31. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/types/__init__.py +0 -0
  32. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/types/frozen.py +0 -0
  33. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/types/missing.py +0 -0
  34. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/utils/__init__.py +0 -0
  35. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/utils/always.py +0 -0
  36. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/utils/env.py +0 -0
  37. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/utils/immutable.py +0 -0
  38. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/utils/logs.py +0 -0
  39. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/utils/mimic.py +0 -0
  40. {haiway-0.3.2 → haiway-0.4.0}/src/haiway/utils/noop.py +0 -0
  41. {haiway-0.3.2 → haiway-0.4.0}/src/haiway.egg-info/dependency_links.txt +0 -0
  42. {haiway-0.3.2 → haiway-0.4.0}/src/haiway.egg-info/requires.txt +0 -0
  43. {haiway-0.3.2 → haiway-0.4.0}/src/haiway.egg-info/top_level.txt +0 -0
  44. {haiway-0.3.2 → haiway-0.4.0}/tests/test_async_queue.py +0 -0
  45. {haiway-0.3.2 → haiway-0.4.0}/tests/test_auto_retry.py +0 -0
  46. {haiway-0.3.2 → haiway-0.4.0}/tests/test_cache.py +0 -0
  47. {haiway-0.3.2 → haiway-0.4.0}/tests/test_state.py +0 -0
  48. {haiway-0.3.2 → haiway-0.4.0}/tests/test_timeout.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: haiway
3
- Version: 0.3.2
3
+ Version: 0.4.0
4
4
  Summary: Framework for dependency injection and state management within structured concurrency model.
5
5
  Maintainer-email: Kacper Kaliński <kacper.kalinski@miquido.com>
6
6
  License: MIT License
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
5
5
  [project]
6
6
  name = "haiway"
7
7
  description = "Framework for dependency injection and state management within structured concurrency model."
8
- version = "0.3.2"
8
+ version = "0.4.0"
9
9
  readme = "README.md"
10
10
  maintainers = [
11
11
  { name = "Kacper Kaliński", email = "kacper.kalinski@miquido.com" },
@@ -1,12 +1,16 @@
1
1
  from asyncio import (
2
+ CancelledError,
2
3
  Task,
3
4
  current_task,
4
5
  )
5
6
  from collections.abc import (
7
+ AsyncGenerator,
8
+ AsyncIterator,
6
9
  Callable,
7
10
  Coroutine,
8
11
  Iterable,
9
12
  )
13
+ from contextvars import Context, copy_context
10
14
  from logging import Logger
11
15
  from types import TracebackType
12
16
  from typing import Any, final
@@ -32,32 +36,28 @@ class ScopeContext:
32
36
  logger: Logger | None,
33
37
  state: tuple[State, ...],
34
38
  disposables: Disposables | None,
35
- task_group: TaskGroupContext,
36
- completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None,
39
+ completion: Callable[[ScopeMetrics], Coroutine[None, None, None]]
40
+ | Callable[[ScopeMetrics], None]
41
+ | None,
37
42
  ) -> None:
38
- self._task_group: TaskGroupContext = task_group
39
- self._logger: Logger | None = logger
40
- self._trace_id: str | None = trace_id
41
- self._name: str = name
43
+ self._task_group_context: TaskGroupContext = TaskGroupContext()
44
+ # postponing state creation to include disposables if needed
42
45
  self._state_context: StateContext
43
46
  self._state: tuple[State, ...] = state
44
47
  self._disposables: Disposables | None = disposables
45
- self._metrics_context: MetricsContext
46
- self._completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None = completion
48
+ # pre-building metrics context to ensure nested context registering
49
+ self._metrics_context: MetricsContext = MetricsContext.scope(
50
+ name,
51
+ logger=logger,
52
+ trace_id=trace_id,
53
+ completion=completion,
54
+ )
47
55
 
48
56
  freeze(self)
49
57
 
50
58
  def __enter__(self) -> None:
51
- assert self._completion is None, "Can't enter synchronous context with completion" # nosec: B101
52
59
  assert self._disposables is None, "Can't enter synchronous context with disposables" # nosec: B101
53
-
54
60
  self._state_context = StateContext.updated(self._state)
55
- self._metrics_context = MetricsContext.scope(
56
- self._name,
57
- logger=self._logger,
58
- trace_id=self._trace_id,
59
- )
60
-
61
61
  self._state_context.__enter__()
62
62
  self._metrics_context.__enter__()
63
63
 
@@ -80,9 +80,9 @@ class ScopeContext:
80
80
  )
81
81
 
82
82
  async def __aenter__(self) -> None:
83
- await self._task_group.__aenter__()
83
+ await self._task_group_context.__aenter__()
84
84
 
85
- if self._disposables:
85
+ if self._disposables is not None:
86
86
  self._state_context = StateContext.updated(
87
87
  (*self._state, *await self._disposables.__aenter__())
88
88
  )
@@ -90,12 +90,6 @@ class ScopeContext:
90
90
  else:
91
91
  self._state_context = StateContext.updated(self._state)
92
92
 
93
- self._metrics_context = MetricsContext.scope(
94
- self._name,
95
- logger=self._logger,
96
- trace_id=self._trace_id,
97
- )
98
-
99
93
  self._state_context.__enter__()
100
94
  self._metrics_context.__enter__()
101
95
 
@@ -105,14 +99,14 @@ class ScopeContext:
105
99
  exc_val: BaseException | None,
106
100
  exc_tb: TracebackType | None,
107
101
  ) -> None:
108
- if self._disposables:
102
+ if self._disposables is not None:
109
103
  await self._disposables.__aexit__(
110
104
  exc_type=exc_type,
111
105
  exc_val=exc_val,
112
106
  exc_tb=exc_tb,
113
107
  )
114
108
 
115
- await self._task_group.__aexit__(
109
+ await self._task_group_context.__aexit__(
116
110
  exc_type=exc_type,
117
111
  exc_val=exc_val,
118
112
  exc_tb=exc_tb,
@@ -130,9 +124,6 @@ class ScopeContext:
130
124
  exc_tb=exc_tb,
131
125
  )
132
126
 
133
- if completion := self._completion:
134
- await completion(self._metrics_context._metrics) # pyright: ignore[reportPrivateUsage]
135
-
136
127
 
137
128
  @final
138
129
  class ctx:
@@ -144,7 +135,9 @@ class ctx:
144
135
  disposables: Disposables | Iterable[Disposable] | None = None,
145
136
  logger: Logger | None = None,
146
137
  trace_id: str | None = None,
147
- completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None = None,
138
+ completion: Callable[[ScopeMetrics], Coroutine[None, None, None]]
139
+ | Callable[[ScopeMetrics], None]
140
+ | None = None,
148
141
  ) -> ScopeContext:
149
142
  """
150
143
  Access scope context with given parameters. When called within an existing context\
@@ -173,7 +166,7 @@ class ctx:
173
166
  provided current identifier will be used if any, otherwise it random id will\
174
167
  be generated
175
168
 
176
- completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None = None
169
+ completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | Callable[[ScopeMetrics], None] | None = None
177
170
  completion callback called on exit from the scope granting access to finished\
178
171
  scope metrics. Completion is called outside of the context when its metrics is\
179
172
  already finished. Make sure to avoid any long operations within the completion.
@@ -182,7 +175,7 @@ class ctx:
182
175
  -------
183
176
  ScopeContext
184
177
  context object intended to enter context manager with it
185
- """
178
+ """ # noqa: E501
186
179
 
187
180
  resolved_disposables: Disposables | None
188
181
  match disposables:
@@ -201,7 +194,6 @@ class ctx:
201
194
  logger=logger,
202
195
  state=state,
203
196
  disposables=resolved_disposables,
204
- task_group=TaskGroupContext(),
205
197
  completion=completion,
206
198
  )
207
199
 
@@ -257,6 +249,62 @@ class ctx:
257
249
 
258
250
  return TaskGroupContext.run(function, *args, **kwargs)
259
251
 
252
+ @staticmethod
253
+ def stream[Result, **Arguments](
254
+ source: Callable[Arguments, AsyncGenerator[Result, None]],
255
+ /,
256
+ *args: Arguments.args,
257
+ **kwargs: Arguments.kwargs,
258
+ ) -> AsyncIterator[Result]:
259
+ """
260
+ Stream results produced by a generator within the proper context state.
261
+
262
+ Parameters
263
+ ----------
264
+ source: Callable[Arguments, AsyncGenerator[Result, None]]
265
+ generator streamed as the result
266
+
267
+ *args: Arguments.args
268
+ positional arguments passed to generator call
269
+
270
+ **kwargs: Arguments.kwargs
271
+ keyword arguments passed to generator call
272
+
273
+ Returns
274
+ -------
275
+ AsyncIterator[Result]
276
+ iterator for accessing generated results
277
+ """
278
+
279
+ # prepare context snapshot
280
+ context_snapshot: Context = copy_context()
281
+
282
+ # prepare nested context
283
+ streaming_context: ScopeContext = ctx.scope(
284
+ getattr(
285
+ source,
286
+ "__name__",
287
+ "streaming",
288
+ )
289
+ )
290
+
291
+ async def generator() -> AsyncGenerator[Result, None]:
292
+ async with streaming_context:
293
+ async for result in source(*args, **kwargs):
294
+ yield result
295
+
296
+ # finally return it as an iterator
297
+ return context_snapshot.run(generator)
298
+
299
+ @staticmethod
300
+ def check_cancellation() -> None:
301
+ """
302
+ Check if current asyncio task is cancelled, raises CancelledError if so.
303
+ """
304
+
305
+ if (task := current_task()) and task.cancelled():
306
+ raise CancelledError()
307
+
260
308
  @staticmethod
261
309
  def cancel() -> None:
262
310
  """
@@ -1,5 +1,12 @@
1
- from asyncio import Future, gather, get_event_loop
2
- from collections.abc import Callable
1
+ from asyncio import (
2
+ AbstractEventLoop,
3
+ Future,
4
+ gather,
5
+ get_event_loop,
6
+ iscoroutinefunction,
7
+ run_coroutine_threadsafe,
8
+ )
9
+ from collections.abc import Callable, Coroutine
3
10
  from contextvars import ContextVar, Token
4
11
  from copy import copy
5
12
  from itertools import chain
@@ -27,6 +34,8 @@ class ScopeMetrics:
27
34
  trace_id: str | None,
28
35
  scope: str,
29
36
  logger: Logger | None,
37
+ parent: Self | None,
38
+ completion: Callable[[Self], Coroutine[None, None, None]] | Callable[[Self], None] | None,
30
39
  ) -> None:
31
40
  self.trace_id: str = trace_id or uuid4().hex
32
41
  self.identifier: str = uuid4().hex
@@ -37,15 +46,38 @@ class ScopeMetrics:
37
46
  else f"[{self.trace_id}] [{self.identifier}]"
38
47
  )
39
48
  self._logger: Logger = logger or getLogger(name=scope)
49
+ self._parent: Self | None = parent if parent else None
40
50
  self._metrics: dict[type[State], State] = {}
41
- self._nested: list[ScopeMetrics] = []
51
+ self._nested: set[ScopeMetrics] = set()
42
52
  self._timestamp: float = monotonic()
43
- self._completed: Future[float] = get_event_loop().create_future()
53
+ self._finished: bool = False
54
+ self._loop: AbstractEventLoop = get_event_loop()
55
+ self._completed: Future[float] = self._loop.create_future()
56
+
57
+ if parent := parent:
58
+ parent._nested.add(self)
44
59
 
45
60
  freeze(self)
46
61
 
62
+ if completion := completion:
63
+ metrics: Self = self
64
+ if iscoroutinefunction(completion):
65
+
66
+ def callback(_: Future[float]) -> None:
67
+ run_coroutine_threadsafe(
68
+ completion(metrics),
69
+ metrics._loop,
70
+ )
71
+
72
+ else:
73
+
74
+ def callback(_: Future[float]) -> None:
75
+ completion(metrics)
76
+
77
+ self._completed.add_done_callback(callback)
78
+
47
79
  def __del__(self) -> None:
48
- self._complete() # ensure completion on deinit
80
+ assert self.is_completed, "Deinitializing not completed scope metrics" # nosec: B101
49
81
 
50
82
  def __str__(self) -> str:
51
83
  return f"{self.label}[{self.identifier}]@[{self.trace_id}]"
@@ -113,8 +145,8 @@ class ScopeMetrics:
113
145
  self._metrics[metric_type] = metric
114
146
 
115
147
  @property
116
- def completed(self) -> bool:
117
- return self._completed.done() and all(nested.completed for nested in self._nested)
148
+ def is_completed(self) -> bool:
149
+ return self._completed.done() and all(nested.is_completed for nested in self._nested)
118
150
 
119
151
  @property
120
152
  def time(self) -> float:
@@ -131,24 +163,36 @@ class ScopeMetrics:
131
163
  return_exceptions=False,
132
164
  )
133
165
 
134
- def _complete(self) -> None:
135
- if self._completed.done():
136
- return # already completed
166
+ def _finish(self) -> None:
167
+ assert ( # nosec: B101
168
+ not self._completed.done()
169
+ ), "Invalid state - called finish on already completed scope"
170
+
171
+ assert ( # nosec: B101
172
+ not self._finished
173
+ ), "Invalid state - called completion on already finished scope"
174
+
175
+ self._finished = True # self is now finished
176
+
177
+ self._complete_if_able()
178
+
179
+ def _complete_if_able(self) -> None:
180
+ assert ( # nosec: B101
181
+ not self._completed.done()
182
+ ), "Invalid state - called complete on already completed scope"
137
183
 
184
+ if not self._finished:
185
+ return # wait for finishing self
186
+
187
+ if any(not nested.is_completed for nested in self._nested):
188
+ return # wait for completing all nested scopes
189
+
190
+ # set completion time
138
191
  self._completed.set_result(monotonic() - self._timestamp)
139
192
 
140
- def scope(
141
- self,
142
- name: str,
143
- /,
144
- ) -> Self:
145
- nested: Self = self.__class__(
146
- scope=name,
147
- logger=self._logger,
148
- trace_id=self.trace_id,
149
- )
150
- self._nested.append(nested)
151
- return nested
193
+ # notify parent about completion
194
+ if parent := self._parent:
195
+ parent._complete_if_able()
152
196
 
153
197
  def log(
154
198
  self,
@@ -178,29 +222,37 @@ class MetricsContext:
178
222
  *,
179
223
  trace_id: str | None = None,
180
224
  logger: Logger | None = None,
225
+ completion: Callable[[ScopeMetrics], Coroutine[None, None, None]]
226
+ | Callable[[ScopeMetrics], None]
227
+ | None,
181
228
  ) -> Self:
182
- try:
183
- context: ScopeMetrics = cls._context.get()
184
- if trace_id is None or context.trace_id == trace_id:
185
- return cls(context.scope(name))
229
+ current: ScopeMetrics
230
+ try: # check for current scope context
231
+ current = cls._context.get()
186
232
 
187
- else:
188
- return cls(
189
- ScopeMetrics(
190
- trace_id=trace_id,
191
- scope=name,
192
- logger=logger or context._logger, # pyright: ignore[reportPrivateUsage]
193
- )
194
- )
195
- except LookupError: # create metrics scope when missing yet
233
+ except LookupError:
234
+ # create metrics scope when missing yet
196
235
  return cls(
197
236
  ScopeMetrics(
198
237
  trace_id=trace_id,
199
238
  scope=name,
200
239
  logger=logger,
240
+ parent=None,
241
+ completion=completion,
201
242
  )
202
243
  )
203
244
 
245
+ # or create nested metrics otherwise
246
+ return cls(
247
+ ScopeMetrics(
248
+ trace_id=trace_id,
249
+ scope=name,
250
+ logger=logger or current._logger, # pyright: ignore[reportPrivateUsage]
251
+ parent=current,
252
+ completion=completion,
253
+ )
254
+ )
255
+
204
256
  @classmethod
205
257
  def record[Metric: State](
206
258
  cls,
@@ -320,15 +372,12 @@ class MetricsContext:
320
372
  ) -> None:
321
373
  self._metrics: ScopeMetrics = metrics
322
374
  self._token: Token[ScopeMetrics] | None = None
323
- self._started: float | None = None
324
- self._finished: float | None = None
325
375
 
326
376
  def __enter__(self) -> None:
327
377
  assert ( # nosec: B101
328
- self._token is None and self._started is None
378
+ self._token is None and not self._metrics._finished # pyright: ignore[reportPrivateUsage]
329
379
  ), "MetricsContext reentrance is not allowed"
330
380
  self._token = MetricsContext._context.set(self._metrics)
331
- self._started = monotonic()
332
381
 
333
382
  def __exit__(
334
383
  self,
@@ -337,8 +386,8 @@ class MetricsContext:
337
386
  exc_tb: TracebackType | None,
338
387
  ) -> None:
339
388
  assert ( # nosec: B101
340
- self._token is not None and self._started is not None and self._finished is None
389
+ self._token is not None
341
390
  ), "Unbalanced MetricsContext context enter/exit"
342
- self._finished = monotonic()
343
391
  MetricsContext._context.reset(self._token)
392
+ self._metrics._finish() # pyright: ignore[reportPrivateUsage]
344
393
  self._token = None
@@ -3,8 +3,6 @@ from collections import deque
3
3
  from collections.abc import AsyncIterator
4
4
  from typing import Self
5
5
 
6
- from haiway.utils.immutable import freeze
7
-
8
6
  __all__ = [
9
7
  "AsyncQueue",
10
8
  ]
@@ -18,20 +16,19 @@ class AsyncQueue[Element](AsyncIterator[Element]):
18
16
 
19
17
  def __init__(
20
18
  self,
19
+ *elements: Element,
21
20
  loop: AbstractEventLoop | None = None,
22
21
  ) -> None:
23
22
  self._loop: AbstractEventLoop = loop or get_running_loop()
24
- self._queue: deque[Element] = deque()
23
+ self._queue: deque[Element] = deque(elements)
25
24
  self._waiting: Future[Element] | None = None
26
25
  self._finish_reason: BaseException | None = None
27
26
 
28
- freeze(self)
29
-
30
27
  def __del__(self) -> None:
31
28
  self.finish()
32
29
 
33
30
  @property
34
- def finished(self) -> bool:
31
+ def is_finished(self) -> bool:
35
32
  return self._finish_reason is not None
36
33
 
37
34
  def enqueue(
@@ -40,7 +37,7 @@ class AsyncQueue[Element](AsyncIterator[Element]):
40
37
  /,
41
38
  *elements: Element,
42
39
  ) -> None:
43
- if self.finished:
40
+ if self.is_finished:
44
41
  raise RuntimeError("AsyncQueue is already finished")
45
42
 
46
43
  if self._waiting is not None and not self._waiting.done():
@@ -55,7 +52,7 @@ class AsyncQueue[Element](AsyncIterator[Element]):
55
52
  self,
56
53
  exception: BaseException | None = None,
57
54
  ) -> None:
58
- if self.finished:
55
+ if self.is_finished:
59
56
  return # already finished, ignore
60
57
 
61
58
  self._finish_reason = exception or StopAsyncIteration()
@@ -70,7 +67,7 @@ class AsyncQueue[Element](AsyncIterator[Element]):
70
67
  return self
71
68
 
72
69
  async def __anext__(self) -> Element:
73
- assert self._waiting is None, "Only a single queue iterator is supported!" # nosec: B101
70
+ assert self._waiting is None, "Only a single queue consumer is supported!" # nosec: B101
74
71
 
75
72
  if self._queue: # check the queue, let it finish
76
73
  return self._queue.popleft()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: haiway
3
- Version: 0.3.2
3
+ Version: 0.4.0
4
4
  Summary: Framework for dependency injection and state management within structured concurrency model.
5
5
  Maintainer-email: Kacper Kaliński <kacper.kalinski@miquido.com>
6
6
  License: MIT License
@@ -42,4 +42,5 @@ tests/test_auto_retry.py
42
42
  tests/test_cache.py
43
43
  tests/test_context.py
44
44
  tests/test_state.py
45
+ tests/test_streaming.py
45
46
  tests/test_timeout.py
@@ -1,3 +1,5 @@
1
+ from asyncio import get_running_loop
2
+
1
3
  from haiway import MissingContext, ScopeMetrics, State, ctx
2
4
  from pytest import mark, raises
3
5
 
@@ -64,38 +66,92 @@ async def test_exceptions_are_propagated():
64
66
 
65
67
  @mark.asyncio
66
68
  async def test_completions_are_called_according_to_context_exits():
69
+ completion_future = get_running_loop().create_future()
70
+ nested_completion_future = get_running_loop().create_future()
71
+ executions: int = 0
72
+
73
+ def completion(metrics: ScopeMetrics):
74
+ nonlocal executions
75
+ executions += 1
76
+ completion_future.set_result(())
77
+
78
+ def nested_completion(metrics: ScopeMetrics):
79
+ nonlocal executions
80
+ executions += 1
81
+ nested_completion_future.set_result(())
82
+
83
+ async with ctx.scope("outer", completion=completion):
84
+ assert executions == 0
85
+
86
+ async with ctx.scope("inner", completion=nested_completion):
87
+ assert executions == 0
88
+
89
+ await nested_completion_future
90
+ assert executions == 1
91
+
92
+ await completion_future
93
+ assert executions == 2
94
+
95
+
96
+ @mark.asyncio
97
+ async def test_async_completions_are_called_according_to_context_exits():
98
+ completion_future = get_running_loop().create_future()
99
+ nested_completion_future = get_running_loop().create_future()
67
100
  executions: int = 0
68
101
 
69
102
  async def completion(metrics: ScopeMetrics):
70
103
  nonlocal executions
71
104
  executions += 1
105
+ completion_future.set_result(())
106
+
107
+ async def nested_completion(metrics: ScopeMetrics):
108
+ nonlocal executions
109
+ executions += 1
110
+ nested_completion_future.set_result(())
72
111
 
73
112
  async with ctx.scope("outer", completion=completion):
74
113
  assert executions == 0
75
114
 
76
- async with ctx.scope("inner", completion=completion):
115
+ async with ctx.scope("inner", completion=nested_completion):
77
116
  assert executions == 0
78
117
 
118
+ await nested_completion_future
79
119
  assert executions == 1
80
120
 
121
+ await completion_future
81
122
  assert executions == 2
82
123
 
83
124
 
84
125
  @mark.asyncio
85
126
  async def test_metrics_are_recorded_within_context():
86
- def verify_example_metrics(state: str):
87
- async def completion(metrics: ScopeMetrics):
88
- assert metrics.read(ExampleState, default=ExampleState()).state == state
127
+ completion_future = get_running_loop().create_future()
128
+ nested_completion_future = get_running_loop().create_future()
129
+ metric: ExampleState = ExampleState()
130
+ nested_metric: ExampleState = ExampleState()
131
+
132
+ async def completion(metrics: ScopeMetrics):
133
+ nonlocal metric
134
+ metric = metrics.read(ExampleState, default=ExampleState())
135
+ completion_future.set_result(())
89
136
 
90
- return completion
137
+ async def nested_completion(metrics: ScopeMetrics):
138
+ nonlocal nested_metric
139
+ nested_metric = metrics.read(ExampleState, default=ExampleState())
140
+ nested_completion_future.set_result(())
91
141
 
92
- async with ctx.scope("outer", completion=verify_example_metrics("outer-in-out")):
142
+ async with ctx.scope("outer", completion=completion):
93
143
  ctx.record(ExampleState(state="outer-in"))
94
144
 
95
- async with ctx.scope("inner", completion=verify_example_metrics("inner")):
145
+ async with ctx.scope("inner", completion=nested_completion):
96
146
  ctx.record(ExampleState(state="inner"))
97
147
 
148
+ await nested_completion_future
149
+ assert nested_metric.state == "inner"
150
+
98
151
  ctx.record(
99
152
  ExampleState(state="-out"),
100
153
  merge=lambda lhs, rhs: ExampleState(state=lhs.state + rhs.state),
101
154
  )
155
+
156
+ await completion_future
157
+ assert metric.state == "outer-in-out"
@@ -0,0 +1,160 @@
1
+ from asyncio import CancelledError, get_running_loop, sleep
2
+ from collections.abc import AsyncGenerator, AsyncIterator
3
+
4
+ from haiway import ctx
5
+ from haiway.context.metrics import ScopeMetrics
6
+ from haiway.state.structure import State
7
+ from pytest import mark, raises
8
+
9
+
10
+ class FakeException(Exception):
11
+ pass
12
+
13
+
14
+ @mark.asyncio
15
+ async def test_fails_when_generator_fails():
16
+ async def generator(value: int) -> AsyncGenerator[int, None]:
17
+ yield value
18
+ raise FakeException()
19
+
20
+ elements: int = 0
21
+ with raises(FakeException):
22
+ async for _ in ctx.stream(generator, 42):
23
+ elements += 1
24
+
25
+ assert elements == 1
26
+
27
+
28
+ @mark.asyncio
29
+ async def test_cancels_when_iteration_cancels():
30
+ async def generator(value: int) -> AsyncGenerator[int, None]:
31
+ await sleep(0)
32
+ yield value
33
+
34
+ elements: int = 0
35
+ with raises(CancelledError):
36
+ ctx.cancel()
37
+ async for _ in ctx.stream(generator, 42):
38
+ elements += 1
39
+
40
+ assert elements == 0
41
+
42
+
43
+ @mark.asyncio
44
+ async def test_ends_when_generator_ends():
45
+ async def generator(value: int) -> AsyncGenerator[int, None]:
46
+ yield value
47
+
48
+ elements: int = 0
49
+ async for _ in ctx.stream(generator, 42):
50
+ elements += 1
51
+
52
+ assert elements == 1
53
+
54
+
55
+ @mark.asyncio
56
+ async def test_delivers_updates_when_generating():
57
+ async def generator(value: int) -> AsyncGenerator[int, None]:
58
+ for i in range(0, value):
59
+ yield i
60
+
61
+ elements: list[int] = []
62
+
63
+ async for element in ctx.stream(generator, 10):
64
+ elements.append(element)
65
+
66
+ assert elements == list(range(0, 10))
67
+
68
+
69
+ @mark.asyncio
70
+ async def test_streaming_context_variables_access_is_preserved():
71
+ class TestState(State):
72
+ value: int = 42
73
+ other: str = "other"
74
+
75
+ async def generator(value: int) -> AsyncGenerator[TestState, None]:
76
+ yield ctx.state(TestState)
77
+ with ctx.scope("nested", ctx.state(TestState).updated(value=value)):
78
+ yield ctx.state(TestState)
79
+
80
+ stream: AsyncIterator[TestState]
81
+ async with ctx.scope("test", TestState(value=42)):
82
+ elements: list[TestState] = []
83
+
84
+ stream = ctx.stream(generator, 10)
85
+
86
+ async for element in stream:
87
+ elements.append(element)
88
+
89
+ assert elements == [
90
+ TestState(value=42),
91
+ TestState(value=10),
92
+ ]
93
+
94
+
95
+ @mark.asyncio
96
+ async def test_nested_streaming_streams_correctly():
97
+ class TestState(State):
98
+ value: int = 42
99
+ other: str = "other"
100
+
101
+ async def inner(value: int) -> AsyncGenerator[TestState, None]:
102
+ yield ctx.state(TestState)
103
+ with ctx.scope("inner", ctx.state(TestState).updated(value=value, other="inner")):
104
+ yield ctx.state(TestState)
105
+
106
+ async def outer(value: int) -> AsyncGenerator[TestState, None]:
107
+ yield ctx.state(TestState)
108
+ with ctx.scope("outer", ctx.state(TestState).updated(other="outer")):
109
+ async for item in ctx.stream(inner, value):
110
+ yield item
111
+
112
+ stream: AsyncIterator[TestState]
113
+ async with ctx.scope("test", TestState(value=42)):
114
+ elements: list[TestState] = []
115
+
116
+ stream = ctx.stream(outer, 10)
117
+
118
+ async for element in stream:
119
+ elements.append(element)
120
+
121
+ assert elements == [
122
+ TestState(value=42),
123
+ TestState(value=42, other="outer"),
124
+ TestState(value=10, other="inner"),
125
+ ]
126
+
127
+
128
+ @mark.asyncio
129
+ async def test_streaming_context_completion_is_called_at_the_end_of_stream():
130
+ completion_future = get_running_loop().create_future()
131
+
132
+ class IterationMetric(State):
133
+ value: int = 0
134
+
135
+ metric: IterationMetric = IterationMetric()
136
+
137
+ def completion(metrics: ScopeMetrics):
138
+ nonlocal metric
139
+ metric = metrics.metrics(merge=lambda current, nested: nested)[0] # pyright: ignore[reportAssignmentType]
140
+ completion_future.set_result(())
141
+
142
+ async def generator(value: int) -> AsyncGenerator[int, None]:
143
+ for i in range(0, value):
144
+ ctx.record(
145
+ IterationMetric(value=i),
146
+ merge=lambda lhs, rhs: IterationMetric(value=lhs.value + rhs.value),
147
+ )
148
+ yield i
149
+
150
+ stream: AsyncIterator[int]
151
+ async with ctx.scope("test", completion=completion):
152
+ elements: list[int] = []
153
+
154
+ stream = ctx.stream(generator, 10)
155
+
156
+ async for element in stream:
157
+ elements.append(element)
158
+
159
+ await completion_future
160
+ assert metric.value == 45
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes