haiway 0.3.1__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {haiway-0.3.1/src/haiway.egg-info → haiway-0.4.0}/PKG-INFO +1 -1
  2. {haiway-0.3.1 → haiway-0.4.0}/pyproject.toml +1 -1
  3. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/__init__.py +13 -1
  4. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/context/access.py +81 -33
  5. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/context/metrics.py +113 -49
  6. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/helpers/__init__.py +4 -0
  7. haiway-0.4.0/src/haiway/helpers/tracing.py +136 -0
  8. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/state/validation.py +25 -1
  9. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/types/missing.py +4 -1
  10. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/utils/queue.py +6 -9
  11. {haiway-0.3.1 → haiway-0.4.0/src/haiway.egg-info}/PKG-INFO +1 -1
  12. {haiway-0.3.1 → haiway-0.4.0}/src/haiway.egg-info/SOURCES.txt +2 -0
  13. {haiway-0.3.1 → haiway-0.4.0}/tests/test_context.py +63 -7
  14. {haiway-0.3.1 → haiway-0.4.0}/tests/test_state.py +7 -1
  15. haiway-0.4.0/tests/test_streaming.py +160 -0
  16. {haiway-0.3.1 → haiway-0.4.0}/LICENSE +0 -0
  17. {haiway-0.3.1 → haiway-0.4.0}/README.md +0 -0
  18. {haiway-0.3.1 → haiway-0.4.0}/setup.cfg +0 -0
  19. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/context/__init__.py +0 -0
  20. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/context/disposables.py +0 -0
  21. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/context/state.py +0 -0
  22. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/context/tasks.py +0 -0
  23. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/context/types.py +0 -0
  24. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/helpers/asynchrony.py +0 -0
  25. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/helpers/caching.py +0 -0
  26. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/helpers/retries.py +0 -0
  27. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/helpers/throttling.py +0 -0
  28. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/helpers/timeouted.py +0 -0
  29. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/py.typed +0 -0
  30. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/state/__init__.py +0 -0
  31. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/state/attributes.py +0 -0
  32. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/state/structure.py +0 -0
  33. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/types/__init__.py +0 -0
  34. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/types/frozen.py +0 -0
  35. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/utils/__init__.py +0 -0
  36. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/utils/always.py +0 -0
  37. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/utils/env.py +0 -0
  38. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/utils/immutable.py +0 -0
  39. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/utils/logs.py +0 -0
  40. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/utils/mimic.py +0 -0
  41. {haiway-0.3.1 → haiway-0.4.0}/src/haiway/utils/noop.py +0 -0
  42. {haiway-0.3.1 → haiway-0.4.0}/src/haiway.egg-info/dependency_links.txt +0 -0
  43. {haiway-0.3.1 → haiway-0.4.0}/src/haiway.egg-info/requires.txt +0 -0
  44. {haiway-0.3.1 → haiway-0.4.0}/src/haiway.egg-info/top_level.txt +0 -0
  45. {haiway-0.3.1 → haiway-0.4.0}/tests/test_async_queue.py +0 -0
  46. {haiway-0.3.1 → haiway-0.4.0}/tests/test_auto_retry.py +0 -0
  47. {haiway-0.3.1 → haiway-0.4.0}/tests/test_cache.py +0 -0
  48. {haiway-0.3.1 → haiway-0.4.0}/tests/test_timeout.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: haiway
3
- Version: 0.3.1
3
+ Version: 0.4.0
4
4
  Summary: Framework for dependency injection and state management within structured concurrency model.
5
5
  Maintainer-email: Kacper Kaliński <kacper.kalinski@miquido.com>
6
6
  License: MIT License
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
5
5
  [project]
6
6
  name = "haiway"
7
7
  description = "Framework for dependency injection and state management within structured concurrency model."
8
- version = "0.3.1"
8
+ version = "0.4.0"
9
9
  readme = "README.md"
10
10
  maintainers = [
11
11
  { name = "Kacper Kaliński", email = "kacper.kalinski@miquido.com" },
@@ -6,7 +6,16 @@ from haiway.context import (
6
6
  ScopeMetrics,
7
7
  ctx,
8
8
  )
9
- from haiway.helpers import asynchronous, cache, retry, throttle, timeout
9
+ from haiway.helpers import (
10
+ ArgumentsTrace,
11
+ ResultTrace,
12
+ asynchronous,
13
+ cache,
14
+ retry,
15
+ throttle,
16
+ timeout,
17
+ traced,
18
+ )
10
19
  from haiway.state import State
11
20
  from haiway.types import (
12
21
  MISSING,
@@ -34,6 +43,7 @@ from haiway.utils import (
34
43
 
35
44
  __all__ = [
36
45
  "always",
46
+ "ArgumentsTrace",
37
47
  "async_always",
38
48
  "async_noop",
39
49
  "asynchronous",
@@ -57,11 +67,13 @@ __all__ = [
57
67
  "MissingState",
58
68
  "noop",
59
69
  "not_missing",
70
+ "ResultTrace",
60
71
  "retry",
61
72
  "ScopeMetrics",
62
73
  "setup_logging",
63
74
  "State",
64
75
  "throttle",
65
76
  "timeout",
77
+ "traced",
66
78
  "when_missing",
67
79
  ]
@@ -1,12 +1,16 @@
1
1
  from asyncio import (
2
+ CancelledError,
2
3
  Task,
3
4
  current_task,
4
5
  )
5
6
  from collections.abc import (
7
+ AsyncGenerator,
8
+ AsyncIterator,
6
9
  Callable,
7
10
  Coroutine,
8
11
  Iterable,
9
12
  )
13
+ from contextvars import Context, copy_context
10
14
  from logging import Logger
11
15
  from types import TracebackType
12
16
  from typing import Any, final
@@ -32,32 +36,28 @@ class ScopeContext:
32
36
  logger: Logger | None,
33
37
  state: tuple[State, ...],
34
38
  disposables: Disposables | None,
35
- task_group: TaskGroupContext,
36
- completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None,
39
+ completion: Callable[[ScopeMetrics], Coroutine[None, None, None]]
40
+ | Callable[[ScopeMetrics], None]
41
+ | None,
37
42
  ) -> None:
38
- self._task_group: TaskGroupContext = task_group
39
- self._logger: Logger | None = logger
40
- self._trace_id: str | None = trace_id
41
- self._name: str = name
43
+ self._task_group_context: TaskGroupContext = TaskGroupContext()
44
+ # postponing state creation to include disposables if needed
42
45
  self._state_context: StateContext
43
46
  self._state: tuple[State, ...] = state
44
47
  self._disposables: Disposables | None = disposables
45
- self._metrics_context: MetricsContext
46
- self._completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None = completion
48
+ # pre-building metrics context to ensure nested context registering
49
+ self._metrics_context: MetricsContext = MetricsContext.scope(
50
+ name,
51
+ logger=logger,
52
+ trace_id=trace_id,
53
+ completion=completion,
54
+ )
47
55
 
48
56
  freeze(self)
49
57
 
50
58
  def __enter__(self) -> None:
51
- assert self._completion is None, "Can't enter synchronous context with completion" # nosec: B101
52
59
  assert self._disposables is None, "Can't enter synchronous context with disposables" # nosec: B101
53
-
54
60
  self._state_context = StateContext.updated(self._state)
55
- self._metrics_context = MetricsContext.scope(
56
- self._name,
57
- logger=self._logger,
58
- trace_id=self._trace_id,
59
- )
60
-
61
61
  self._state_context.__enter__()
62
62
  self._metrics_context.__enter__()
63
63
 
@@ -80,9 +80,9 @@ class ScopeContext:
80
80
  )
81
81
 
82
82
  async def __aenter__(self) -> None:
83
- await self._task_group.__aenter__()
83
+ await self._task_group_context.__aenter__()
84
84
 
85
- if self._disposables:
85
+ if self._disposables is not None:
86
86
  self._state_context = StateContext.updated(
87
87
  (*self._state, *await self._disposables.__aenter__())
88
88
  )
@@ -90,12 +90,6 @@ class ScopeContext:
90
90
  else:
91
91
  self._state_context = StateContext.updated(self._state)
92
92
 
93
- self._metrics_context = MetricsContext.scope(
94
- self._name,
95
- logger=self._logger,
96
- trace_id=self._trace_id,
97
- )
98
-
99
93
  self._state_context.__enter__()
100
94
  self._metrics_context.__enter__()
101
95
 
@@ -105,14 +99,14 @@ class ScopeContext:
105
99
  exc_val: BaseException | None,
106
100
  exc_tb: TracebackType | None,
107
101
  ) -> None:
108
- if self._disposables:
102
+ if self._disposables is not None:
109
103
  await self._disposables.__aexit__(
110
104
  exc_type=exc_type,
111
105
  exc_val=exc_val,
112
106
  exc_tb=exc_tb,
113
107
  )
114
108
 
115
- await self._task_group.__aexit__(
109
+ await self._task_group_context.__aexit__(
116
110
  exc_type=exc_type,
117
111
  exc_val=exc_val,
118
112
  exc_tb=exc_tb,
@@ -130,9 +124,6 @@ class ScopeContext:
130
124
  exc_tb=exc_tb,
131
125
  )
132
126
 
133
- if completion := self._completion:
134
- await completion(self._metrics_context._metrics) # pyright: ignore[reportPrivateUsage]
135
-
136
127
 
137
128
  @final
138
129
  class ctx:
@@ -144,7 +135,9 @@ class ctx:
144
135
  disposables: Disposables | Iterable[Disposable] | None = None,
145
136
  logger: Logger | None = None,
146
137
  trace_id: str | None = None,
147
- completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None = None,
138
+ completion: Callable[[ScopeMetrics], Coroutine[None, None, None]]
139
+ | Callable[[ScopeMetrics], None]
140
+ | None = None,
148
141
  ) -> ScopeContext:
149
142
  """
150
143
  Access scope context with given parameters. When called within an existing context\
@@ -173,7 +166,7 @@ class ctx:
173
166
  provided current identifier will be used if any, otherwise a random id will\
174
167
  be generated
175
168
 
176
- completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None = None
169
+ completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | Callable[[ScopeMetrics], None] | None = None
177
170
  completion callback called on exit from the scope granting access to finished\
178
171
  scope metrics. Completion is called outside of the context when its metrics is\
179
172
  already finished. Make sure to avoid any long operations within the completion.
@@ -182,7 +175,7 @@ class ctx:
182
175
  -------
183
176
  ScopeContext
184
177
  context object intended to enter context manager with it
185
- """
178
+ """ # noqa: E501
186
179
 
187
180
  resolved_disposables: Disposables | None
188
181
  match disposables:
@@ -201,7 +194,6 @@ class ctx:
201
194
  logger=logger,
202
195
  state=state,
203
196
  disposables=resolved_disposables,
204
- task_group=TaskGroupContext(),
205
197
  completion=completion,
206
198
  )
207
199
 
@@ -257,6 +249,62 @@ class ctx:
257
249
 
258
250
  return TaskGroupContext.run(function, *args, **kwargs)
259
251
 
252
+ @staticmethod
253
+ def stream[Result, **Arguments](
254
+ source: Callable[Arguments, AsyncGenerator[Result, None]],
255
+ /,
256
+ *args: Arguments.args,
257
+ **kwargs: Arguments.kwargs,
258
+ ) -> AsyncIterator[Result]:
259
+ """
260
+ Stream results produced by a generator within the proper context state.
261
+
262
+ Parameters
263
+ ----------
264
+ source: Callable[Arguments, AsyncGenerator[Result, None]]
265
+ generator streamed as the result
266
+
267
+ *args: Arguments.args
268
+ positional arguments passed to generator call
269
+
270
+ **kwargs: Arguments.kwargs
271
+ keyword arguments passed to generator call
272
+
273
+ Returns
274
+ -------
275
+ AsyncIterator[Result]
276
+ iterator for accessing generated results
277
+ """
278
+
279
+ # prepare context snapshot
280
+ context_snapshot: Context = copy_context()
281
+
282
+ # prepare nested context
283
+ streaming_context: ScopeContext = ctx.scope(
284
+ getattr(
285
+ source,
286
+ "__name__",
287
+ "streaming",
288
+ )
289
+ )
290
+
291
+ async def generator() -> AsyncGenerator[Result, None]:
292
+ async with streaming_context:
293
+ async for result in source(*args, **kwargs):
294
+ yield result
295
+
296
+ # finally return it as an iterator
297
+ return context_snapshot.run(generator)
298
+
299
+ @staticmethod
300
+ def check_cancellation() -> None:
301
+ """
302
+ Check if current asyncio task is cancelled, raises CancelledError if so.
303
+ """
304
+
305
+ if (task := current_task()) and task.cancelled():
306
+ raise CancelledError()
307
+
260
308
  @staticmethod
261
309
  def cancel() -> None:
262
310
  """
@@ -1,5 +1,12 @@
1
- from asyncio import Future, gather, get_event_loop
2
- from collections.abc import Callable
1
+ from asyncio import (
2
+ AbstractEventLoop,
3
+ Future,
4
+ gather,
5
+ get_event_loop,
6
+ iscoroutinefunction,
7
+ run_coroutine_threadsafe,
8
+ )
9
+ from collections.abc import Callable, Coroutine
3
10
  from contextvars import ContextVar, Token
4
11
  from copy import copy
5
12
  from itertools import chain
@@ -10,6 +17,7 @@ from typing import Any, Self, cast, final, overload
10
17
  from uuid import uuid4
11
18
 
12
19
  from haiway.state import State
20
+ from haiway.types import MISSING, Missing, not_missing
13
21
  from haiway.utils import freeze
14
22
 
15
23
  __all__ = [
@@ -26,36 +34,75 @@ class ScopeMetrics:
26
34
  trace_id: str | None,
27
35
  scope: str,
28
36
  logger: Logger | None,
37
+ parent: Self | None,
38
+ completion: Callable[[Self], Coroutine[None, None, None]] | Callable[[Self], None] | None,
29
39
  ) -> None:
30
40
  self.trace_id: str = trace_id or uuid4().hex
31
- self._label: str = f"{self.trace_id}|{scope}" if scope else self.trace_id
41
+ self.identifier: str = uuid4().hex
42
+ self.label: str = scope
43
+ self._logger_prefix: str = (
44
+ f"[{self.trace_id}] [{scope}] [{self.identifier}]"
45
+ if scope
46
+ else f"[{self.trace_id}] [{self.identifier}]"
47
+ )
32
48
  self._logger: Logger = logger or getLogger(name=scope)
49
+ self._parent: Self | None = parent if parent else None
33
50
  self._metrics: dict[type[State], State] = {}
34
- self._nested: list[ScopeMetrics] = []
51
+ self._nested: set[ScopeMetrics] = set()
35
52
  self._timestamp: float = monotonic()
36
- self._completed: Future[float] = get_event_loop().create_future()
53
+ self._finished: bool = False
54
+ self._loop: AbstractEventLoop = get_event_loop()
55
+ self._completed: Future[float] = self._loop.create_future()
56
+
57
+ if parent := parent:
58
+ parent._nested.add(self)
37
59
 
38
60
  freeze(self)
39
61
 
62
+ if completion := completion:
63
+ metrics: Self = self
64
+ if iscoroutinefunction(completion):
65
+
66
+ def callback(_: Future[float]) -> None:
67
+ run_coroutine_threadsafe(
68
+ completion(metrics),
69
+ metrics._loop,
70
+ )
71
+
72
+ else:
73
+
74
+ def callback(_: Future[float]) -> None:
75
+ completion(metrics)
76
+
77
+ self._completed.add_done_callback(callback)
78
+
40
79
  def __del__(self) -> None:
41
- self._complete() # ensure completion on deinit
80
+ assert self.is_completed, "Deinitializing not completed scope metrics" # nosec: B101
42
81
 
43
82
  def __str__(self) -> str:
44
- return self._label
83
+ return f"{self.label}[{self.identifier}]@[{self.trace_id}]"
45
84
 
46
85
  def metrics(
47
86
  self,
48
87
  *,
49
- merge: Callable[[State, State], State] = lambda lhs, rhs: lhs,
88
+ merge: Callable[[State | Missing, State], State | Missing] | None = None,
50
89
  ) -> list[State]:
90
+ if not merge:
91
+ return list(self._metrics.values())
92
+
51
93
  metrics: dict[type[State], State] = copy(self._metrics)
52
94
  for metric in chain.from_iterable(nested.metrics(merge=merge) for nested in self._nested):
53
95
  metric_type: type[State] = type(metric)
54
- if current := metrics.get(metric_type):
55
- metrics[metric_type] = merge(current, metric)
96
+ merged: State | Missing = merge(
97
+ metrics.get( # current
98
+ metric_type,
99
+ MISSING,
100
+ ),
101
+ metric, # received
102
+ )
56
103
 
57
- else:
58
- metrics[metric_type] = metric
104
+ if not_missing(merged):
105
+ metrics[metric_type] = merged
59
106
 
60
107
  return list(metrics.values())
61
108
 
@@ -98,8 +145,8 @@ class ScopeMetrics:
98
145
  self._metrics[metric_type] = metric
99
146
 
100
147
  @property
101
- def completed(self) -> bool:
102
- return self._completed.done() and all(nested.completed for nested in self._nested)
148
+ def is_completed(self) -> bool:
149
+ return self._completed.done() and all(nested.is_completed for nested in self._nested)
103
150
 
104
151
  @property
105
152
  def time(self) -> float:
@@ -116,24 +163,36 @@ class ScopeMetrics:
116
163
  return_exceptions=False,
117
164
  )
118
165
 
119
- def _complete(self) -> None:
120
- if self._completed.done():
121
- return # already completed
166
+ def _finish(self) -> None:
167
+ assert ( # nosec: B101
168
+ not self._completed.done()
169
+ ), "Invalid state - called finish on already completed scope"
170
+
171
+ assert ( # nosec: B101
172
+ not self._finished
173
+ ), "Invalid state - called completion on already finished scope"
174
+
175
+ self._finished = True # self is now finished
176
+
177
+ self._complete_if_able()
122
178
 
179
+ def _complete_if_able(self) -> None:
180
+ assert ( # nosec: B101
181
+ not self._completed.done()
182
+ ), "Invalid state - called complete on already completed scope"
183
+
184
+ if not self._finished:
185
+ return # wait for finishing self
186
+
187
+ if any(not nested.is_completed for nested in self._nested):
188
+ return # wait for completing all nested scopes
189
+
190
+ # set completion time
123
191
  self._completed.set_result(monotonic() - self._timestamp)
124
192
 
125
- def scope(
126
- self,
127
- name: str,
128
- /,
129
- ) -> Self:
130
- nested: Self = self.__class__(
131
- scope=name,
132
- logger=self._logger,
133
- trace_id=self.trace_id,
134
- )
135
- self._nested.append(nested)
136
- return nested
193
+ # notify parent about completion
194
+ if parent := self._parent:
195
+ parent._complete_if_able()
137
196
 
138
197
  def log(
139
198
  self,
@@ -145,7 +204,7 @@ class ScopeMetrics:
145
204
  ) -> None:
146
205
  self._logger.log(
147
206
  level,
148
- f"[{self}] {message}",
207
+ f"{self._logger_prefix} {message}",
149
208
  *args,
150
209
  exc_info=exception,
151
210
  )
@@ -163,29 +222,37 @@ class MetricsContext:
163
222
  *,
164
223
  trace_id: str | None = None,
165
224
  logger: Logger | None = None,
225
+ completion: Callable[[ScopeMetrics], Coroutine[None, None, None]]
226
+ | Callable[[ScopeMetrics], None]
227
+ | None,
166
228
  ) -> Self:
167
- try:
168
- context: ScopeMetrics = cls._context.get()
169
- if trace_id is None or context.trace_id == trace_id:
170
- return cls(context.scope(name))
229
+ current: ScopeMetrics
230
+ try: # check for current scope context
231
+ current = cls._context.get()
171
232
 
172
- else:
173
- return cls(
174
- ScopeMetrics(
175
- trace_id=trace_id,
176
- scope=name,
177
- logger=logger or context._logger, # pyright: ignore[reportPrivateUsage]
178
- )
179
- )
180
- except LookupError: # create metrics scope when missing yet
233
+ except LookupError:
234
+ # create metrics scope when missing yet
181
235
  return cls(
182
236
  ScopeMetrics(
183
237
  trace_id=trace_id,
184
238
  scope=name,
185
239
  logger=logger,
240
+ parent=None,
241
+ completion=completion,
186
242
  )
187
243
  )
188
244
 
245
+ # or create nested metrics otherwise
246
+ return cls(
247
+ ScopeMetrics(
248
+ trace_id=trace_id,
249
+ scope=name,
250
+ logger=logger or current._logger, # pyright: ignore[reportPrivateUsage]
251
+ parent=current,
252
+ completion=completion,
253
+ )
254
+ )
255
+
189
256
  @classmethod
190
257
  def record[Metric: State](
191
258
  cls,
@@ -305,15 +372,12 @@ class MetricsContext:
305
372
  ) -> None:
306
373
  self._metrics: ScopeMetrics = metrics
307
374
  self._token: Token[ScopeMetrics] | None = None
308
- self._started: float | None = None
309
- self._finished: float | None = None
310
375
 
311
376
  def __enter__(self) -> None:
312
377
  assert ( # nosec: B101
313
- self._token is None and self._started is None
378
+ self._token is None and not self._metrics._finished # pyright: ignore[reportPrivateUsage]
314
379
  ), "MetricsContext reentrance is not allowed"
315
380
  self._token = MetricsContext._context.set(self._metrics)
316
- self._started = monotonic()
317
381
 
318
382
  def __exit__(
319
383
  self,
@@ -322,8 +386,8 @@ class MetricsContext:
322
386
  exc_tb: TracebackType | None,
323
387
  ) -> None:
324
388
  assert ( # nosec: B101
325
- self._token is not None and self._started is not None and self._finished is None
389
+ self._token is not None
326
390
  ), "Unbalanced MetricsContext context enter/exit"
327
- self._finished = monotonic()
328
391
  MetricsContext._context.reset(self._token)
392
+ self._metrics._finish() # pyright: ignore[reportPrivateUsage]
329
393
  self._token = None
@@ -3,11 +3,15 @@ from haiway.helpers.caching import cache
3
3
  from haiway.helpers.retries import retry
4
4
  from haiway.helpers.throttling import throttle
5
5
  from haiway.helpers.timeouted import timeout
6
+ from haiway.helpers.tracing import ArgumentsTrace, ResultTrace, traced
6
7
 
7
8
  __all__ = [
9
+ "ArgumentsTrace",
8
10
  "asynchronous",
9
11
  "cache",
12
+ "ResultTrace",
10
13
  "retry",
11
14
  "throttle",
12
15
  "timeout",
16
+ "traced",
13
17
  ]
@@ -0,0 +1,136 @@
1
+ from asyncio import iscoroutinefunction
2
+ from collections.abc import Callable, Coroutine
3
+ from typing import Any, Self, cast
4
+
5
+ from haiway.context import ctx
6
+ from haiway.state import State
7
+ from haiway.types import MISSING, Missing
8
+ from haiway.utils import mimic_function
9
+
10
+ __all__ = [
11
+ "traced",
12
+ "ArgumentsTrace",
13
+ "ResultTrace",
14
+ ]
15
+
16
+
17
+ class ArgumentsTrace(State):
18
+ if __debug__:
19
+
20
+ @classmethod
21
+ def of(cls, *args: Any, **kwargs: Any) -> Self:
22
+ return cls(
23
+ args=args if args else MISSING,
24
+ kwargs=kwargs if kwargs else MISSING,
25
+ )
26
+
27
+ else: # remove tracing for non debug runs to prevent accidental secret leaks
28
+
29
+ @classmethod
30
+ def of(cls, *args: Any, **kwargs: Any) -> Self:
31
+ return cls(
32
+ args=MISSING,
33
+ kwargs=MISSING,
34
+ )
35
+
36
+ args: tuple[Any, ...] | Missing
37
+ kwargs: dict[str, Any] | Missing
38
+
39
+
40
+ class ResultTrace(State):
41
+ if __debug__:
42
+
43
+ @classmethod
44
+ def of(
45
+ cls,
46
+ value: Any,
47
+ /,
48
+ ) -> Self:
49
+ return cls(result=value)
50
+
51
+ else: # remove tracing for non debug runs to prevent accidental secret leaks
52
+
53
+ @classmethod
54
+ def of(
55
+ cls,
56
+ value: Any,
57
+ /,
58
+ ) -> Self:
59
+ return cls(result=MISSING)
60
+
61
+ result: Any | Missing
62
+
63
+
64
+ def traced[**Args, Result](
65
+ function: Callable[Args, Result],
66
+ /,
67
+ ) -> Callable[Args, Result]:
68
+ if __debug__:
69
+ if iscoroutinefunction(function):
70
+ return cast(
71
+ Callable[Args, Result],
72
+ _traced_async(
73
+ function,
74
+ label=function.__name__,
75
+ ),
76
+ )
77
+ else:
78
+ return _traced_sync(
79
+ function,
80
+ label=function.__name__,
81
+ )
82
+
83
+ else: # do not trace on non debug runs
84
+ return function
85
+
86
+
87
+ def _traced_sync[**Args, Result](
88
+ function: Callable[Args, Result],
89
+ /,
90
+ label: str,
91
+ ) -> Callable[Args, Result]:
92
+ def traced(
93
+ *args: Args.args,
94
+ **kwargs: Args.kwargs,
95
+ ) -> Result:
96
+ with ctx.scope(label):
97
+ ctx.record(ArgumentsTrace.of(*args, **kwargs))
98
+ try:
99
+ result: Result = function(*args, **kwargs)
100
+ ctx.record(ResultTrace.of(result))
101
+ return result
102
+
103
+ except BaseException as exc:
104
+ ctx.record(ResultTrace.of(exc))
105
+ raise exc
106
+
107
+ return mimic_function(
108
+ function,
109
+ within=traced,
110
+ )
111
+
112
+
113
+ def _traced_async[**Args, Result](
114
+ function: Callable[Args, Coroutine[Any, Any, Result]],
115
+ /,
116
+ label: str,
117
+ ) -> Callable[Args, Coroutine[Any, Any, Result]]:
118
+ async def traced(
119
+ *args: Args.args,
120
+ **kwargs: Args.kwargs,
121
+ ) -> Result:
122
+ with ctx.scope(label):
123
+ ctx.record(ArgumentsTrace.of(*args, **kwargs))
124
+ try:
125
+ result: Result = await function(*args, **kwargs)
126
+ ctx.record(ResultTrace.of(result))
127
+ return result
128
+
129
+ except BaseException as exc:
130
+ ctx.record(ResultTrace.of(exc))
131
+ raise exc
132
+
133
+ return mimic_function(
134
+ function,
135
+ within=traced,
136
+ )
@@ -11,7 +11,7 @@ __all__ = [
11
11
  ]
12
12
 
13
13
 
14
- def attribute_type_validator(
14
+ def attribute_type_validator( # noqa: PLR0911
15
15
  annotation: AttributeAnnotation,
16
16
  /,
17
17
  ) -> Callable[[Any], Any]:
@@ -31,6 +31,10 @@ def attribute_type_validator(
31
31
  case typing.Any:
32
32
  return _any_validator
33
33
 
34
+ # typed dicts fail on type checks
35
+ case typed_dict if typing.is_typeddict(typed_dict):
36
+ return _prepare_typed_dict_validator(typed_dict)
37
+
34
38
  case type() as other_type:
35
39
  return _prepare_type_validator(other_type)
36
40
 
@@ -123,3 +127,23 @@ def _prepare_type_validator(
123
127
  )
124
128
 
125
129
  return type_validator
130
+
131
+
132
+ def _prepare_typed_dict_validator(
133
+ validated_type: type[Any],
134
+ /,
135
+ ) -> Callable[[Any], Any]:
136
+ def typed_dict_validator(
137
+ value: Any,
138
+ ) -> Any:
139
+ match value:
140
+ case value if isinstance(value, dict):
141
+ # for typed dicts check only if that is a dict
142
+ return value # pyright: ignore[reportUnknownVariableType]
143
+
144
+ case _:
145
+ raise TypeError(
146
+ f"Type '{type(value)}' is not matching expected type '{validated_type}'"
147
+ )
148
+
149
+ return typed_dict_validator
@@ -27,6 +27,9 @@ class Missing(metaclass=MissingType):
27
27
  Type representing absence of a value. Use MISSING constant for its value.
28
28
  """
29
29
 
30
+ __slots__ = ()
31
+ __match_args__ = ()
32
+
30
33
  def __bool__(self) -> bool:
31
34
  return False
32
35
 
@@ -42,7 +45,7 @@ class Missing(metaclass=MissingType):
42
45
  def __repr__(self) -> str:
43
46
  return "MISSING"
44
47
 
45
- def __getattribute__(
48
+ def __getattr__(
46
49
  self,
47
50
  name: str,
48
51
  ) -> Any:
@@ -3,8 +3,6 @@ from collections import deque
3
3
  from collections.abc import AsyncIterator
4
4
  from typing import Self
5
5
 
6
- from haiway.utils.immutable import freeze
7
-
8
6
  __all__ = [
9
7
  "AsyncQueue",
10
8
  ]
@@ -18,20 +16,19 @@ class AsyncQueue[Element](AsyncIterator[Element]):
18
16
 
19
17
  def __init__(
20
18
  self,
19
+ *elements: Element,
21
20
  loop: AbstractEventLoop | None = None,
22
21
  ) -> None:
23
22
  self._loop: AbstractEventLoop = loop or get_running_loop()
24
- self._queue: deque[Element] = deque()
23
+ self._queue: deque[Element] = deque(elements)
25
24
  self._waiting: Future[Element] | None = None
26
25
  self._finish_reason: BaseException | None = None
27
26
 
28
- freeze(self)
29
-
30
27
  def __del__(self) -> None:
31
28
  self.finish()
32
29
 
33
30
  @property
34
- def finished(self) -> bool:
31
+ def is_finished(self) -> bool:
35
32
  return self._finish_reason is not None
36
33
 
37
34
  def enqueue(
@@ -40,7 +37,7 @@ class AsyncQueue[Element](AsyncIterator[Element]):
40
37
  /,
41
38
  *elements: Element,
42
39
  ) -> None:
43
- if self.finished:
40
+ if self.is_finished:
44
41
  raise RuntimeError("AsyncQueue is already finished")
45
42
 
46
43
  if self._waiting is not None and not self._waiting.done():
@@ -55,7 +52,7 @@ class AsyncQueue[Element](AsyncIterator[Element]):
55
52
  self,
56
53
  exception: BaseException | None = None,
57
54
  ) -> None:
58
- if self.finished:
55
+ if self.is_finished:
59
56
  return # already finished, ignore
60
57
 
61
58
  self._finish_reason = exception or StopAsyncIteration()
@@ -70,7 +67,7 @@ class AsyncQueue[Element](AsyncIterator[Element]):
70
67
  return self
71
68
 
72
69
  async def __anext__(self) -> Element:
73
- assert self._waiting is None, "Only a single queue iterator is supported!" # nosec: B101
70
+ assert self._waiting is None, "Only a single queue consumer is supported!" # nosec: B101
74
71
 
75
72
  if self._queue: # check the queue, let it finish
76
73
  return self._queue.popleft()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: haiway
3
- Version: 0.3.1
3
+ Version: 0.4.0
4
4
  Summary: Framework for dependency injection and state management within structured concurrency model.
5
5
  Maintainer-email: Kacper Kaliński <kacper.kalinski@miquido.com>
6
6
  License: MIT License
@@ -21,6 +21,7 @@ src/haiway/helpers/caching.py
21
21
  src/haiway/helpers/retries.py
22
22
  src/haiway/helpers/throttling.py
23
23
  src/haiway/helpers/timeouted.py
24
+ src/haiway/helpers/tracing.py
24
25
  src/haiway/state/__init__.py
25
26
  src/haiway/state/attributes.py
26
27
  src/haiway/state/structure.py
@@ -41,4 +42,5 @@ tests/test_auto_retry.py
41
42
  tests/test_cache.py
42
43
  tests/test_context.py
43
44
  tests/test_state.py
45
+ tests/test_streaming.py
44
46
  tests/test_timeout.py
@@ -1,3 +1,5 @@
1
+ from asyncio import get_running_loop
2
+
1
3
  from haiway import MissingContext, ScopeMetrics, State, ctx
2
4
  from pytest import mark, raises
3
5
 
@@ -64,38 +66,92 @@ async def test_exceptions_are_propagated():
64
66
 
65
67
  @mark.asyncio
66
68
  async def test_completions_are_called_according_to_context_exits():
69
+ completion_future = get_running_loop().create_future()
70
+ nested_completion_future = get_running_loop().create_future()
71
+ executions: int = 0
72
+
73
+ def completion(metrics: ScopeMetrics):
74
+ nonlocal executions
75
+ executions += 1
76
+ completion_future.set_result(())
77
+
78
+ def nested_completion(metrics: ScopeMetrics):
79
+ nonlocal executions
80
+ executions += 1
81
+ nested_completion_future.set_result(())
82
+
83
+ async with ctx.scope("outer", completion=completion):
84
+ assert executions == 0
85
+
86
+ async with ctx.scope("inner", completion=nested_completion):
87
+ assert executions == 0
88
+
89
+ await nested_completion_future
90
+ assert executions == 1
91
+
92
+ await completion_future
93
+ assert executions == 2
94
+
95
+
96
+ @mark.asyncio
97
+ async def test_async_completions_are_called_according_to_context_exits():
98
+ completion_future = get_running_loop().create_future()
99
+ nested_completion_future = get_running_loop().create_future()
67
100
  executions: int = 0
68
101
 
69
102
  async def completion(metrics: ScopeMetrics):
70
103
  nonlocal executions
71
104
  executions += 1
105
+ completion_future.set_result(())
106
+
107
+ async def nested_completion(metrics: ScopeMetrics):
108
+ nonlocal executions
109
+ executions += 1
110
+ nested_completion_future.set_result(())
72
111
 
73
112
  async with ctx.scope("outer", completion=completion):
74
113
  assert executions == 0
75
114
 
76
- async with ctx.scope("inner", completion=completion):
115
+ async with ctx.scope("inner", completion=nested_completion):
77
116
  assert executions == 0
78
117
 
118
+ await nested_completion_future
79
119
  assert executions == 1
80
120
 
121
+ await completion_future
81
122
  assert executions == 2
82
123
 
83
124
 
84
125
  @mark.asyncio
85
126
  async def test_metrics_are_recorded_within_context():
86
- def verify_example_metrics(state: str):
87
- async def completion(metrics: ScopeMetrics):
88
- assert metrics.read(ExampleState, default=ExampleState()).state == state
127
+ completion_future = get_running_loop().create_future()
128
+ nested_completion_future = get_running_loop().create_future()
129
+ metric: ExampleState = ExampleState()
130
+ nested_metric: ExampleState = ExampleState()
131
+
132
+ async def completion(metrics: ScopeMetrics):
133
+ nonlocal metric
134
+ metric = metrics.read(ExampleState, default=ExampleState())
135
+ completion_future.set_result(())
89
136
 
90
- return completion
137
+ async def nested_completion(metrics: ScopeMetrics):
138
+ nonlocal nested_metric
139
+ nested_metric = metrics.read(ExampleState, default=ExampleState())
140
+ nested_completion_future.set_result(())
91
141
 
92
- async with ctx.scope("outer", completion=verify_example_metrics("outer-in-out")):
142
+ async with ctx.scope("outer", completion=completion):
93
143
  ctx.record(ExampleState(state="outer-in"))
94
144
 
95
- async with ctx.scope("inner", completion=verify_example_metrics("inner")):
145
+ async with ctx.scope("inner", completion=nested_completion):
96
146
  ctx.record(ExampleState(state="inner"))
97
147
 
148
+ await nested_completion_future
149
+ assert nested_metric.state == "inner"
150
+
98
151
  ctx.record(
99
152
  ExampleState(state="-out"),
100
153
  merge=lambda lhs, rhs: ExampleState(state=lhs.state + rhs.state),
101
154
  )
155
+
156
+ await completion_future
157
+ assert metric.state == "outer-in-out"
@@ -1,10 +1,13 @@
1
1
  from collections.abc import Callable
2
- from typing import Literal, Protocol, Self, runtime_checkable
2
+ from typing import Literal, Protocol, Self, TypedDict, runtime_checkable
3
3
 
4
4
  from haiway import State, frozenlist
5
5
 
6
6
 
7
7
  def test_basic_initializes_with_arguments() -> None:
8
+ class DictTyped(TypedDict):
9
+ value: str
10
+
8
11
  @runtime_checkable
9
12
  class Proto(Protocol):
10
13
  def __call__(self) -> None: ...
@@ -20,6 +23,7 @@ def test_basic_initializes_with_arguments() -> None:
20
23
  none: None
21
24
  function: Callable[[], None]
22
25
  proto: Proto
26
+ dict_typed: DictTyped
23
27
 
24
28
  basic = Basics(
25
29
  string="string",
@@ -32,6 +36,7 @@ def test_basic_initializes_with_arguments() -> None:
32
36
  none=None,
33
37
  function=lambda: None,
34
38
  proto=lambda: None,
39
+ dict_typed={"value": "42"},
35
40
  )
36
41
  assert basic.string == "string"
37
42
  assert basic.literal == "A"
@@ -41,6 +46,7 @@ def test_basic_initializes_with_arguments() -> None:
41
46
  assert basic.union == "union"
42
47
  assert basic.optional == "optional"
43
48
  assert basic.none is None
49
+ assert basic.dict_typed == {"value": "42"}
44
50
  assert callable(basic.function)
45
51
  assert isinstance(basic.proto, Proto)
46
52
 
@@ -0,0 +1,160 @@
1
+ from asyncio import CancelledError, get_running_loop, sleep
2
+ from collections.abc import AsyncGenerator, AsyncIterator
3
+
4
+ from haiway import ctx
5
+ from haiway.context.metrics import ScopeMetrics
6
+ from haiway.state.structure import State
7
+ from pytest import mark, raises
8
+
9
+
10
+ class FakeException(Exception):
11
+ pass
12
+
13
+
14
+ @mark.asyncio
15
+ async def test_fails_when_generator_fails():
16
+ async def generator(value: int) -> AsyncGenerator[int, None]:
17
+ yield value
18
+ raise FakeException()
19
+
20
+ elements: int = 0
21
+ with raises(FakeException):
22
+ async for _ in ctx.stream(generator, 42):
23
+ elements += 1
24
+
25
+ assert elements == 1
26
+
27
+
28
+ @mark.asyncio
29
+ async def test_cancels_when_iteration_cancels():
30
+ async def generator(value: int) -> AsyncGenerator[int, None]:
31
+ await sleep(0)
32
+ yield value
33
+
34
+ elements: int = 0
35
+ with raises(CancelledError):
36
+ ctx.cancel()
37
+ async for _ in ctx.stream(generator, 42):
38
+ elements += 1
39
+
40
+ assert elements == 0
41
+
42
+
43
+ @mark.asyncio
44
+ async def test_ends_when_generator_ends():
45
+ async def generator(value: int) -> AsyncGenerator[int, None]:
46
+ yield value
47
+
48
+ elements: int = 0
49
+ async for _ in ctx.stream(generator, 42):
50
+ elements += 1
51
+
52
+ assert elements == 1
53
+
54
+
55
+ @mark.asyncio
56
+ async def test_delivers_updates_when_generating():
57
+ async def generator(value: int) -> AsyncGenerator[int, None]:
58
+ for i in range(0, value):
59
+ yield i
60
+
61
+ elements: list[int] = []
62
+
63
+ async for element in ctx.stream(generator, 10):
64
+ elements.append(element)
65
+
66
+ assert elements == list(range(0, 10))
67
+
68
+
69
+ @mark.asyncio
70
+ async def test_streaming_context_variables_access_is_preserved():
71
+ class TestState(State):
72
+ value: int = 42
73
+ other: str = "other"
74
+
75
+ async def generator(value: int) -> AsyncGenerator[TestState, None]:
76
+ yield ctx.state(TestState)
77
+ with ctx.scope("nested", ctx.state(TestState).updated(value=value)):
78
+ yield ctx.state(TestState)
79
+
80
+ stream: AsyncIterator[TestState]
81
+ async with ctx.scope("test", TestState(value=42)):
82
+ elements: list[TestState] = []
83
+
84
+ stream = ctx.stream(generator, 10)
85
+
86
+ async for element in stream:
87
+ elements.append(element)
88
+
89
+ assert elements == [
90
+ TestState(value=42),
91
+ TestState(value=10),
92
+ ]
93
+
94
+
95
+ @mark.asyncio
96
+ async def test_nested_streaming_streams_correctly():
97
+ class TestState(State):
98
+ value: int = 42
99
+ other: str = "other"
100
+
101
+ async def inner(value: int) -> AsyncGenerator[TestState, None]:
102
+ yield ctx.state(TestState)
103
+ with ctx.scope("inner", ctx.state(TestState).updated(value=value, other="inner")):
104
+ yield ctx.state(TestState)
105
+
106
+ async def outer(value: int) -> AsyncGenerator[TestState, None]:
107
+ yield ctx.state(TestState)
108
+ with ctx.scope("outer", ctx.state(TestState).updated(other="outer")):
109
+ async for item in ctx.stream(inner, value):
110
+ yield item
111
+
112
+ stream: AsyncIterator[TestState]
113
+ async with ctx.scope("test", TestState(value=42)):
114
+ elements: list[TestState] = []
115
+
116
+ stream = ctx.stream(outer, 10)
117
+
118
+ async for element in stream:
119
+ elements.append(element)
120
+
121
+ assert elements == [
122
+ TestState(value=42),
123
+ TestState(value=42, other="outer"),
124
+ TestState(value=10, other="inner"),
125
+ ]
126
+
127
+
128
+ @mark.asyncio
129
+ async def test_streaming_context_completion_is_called_at_the_end_of_stream():
130
+ completion_future = get_running_loop().create_future()
131
+
132
+ class IterationMetric(State):
133
+ value: int = 0
134
+
135
+ metric: IterationMetric = IterationMetric()
136
+
137
+ def completion(metrics: ScopeMetrics):
138
+ nonlocal metric
139
+ metric = metrics.metrics(merge=lambda current, nested: nested)[0] # pyright: ignore[reportAssignmentType]
140
+ completion_future.set_result(())
141
+
142
+ async def generator(value: int) -> AsyncGenerator[int, None]:
143
+ for i in range(0, value):
144
+ ctx.record(
145
+ IterationMetric(value=i),
146
+ merge=lambda lhs, rhs: IterationMetric(value=lhs.value + rhs.value),
147
+ )
148
+ yield i
149
+
150
+ stream: AsyncIterator[int]
151
+ async with ctx.scope("test", completion=completion):
152
+ elements: list[int] = []
153
+
154
+ stream = ctx.stream(generator, 10)
155
+
156
+ async for element in stream:
157
+ elements.append(element)
158
+
159
+ await completion_future
160
+ assert metric.value == 45
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes