ttasks 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ttasks/_graph.py ADDED
@@ -0,0 +1,409 @@
1
+ """TaskGraph: a DAG of Tasks executed on a ThreadPoolExecutor.
2
+
3
+ The graph owns its own task references and dependency edges; it does not
4
+ delegate task storage to a separate ledger. Persistence is the job of
5
+ ``ttasks.store.Store`` (consulted via ``TaskExecutor`` for auto-save).
6
+ """
7
+
8
+ import uuid
9
+ from collections.abc import Iterable, Iterator
10
+ from concurrent.futures import Future, ThreadPoolExecutor
11
+ from datetime import datetime
12
+ from threading import Event, RLock
13
+
14
+ from ._executor import TaskExecutor
15
+ from ._task import Task, TaskStatus
16
+
17
+
18
class TaskGraph:
    """A directed acyclic graph of :class:`Task` objects.

    Tasks and edges are stored on the graph itself. ``graph.run(executor)``
    submits ready tasks to a thread pool; if the executor was constructed
    with a ``store``, every lifecycle transition is auto-persisted there.
    """

    def __init__(self, *, title: str = "") -> None:
        """Create a graph with display ``title``.

        Raises ``TypeError`` if ``title`` is not a ``str``.
        """
        if not isinstance(title, str):
            raise TypeError("title must be a str")
        self._id = str(uuid.uuid4())
        self.title = title
        self.created_at = datetime.now()
        # task_id -> Task; insertion order preserved for stable iteration.
        self._tasks: dict[str, Task] = {}
        # task_id -> list of distinct upstream task_ids (declaration order).
        self._deps: dict[str, list[str]] = {}
        # Exceptions raised by submitted task futures during the most recent
        # run, keyed by task id. Cleared at the start of each run().
        self._errors: dict[str, BaseException] = {}
        # Finally tasks run after their dependencies finish, fail, cancel, or
        # become blocked. Optional tasks report failures without making ok
        # false.
        self._finally: set[str] = set()
        self._optional: set[str] = set()

    # ---- mapping protocol ---------------------------------------------------

    def add(
        self,
        task: Task,
        *,
        after: Iterable[Task] = (),
        finally_: bool = False,
        required: bool = True,
    ) -> None:
        """Register ``task`` in the graph.

        ``after`` lists upstream tasks that must complete before ``task`` runs;
        repeated entries are collapsed to a single edge.
        ``finally_=True`` registers a finally task: it becomes ready once every
        listed upstream task is no longer active, regardless of success.
        ``required=False`` is only meaningful with ``finally_=True`` and marks
        the task as optional so its failure does not make :attr:`ok` false.

        Adding a task whose id is already registered replaces its edges and
        its finally/optional flags.

        Raises ``TypeError`` if ``task`` or any element of ``after`` is not a
        :class:`Task`, or if a flag is not a ``bool``; raises ``ValueError``
        for ``required=False`` without ``finally_=True``. The graph is left
        unchanged when an exception is raised.
        """
        if not isinstance(task, Task):
            raise TypeError(f"Expected Task, got {type(task).__name__}")
        if not isinstance(finally_, bool):
            raise TypeError("finally_ must be a bool")
        if not isinstance(required, bool):
            raise TypeError("required must be a bool")
        if not required and not finally_:
            raise ValueError("required=False is only valid with finally_=True")

        # Resolve every edge before mutating anything so that a bad element
        # in ``after`` cannot leave _tasks and _deps out of sync.
        dep_ids: list[str] = []
        for dep in after:
            if not isinstance(dep, Task):
                raise TypeError(
                    f"Expected Task in after, got {type(dep).__name__}"
                )
            dep_ids.append(dep.id)

        self._tasks[task.id] = task
        # dict.fromkeys de-duplicates while keeping declaration order; a
        # repeated parent is still just one edge.
        self._deps[task.id] = list(dict.fromkeys(dep_ids))
        if finally_:
            self._finally.add(task.id)
            if required:
                self._optional.discard(task.id)
            else:
                self._optional.add(task.id)
        else:
            self._finally.discard(task.id)
            self._optional.discard(task.id)

    def __setitem__(self, task: Task, deps: Iterable[Task]) -> None:
        """Register ``task`` in the graph and record its upstream dependencies.

        Mapping-syntax sugar for :meth:`add` (without ``finally_``).
        """
        self.add(task, after=deps)

    def __getitem__(self, task: Task) -> list[Task]:
        """Return the upstream :class:`Task` objects ``task`` depends on."""
        return self.dependencies(task)

    def __contains__(self, task: object) -> bool:
        """Return whether ``task`` is a Task registered in this graph."""
        return isinstance(task, Task) and task.id in self._deps

    def __iter__(self) -> Iterator[Task]:
        """Iterate over graph tasks in insertion order."""
        return (self._tasks[tid] for tid in self._deps)

    def __len__(self) -> int:
        """Return the number of tasks registered in this graph."""
        return len(self._deps)

    def __repr__(self) -> str:
        """Return a concise representation including dependency edges.

        Edges are rendered ``upstream->downstream`` by task title. An edge
        that points at a task id not (yet) registered in the graph is shown
        by its raw id so that ``repr`` never raises.
        """

        def label(tid: str) -> str:
            task = self._tasks.get(tid)
            return tid if task is None else task.title

        edges = ", ".join(
            f"{label(d)}->{label(t)}"
            for t, ds in self._deps.items()
            for d in ds
        )
        return f"TaskGraph({len(self)} tasks, edges=[{edges}])"

    # ---- public introspection (used by persistence backends) ---------------

    @property
    def id(self) -> str:
        """Return the immutable graph identity."""
        return self._id

    def dependencies(self, task: Task) -> list[Task]:
        """Return the direct upstream tasks of ``task``.

        Raises ``KeyError`` if ``task`` is not registered, or if one of its
        upstream ids has not been registered in this graph.
        """
        return [self._tasks[d] for d in self._deps[task.id]]

    def is_finally(self, task: Task) -> bool:
        """Return whether ``task`` was registered with ``finally_=True``."""
        return task.id in self._finally

    def is_optional(self, task: Task) -> bool:
        """Return whether ``task`` is a finally task with ``required=False``."""
        return task.id in self._optional

    def items(self) -> Iterator[tuple[Task, list[Task]]]:
        """Yield ``(task, deps)`` pairs in insertion order."""
        for tid in self._deps:
            task = self._tasks[tid]
            yield task, [self._tasks[d] for d in self._deps[tid]]

    # ---- status views (post-run) --------------------------------------------

    @property
    def succeeded(self) -> list[Task]:
        """Tasks in this graph whose status is SUCCEEDED."""
        return [t for t in self if t.status == TaskStatus.SUCCEEDED]

    @property
    def failed(self) -> list[Task]:
        """Tasks in this graph whose status is FAILED."""
        return [t for t in self if t.status == TaskStatus.FAILED]

    @property
    def cancelled(self) -> list[Task]:
        """Tasks in this graph whose status is CANCELLED."""
        return [t for t in self if t.status == TaskStatus.CANCELLED]

    @property
    def blocked(self) -> list[Task]:
        """Tasks in this graph whose status is BLOCKED."""
        return [t for t in self if t.status == TaskStatus.BLOCKED]

    @property
    def errors(self) -> dict[str, BaseException]:
        """Exceptions raised by task futures during the most recent run.

        Returns a shallow copy keyed by task id, so callers cannot mutate the
        graph's own record.
        """
        return dict(self._errors)

    @property
    def ok(self) -> bool:
        """True iff every required task succeeded without run errors.

        Optional tasks (finally tasks added with ``required=False``) are
        ignored entirely. An empty graph is vacuously ok.
        """
        return all(
            self._tasks[tid].status == TaskStatus.SUCCEEDED
            # A required task whose future raised during the most recent run
            # is not ok even if its status reads SUCCEEDED.
            and tid not in self._errors
            for tid in self._deps
            if tid not in self._optional
        )

    # ---- topology views -----------------------------------------------------

    def roots(self) -> list[Task]:
        """Tasks with no upstream dependencies."""
        return [self._tasks[tid] for tid, ds in self._deps.items() if not ds]

    def leaves(self) -> list[Task]:
        """Tasks that no other task depends on."""
        depended_on: set[str] = set()
        for ds in self._deps.values():
            depended_on.update(ds)
        return [
            self._tasks[tid] for tid in self._deps if tid not in depended_on
        ]

    # ---- validation ---------------------------------------------------------

    def _validate(self) -> None:
        """Raise ValueError on missing deps, cycles, or stale RUNNING state.

        Called from :meth:`run`. A task already in RUNNING cannot transition
        again and would deadlock the scheduler, so we surface it eagerly
        rather than time out.
        """
        for task in self._tasks.values():
            if task.status == TaskStatus.RUNNING:
                raise ValueError(
                    f"task {task.title!r} is RUNNING; reset before run()"
                )

        # Build parent -> children adjacency and in-degrees in one pass.
        # In-degree counts *distinct* parents: each parent decrements a child
        # exactly once below, so counting duplicates would leave the child
        # stranded and misreport a cycle.
        children: dict[str, list[str]] = {tid: [] for tid in self._deps}
        indeg: dict[str, int] = {}
        for tid, ds in self._deps.items():
            parents = dict.fromkeys(ds)
            for d in parents:
                if d not in self._deps:
                    raise ValueError(
                        f"task {self._tasks[tid].title!r} depends on "
                        f"unregistered task id {d!r}"
                    )
                children[d].append(tid)
            indeg[tid] = len(parents)

        # Kahn's algorithm: count visited nodes vs total.
        queue = [tid for tid, n in indeg.items() if n == 0]
        visited = 0
        while queue:
            cur = queue.pop()
            visited += 1
            for child in children[cur]:
                indeg[child] -= 1
                if indeg[child] == 0:
                    queue.append(child)
        if visited != len(self._deps):
            raise ValueError("TaskGraph contains a cycle")

    # ---- execution ----------------------------------------------------------

    def run(
        self,
        executor: TaskExecutor,
        max_workers: int = 4,
    ) -> "TaskGraph":
        """Execute the DAG. Blocks until done. Returns ``self`` for chaining.

        Failure policy: if a task fails or is cancelled, every descendant is
        marked blocked and never submitted; the run terminates instead of
        hanging. Already-SUCCEEDED tasks count as satisfied dependencies so a
        graph can be run again or extended after partial completion. Use
        :attr:`failed` and :attr:`blocked` to inspect the outcome.

        Raises ``ValueError`` if ``max_workers`` is not positive or the graph
        fails validation, and re-raises any error captured by the scheduler
        while the run was in progress.
        """
        if max_workers <= 0:
            raise ValueError("max_workers must be greater than 0")

        self._validate()
        # Auto-save after validation so an invalid graph leaves no trace.
        executor._persist_graph(self)

        try:
            return self._run_inner(executor, max_workers)
        finally:
            # Persist the final state whether the run completed or raised.
            executor._persist_graph(self)

    def _run_inner(
        self,
        executor: TaskExecutor,
        max_workers: int,
    ) -> "TaskGraph":
        """Inner scheduler loop, split out so :meth:`run` can wrap save logic.

        Scheduling is driven by future done-callbacks: each completed task
        re-enters ``schedule()`` under ``lock`` to submit whatever became
        ready. The calling thread blocks on ``done`` until every task is
        finished or the scheduler records an error.
        """
        # Reset run-scoped state from any previous run.
        self._errors = {}

        # Empty graph: nothing to wait for. Return early to avoid deadlock.
        if not self._deps:
            return self

        # Snapshot tasks that entered this run already BLOCKED. Only these
        # are eligible for in-run retry; tasks that get BLOCKED during this
        # run stay terminal so finally readiness and inactive() remain
        # consistent within a single invocation.
        entering_blocked = {
            tid
            for tid, t in self._tasks.items()
            if t.status == TaskStatus.BLOCKED
        }

        futures: dict[str, Future] = {}
        # Reentrant: a done-callback attached to an already-finished future
        # runs inline on the thread that is holding the lock in schedule().
        lock = RLock()
        done = Event()
        # ``scheduler_error`` is the escape hatch for surfacing exceptions
        # raised inside ThreadPoolExecutor callbacks (e.g., the no-progress
        # guard) so they propagate out of run() instead of being silently
        # swallowed in a callback thread. A non-empty list also means the
        # run has been aborted and no further scheduling should happen.
        scheduler_error: list[BaseException] = []

        with ThreadPoolExecutor(max_workers=max_workers) as pool:

            def succeeded(tid: str) -> bool:
                """Return whether tid is already done or succeeded in this run."""
                return self._tasks[tid].status == TaskStatus.SUCCEEDED

            def inactive(tid: str) -> bool:
                """Return whether tid can no longer change in this run."""
                task = self._tasks[tid]
                return (
                    task.is_terminal
                    or tid in self._errors
                    or (tid in futures and futures[tid].done())
                )

            def ready(tid: str) -> bool:
                """Return whether all upstream dependencies are satisfied."""
                if tid in self._finally:
                    return all(inactive(d) for d in self._deps[tid])
                return all(succeeded(d) for d in self._deps[tid])

            def first_bad_parent(tid: str) -> str | None:
                """Return the first dep (in declaration order) blocking ``tid``.

                A parent "blocks" when its status is FAILED, CANCELLED, or
                BLOCKED. Returns ``None`` if every parent is still
                recoverable. Pre-start handler errors terminalize the parent
                to FAILED before raising, so status alone is authoritative.
                """
                for d in self._deps[tid]:
                    if self._tasks[d].status.is_bad:
                        return d
                return None

            def finished(tid: str) -> bool:
                """Return whether tid no longer needs scheduler attention."""
                task = self._tasks[tid]
                return task.is_terminal or (
                    tid in futures and futures[tid].done()
                )

            def upstream_tasks(tid: str) -> dict[str, Task]:
                """Return direct upstream task refs for tid from the graph."""
                return {
                    dep_id: self._tasks[dep_id] for dep_id in self._deps[tid]
                }

            def submit(tid: str) -> None:
                """Submit tid to the thread pool and register its callback."""
                fut = pool.submit(
                    executor.execute,
                    self._tasks[tid],
                    upstream_tasks(tid),
                )
                # Record the future before attaching the callback: the
                # callback may fire immediately and consults ``futures``.
                futures[tid] = fut
                fut.add_done_callback(lambda f, t=tid: on_finish(t, f))

            def schedule() -> None:
                """Advance scheduling until no more tasks can change state."""
                changed = True
                while changed:
                    changed = False
                    for tid in self._deps:
                        task = self._tasks[tid]
                        if tid in futures:
                            continue
                        # SUCCEEDED and CANCELLED are absolute sinks: never
                        # retry-eligible. BLOCKED tasks that entered this
                        # run blocked are eligible for retry (carryover);
                        # BLOCKED tasks that became blocked during this run
                        # stay blocked so finally readiness is unambiguous.
                        if task.status.is_sink:
                            continue
                        if (
                            task.status == TaskStatus.BLOCKED
                            and tid not in entering_blocked
                        ):
                            continue
                        if tid not in self._finally:
                            bad = first_bad_parent(tid)
                            if bad is not None:
                                if task.status == TaskStatus.PENDING:
                                    executor.mark_blocked(task, bad)
                                    changed = True
                                # Carryover BLOCKED whose parents are still
                                # bad: stays BLOCKED until parents recover.
                                continue
                        if ready(tid) and task.can_transition_to(
                            TaskStatus.RUNNING
                        ):
                            submit(tid)
                            changed = True

                if all(finished(tid) for tid in self._deps):
                    done.set()
                    return
                # No live work to wait on and not finished: deadlocked.
                live = any(not f.done() for f in futures.values())
                if not live:
                    stuck = [
                        self._tasks[tid].title
                        for tid in self._deps
                        if not finished(tid)
                    ]
                    scheduler_error.append(
                        RuntimeError(
                            f"scheduler made no progress; stuck={stuck!r}"
                        )
                    )
                    done.set()

            def on_finish(tid: str, fut: Future) -> None:
                """Resume scheduling after a submitted task future completes."""
                with lock:
                    try:
                        exception = fut.exception()
                        if exception is not None:
                            self._errors[tid] = exception
                        # The run was already aborted; record the outcome
                        # above but do not submit any further work.
                        if scheduler_error:
                            return
                        schedule()
                    except BaseException as exc:
                        # concurrent.futures logs and discards exceptions
                        # raised by done-callbacks. Without this, a failure
                        # here would leave ``done`` unset and run() blocked
                        # forever, so hand the error to the waiting thread.
                        scheduler_error.append(exc)
                        done.set()

            # Kick off every task whose dependencies are already satisfied.
            with lock:
                try:
                    schedule()
                except BaseException as exc:
                    # Mark the run aborted so callbacks of in-flight tasks
                    # stop scheduling while the pool drains, then propagate.
                    scheduler_error.append(exc)
                    raise

            done.wait()

        if scheduler_error:
            raise scheduler_error[0]
        return self