aionode 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aionode-0.1.0/PKG-INFO +158 -0
- aionode-0.1.0/README.md +146 -0
- aionode-0.1.0/pyproject.toml +48 -0
- aionode-0.1.0/src/aionode/__init__.py +623 -0
- aionode-0.1.0/src/aionode/py.typed +0 -0
aionode-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: aionode
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Lightweight asyncio task tracking as call tree and DAG
|
|
5
|
+
Author: Matteo De Pellegrin
|
|
6
|
+
Author-email: Matteo De Pellegrin <matteo.dep97@gmail.com>
|
|
7
|
+
License-Expression: MIT
|
|
8
|
+
Requires-Python: >=3.12
|
|
9
|
+
Project-URL: Repository, https://github.com/MatteoDep/aionode
|
|
10
|
+
Project-URL: Issues, https://github.com/MatteoDep/aionode/issues
|
|
11
|
+
Description-Content-Type: text/markdown
|
|
12
|
+
|
|
13
|
+
# aionode
|
|
14
|
+
|
|
15
|
+
Lightweight asyncio task tracking with dependency graphs and progress rendering.
|
|
16
|
+
|
|
17
|
+
## Installation
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
pip install aionode
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
For Rich table rendering:
|
|
24
|
+
|
|
25
|
+
```bash
|
|
26
|
+
pip install aionode[viz]
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
## Quick Start
|
|
30
|
+
|
|
31
|
+
```python
|
|
32
|
+
import asyncio
|
|
33
|
+
import aionode
|
|
34
|
+
|
|
35
|
+
async def fetch_data() -> list[int]:
|
|
36
|
+
await asyncio.sleep(1)
|
|
37
|
+
return list(range(100))
|
|
38
|
+
|
|
39
|
+
async def process(data: list[int]) -> int:
|
|
40
|
+
await asyncio.sleep(0.5)
|
|
41
|
+
return sum(data)
|
|
42
|
+
|
|
43
|
+
async def pipeline() -> None:
|
|
44
|
+
async with asyncio.TaskGroup() as tg:
|
|
45
|
+
fetch = tg.create_task(aionode.node(fetch_data)(), name="fetch")
|
|
46
|
+
tg.create_task(aionode.node(process)(aionode.resolve(fetch)), name="process")
|
|
47
|
+
|
|
48
|
+
async def main() -> None:
|
|
49
|
+
root = asyncio.create_task(aionode.track(pipeline)(), name="pipeline")
|
|
50
|
+
root_id = await aionode.get_task_id(root)
|
|
51
|
+
graph = aionode.TaskGraph(root_id=root_id)
|
|
52
|
+
|
|
53
|
+
await aionode.watch(graph, interval=0.3)
|
|
54
|
+
await root
|
|
55
|
+
|
|
56
|
+
asyncio.run(main())
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
## Core Concepts
|
|
60
|
+
|
|
61
|
+
### `node(func, wait_for=[], track=True, auto_progress=True)`
|
|
62
|
+
|
|
63
|
+
Wraps an async function as a DAG node. Use `resolve()` to pass awaitables as arguments — they are gathered concurrently before the function is called. Sync functions must be wrapped with `make_async` first.
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
# Async function — pass upstream tasks with resolve()
|
|
67
|
+
fetch = tg.create_task(aionode.node(fetch_data)(), name="fetch")
|
|
68
|
+
process_task = tg.create_task(aionode.node(process)(aionode.resolve(fetch)), name="process")
|
|
69
|
+
|
|
70
|
+
# Sync function — wrap with make_async first
|
|
71
|
+
summarize = tg.create_task(
|
|
72
|
+
aionode.node(aionode.make_async(my_sync_fn))(aionode.resolve(process_task)),
|
|
73
|
+
name="summarize",
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
# Side-only dependency (no value passed): use wait_for
|
|
77
|
+
task_b = tg.create_task(aionode.node(cleanup, wait_for=[fetch])(), name="cleanup")
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
### `resolve(awaitable)`
|
|
81
|
+
|
|
82
|
+
Marks an awaitable to be resolved before being passed as an argument to `node()`. This preserves type information — the type checker sees `resolve(task: Task[T])` as returning `T`.
|
|
83
|
+
|
|
84
|
+
```python
|
|
85
|
+
result = await aionode.node(process)(aionode.resolve(upstream_task))
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
### `track(func, start=True)`
|
|
89
|
+
|
|
90
|
+
Tracks a coroutine by registering it in the task graph. Use this for the root task or tasks that don't need `node()`'s dependency resolution.
|
|
91
|
+
|
|
92
|
+
```python
|
|
93
|
+
root = asyncio.create_task(aionode.track(my_coro)(), name="root")
|
|
94
|
+
```
|
|
95
|
+
|
|
96
|
+
### `TaskGraph`
|
|
97
|
+
|
|
98
|
+
Query the dependency graph:
|
|
99
|
+
|
|
100
|
+
```python
|
|
101
|
+
graph = aionode.TaskGraph(root_id=root_id)
|
|
102
|
+
|
|
103
|
+
graph.nodes() # All tasks in topological order
|
|
104
|
+
graph.roots() # Tasks with no upstream deps
|
|
105
|
+
graph.leaves() # Tasks with no downstream dependents
|
|
106
|
+
graph.upstream(task_id) # Transitive upstream deps
|
|
107
|
+
graph.downstream(task_id) # Transitive downstream dependents
|
|
108
|
+
graph.summary() # {TaskStatus: count}
|
|
109
|
+
graph.critical_path() # Longest-duration path
|
|
110
|
+
```
|
|
111
|
+
|
|
112
|
+
### Rendering
|
|
113
|
+
|
|
114
|
+
```python
|
|
115
|
+
# Auto-detect Rich or fall back to ANSI text
|
|
116
|
+
renderer = aionode.get_render()
|
|
117
|
+
|
|
118
|
+
# Force plain text
|
|
119
|
+
renderer = aionode.get_render(rich=False)
|
|
120
|
+
|
|
121
|
+
# Live-updating display
|
|
122
|
+
await aionode.watch(graph, interval=0.5, renderer=renderer)
|
|
123
|
+
```
|
|
124
|
+
|
|
125
|
+
### Task Inspection
|
|
126
|
+
|
|
127
|
+
```python
|
|
128
|
+
task_id = await aionode.get_task_id(asyncio_task)
|
|
129
|
+
info = aionode.get_task(task_id)
|
|
130
|
+
|
|
131
|
+
info.status # TaskStatus: WAITING, RUNNING, DONE, FAILED, CANCELLED
|
|
132
|
+
info.duration() # Elapsed seconds
|
|
133
|
+
info.description # Task name
|
|
134
|
+
info.deps # Upstream dependency IDs
|
|
135
|
+
info.dependents # Downstream dependent IDs
|
|
136
|
+
info.logs # Accumulated log output
|
|
137
|
+
|
|
138
|
+
await aionode.log("processing record 42") # Append to current task's logs
|
|
139
|
+
```
|
|
140
|
+
|
|
141
|
+
## API Reference
|
|
142
|
+
|
|
143
|
+
| Function | Description |
|
|
144
|
+
|---|---|
|
|
145
|
+
| `node(func, wait_for, track, auto_progress)` | Wrap an async function as a DAG node |
|
|
146
|
+
| `resolve(awaitable)` | Mark an awaitable to be resolved as a node argument |
|
|
147
|
+
| `track(func, start)` | Track a coroutine in the task graph |
|
|
148
|
+
| `get_task_id(task, timeout)` | Get the task ID for an asyncio.Task |
|
|
149
|
+
| `get_task(task_id)` | Get TaskInfo by ID |
|
|
150
|
+
| `remove_task(task_id)` | Remove a task and its descendants |
|
|
151
|
+
| `log(value, end)` | Append to the current task's logs |
|
|
152
|
+
| `make_async(func)` | Run a sync function in a thread |
|
|
153
|
+
| `make_async_generator(gen)` | Async iterate a sync iterator via threads |
|
|
154
|
+
| `TaskGraph(root_id)` | Create a graph view rooted at a task |
|
|
155
|
+
| `TaskGraph.from_task(task)` | Create a graph from an asyncio.Task |
|
|
156
|
+
| `TaskGraph.current()` | Graph over all tasks in the current loop |
|
|
157
|
+
| `get_render(rich, bar_width, ...)` | Get a configured render function |
|
|
158
|
+
| `watch(graph, interval, renderer)` | Live-render until all tasks finish |
|
aionode-0.1.0/README.md
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
# aionode
|
|
2
|
+
|
|
3
|
+
Lightweight asyncio task tracking with dependency graphs and progress rendering.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install aionode
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
For Rich table rendering:
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
pip install aionode[viz]
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
## Quick Start
|
|
18
|
+
|
|
19
|
+
```python
|
|
20
|
+
import asyncio
|
|
21
|
+
import aionode
|
|
22
|
+
|
|
23
|
+
async def fetch_data() -> list[int]:
|
|
24
|
+
await asyncio.sleep(1)
|
|
25
|
+
return list(range(100))
|
|
26
|
+
|
|
27
|
+
async def process(data: list[int]) -> int:
|
|
28
|
+
await asyncio.sleep(0.5)
|
|
29
|
+
return sum(data)
|
|
30
|
+
|
|
31
|
+
async def pipeline() -> None:
|
|
32
|
+
async with asyncio.TaskGroup() as tg:
|
|
33
|
+
fetch = tg.create_task(aionode.node(fetch_data)(), name="fetch")
|
|
34
|
+
tg.create_task(aionode.node(process)(aionode.resolve(fetch)), name="process")
|
|
35
|
+
|
|
36
|
+
async def main() -> None:
|
|
37
|
+
root = asyncio.create_task(aionode.track(pipeline)(), name="pipeline")
|
|
38
|
+
root_id = await aionode.get_task_id(root)
|
|
39
|
+
graph = aionode.TaskGraph(root_id=root_id)
|
|
40
|
+
|
|
41
|
+
await aionode.watch(graph, interval=0.3)
|
|
42
|
+
await root
|
|
43
|
+
|
|
44
|
+
asyncio.run(main())
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
## Core Concepts
|
|
48
|
+
|
|
49
|
+
### `node(func, wait_for=[], track=True, auto_progress=True)`
|
|
50
|
+
|
|
51
|
+
Wraps an async function as a DAG node. Use `resolve()` to pass awaitables as arguments — they are gathered concurrently before the function is called. Sync functions must be wrapped with `make_async` first.
|
|
52
|
+
|
|
53
|
+
```python
|
|
54
|
+
# Async function — pass upstream tasks with resolve()
|
|
55
|
+
fetch = tg.create_task(aionode.node(fetch_data)(), name="fetch")
|
|
56
|
+
process_task = tg.create_task(aionode.node(process)(aionode.resolve(fetch)), name="process")
|
|
57
|
+
|
|
58
|
+
# Sync function — wrap with make_async first
|
|
59
|
+
summarize = tg.create_task(
|
|
60
|
+
aionode.node(aionode.make_async(my_sync_fn))(aionode.resolve(process_task)),
|
|
61
|
+
name="summarize",
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
# Side-only dependency (no value passed): use wait_for
|
|
65
|
+
task_b = tg.create_task(aionode.node(cleanup, wait_for=[fetch])(), name="cleanup")
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
### `resolve(awaitable)`
|
|
69
|
+
|
|
70
|
+
Marks an awaitable to be resolved before being passed as an argument to `node()`. This preserves type information — the type checker sees `resolve(task: Task[T])` as returning `T`.
|
|
71
|
+
|
|
72
|
+
```python
|
|
73
|
+
result = await aionode.node(process)(aionode.resolve(upstream_task))
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
### `track(func, start=True)`
|
|
77
|
+
|
|
78
|
+
Tracks a coroutine by registering it in the task graph. Use this for the root task or tasks that don't need `node()`'s dependency resolution.
|
|
79
|
+
|
|
80
|
+
```python
|
|
81
|
+
root = asyncio.create_task(aionode.track(my_coro)(), name="root")
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
### `TaskGraph`
|
|
85
|
+
|
|
86
|
+
Query the dependency graph:
|
|
87
|
+
|
|
88
|
+
```python
|
|
89
|
+
graph = aionode.TaskGraph(root_id=root_id)
|
|
90
|
+
|
|
91
|
+
graph.nodes() # All tasks in topological order
|
|
92
|
+
graph.roots() # Tasks with no upstream deps
|
|
93
|
+
graph.leaves() # Tasks with no downstream dependents
|
|
94
|
+
graph.upstream(task_id) # Transitive upstream deps
|
|
95
|
+
graph.downstream(task_id) # Transitive downstream dependents
|
|
96
|
+
graph.summary() # {TaskStatus: count}
|
|
97
|
+
graph.critical_path() # Longest-duration path
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
### Rendering
|
|
101
|
+
|
|
102
|
+
```python
|
|
103
|
+
# Auto-detect Rich or fall back to ANSI text
|
|
104
|
+
renderer = aionode.get_render()
|
|
105
|
+
|
|
106
|
+
# Force plain text
|
|
107
|
+
renderer = aionode.get_render(rich=False)
|
|
108
|
+
|
|
109
|
+
# Live-updating display
|
|
110
|
+
await aionode.watch(graph, interval=0.5, renderer=renderer)
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
### Task Inspection
|
|
114
|
+
|
|
115
|
+
```python
|
|
116
|
+
task_id = await aionode.get_task_id(asyncio_task)
|
|
117
|
+
info = aionode.get_task(task_id)
|
|
118
|
+
|
|
119
|
+
info.status # TaskStatus: WAITING, RUNNING, DONE, FAILED, CANCELLED
|
|
120
|
+
info.duration() # Elapsed seconds
|
|
121
|
+
info.description # Task name
|
|
122
|
+
info.deps # Upstream dependency IDs
|
|
123
|
+
info.dependents # Downstream dependent IDs
|
|
124
|
+
info.logs # Accumulated log output
|
|
125
|
+
|
|
126
|
+
await aionode.log("processing record 42") # Append to current task's logs
|
|
127
|
+
```
|
|
128
|
+
|
|
129
|
+
## API Reference
|
|
130
|
+
|
|
131
|
+
| Function | Description |
|
|
132
|
+
|---|---|
|
|
133
|
+
| `node(func, wait_for, track, auto_progress)` | Wrap an async function as a DAG node |
|
|
134
|
+
| `resolve(awaitable)` | Mark an awaitable to be resolved as a node argument |
|
|
135
|
+
| `track(func, start)` | Track a coroutine in the task graph |
|
|
136
|
+
| `get_task_id(task, timeout)` | Get the task ID for an asyncio.Task |
|
|
137
|
+
| `get_task(task_id)` | Get TaskInfo by ID |
|
|
138
|
+
| `remove_task(task_id)` | Remove a task and its descendants |
|
|
139
|
+
| `log(value, end)` | Append to the current task's logs |
|
|
140
|
+
| `make_async(func)` | Run a sync function in a thread |
|
|
141
|
+
| `make_async_generator(gen)` | Async iterate a sync iterator via threads |
|
|
142
|
+
| `TaskGraph(root_id)` | Create a graph view rooted at a task |
|
|
143
|
+
| `TaskGraph.from_task(task)` | Create a graph from an asyncio.Task |
|
|
144
|
+
| `TaskGraph.current()` | Graph over all tasks in the current loop |
|
|
145
|
+
| `get_render(rich, bar_width, ...)` | Get a configured render function |
|
|
146
|
+
| `watch(graph, interval, renderer)` | Live-render until all tasks finish |
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "aionode"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Lightweight asyncio task tracking as call tree and DAG"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
authors = [
|
|
7
|
+
{ name = "Matteo De Pellegrin", email = "matteo.dep97@gmail.com" }
|
|
8
|
+
]
|
|
9
|
+
license = "MIT"
|
|
10
|
+
requires-python = ">=3.12"
|
|
11
|
+
dependencies = []
|
|
12
|
+
|
|
13
|
+
[project.urls]
|
|
14
|
+
Repository = "https://github.com/MatteoDep/aionode"
|
|
15
|
+
Issues = "https://github.com/MatteoDep/aionode/issues"
|
|
16
|
+
|
|
17
|
+
[build-system]
|
|
18
|
+
requires = ["uv_build>=0.10.3,<0.11.0"]
|
|
19
|
+
build-backend = "uv_build"
|
|
20
|
+
|
|
21
|
+
[dependency-groups]
|
|
22
|
+
dev = [
|
|
23
|
+
"pytest>=9.0.2",
|
|
24
|
+
"pytest-asyncio>=1.3.0",
|
|
25
|
+
"ruff>=0.15.4",
|
|
26
|
+
"ty>=0.0.19",
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
[tool.pytest.ini_options]
|
|
30
|
+
asyncio_mode = "auto"
|
|
31
|
+
|
|
32
|
+
[tool.ruff]
|
|
33
|
+
line-length = 120
|
|
34
|
+
|
|
35
|
+
[tool.ruff.lint]
|
|
36
|
+
select = [
|
|
37
|
+
"E", # pycodestyle errors
|
|
38
|
+
"W", # pycodestyle warnings
|
|
39
|
+
"F", # pyflakes
|
|
40
|
+
"I", # isort
|
|
41
|
+
"UP", # pyupgrade
|
|
42
|
+
"B", # flake8-bugbear
|
|
43
|
+
"C4", # flake8-comprehensions
|
|
44
|
+
"RUF", # ruff-specific rules
|
|
45
|
+
]
|
|
46
|
+
|
|
47
|
+
[tool.ruff.lint.per-file-ignores]
|
|
48
|
+
"tests/**" = ["S101"] # allow assert in tests
|
|
@@ -0,0 +1,623 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
__version__ = "0.1.0"
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import functools
|
|
7
|
+
import inspect
|
|
8
|
+
import threading
|
|
9
|
+
import time
|
|
10
|
+
import weakref
|
|
11
|
+
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Iterator, Sequence
|
|
12
|
+
from contextlib import asynccontextmanager
|
|
13
|
+
from contextvars import ContextVar
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from datetime import datetime
|
|
16
|
+
from enum import StrEnum
|
|
17
|
+
from typing import Any, Protocol, cast, runtime_checkable
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(slots=True)
class _Resolved[T]:
    """Internal marker wrapping an awaitable that node() must await before
    passing its result through as an argument. Created by resolve()."""

    # The awaitable whose result replaces this marker at call time.
    awaitable: Awaitable[T]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def resolve[T](awaitable: Awaitable[T]) -> T:
    """Mark *awaitable* to be resolved before being passed as a node() argument.

    At runtime this returns a _Resolved wrapper; the signature declares T so
    type checkers treat the argument as the already-awaited value.
    """
    return _Resolved(awaitable)  # type: ignore[return-value]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def node[**P, R](
    func: Callable[P, Coroutine[Any, Any, R]],
    /,
    wait_for: Sequence[Awaitable[Any]] | None = None,
    track: bool = True,
    auto_progress: bool = True,
) -> Callable[P, Coroutine[Any, Any, R]]:
    """Wrap an async function as a DAG node.

    Arguments wrapped with resolve() are awaited concurrently (together with
    *wait_for*) before *func* is called. Any of those awaitables that are
    asyncio Tasks already known to the tracker become dependency edges.

    Args:
        func: The coroutine function to wrap.
        wait_for: Extra awaitables to finish before starting; their results
            are discarded (side-effect-only dependencies).
        track: When True, register this invocation in the task graph.
        auto_progress: Forwarded to the tracker so the parent's progress can
            advance automatically as subtasks complete.

    Returns:
        An async wrapper with the same call signature as *func*.

    Raises:
        RuntimeError: if any upstream awaitable fails while waiting to start
            (the original exception is chained as the cause).
    """

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        if track:
            # Register as WAITING now; _start flips the status to RUNNING
            # only after all dependencies have resolved.
            await _init_task_info(start=False, auto_progress=auto_progress)
            _start = _start_task
        else:

            async def _start() -> None:
                pass

        # Identify _Resolved positions among positional and keyword arguments.
        resolved_arg_idxs = [(i, a) for i, a in enumerate(args) if isinstance(a, _Resolved)]
        resolved_kwarg_keys = [(k, v) for k, v in kwargs.items() if isinstance(v, _Resolved)]
        all_awaitables = (
            [a.awaitable for _, a in resolved_arg_idxs]
            + [v.awaitable for _, v in resolved_kwarg_keys]
            + list(wait_for or [])
        )

        if track:
            state = _get_state()
            our_id = _task_id.get()
            # Only dependencies that are asyncio Tasks can appear as graph
            # edges; bare coroutines/futures are still awaited but not linked.
            dep_tasks: list[asyncio.Task] = (
                [a.awaitable for _, a in resolved_arg_idxs if isinstance(a.awaitable, asyncio.Task)]
                + [v.awaitable for _, v in resolved_kwarg_keys if isinstance(v.awaitable, asyncio.Task)]
                + [d for d in (wait_for or []) if isinstance(d, asyncio.Task)]
            )
            for dep_task in dep_tasks:
                if dep_task in state.task_ids:
                    await _register_dep(our_id, state.task_ids[dep_task])

        try:
            if all_awaitables:
                # gather() preserves order, so results line up with the
                # concatenation order of all_awaitables above.
                results = await asyncio.gather(*all_awaitables)
                n_args = len(resolved_arg_idxs)
                n_kw = len(resolved_kwarg_keys)
                arg_results = list(results[:n_args])
                kwarg_results = list(results[n_args : n_args + n_kw])
                # results[n_args + n_kw:] are wait_for results — discarded
            else:
                arg_results, kwarg_results = [], []
        except Exception as e:
            msg = "Failed while waiting to start."
            raise RuntimeError(msg) from e

        # Rebuild args/kwargs with resolved values in their original positions.
        resolved_args = list(args)
        for (i, _), val in zip(resolved_arg_idxs, arg_results, strict=True):
            resolved_args[i] = val
        resolved_kwargs = dict(kwargs)
        for (k, _), val in zip(resolved_kwarg_keys, kwarg_results, strict=True):
            resolved_kwargs[k] = val

        # Mark the task as started only once every dependency has resolved.
        await _start()

        result = func(*resolved_args, **resolved_kwargs)
        # Tolerates sync callables despite the annotation: await only if needed.
        return await result if inspect.isawaitable(result) else result

    return wrapper
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class _Unset:
    """Sentinel for unset keyword arguments."""


# Shared sentinel instance; lets callers distinguish "not provided" from None.
_UNSET = _Unset()
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class TaskStatus(StrEnum):
    """Status of a tracked asyncio task."""

    WAITING = "waiting to start"  # initialized but not yet started
    RUNNING = "running"  # started and not yet finished
    DONE = "done"  # finished without error
    FAILED = "failed"  # finished with an exception
    CANCELLED = "cancelled"  # the underlying asyncio task was cancelled
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@dataclass(slots=True)
class TaskInfo:
    """Metadata and state for a single tracked asyncio task.

    Attributes are read-only from the outside: __setattr__ rejects writes
    unless performed inside the `allow_edit` async context manager, which
    also serializes concurrent edits via a per-instance lock.
    """

    id: int  # tracker-assigned id, unique per event loop
    task: asyncio.Task  # the underlying asyncio task
    name: str  # taken from asyncio.Task.get_name()
    parent: int | None  # id of the task that spawned this one, if tracked
    subtasks: list[int]  # ids of all tracked children
    running_subtasks: list[int]  # ids of children currently running
    status: TaskStatus
    started_at: datetime | None = None  # wall-clock start, None until started
    finished_at: datetime | None = None  # wall-clock finish, None until done
    exception: BaseException | None = None  # set when status is FAILED
    logs: str = ""  # accumulated output appended via log()
    completed: float = 0  # progress numerator
    total: float | None = None  # progress denominator; None means unknown
    auto_progress: bool = True  # advance progress as subtasks complete
    deps: list[int] = field(default_factory=list)  # upstream DAG edges
    dependents: list[int] = field(default_factory=list)  # downstream DAG edges
    tree_depth: int = 0  # depth in the parent/child call tree
    dag_depth: int = 0  # longest dependency-chain length ending here
    # Monotonic timestamps used for duration(); kept separate from the
    # wall-clock fields so duration is immune to clock adjustments.
    _start_mono: float | None = field(default=None, repr=False, compare=False)
    _finish_mono: float | None = field(default=None, repr=False, compare=False)
    _lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False, compare=False)
    _edit_allowed: bool = field(default=False, repr=False, compare=False)

    def __setattr__(self, name: str, value: Any, /) -> None:
        """Reject attribute writes outside `allow_edit` (guard fields excepted)."""
        # The guard machinery itself must always be writable, or allow_edit
        # could never toggle it.
        if name in ("_edit_allowed", "_lock", "_start_mono", "_finish_mono"):
            object.__setattr__(self, name, value)
            return

        # During dataclass __init__ the slot is not yet set, so hasattr is
        # False and initialization proceeds unguarded.
        if hasattr(self, "_edit_allowed") and not object.__getattribute__(self, "_edit_allowed"):
            msg = "Edit not allowed. Use the `allow_edit` context manager."
            raise RuntimeError(msg)

        object.__setattr__(self, name, value)

    @asynccontextmanager
    async def allow_edit(self) -> AsyncGenerator[None]:
        """Hold the instance lock and temporarily permit attribute writes."""
        async with self._lock:
            self._edit_allowed = True
            try:
                yield
            finally:
                self._edit_allowed = False

    async def update(
        self,
        *,
        completed: float | _Unset = _UNSET,
        total: float | None | _Unset = _UNSET,
    ) -> None:
        """Update user-facing fields atomically."""
        async with self.allow_edit():
            # _UNSET (not None) marks "leave unchanged", since None is a
            # legitimate value for total.
            if not isinstance(completed, _Unset):
                self.completed = completed
            if not isinstance(total, _Unset):
                self.total = total

    def subtasks_info(
        self,
        fmt: Callable[[TaskInfo], str] = "- {0.name}: {0.status.value}".format,
        sep: str = "\n",
        all_subtasks: bool = False,
    ) -> str:
        """Format the (running, or all) subtasks as a *sep*-joined string.

        Subtask ids whose info has been removed from tracking are skipped.
        """
        items: list[str] = []
        for child_id in self.subtasks if all_subtasks else self.running_subtasks:
            try:
                items.append(fmt(get_task_info(child_id)))
            except ValueError:
                # Child was removed from tracking; ignore the stale id.
                continue
        return sep.join(items)

    def started(self) -> bool:
        """Return True once the task has actually begun running."""
        return self.started_at is not None

    def done(self) -> bool:
        """Return True once the task has finished (any terminal status)."""
        return self.finished_at is not None

    def duration(self) -> float:
        """Get task duration in seconds.

        0.0 before start; elapsed-so-far while running; total once finished.
        """
        if self._start_mono is None:
            return 0.0
        end = self._finish_mono if self._finish_mono is not None else time.monotonic()
        return end - self._start_mono
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
# Context variable holding the current task's tracking id. Child tasks inherit
# a copy of the context, which is how _init_task_info discovers its parent.
_task_id: ContextVar[int] = ContextVar("task_id")
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
@dataclass
class _LoopState:
    """Per-event-loop registry of tracked tasks and bookkeeping state."""

    # Maps task_id -> TaskInfo for every tracked task on this loop.
    task_infos: dict[int, TaskInfo] = field(default_factory=dict)
    # Reverse mapping: asyncio.Task object -> its allocated task_id.
    task_ids: dict[asyncio.Task, int] = field(default_factory=dict)
    # Strong refs to fire-and-forget bookkeeping tasks so they are not GC'd.
    background_tasks: set[asyncio.Task] = field(default_factory=set)
    # Counter backing allocate_id(); guarded by _id_lock.
    _next_id: int = field(default=0)
    _id_lock: threading.Lock = field(default_factory=threading.Lock)

    def allocate_id(self) -> int:
        """Hand out the next unique, monotonically increasing task id.

        Thread-safe: the increment is performed under a threading lock.
        """
        with self._id_lock:
            new_id = self._next_id
            self._next_id = new_id + 1
        return new_id
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
# One tracking state per event loop; weak keys let the state be reclaimed
# automatically when a loop is garbage-collected.
_loop_states: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, _LoopState] = weakref.WeakKeyDictionary()
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _get_state() -> _LoopState:
    """Fetch (or lazily create) the tracking state bound to the running loop."""
    loop = asyncio.get_running_loop()
    state = _loop_states.get(loop)
    if state is None:
        state = _LoopState()
        _loop_states[loop] = state
    return state
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
async def _set_done(task_id: int) -> None:
    """Record a tracked task's terminal status and finish timestamps."""
    state = _get_state()
    task_info = state.task_infos[task_id]
    async with task_info.allow_edit():
        task_info._finish_mono = time.monotonic()
        task_info.finished_at = datetime.now()
        # Check cancellation first: Task.exception() raises CancelledError
        # when called on a cancelled task.
        if task_info.task.cancelled():
            task_info.status = TaskStatus.CANCELLED
        elif exc := task_info.task.exception():
            task_info.status = TaskStatus.FAILED
            task_info.exception = exc
        else:
            task_info.status = TaskStatus.DONE
            # No explicit progress was ever reported: show a full 1/1 bar.
            if task_info.total is None:
                task_info.total = 1
                task_info.completed = 1
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
async def _update_parent(task_id: int, parent_id: int, auto_progress: bool) -> None:
    """Reflect a finished child task in its parent's progress bookkeeping."""
    parent_info = _get_state().task_infos[parent_id]
    async with parent_info.allow_edit():
        if auto_progress:
            # Each finished child advances the parent's progress by one.
            parent_info.completed = (parent_info.completed or 0) + 1
        try:
            parent_info.running_subtasks.remove(task_id)
        except ValueError:
            # Child was never marked running (or already removed) — fine.
            pass
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def _add_done_callback(task: asyncio.Task, task_id: int, state: _LoopState) -> None:
    """Arrange for _set_done(task_id) to run once *task* finishes."""

    def _on_done(_finished: asyncio.Task) -> None:
        # Done-callbacks run synchronously, so spawn the async bookkeeping as
        # a background task and keep a strong reference to prevent GC.
        bg = asyncio.create_task(_set_done(task_id=task_id))
        state.background_tasks.add(bg)
        bg.add_done_callback(state.background_tasks.discard)

    task.add_done_callback(_on_done)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def _add_update_parent_callback(
    task: asyncio.Task,
    task_id: int,
    parent_id: int,
    auto_progress: bool,
    state: _LoopState,
) -> None:
    """Arrange for the parent's progress to be updated once *task* finishes."""

    def _on_done(_finished: asyncio.Task) -> None:
        # Spawn the async parent update as a background task; a strong
        # reference in state.background_tasks keeps it alive until done.
        bg = asyncio.create_task(
            _update_parent(task_id=task_id, parent_id=parent_id, auto_progress=auto_progress)
        )
        state.background_tasks.add(bg)
        bg.add_done_callback(state.background_tasks.discard)

    task.add_done_callback(_on_done)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def _get_task() -> asyncio.Task:
    """Return the currently executing asyncio task, with explicit errors.

    Raises RuntimeError when called with no running loop, or from loop
    callbacks where there is no current task.
    """
    try:
        current = asyncio.current_task()
    except RuntimeError as e:
        msg = "This function can only be called from a coroutine."
        raise RuntimeError(msg) from e
    if current is None:
        msg = "No current task. This function must be called from within an asyncio task, not a callback."
        raise RuntimeError(msg)
    return current
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
async def _init_task_info(start: bool = True, auto_progress: bool = True) -> None:
    """Register the current asyncio task with the tracker.

    Allocates an id, builds the TaskInfo, links it to the parent task found
    via the _task_id context variable, and installs the done-callbacks that
    finalize status and parent progress.

    Args:
        start: When True the task is recorded as already RUNNING; otherwise
            it is WAITING and must be started later via _start_task.
        auto_progress: Stored on the TaskInfo to control subtask-driven
            progress updates.

    Raises:
        RuntimeError: if the current task was already initialized.
    """
    state = _get_state()
    task = _get_task()
    task_name = task.get_name()
    if task in state.task_ids:
        msg = f"Task {task_name} is already initialized"
        raise RuntimeError(msg)

    task_id = state.allocate_id()

    # Find the parent: the context inherited from the spawning task still
    # carries its id; LookupError means this is a root task.
    try:
        parent_id = _task_id.get()
    except LookupError:
        parent_id = None

    parent_tree_depth = state.task_infos[parent_id].tree_depth if parent_id is not None else -1

    task_info = TaskInfo(
        id=task_id,
        name=task_name,
        parent=parent_id,
        subtasks=[],
        started_at=datetime.now() if start else None,
        _start_mono=time.monotonic() if start else None,
        status=TaskStatus.RUNNING if start else TaskStatus.WAITING,
        task=task,
        running_subtasks=[],
        auto_progress=auto_progress,
        tree_depth=parent_tree_depth + 1,
        dag_depth=0,
    )

    async with task_info.allow_edit():
        state.task_infos[task_id] = task_info
        # Finalize status/timestamps whenever the task ends.
        _add_done_callback(task, task_id=task_id, state=state)
        if parent_id is not None:
            parent_task_info = state.task_infos[parent_id]
            async with parent_task_info.allow_edit():
                parent_task_info.subtasks.append(task_id)
                if start:
                    # A WAITING task joins running_subtasks later, in _start_task.
                    parent_task_info.running_subtasks.append(task_id)
                _add_update_parent_callback(
                    task,
                    task_id=task_id,
                    parent_id=parent_id,
                    auto_progress=parent_task_info.auto_progress,
                    state=state,
                )
                if parent_task_info.auto_progress:
                    # Each new child grows the parent's progress denominator.
                    total = parent_task_info.total or 0
                    parent_task_info.total = total + 1
        state.task_ids[task] = task_id
        # Publish our id into the context so our own children can find us.
        _task_id.set(task_id)
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
async def _start_task() -> None:
    """Transition the current (previously initialized) task to RUNNING.

    Records the start timestamps and adds the task to its parent's
    running_subtasks list.

    Raises:
        RuntimeError: if the current task was never registered via
            _init_task_info.
    """
    state = _get_state()
    task = _get_task()
    if task not in state.task_ids:
        msg = f"Cannot start uninitialized task {task.get_name()}"
        raise RuntimeError(msg)
    task_id = _task_id.get()
    task_info = state.task_infos[task_id]
    async with task_info.allow_edit():
        task_info._start_mono = time.monotonic()
        task_info.started_at = datetime.now()
        task_info.status = TaskStatus.RUNNING
    if task_info.parent is not None:
        parent_info = state.task_infos[task_info.parent]
        async with parent_info.allow_edit():
            parent_info.running_subtasks.append(task_id)
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def _would_cycle(state: _LoopState, from_id: int, to_id: int) -> bool:
    """Return True if adding edge from_id -> to_id would create a cycle.

    The edge means "from_id depends on to_id". A cycle arises exactly when
    to_id already (transitively) depends on from_id, i.e. to_id is reachable
    from from_id by following `dependents` edges. Iterative DFS; unknown ids
    are treated as leaves.
    """
    # Fix: a self-edge is itself a cycle. The previous traversal started at
    # from_id and skipped it via `continue`, so from_id == to_id slipped
    # through and reported False.
    if from_id == to_id:
        return True
    visited: set[int] = set()
    stack = [from_id]
    while stack:
        tid = stack.pop()
        if tid in visited:
            continue
        visited.add(tid)
        info = state.task_infos.get(tid)
        if info is None:
            # Task no longer tracked; nothing to traverse through it.
            continue
        for dep_id in info.dependents:
            if dep_id == to_id:
                return True
            stack.append(dep_id)
    return False
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
async def _register_dep(from_id: int, to_id: int) -> None:
    """Register a dependency edge: from_id depends on to_id.

    Silently ignores ids that are no longer tracked. Updates both sides of
    the edge (deps and dependents) and from_id's dag_depth, acquiring each
    TaskInfo's edit lock in turn.

    Raises:
        RuntimeError: if the edge would introduce a cycle into the DAG.
    """
    state = _get_state()
    from_info = state.task_infos.get(from_id)
    to_info = state.task_infos.get(to_id)
    if from_info is None or to_info is None:
        return
    if _would_cycle(state, from_id, to_id):
        from_desc = from_info.name
        to_desc = to_info.name
        msg = f"Circular dependency detected: {from_desc!r} -> {to_desc!r} would create a cycle."
        raise RuntimeError(msg)
    async with from_info.allow_edit():
        if to_id not in from_info.deps:
            from_info.deps.append(to_id)
            # dag_depth tracks the longest dependency chain ending at this
            # task; only grow it, never shrink.
            new_dag_depth = to_info.dag_depth + 1
            if new_dag_depth > from_info.dag_depth:
                from_info.dag_depth = new_dag_depth
    async with to_info.allow_edit():
        if from_id not in to_info.dependents:
            to_info.dependents.append(from_id)
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
async def log(value: str = "", end: str = "\n") -> None:
    """Append *value* (followed by *end*) to the current task's log buffer.

    Does nothing when called outside a tracked task.
    """
    try:
        current_id = _task_id.get()
    except LookupError:
        # Not inside a tracked task — logging is a silent no-op.
        return
    info = _get_state().task_infos[current_id]
    async with info.allow_edit():
        info.logs = info.logs + value + end
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
async def get_task_id(task: asyncio.Task, timeout: float = 1) -> int:
    """Get the task_id associated with an asyncio task.

    Polls (yielding to the event loop) until the tracking wrapper has
    registered *task*; asyncio.timeout raises TimeoutError after *timeout*
    seconds if that never happens.
    """
    state = _get_state()
    async with asyncio.timeout(timeout):
        while True:
            registered = state.task_ids.get(task)
            if registered is not None:
                return registered
            # Let the wrapped coroutine run far enough to register itself.
            await asyncio.sleep(0)
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
def current_task_info() -> TaskInfo:
    """Return the TaskInfo of the task currently running under node().

    Raises:
        RuntimeError: when called outside a node()-wrapped coroutine.
    """
    try:
        tid = _task_id.get()
    except LookupError:
        raise RuntimeError(
            "Not inside a tracked task. Call this from within a node()-wrapped coroutine."
        ) from None
    return get_task_info(tid)
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def get_task_info(task_id: int) -> TaskInfo:
    """Look up the TaskInfo registered under *task_id* on the running loop.

    Raises:
        ValueError: if no such task is tracked on the current event loop.
    """
    current_loop = asyncio.get_running_loop()
    try:
        # Both lookups stay inside the try: a loop with no tracked state
        # at all is reported the same way as an unknown task id.
        return _loop_states[current_loop].task_infos[task_id]
    except KeyError:
        raise ValueError(
            f"No task with id {task_id!r} found in the current event loop."
        ) from None
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
def remove_task(task_id: int) -> None:
    """Drop a tracked task and its entire subtree to release memory.

    Raises:
        ValueError: if *task_id* is not tracked on the current event loop.
    """
    state = _loop_states[asyncio.get_running_loop()]
    if task_id not in state.task_infos:
        raise ValueError(f"No task with id {task_id!r} found in the current event loop.")
    # Iterative DFS over the subtree; pop(..., None) tolerates ids that
    # were already removed or never fully registered.
    pending = [task_id]
    while pending:
        info = state.task_infos.pop(pending.pop(), None)
        if info is not None:
            state.task_ids.pop(info.task, None)
            pending.extend(info.subtasks)
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
def track[**P, R](
|
|
479
|
+
func: Callable[P, Coroutine[Any, Any, R]],
|
|
480
|
+
start: bool = True,
|
|
481
|
+
) -> Callable[P, Coroutine[Any, Any, R]]:
|
|
482
|
+
"""Track a coroutine by recording task info."""
|
|
483
|
+
|
|
484
|
+
@functools.wraps(func)
|
|
485
|
+
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
486
|
+
await _init_task_info(start=start)
|
|
487
|
+
return await func(*args, **kwargs)
|
|
488
|
+
|
|
489
|
+
return wrapper
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
def make_async[**P, T](
|
|
494
|
+
func: Callable[P, T],
|
|
495
|
+
) -> Callable[P, Coroutine[Any, Any, T]]:
|
|
496
|
+
"""Run function in a separate thread."""
|
|
497
|
+
|
|
498
|
+
@functools.wraps(func)
|
|
499
|
+
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
500
|
+
return await asyncio.to_thread(func, *args, **kwargs)
|
|
501
|
+
|
|
502
|
+
return wrapper
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
@runtime_checkable
class SupportsNext[T](Protocol):
    """Structural protocol for objects consumable with the builtin next()."""

    def __next__(self) -> T: ...
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
async def make_async_generator[T](gen: SupportsNext[T]) -> AsyncGenerator[T]:
|
|
511
|
+
"""Run each `next` call in a separate thread."""
|
|
512
|
+
sentinel = object()
|
|
513
|
+
|
|
514
|
+
def step() -> T | object:
|
|
515
|
+
return next(gen, sentinel)
|
|
516
|
+
|
|
517
|
+
while True:
|
|
518
|
+
obj = await asyncio.to_thread(step)
|
|
519
|
+
if obj is sentinel:
|
|
520
|
+
break
|
|
521
|
+
yield cast("T", obj)
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
def walk_tree(root: asyncio.Task | int | None = None) -> Iterator[TaskInfo]:
    """Yield TaskInfo objects in DFS pre-order over the call tree.

    *root* may be a tracked ``asyncio.Task`` (its id is looked up), an
    explicit task id, or ``None`` to start from the task whose parent is
    ``None``. Yields nothing when no starting task can be resolved.
    """
    state = _get_state()
    if isinstance(root, int):
        start = root
    elif root is None:
        start = next(
            (tid for tid, info in state.task_infos.items() if info.parent is None),
            None,
        )
    else:
        start = state.task_ids.get(root)
    if start is None:
        return

    pending = [start]
    while pending:
        info = state.task_infos.get(pending.pop())
        if info is not None:
            yield info
            # reversed() keeps the leftmost child on top of the stack so
            # it is visited first.
            pending.extend(reversed(info.subtasks))
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
def walk_dag(root: asyncio.Task | int | None = None) -> Iterator[TaskInfo]:
    """Yield TaskInfo objects in topological order (Kahn's algorithm).

    Edges considered are tree edges (parent -> child) and dependency
    edges (dep -> dependent), so a task is yielded only after its parent
    and all of its registered dependencies. Among simultaneously-ready
    tasks the smallest task id is yielded first, which makes the order
    deterministic.

    If *root* is ``None``, every task tracked on the current event loop
    is included; otherwise only *root* (an ``asyncio.Task`` or task id)
    and its subtree are considered.
    """
    import heapq  # local import: the module-level import block is untouched

    state = _get_state()

    # Collect the set of task IDs to include.
    if root is None:
        ids = list(state.task_infos.keys())
    else:
        start_id = root if isinstance(root, int) else state.task_ids.get(root)
        if start_id is None:
            return
        ids = []
        visited: set[int] = set()
        pending = [start_id]
        while pending:
            tid = pending.pop()
            if tid in visited or tid not in state.task_infos:
                continue
            visited.add(tid)
            ids.append(tid)
            pending.extend(state.task_infos[tid].subtasks)

    infos = {tid: state.task_infos[tid] for tid in ids if tid in state.task_infos}
    if not infos:
        return

    # Kahn's algorithm — edges: parent->child and dep->dependent.
    in_degree: dict[int, int] = dict.fromkeys(infos, 0)
    successors: dict[int, list[int]] = {tid: [] for tid in infos}
    for tid, info in infos.items():
        if info.parent is not None and info.parent in infos:
            in_degree[tid] += 1
            successors[info.parent].append(tid)
        for dep_id in info.deps:
            # Skip the parent edge here so it is not counted twice.
            if dep_id in infos and dep_id != info.parent:
                in_degree[tid] += 1
                successors[dep_id].append(tid)

    # Min-heap of ready ids: O((V + E) log V) overall, replacing the old
    # queue.pop(0) plus full re-sort of the ready list on every iteration
    # (quadratic), while yielding the exact same smallest-id-first order.
    ready = [tid for tid, deg in in_degree.items() if deg == 0]
    heapq.heapify(ready)
    while ready:
        tid = heapq.heappop(ready)
        yield infos[tid]
        for succ in successors[tid]:
            in_degree[succ] -= 1
            if in_degree[succ] == 0:
                heapq.heappush(ready, succ)
|
|
607
|
+
|
|
608
|
+
|
|
609
|
+
# Explicit public API: the names exported by `from aionode import *` and
# the surface intended for external use.
__all__ = [
    "TaskInfo",
    "TaskStatus",
    "current_task_info",
    "get_task_id",
    "get_task_info",
    "log",
    "make_async",
    "make_async_generator",
    "node",
    "remove_task",
    "resolve",
    "walk_dag",
    "walk_tree",
]
|
|
File without changes
|