aionode 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aionode-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,158 @@
1
+ Metadata-Version: 2.4
2
+ Name: aionode
3
+ Version: 0.1.0
4
+ Summary: Lightweight asyncio task tracking as call tree and DAG
5
+ Author: Matteo De Pellegrin
6
+ Author-email: Matteo De Pellegrin <matteo.dep97@gmail.com>
7
+ License-Expression: MIT
8
+ Requires-Python: >=3.12
9
+ Project-URL: Repository, https://github.com/MatteoDep/aionode
10
+ Project-URL: Issues, https://github.com/MatteoDep/aionode/issues
11
+ Description-Content-Type: text/markdown
12
+
13
+ # aionode
14
+
15
+ Lightweight asyncio task tracking with dependency graphs and progress rendering.
16
+
17
+ ## Installation
18
+
19
+ ```bash
20
+ pip install aionode
21
+ ```
22
+
23
+ For Rich table rendering:
24
+
25
+ ```bash
26
+ pip install aionode[viz]
27
+ ```
28
+
29
+ ## Quick Start
30
+
31
+ ```python
32
+ import asyncio
33
+ import aionode
34
+
35
+ async def fetch_data() -> list[int]:
36
+ await asyncio.sleep(1)
37
+ return list(range(100))
38
+
39
+ async def process(data: list[int]) -> int:
40
+ await asyncio.sleep(0.5)
41
+ return sum(data)
42
+
43
+ async def pipeline() -> None:
44
+ async with asyncio.TaskGroup() as tg:
45
+ fetch = tg.create_task(aionode.node(fetch_data)(), name="fetch")
46
+ tg.create_task(aionode.node(process)(aionode.resolve(fetch)), name="process")
47
+
48
+ async def main() -> None:
49
+ root = asyncio.create_task(aionode.track(pipeline)(), name="pipeline")
50
+ root_id = await aionode.get_task_id(root)
51
+ graph = aionode.TaskGraph(root_id=root_id)
52
+
53
+ await aionode.watch(graph, interval=0.3)
54
+ await root
55
+
56
+ asyncio.run(main())
57
+ ```
58
+
59
+ ## Core Concepts
60
+
61
+ ### `node(func, wait_for=None, track=True, auto_progress=True)`
62
+
63
+ Wraps an async function as a DAG node. Use `resolve()` to pass awaitables as arguments — they are gathered concurrently before the function is called. Sync functions must be wrapped with `make_async` first.
64
+
65
+ ```python
66
+ # Async function — pass upstream tasks with resolve()
67
+ fetch = tg.create_task(aionode.node(fetch_data)(), name="fetch")
68
+ process_task = tg.create_task(aionode.node(process)(aionode.resolve(fetch)), name="process")
69
+
70
+ # Sync function — wrap with make_async first
71
+ summarize = tg.create_task(
72
+ aionode.node(aionode.make_async(my_sync_fn))(aionode.resolve(process_task)),
73
+ name="summarize",
74
+ )
75
+
76
+ # Side-only dependency (no value passed): use wait_for
77
+ task_b = tg.create_task(aionode.node(cleanup, wait_for=[fetch])(), name="cleanup")
78
+ ```
79
+
80
+ ### `resolve(awaitable)`
81
+
82
+ Marks an awaitable to be resolved before being passed as an argument to `node()`. This preserves type information — the type checker sees `resolve(task: Task[T])` as returning `T`.
83
+
84
+ ```python
85
+ result = await aionode.node(process)(aionode.resolve(upstream_task))
86
+ ```
87
+
88
+ ### `track(func, start=True)`
89
+
90
+ Tracks a coroutine by registering it in the task graph. Use this for the root task or tasks that don't need `node()`'s dependency resolution.
91
+
92
+ ```python
93
+ root = asyncio.create_task(aionode.track(my_coro)(), name="root")
94
+ ```
95
+
96
+ ### `TaskGraph`
97
+
98
+ Query the dependency graph:
99
+
100
+ ```python
101
+ graph = aionode.TaskGraph(root_id=root_id)
102
+
103
+ graph.nodes() # All tasks in topological order
104
+ graph.roots() # Tasks with no upstream deps
105
+ graph.leaves() # Tasks with no downstream dependents
106
+ graph.upstream(task_id) # Transitive upstream deps
107
+ graph.downstream(task_id) # Transitive downstream dependents
108
+ graph.summary() # {TaskStatus: count}
109
+ graph.critical_path() # Longest-duration path
110
+ ```
111
+
112
+ ### Rendering
113
+
114
+ ```python
115
+ # Auto-detect Rich or fall back to ANSI text
116
+ renderer = aionode.get_render()
117
+
118
+ # Force plain text
119
+ renderer = aionode.get_render(rich=False)
120
+
121
+ # Live-updating display
122
+ await aionode.watch(graph, interval=0.5, renderer=renderer)
123
+ ```
124
+
125
+ ### Task Inspection
126
+
127
+ ```python
128
+ task_id = await aionode.get_task_id(asyncio_task)
129
+ info = aionode.get_task_info(task_id)
130
+
131
+ info.status # TaskStatus: WAITING, RUNNING, DONE, FAILED, CANCELLED
132
+ info.duration() # Elapsed seconds
133
+ info.name # Task name
134
+ info.deps # Upstream dependency IDs
135
+ info.dependents # Downstream dependent IDs
136
+ info.logs # Accumulated log output
137
+
138
+ await aionode.log("processing record 42") # Append to current task's logs
139
+ ```
140
+
141
+ ## API Reference
142
+
143
+ | Function | Description |
144
+ |---|---|
145
+ | `node(func, wait_for, track, auto_progress)` | Wrap an async function as a DAG node |
146
+ | `resolve(awaitable)` | Mark an awaitable to be resolved as a node argument |
147
+ | `track(func, start)` | Track a coroutine in the task graph |
148
+ | `get_task_id(task, timeout)` | Get the task ID for an asyncio.Task |
149
+ | `get_task_info(task_id)` | Get TaskInfo by ID |
150
+ | `remove_task(task_id)` | Remove a task and its descendants |
151
+ | `log(value, end)` | Append to the current task's logs |
152
+ | `make_async(func)` | Run a sync function in a thread |
153
+ | `make_async_generator(gen)` | Async iterate a sync iterator via threads |
154
+ | `TaskGraph(root_id)` | Create a graph view rooted at a task |
155
+ | `TaskGraph.from_task(task)` | Create a graph from an asyncio.Task |
156
+ | `TaskGraph.current()` | Graph over all tasks in the current loop |
157
+ | `get_render(rich, bar_width, ...)` | Get a configured render function |
158
+ | `watch(graph, interval, renderer)` | Live-render until all tasks finish |
@@ -0,0 +1,146 @@
1
+ # aionode
2
+
3
+ Lightweight asyncio task tracking with dependency graphs and progress rendering.
4
+
5
+ ## Installation
6
+
7
+ ```bash
8
+ pip install aionode
9
+ ```
10
+
11
+ For Rich table rendering:
12
+
13
+ ```bash
14
+ pip install aionode[viz]
15
+ ```
16
+
17
+ ## Quick Start
18
+
19
+ ```python
20
+ import asyncio
21
+ import aionode
22
+
23
+ async def fetch_data() -> list[int]:
24
+ await asyncio.sleep(1)
25
+ return list(range(100))
26
+
27
+ async def process(data: list[int]) -> int:
28
+ await asyncio.sleep(0.5)
29
+ return sum(data)
30
+
31
+ async def pipeline() -> None:
32
+ async with asyncio.TaskGroup() as tg:
33
+ fetch = tg.create_task(aionode.node(fetch_data)(), name="fetch")
34
+ tg.create_task(aionode.node(process)(aionode.resolve(fetch)), name="process")
35
+
36
+ async def main() -> None:
37
+ root = asyncio.create_task(aionode.track(pipeline)(), name="pipeline")
38
+ root_id = await aionode.get_task_id(root)
39
+ graph = aionode.TaskGraph(root_id=root_id)
40
+
41
+ await aionode.watch(graph, interval=0.3)
42
+ await root
43
+
44
+ asyncio.run(main())
45
+ ```
46
+
47
+ ## Core Concepts
48
+
49
+ ### `node(func, wait_for=None, track=True, auto_progress=True)`
50
+
51
+ Wraps an async function as a DAG node. Use `resolve()` to pass awaitables as arguments — they are gathered concurrently before the function is called. Sync functions must be wrapped with `make_async` first.
52
+
53
+ ```python
54
+ # Async function — pass upstream tasks with resolve()
55
+ fetch = tg.create_task(aionode.node(fetch_data)(), name="fetch")
56
+ process_task = tg.create_task(aionode.node(process)(aionode.resolve(fetch)), name="process")
57
+
58
+ # Sync function — wrap with make_async first
59
+ summarize = tg.create_task(
60
+ aionode.node(aionode.make_async(my_sync_fn))(aionode.resolve(process_task)),
61
+ name="summarize",
62
+ )
63
+
64
+ # Side-only dependency (no value passed): use wait_for
65
+ task_b = tg.create_task(aionode.node(cleanup, wait_for=[fetch])(), name="cleanup")
66
+ ```
67
+
68
+ ### `resolve(awaitable)`
69
+
70
+ Marks an awaitable to be resolved before being passed as an argument to `node()`. This preserves type information — the type checker sees `resolve(task: Task[T])` as returning `T`.
71
+
72
+ ```python
73
+ result = await aionode.node(process)(aionode.resolve(upstream_task))
74
+ ```
75
+
76
+ ### `track(func, start=True)`
77
+
78
+ Tracks a coroutine by registering it in the task graph. Use this for the root task or tasks that don't need `node()`'s dependency resolution.
79
+
80
+ ```python
81
+ root = asyncio.create_task(aionode.track(my_coro)(), name="root")
82
+ ```
83
+
84
+ ### `TaskGraph`
85
+
86
+ Query the dependency graph:
87
+
88
+ ```python
89
+ graph = aionode.TaskGraph(root_id=root_id)
90
+
91
+ graph.nodes() # All tasks in topological order
92
+ graph.roots() # Tasks with no upstream deps
93
+ graph.leaves() # Tasks with no downstream dependents
94
+ graph.upstream(task_id) # Transitive upstream deps
95
+ graph.downstream(task_id) # Transitive downstream dependents
96
+ graph.summary() # {TaskStatus: count}
97
+ graph.critical_path() # Longest-duration path
98
+ ```
99
+
100
+ ### Rendering
101
+
102
+ ```python
103
+ # Auto-detect Rich or fall back to ANSI text
104
+ renderer = aionode.get_render()
105
+
106
+ # Force plain text
107
+ renderer = aionode.get_render(rich=False)
108
+
109
+ # Live-updating display
110
+ await aionode.watch(graph, interval=0.5, renderer=renderer)
111
+ ```
112
+
113
+ ### Task Inspection
114
+
115
+ ```python
116
+ task_id = await aionode.get_task_id(asyncio_task)
117
+ info = aionode.get_task_info(task_id)
118
+
119
+ info.status # TaskStatus: WAITING, RUNNING, DONE, FAILED, CANCELLED
120
+ info.duration() # Elapsed seconds
121
+ info.name # Task name
122
+ info.deps # Upstream dependency IDs
123
+ info.dependents # Downstream dependent IDs
124
+ info.logs # Accumulated log output
125
+
126
+ await aionode.log("processing record 42") # Append to current task's logs
127
+ ```
128
+
129
+ ## API Reference
130
+
131
+ | Function | Description |
132
+ |---|---|
133
+ | `node(func, wait_for, track, auto_progress)` | Wrap an async function as a DAG node |
134
+ | `resolve(awaitable)` | Mark an awaitable to be resolved as a node argument |
135
+ | `track(func, start)` | Track a coroutine in the task graph |
136
+ | `get_task_id(task, timeout)` | Get the task ID for an asyncio.Task |
137
+ | `get_task_info(task_id)` | Get TaskInfo by ID |
138
+ | `remove_task(task_id)` | Remove a task and its descendants |
139
+ | `log(value, end)` | Append to the current task's logs |
140
+ | `make_async(func)` | Run a sync function in a thread |
141
+ | `make_async_generator(gen)` | Async iterate a sync iterator via threads |
142
+ | `TaskGraph(root_id)` | Create a graph view rooted at a task |
143
+ | `TaskGraph.from_task(task)` | Create a graph from an asyncio.Task |
144
+ | `TaskGraph.current()` | Graph over all tasks in the current loop |
145
+ | `get_render(rich, bar_width, ...)` | Get a configured render function |
146
+ | `watch(graph, interval, renderer)` | Live-render until all tasks finish |
@@ -0,0 +1,48 @@
1
+ [project]
2
+ name = "aionode"
3
+ version = "0.1.0"
4
+ description = "Lightweight asyncio task tracking as call tree and DAG"
5
+ readme = "README.md"
6
+ authors = [
7
+ { name = "Matteo De Pellegrin", email = "matteo.dep97@gmail.com" }
8
+ ]
9
+ license = "MIT"
10
+ requires-python = ">=3.12"
11
+ dependencies = []
12
+
13
+ [project.urls]
14
+ Repository = "https://github.com/MatteoDep/aionode"
15
+ Issues = "https://github.com/MatteoDep/aionode/issues"
16
+
17
+ [build-system]
18
+ requires = ["uv_build>=0.10.3,<0.11.0"]
19
+ build-backend = "uv_build"
20
+
21
+ [dependency-groups]
22
+ dev = [
23
+ "pytest>=9.0.2",
24
+ "pytest-asyncio>=1.3.0",
25
+ "ruff>=0.15.4",
26
+ "ty>=0.0.19",
27
+ ]
28
+
29
+ [tool.pytest.ini_options]
30
+ asyncio_mode = "auto"
31
+
32
+ [tool.ruff]
33
+ line-length = 120
34
+
35
+ [tool.ruff.lint]
36
+ select = [
37
+ "E", # pycodestyle errors
38
+ "W", # pycodestyle warnings
39
+ "F", # pyflakes
40
+ "I", # isort
41
+ "UP", # pyupgrade
42
+ "B", # flake8-bugbear
43
+ "C4", # flake8-comprehensions
44
+ "RUF", # ruff-specific rules
45
+ ]
46
+
47
+ [tool.ruff.lint.per-file-ignores]
48
+ "tests/**" = ["S101"] # allow assert in tests
@@ -0,0 +1,623 @@
1
+ from __future__ import annotations
2
+
3
+ __version__ = "0.1.0"
4
+
5
+ import asyncio
6
+ import functools
7
+ import inspect
8
+ import threading
9
+ import time
10
+ import weakref
11
+ from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Iterator, Sequence
12
+ from contextlib import asynccontextmanager
13
+ from contextvars import ContextVar
14
+ from dataclasses import dataclass, field
15
+ from datetime import datetime
16
+ from enum import StrEnum
17
+ from typing import Any, Protocol, cast, runtime_checkable
18
+
19
+
20
@dataclass(slots=True)
class _Resolved[T]:
    """Marker wrapping an awaitable passed to a ``node()``-wrapped function.

    Instances are created by ``resolve()``. The ``node()`` wrapper looks for
    them among the positional and keyword arguments, awaits their
    ``awaitable`` (concurrently with any ``wait_for`` awaitables) and passes
    the result to the wrapped function in place of the marker.
    """

    # The awaitable whose result replaces this marker before the call.
    awaitable: Awaitable[T]
23
+
24
+
25
+ def resolve[T](awaitable: Awaitable[T]) -> T:
26
+ return _Resolved(awaitable) # type: ignore[return-value]
27
+
28
+
29
+ def node[**P, R](
30
+ func: Callable[P, Coroutine[Any, Any, R]],
31
+ /,
32
+ wait_for: Sequence[Awaitable[Any]] | None = None,
33
+ track: bool = True,
34
+ auto_progress: bool = True,
35
+ ) -> Callable[P, Coroutine[Any, Any, R]]:
36
+ @functools.wraps(func)
37
+ async def wrapper(*args: Any, **kwargs: Any) -> Any:
38
+ if track:
39
+ await _init_task_info(start=False, auto_progress=auto_progress)
40
+ _start = _start_task
41
+ else:
42
+
43
+ async def _start() -> None:
44
+ pass
45
+
46
+ # Identify _Resolved positions
47
+ resolved_arg_idxs = [(i, a) for i, a in enumerate(args) if isinstance(a, _Resolved)]
48
+ resolved_kwarg_keys = [(k, v) for k, v in kwargs.items() if isinstance(v, _Resolved)]
49
+ all_awaitables = (
50
+ [a.awaitable for _, a in resolved_arg_idxs]
51
+ + [v.awaitable for _, v in resolved_kwarg_keys]
52
+ + list(wait_for or [])
53
+ )
54
+
55
+ if track:
56
+ state = _get_state()
57
+ our_id = _task_id.get()
58
+ dep_tasks: list[asyncio.Task] = (
59
+ [a.awaitable for _, a in resolved_arg_idxs if isinstance(a.awaitable, asyncio.Task)]
60
+ + [v.awaitable for _, v in resolved_kwarg_keys if isinstance(v.awaitable, asyncio.Task)]
61
+ + [d for d in (wait_for or []) if isinstance(d, asyncio.Task)]
62
+ )
63
+ for dep_task in dep_tasks:
64
+ if dep_task in state.task_ids:
65
+ await _register_dep(our_id, state.task_ids[dep_task])
66
+
67
+ try:
68
+ if all_awaitables:
69
+ results = await asyncio.gather(*all_awaitables)
70
+ n_args = len(resolved_arg_idxs)
71
+ n_kw = len(resolved_kwarg_keys)
72
+ arg_results = list(results[:n_args])
73
+ kwarg_results = list(results[n_args : n_args + n_kw])
74
+ # results[n_args + n_kw:] are wait_for results — discarded
75
+ else:
76
+ arg_results, kwarg_results = [], []
77
+ except Exception as e:
78
+ msg = "Failed while waiting to start."
79
+ raise RuntimeError(msg) from e
80
+
81
+ # Rebuild args/kwargs with resolved values
82
+ resolved_args = list(args)
83
+ for (i, _), val in zip(resolved_arg_idxs, arg_results, strict=True):
84
+ resolved_args[i] = val
85
+ resolved_kwargs = dict(kwargs)
86
+ for (k, _), val in zip(resolved_kwarg_keys, kwarg_results, strict=True):
87
+ resolved_kwargs[k] = val
88
+
89
+ await _start()
90
+
91
+ result = func(*resolved_args, **resolved_kwargs)
92
+ return await result if inspect.isawaitable(result) else result
93
+
94
+ return wrapper
95
+
96
+
97
class _Unset:
    """Sentinel type for keyword arguments that were not supplied.

    Lets a function tell "argument omitted" apart from an explicit ``None``
    (``TaskInfo.update`` accepts ``total=None`` as a real value).
    """


# Single shared sentinel instance; callers test for it with ``isinstance(x, _Unset)``.
_UNSET = _Unset()
102
+
103
+
104
class TaskStatus(StrEnum):
    """Status of a tracked asyncio task.

    Values are human-readable strings (``StrEnum``), which is what
    ``TaskInfo.subtasks_info`` prints by default.
    """

    WAITING = "waiting to start"  # registered with start=False, e.g. a node() still awaiting its inputs
    RUNNING = "running"
    DONE = "done"  # finished without raising
    FAILED = "failed"  # finished with an exception (kept in TaskInfo.exception)
    CANCELLED = "cancelled"
112
+
113
+
114
@dataclass(slots=True)
class TaskInfo:
    """Metadata and state for a single tracked asyncio task.

    Instances are write-protected: assigning to a public field outside the
    ``allow_edit()`` context manager raises ``RuntimeError``. The guard only
    intercepts attribute assignment. In-place mutation of the list fields
    (``subtasks.append(...)`` etc.) is not checked by it.
    """

    id: int  # id allocated by the per-event-loop _LoopState
    task: asyncio.Task  # the underlying asyncio task
    name: str  # task.get_name() at registration time
    parent: int | None  # id of the tracked task that created this one (call-tree parent), or None for a root
    subtasks: list[int]  # ids of all tracked children, in registration order
    running_subtasks: list[int]  # ids of children that have started and not yet finished
    status: TaskStatus
    started_at: datetime | None = None  # wall-clock start time; None until started
    finished_at: datetime | None = None  # wall-clock finish time; None until done
    exception: BaseException | None = None  # set when the task finished with an exception
    logs: str = ""  # text appended by log()
    completed: float = 0  # progress numerator
    total: float | None = None  # progress denominator; None means "not declared"
    # When true, each tracked child adds 1 to `total` on registration and 1 to `completed` when it finishes.
    auto_progress: bool = True
    deps: list[int] = field(default_factory=list)  # ids this task depends on (DAG upstream)
    dependents: list[int] = field(default_factory=list)  # ids that depend on this task (DAG downstream)
    tree_depth: int = 0  # depth in the call tree; a root task is 0
    # Raised to (upstream dag_depth + 1) whenever a dependency is registered; not re-propagated later.
    dag_depth: int = 0
    # Monotonic timestamps used by duration(); writable without allow_edit().
    _start_mono: float | None = field(default=None, repr=False, compare=False)
    _finish_mono: float | None = field(default=None, repr=False, compare=False)
    # Serializes allow_edit() blocks on this instance (asyncio.Lock is not reentrant).
    _lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False, compare=False)
    # Declared last on purpose: while __init__ assigns the other fields this slot is
    # still unset, so __setattr__ lets those assignments through.
    _edit_allowed: bool = field(default=False, repr=False, compare=False)

    def __setattr__(self, name: str, value: Any, /) -> None:
        # Private bookkeeping fields bypass the edit guard.
        if name in ("_edit_allowed", "_lock", "_start_mono", "_finish_mono"):
            object.__setattr__(self, name, value)
            return

        # hasattr() is False only while __init__ is still running (slot not yet assigned).
        if hasattr(self, "_edit_allowed") and not object.__getattribute__(self, "_edit_allowed"):
            msg = "Edit not allowed. Use the `allow_edit` context manager."
            raise RuntimeError(msg)

        object.__setattr__(self, name, value)

    @asynccontextmanager
    async def allow_edit(self) -> AsyncGenerator[None]:
        """Hold this task's lock and permit field assignment inside the block."""
        async with self._lock:
            self._edit_allowed = True
            try:
                yield
            finally:
                self._edit_allowed = False

    async def update(
        self,
        *,
        completed: float | _Unset = _UNSET,
        total: float | None | _Unset = _UNSET,
    ) -> None:
        """Set ``completed`` and/or ``total`` under ``allow_edit()``.

        Omitted arguments are left unchanged; ``total=None`` explicitly clears
        the total.
        """
        async with self.allow_edit():
            if not isinstance(completed, _Unset):
                self.completed = completed
            if not isinstance(total, _Unset):
                self.total = total

    def subtasks_info(
        self,
        fmt: Callable[[TaskInfo], str] = "- {0.name}: {0.status.value}".format,
        sep: str = "\n",
        all_subtasks: bool = False,
    ) -> str:
        """Format this task's children as one string.

        Args:
            fmt: Formatter applied to each child's ``TaskInfo``.
            sep: Separator placed between formatted children.
            all_subtasks: Include every child instead of only running ones.

        Children whose info cannot be looked up (``get_task_info`` raised
        ``ValueError``, e.g. after ``remove_task``) are skipped.
        """
        items: list[str] = []
        for child_id in self.subtasks if all_subtasks else self.running_subtasks:
            try:
                items.append(fmt(get_task_info(child_id)))
            except ValueError:
                continue
        return sep.join(items)

    def started(self) -> bool:
        """Return True once a start time has been recorded."""
        return self.started_at is not None

    def done(self) -> bool:
        """Return True once a finish time has been recorded."""
        return self.finished_at is not None

    def duration(self) -> float:
        """Get task duration in seconds.

        Measured with the monotonic clock from start to finish, or to now if
        the task is still running; 0.0 if it has not started.
        """
        if self._start_mono is None:
            return 0.0
        end = self._finish_mono if self._finish_mono is not None else time.monotonic()
        return end - self._start_mono
200
+
201
+
202
# Id of the tracked task that owns the current context. Set at the end of
# _init_task_info; tasks created afterwards from that context inherit the value,
# which is how _init_task_info finds a new task's parent.
_task_id: ContextVar[int] = ContextVar("task_id")
203
+
204
+
205
@dataclass
class _LoopState:
    """Tracking registry belonging to a single event loop."""

    # task id -> TaskInfo
    task_infos: dict[int, TaskInfo] = field(default_factory=dict)
    # asyncio task -> task id
    task_ids: dict[asyncio.Task, int] = field(default_factory=dict)
    # strong references to bookkeeping tasks spawned from done-callbacks, until they finish
    background_tasks: set[asyncio.Task] = field(default_factory=set)
    _next_id: int = 0
    _id_lock: threading.Lock = field(default_factory=threading.Lock)

    def allocate_id(self) -> int:
        """Hand out the next id, unique within this registry (0, 1, 2, ...)."""
        with self._id_lock:
            allocated, self._next_id = self._next_id, self._next_id + 1
        return allocated
218
+
219
+
220
# One _LoopState per event loop. Keys are held weakly, so an entry disappears
# once its event loop object is garbage collected.
_loop_states: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, _LoopState] = weakref.WeakKeyDictionary()
221
+
222
+
223
def _get_state() -> _LoopState:
    """Return the registry of the running event loop, creating it on first use.

    Raises:
        RuntimeError: if there is no running event loop.
    """
    loop = asyncio.get_running_loop()
    state = _loop_states.get(loop)
    if state is None:
        state = _LoopState()
        _loop_states[loop] = state
    return state
231
+
232
+
233
async def _set_done(task_id: int) -> None:
    """Record the outcome of tracked task *task_id* after it has finished.

    Runs as a background task scheduled by the done-callback installed by
    ``_add_done_callback``. It stores the finish timestamps and sets
    ``status`` to CANCELLED, FAILED (also storing ``exception``) or DONE. A
    task that finished normally without declaring a ``total`` is reported as
    1/1 complete.

    Does nothing if the task is no longer tracked, e.g. because
    ``remove_task()`` ran while it was still executing. Raising there would
    only surface as an unretrieved exception of the background task.
    """
    task_info = _get_state().task_infos.get(task_id)
    if task_info is None:
        return
    async with task_info.allow_edit():
        task_info._finish_mono = time.monotonic()
        task_info.finished_at = datetime.now()
        if task_info.task.cancelled():
            task_info.status = TaskStatus.CANCELLED
        elif exc := task_info.task.exception():
            task_info.status = TaskStatus.FAILED
            task_info.exception = exc
        else:
            task_info.status = TaskStatus.DONE
            if task_info.total is None:
                task_info.total = 1
                task_info.completed = 1
249
+
250
+
251
async def _update_parent(task_id: int, parent_id: int, auto_progress: bool) -> None:
    """Update the parent's bookkeeping after child *task_id* has finished.

    Removes the child from the parent's ``running_subtasks``. When
    *auto_progress* is true, it also counts the child as one completed unit
    of the parent's progress.

    Does nothing if the parent is no longer tracked (e.g. after
    ``remove_task()``). This runs in a background task, where a KeyError
    would only surface as an unretrieved exception.
    """
    parent_task_info = _get_state().task_infos.get(parent_id)
    if parent_task_info is None:
        return
    async with parent_task_info.allow_edit():
        if auto_progress:
            parent_task_info.completed = (parent_task_info.completed or 0) + 1
        if task_id in parent_task_info.running_subtasks:
            parent_task_info.running_subtasks.remove(task_id)
259
+
260
+
261
def _add_done_callback(task: asyncio.Task, task_id: int, state: _LoopState) -> None:
    """Arrange for ``_set_done(task_id)`` to run once *task* finishes.

    Done-callbacks are synchronous, so the async bookkeeping is spawned as a
    separate task. That task is kept in ``state.background_tasks`` until it
    completes, so something holds a strong reference to it while it is
    pending.
    """

    def _on_task_done(_: asyncio.Task) -> None:
        pending = state.background_tasks
        bookkeeping = asyncio.create_task(_set_done(task_id=task_id))
        pending.add(bookkeeping)
        bookkeeping.add_done_callback(pending.discard)

    task.add_done_callback(_on_task_done)
268
+
269
+
270
def _add_update_parent_callback(
    task: asyncio.Task,
    task_id: int,
    parent_id: int,
    auto_progress: bool,
    state: _LoopState,
) -> None:
    """Arrange for ``_update_parent`` to run once the child *task* finishes.

    The parent's progress and ``running_subtasks`` are updated by a spawned
    background task. That task is kept in ``state.background_tasks`` until it
    completes.
    """

    def _on_child_done(_: asyncio.Task) -> None:
        update = _update_parent(task_id=task_id, parent_id=parent_id, auto_progress=auto_progress)
        spawned = asyncio.create_task(update)
        state.background_tasks.add(spawned)
        spawned.add_done_callback(state.background_tasks.discard)

    task.add_done_callback(_on_child_done)
285
+
286
+
287
def _get_task() -> asyncio.Task:
    """Return the asyncio task that is currently executing.

    Raises:
        RuntimeError: if there is no running event loop, or if the code runs
            inside the loop but outside any task (e.g. in a plain callback).
    """
    try:
        current = asyncio.current_task()
    except RuntimeError as e:
        msg = "This function can only be called from a coroutine."
        raise RuntimeError(msg) from e
    if current is not None:
        return current
    msg = "No current task. This function must be called from within an asyncio task, not a callback."
    raise RuntimeError(msg)
297
+
298
+
299
async def _init_task_info(start: bool = True, auto_progress: bool = True) -> None:
    """Register the current asyncio task for tracking.

    Creates the task's ``TaskInfo`` and links it under the tracked task whose
    id is stored in the ``_task_id`` context variable (the call-tree parent).
    It then installs the done-callbacks and finally stores this task's id in
    ``_task_id``, so tasks spawned from here see it as their parent.

    Args:
        start: If true the task is recorded as RUNNING immediately; otherwise
            it stays WAITING until ``_start_task()`` is called.
        auto_progress: Whether this task's own progress counts its children.

    Raises:
        RuntimeError: if the current task is already registered, or if called
            outside an asyncio task.
    """
    state = _get_state()
    task = _get_task()
    task_name = task.get_name()
    if task in state.task_ids:
        msg = f"Task {task_name} is already initialized"
        raise RuntimeError(msg)

    task_id = state.allocate_id()

    # The parent id is inherited through the context and may be unknown to this
    # loop's registry. The parent may have been dropped with remove_task(), or the
    # context may have been copied from another event loop (for example a nested
    # asyncio.run() in a worker thread). Register such a task as a root instead
    # of failing with KeyError.
    inherited_id = _task_id.get(None)
    parent_task_info = state.task_infos.get(inherited_id) if inherited_id is not None else None
    parent_id = parent_task_info.id if parent_task_info is not None else None
    parent_tree_depth = parent_task_info.tree_depth if parent_task_info is not None else -1

    task_info = TaskInfo(
        id=task_id,
        name=task_name,
        parent=parent_id,
        subtasks=[],
        started_at=datetime.now() if start else None,
        _start_mono=time.monotonic() if start else None,
        status=TaskStatus.RUNNING if start else TaskStatus.WAITING,
        task=task,
        running_subtasks=[],
        auto_progress=auto_progress,
        tree_depth=parent_tree_depth + 1,
        dag_depth=0,
    )

    async with task_info.allow_edit():
        state.task_infos[task_id] = task_info
        _add_done_callback(task, task_id=task_id, state=state)
        if parent_task_info is not None:
            async with parent_task_info.allow_edit():
                parent_task_info.subtasks.append(task_id)
                if start:
                    parent_task_info.running_subtasks.append(task_id)
                _add_update_parent_callback(
                    task,
                    task_id=task_id,
                    parent_id=parent_task_info.id,
                    auto_progress=parent_task_info.auto_progress,
                    state=state,
                )
                if parent_task_info.auto_progress:
                    # Each tracked child adds one unit of expected work to the parent.
                    total = parent_task_info.total or 0
                    parent_task_info.total = total + 1
        state.task_ids[task] = task_id
    _task_id.set(task_id)
353
+
354
+
355
+
356
async def _start_task() -> None:
    """Move the current, already registered task from WAITING to RUNNING.

    Records the start timestamps and adds the task to its parent's
    ``running_subtasks``.

    Raises:
        RuntimeError: if the current task was never registered with
            ``_init_task_info``.
    """
    state = _get_state()
    current = _get_task()
    if current not in state.task_ids:
        raise RuntimeError(f"Cannot start uninitialized task {current.get_name()}")
    own_id = _task_id.get()
    info = state.task_infos[own_id]
    async with info.allow_edit():
        info._start_mono = time.monotonic()
        info.started_at = datetime.now()
        info.status = TaskStatus.RUNNING
    if info.parent is None:
        return
    parent = state.task_infos[info.parent]
    async with parent.allow_edit():
        parent.running_subtasks.append(own_id)
372
+
373
+
374
def _would_cycle(state: _LoopState, from_id: int, to_id: int) -> bool:
    """Return True if recording "*from_id* depends on *to_id*" would create a cycle.

    The new edge goes from *to_id* to *from_id* in the ``dependents``
    direction. A cycle therefore exists if *to_id* is already reachable from
    *from_id* by following ``dependents`` links. A self-dependency
    (``from_id == to_id``) is a cycle of length one: such a task would wait
    on itself forever.
    """
    if from_id == to_id:
        return True
    visited: set[int] = set()
    stack = [from_id]
    while stack:
        tid = stack.pop()
        if tid in visited:
            continue
        visited.add(tid)
        info = state.task_infos.get(tid)
        if info is None:
            continue
        for dep_id in info.dependents:
            if dep_id == to_id:
                return True
            stack.append(dep_id)
    return False
392
+
393
+
394
async def _register_dep(from_id: int, to_id: int) -> None:
    """Record that task *from_id* depends on task *to_id*.

    Adds *to_id* to the downstream task's ``deps`` and *from_id* to the
    upstream task's ``dependents`` (each at most once). It also raises the
    downstream task's ``dag_depth`` to at least one more than the
    upstream's. Ids unknown to the running loop are silently ignored.

    Raises:
        RuntimeError: if the new edge would close a dependency cycle.
    """
    state = _get_state()
    downstream = state.task_infos.get(from_id)
    upstream = state.task_infos.get(to_id)
    if downstream is None or upstream is None:
        return
    if _would_cycle(state, from_id, to_id):
        msg = f"Circular dependency detected: {downstream.name!r} -> {upstream.name!r} would create a cycle."
        raise RuntimeError(msg)

    async with downstream.allow_edit():
        if to_id not in downstream.deps:
            downstream.deps.append(to_id)
        downstream.dag_depth = max(downstream.dag_depth, upstream.dag_depth + 1)

    async with upstream.allow_edit():
        if from_id not in upstream.dependents:
            upstream.dependents.append(from_id)
415
+
416
+
417
async def log(value: str = "", end: str = "\n") -> None:
    """Append ``value + end`` to the ``logs`` of the current tracked task.

    Does nothing when called outside a tracked task. It also does nothing
    when the task id in the current context is not, or no longer, known to
    the running event loop (e.g. after ``remove_task()``).

    Args:
        value: Text to append.
        end: Suffix written after *value*; a newline by default.
    """
    try:
        task_id = _task_id.get()
    except LookupError:
        return
    task_info = _get_state().task_infos.get(task_id)
    if task_info is None:
        return
    async with task_info.allow_edit():
        task_info.logs += value + end
427
+
428
+
429
async def get_task_id(task: asyncio.Task, timeout: float = 1) -> int:
    """Return the tracking id of *task*, waiting until the task has registered.

    A task registers itself when its tracked coroutine starts running, which
    can be after the caller created it. This therefore polls, yielding to the
    event loop on every iteration, until the id appears.

    Raises:
        TimeoutError: if *task* is not registered within *timeout* seconds.
    """
    registry = _get_state().task_ids
    async with asyncio.timeout(timeout):
        while task not in registry:
            await asyncio.sleep(0)
    return registry[task]
436
+
437
+
438
def current_task_info() -> TaskInfo:
    """Return the TaskInfo of the tracked task that owns the current context.

    Raises:
        RuntimeError: if no tracked task id is set in the current context,
            i.e. when not running inside a ``node()``/``track()``-wrapped
            coroutine or a task spawned from one.
        ValueError: if the id is not known to the running event loop.
    """
    try:
        current_id = _task_id.get()
    except LookupError:
        raise RuntimeError(
            "Not inside a tracked task. Call this from within a node()-wrapped coroutine."
        ) from None
    return get_task_info(current_id)
449
+
450
+
451
def get_task_info(task_id: int) -> TaskInfo:
    """Look up the TaskInfo registered under *task_id* in the running event loop.

    Raises:
        ValueError: if no such task is tracked in this loop.
        RuntimeError: if there is no running event loop.
    """
    running_loop = asyncio.get_running_loop()
    try:
        loop_state = _loop_states[running_loop]
        return loop_state.task_infos[task_id]
    except KeyError:
        raise ValueError(f"No task with id {task_id!r} found in the current event loop.") from None
459
+
460
+
461
def remove_task(task_id: int) -> None:
    """Remove a task and all its call-tree descendants from tracking to free memory.

    Only the registry entries are dropped; the asyncio tasks themselves are
    not cancelled. Ids of removed tasks may still appear in the ``subtasks``,
    ``running_subtasks``, ``deps`` or ``dependents`` lists of surviving tasks.

    Raises:
        ValueError: if *task_id* is not tracked in the running event loop,
            including when nothing has been tracked in this loop yet.
        RuntimeError: if there is no running event loop.
    """
    # _get_state() rather than _loop_states[loop]: a loop without any tracked task
    # must report the documented ValueError, not a bare KeyError.
    state = _get_state()
    if task_id not in state.task_infos:
        msg = f"No task with id {task_id!r} found in the current event loop."
        raise ValueError(msg)
    stack = [task_id]
    while stack:
        task_info = state.task_infos.pop(stack.pop(), None)
        if task_info is None:
            continue
        state.task_ids.pop(task_info.task, None)
        stack.extend(task_info.subtasks)
476
+
477
+
478
+ def track[**P, R](
479
+ func: Callable[P, Coroutine[Any, Any, R]],
480
+ start: bool = True,
481
+ ) -> Callable[P, Coroutine[Any, Any, R]]:
482
+ """Track a coroutine by recording task info."""
483
+
484
+ @functools.wraps(func)
485
+ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
486
+ await _init_task_info(start=start)
487
+ return await func(*args, **kwargs)
488
+
489
+ return wrapper
490
+
491
+
492
+
493
+ def make_async[**P, T](
494
+ func: Callable[P, T],
495
+ ) -> Callable[P, Coroutine[Any, Any, T]]:
496
+ """Run function in a separate thread."""
497
+
498
+ @functools.wraps(func)
499
+ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
500
+ return await asyncio.to_thread(func, *args, **kwargs)
501
+
502
+ return wrapper
503
+
504
+
505
@runtime_checkable
class SupportsNext[T](Protocol):
    """Structural type for objects that can be passed to the built-in ``next()``.

    Used to annotate the argument of ``make_async_generator``.
    """

    def __next__(self) -> T: ...
508
+
509
+
510
+ async def make_async_generator[T](gen: SupportsNext[T]) -> AsyncGenerator[T]:
511
+ """Run each `next` call in a separate thread."""
512
+ sentinel = object()
513
+
514
+ def step() -> T | object:
515
+ return next(gen, sentinel)
516
+
517
+ while True:
518
+ obj = await asyncio.to_thread(step)
519
+ if obj is sentinel:
520
+ break
521
+ yield cast("T", obj)
522
+
523
+
524
+ def walk_tree(root: asyncio.Task | int | None = None) -> Iterator[TaskInfo]:
525
+ """DFS pre-order through call tree (parent -> subtasks).
526
+
527
+ If *root* is an ``asyncio.Task``, its tracked id is looked up.
528
+ If *root* is ``None``, the root task (parent=None) is used.
529
+ """
530
+ state = _get_state()
531
+ if root is None:
532
+ root_id = next((tid for tid, info in state.task_infos.items() if info.parent is None), None)
533
+ if root_id is None:
534
+ return
535
+ elif isinstance(root, int):
536
+ root_id = root
537
+ else:
538
+ root_id = state.task_ids.get(root)
539
+ if root_id is None:
540
+ return
541
+
542
+ stack = [root_id]
543
+ while stack:
544
+ tid = stack.pop()
545
+ info = state.task_infos.get(tid)
546
+ if info is None:
547
+ continue
548
+ yield info
549
+ # Push children in reverse so leftmost child is visited first
550
+ stack.extend(reversed(info.subtasks))
551
+
552
+
553
def walk_dag(root: asyncio.Task | int | None = None) -> Iterator[TaskInfo]:
    """Yield tasks in topological order (Kahn's algorithm).

    Both call-tree edges (parent -> child) and DAG edges (dependency ->
    dependent) are honoured. Every task is therefore yielded after its parent
    and after each of its registered dependencies. Among tasks that become
    ready at the same time, the smallest id is yielded first. Edges to tasks
    outside the selected set are ignored. Tasks on a cycle, if any existed,
    would never become ready and are not yielded.

    Args:
        root: ``None`` to include every task tracked in the running event
            loop. Alternatively a task id or tracked ``asyncio.Task``, to
            include only that task and its call-tree descendants; an
            untracked ``asyncio.Task`` yields nothing.
    """
    import heapq  # local import keeps the module's top-level imports unchanged

    state = _get_state()

    # Collect the ids to include.
    if root is None:
        ids = list(state.task_infos)
    else:
        start_id = root if isinstance(root, int) else state.task_ids.get(root)
        if start_id is None:
            return
        ids = []
        seen: set[int] = set()
        stack = [start_id]
        while stack:
            tid = stack.pop()
            if tid in seen or tid not in state.task_infos:
                continue
            seen.add(tid)
            ids.append(tid)
            stack.extend(state.task_infos[tid].subtasks)

    infos = {tid: state.task_infos[tid] for tid in ids if tid in state.task_infos}
    if not infos:
        return

    # In-degrees and successor lists over the selected set only.
    in_degree: dict[int, int] = dict.fromkeys(infos, 0)
    successors: dict[int, list[int]] = {tid: [] for tid in infos}
    for tid, info in infos.items():
        if info.parent is not None and info.parent in infos:
            in_degree[tid] += 1
            successors[info.parent].append(tid)
        for dep_id in info.deps:
            # A dependency on the parent is already covered by the tree edge.
            if dep_id in infos and dep_id != info.parent:
                in_degree[tid] += 1
                successors[dep_id].append(tid)

    # Min-heap of ready ids: the same smallest-id-first order as re-sorting a
    # queue on every step, at O(log n) per operation instead.
    ready = [tid for tid, degree in in_degree.items() if degree == 0]
    heapq.heapify(ready)
    while ready:
        tid = heapq.heappop(ready)
        yield infos[tid]
        for succ in successors[tid]:
            in_degree[succ] -= 1
            if in_degree[succ] == 0:
                heapq.heappush(ready, succ)
607
+
608
+
609
# Names exported by ``from aionode import *``. Keep in sync with the public
# functions and classes defined in this module; ``track`` is public API alongside ``node``.
__all__ = [
    "TaskInfo",
    "TaskStatus",
    "current_task_info",
    "get_task_id",
    "get_task_info",
    "log",
    "make_async",
    "make_async_generator",
    "node",
    "remove_task",
    "resolve",
    "track",
    "walk_dag",
    "walk_tree",
]
File without changes