ttasks 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ttasks/_sqlite.py ADDED
@@ -0,0 +1,587 @@
1
+ """SQLite-backed durable :class:`Store`."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import sqlite3
6
+ import warnings
7
+ from collections.abc import Iterator, MutableMapping
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+ from threading import RLock
11
+ from typing import Any, cast
12
+
13
+ from ._graph import TaskGraph
14
+ from ._task import Task, TaskResult, TaskStatus, TaskType, TerminationReason
15
+
16
# On-disk layout version stamped into ``metadata.schema_version``; any other
# persisted value is treated as incompatible by ``_enforce_schema_version``.
_SCHEMA_VERSION = "4"

# Passed as ``timeout`` to ``sqlite3.connect``: seconds a connection waits on a
# locked database before giving up with ``OperationalError``.
_CONNECT_TIMEOUT_SECONDS = 30.0

# Every table the current schema creates. A database containing any of these
# is considered ours, and the destructive rebuild drops them in this order,
# which lists each referencing table before the tables it references.
_KNOWN_TABLES: tuple[str, ...] = (
    "graph_dependencies",
    "graph_tasks",
    "graphs",
    "task_results",
    "tasks",
    "metadata",
)
28
+
29
+
30
class _ClosingConnection(sqlite3.Connection):
    """``sqlite3.Connection`` whose ``with`` block also closes it.

    The stock ``sqlite3.Connection`` context manager only commits (or rolls
    back when an exception escapes) and leaves the connection open until it
    is garbage collected. Every ``with ... connect() as connection:`` block in
    this module wants a short-lived connection, so closing on exit releases
    the database handle deterministically.
    """

    def __exit__(self, *exc_info: Any) -> bool:
        """Commit or roll back exactly as the base class does, then close.

        Always returns ``False`` so exceptions from the ``with`` body
        propagate unchanged.
        """
        try:
            super().__exit__(*exc_info)
        finally:
            self.close()
        return False


class _Connection:
    """Per-store connection helper: one-time schema setup, per-call connections.

    Every operation opens its own short-lived connection via :meth:`connect`
    instead of sharing one, so work running on different threads (for
    example :meth:`TaskGraph.run` thread pools) never shares a
    ``sqlite3.Connection`` object. WAL journaling plus a generous busy
    timeout absorbs short write contention between those connections.

    Because each call opens a fresh connection, a ``":memory:"`` path would
    give every call its own empty database; use a file path.
    """

    # CREATE TABLE statements in creation order: each table is created after
    # every table its foreign keys reference.
    _TABLE_DDL = (
        """
        CREATE TABLE IF NOT EXISTS metadata (
            key TEXT PRIMARY KEY,
            value TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS tasks (
            id TEXT PRIMARY KEY,
            title TEXT NOT NULL,
            description TEXT NOT NULL,
            payload TEXT NOT NULL,
            type TEXT NOT NULL,
            status TEXT NOT NULL,
            error TEXT,
            timeout REAL,
            blocked_by TEXT,
            created_at TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS task_results (
            task_id TEXT PRIMARY KEY,
            status TEXT NOT NULL,
            started_at TEXT NOT NULL,
            finished_at TEXT NOT NULL,
            duration REAL NOT NULL,
            output TEXT NOT NULL,
            error TEXT,
            returncode INTEGER,
            termination_reason TEXT,
            FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS graphs (
            id TEXT PRIMARY KEY,
            title TEXT NOT NULL,
            created_at TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS graph_tasks (
            graph_id TEXT NOT NULL,
            task_id TEXT NOT NULL,
            is_finally INTEGER NOT NULL DEFAULT 0,
            is_optional INTEGER NOT NULL DEFAULT 0,
            position INTEGER NOT NULL,
            PRIMARY KEY(graph_id, task_id),
            FOREIGN KEY(graph_id) REFERENCES graphs(id) ON DELETE CASCADE,
            FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS graph_dependencies (
            graph_id TEXT NOT NULL,
            task_id TEXT NOT NULL,
            dependency_id TEXT NOT NULL,
            position INTEGER NOT NULL,
            PRIMARY KEY(graph_id, task_id, dependency_id),
            FOREIGN KEY(graph_id) REFERENCES graphs(id) ON DELETE CASCADE,
            FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE,
            FOREIGN KEY(dependency_id) REFERENCES tasks(id) ON DELETE CASCADE
        )
        """,
    )

    def __init__(
        self,
        path: str | Path,
        *,
        allow_destructive_migration: bool = False,
    ) -> None:
        """Open or create the SQLite database at ``path`` and set up the schema.

        Missing parent directories are created first.

        Raises:
            RuntimeError: (from :meth:`_enforce_schema_version`) when an
                existing database has a missing or mismatched schema version
                and ``allow_destructive_migration`` is false.
        """
        self.path = Path(path)
        # Path("x.db").parent is Path("."), which equals Path(""): bare file
        # names need no directory creation.
        if self.path.parent != Path(""):
            self.path.parent.mkdir(parents=True, exist_ok=True)
        self._schema_lock = RLock()
        self._allow_destructive = allow_destructive_migration
        self._init_schema()

    def connect(self) -> sqlite3.Connection:
        """Return a fresh connection configured for the store.

        The connection is a :class:`_ClosingConnection`: using it as a
        context manager commits (or rolls back on error) and then closes it.
        ``foreign_keys`` and ``synchronous`` are per-connection settings in
        SQLite, so they are applied here on every new connection rather than
        once at schema setup.
        """
        connection = sqlite3.connect(
            self.path,
            timeout=_CONNECT_TIMEOUT_SECONDS,
            factory=_ClosingConnection,
        )
        try:
            connection.row_factory = sqlite3.Row
            connection.execute("PRAGMA foreign_keys = ON")
            connection.execute("PRAGMA synchronous = NORMAL")
        except BaseException:
            # Do not leak the handle if configuration fails.
            connection.close()
            raise
        return connection

    def _init_schema(self) -> None:
        """Enable WAL, validate the stored schema version, and create tables.

        ``journal_mode = WAL`` is recorded in the database file itself, so
        setting it once here covers every later connection.
        """
        with self._schema_lock, self.connect() as connection:
            connection.execute("PRAGMA journal_mode = WAL")
            self._enforce_schema_version(connection, self._allow_destructive)
            for ddl in self._TABLE_DDL:
                connection.execute(ddl)
            connection.execute(
                """
                INSERT OR IGNORE INTO metadata(key, value)
                VALUES ('schema_version', ?)
                """,
                (_SCHEMA_VERSION,),
            )

    @staticmethod
    def _enforce_schema_version(
        connection: sqlite3.Connection, allow_destructive: bool
    ) -> None:
        """Validate the persisted schema version; never silently drop data.

        Behaviour:
        - Truly empty database (no recognised tables): accept and let
          ``_init_schema`` stamp the current version.
        - Recognised tables present with a matching ``schema_version`` row:
          accept and reuse.
        - Recognised tables present but ``schema_version`` row missing, or
          version mismatched, *without* ``allow_destructive=True``: raise
          ``RuntimeError`` so callers must opt in to data loss.
        - Mismatch *with* ``allow_destructive=True``: emit a ``UserWarning``
          and drop every known table so ``_init_schema`` can rebuild.

        Raises:
            RuntimeError: on a missing or mismatched version when
                ``allow_destructive`` is false.
        """
        existing_tables = {
            row[0]
            for row in connection.execute(
                "SELECT name FROM sqlite_master WHERE type = 'table'"
            ).fetchall()
        }
        known_present = existing_tables & set(_KNOWN_TABLES)
        if not known_present:
            return

        version: str | None = None
        if "metadata" in known_present:
            row = connection.execute(
                "SELECT value FROM metadata WHERE key = 'schema_version'"
            ).fetchone()
            if row is not None:
                version = str(row[0])

        if version == _SCHEMA_VERSION:
            return

        if version is None:
            message = (
                "SQLite store at this path has known tables but no "
                "schema_version row. Refusing to touch it. Pass "
                "allow_destructive_migration=True to drop and rebuild."
            )
        else:
            message = (
                f"SQLite store schema_version {version!r} does not match "
                f"current {_SCHEMA_VERSION!r}. Pass "
                "allow_destructive_migration=True to drop and rebuild."
            )

        if not allow_destructive:
            raise RuntimeError(message)

        warnings.warn(message, UserWarning, stacklevel=2)
        # _KNOWN_TABLES lists referencing tables first, so drops succeed
        # even with foreign key enforcement on.
        for table in _KNOWN_TABLES:
            connection.execute(f"DROP TABLE IF EXISTS {table}")
207
+
208
+
209
class SQLiteTaskCollection(MutableMapping[str, Task]):
    """Task mapping persisted in SQLite, keyed by task ID.

    Values returned by ``__getitem__`` are rebuilt from the database on every
    access, so mutating one has no stored effect until it is written back
    with :meth:`save` or item assignment.
    """

    def __init__(self, connection: _Connection) -> None:
        """Keep the shared ``_Connection`` used to open per-call connections."""
        self._connection = connection

    def save(self, task: Task) -> None:
        """Store ``task`` under ``task.id`` (shorthand for item assignment)."""
        self.__setitem__(task.id, task)

    def __setitem__(self, task_id: str, task: Task) -> None:
        """Upsert ``task`` and its result row.

        Raises:
            TypeError: if ``task`` is not a ``Task``.
            ValueError: if ``task_id`` differs from ``task.id``.
        """
        if isinstance(task, Task):
            if task_id != task.id:
                raise ValueError("task_id must match task.id")
        else:
            raise TypeError(f"Expected Task, got {type(task).__name__}")
        with self._connection.connect() as db:
            _upsert_task(db, task)

    def __getitem__(self, task_id: str) -> Task:
        """Load ``task_id`` together with its result row, if one exists.

        Raises:
            KeyError: if no task with ``task_id`` is stored.
        """
        with self._connection.connect() as db:
            found = db.execute(
                "SELECT * FROM tasks WHERE id = ?", (task_id,)
            ).fetchone()
            if found is None:
                raise KeyError(task_id)
            stored_result = db.execute(
                "SELECT * FROM task_results WHERE task_id = ?", (task_id,)
            ).fetchone()
        result = None if stored_result is None else _result_from_row(stored_result)
        return _task_from_row(found, result)

    def __delitem__(self, task_id: str) -> None:
        """Delete the task row; dependent rows go via ``ON DELETE CASCADE``.

        Raises:
            KeyError: if no task with ``task_id`` is stored.
        """
        with self._connection.connect() as db:
            removed = db.execute(
                "DELETE FROM tasks WHERE id = ?", (task_id,)
            ).rowcount
            if removed == 0:
                raise KeyError(task_id)

    def __iter__(self) -> Iterator[str]:
        """Yield task IDs sorted by stored ``created_at`` text, then ``id``."""
        with self._connection.connect() as db:
            ids = [
                str(record["id"])
                for record in db.execute(
                    "SELECT id FROM tasks ORDER BY created_at, id"
                ).fetchall()
            ]
        return iter(ids)

    def __len__(self) -> int:
        """Count the rows in the ``tasks`` table."""
        with self._connection.connect() as db:
            tally = db.execute("SELECT COUNT(*) AS count FROM tasks").fetchone()
        return int(tally[0])

    def __contains__(self, key: object) -> bool:
        """Report whether a task ID (or a ``Task``'s ID) is stored.

        Keys that are neither ``Task`` instances nor strings are never
        present.
        """
        lookup = key.id if isinstance(key, Task) else key
        if not isinstance(lookup, str):
            return False
        with self._connection.connect() as db:
            hit = db.execute(
                "SELECT 1 FROM tasks WHERE id = ?", (lookup,)
            ).fetchone()
        return hit is not None

    def cancel(self, task_id: str) -> None:
        """Load ``task_id``, call its ``cancel()``, and write it back."""
        snapshot = self[task_id]
        snapshot.cancel()
        self.save(snapshot)
281
+
282
+
283
class SQLiteGraphCollection(MutableMapping[str, TaskGraph]):
    """SQLite-backed graph collection keyed by graph ID.

    Saving a graph writes its metadata, membership, dependency edges, and
    every member task inside a single connection's transaction. Reading
    returns a detached snapshot rebuilt from the database.
    """

    # Subquery selecting every task ID a stored graph refers to, either as a
    # member or as a dependency. Bind it with the graph ID twice.
    _REFERENCED_TASK_IDS = """
        SELECT task_id FROM graph_tasks WHERE graph_id = ?
        UNION
        SELECT dependency_id FROM graph_dependencies WHERE graph_id = ?
    """

    def __init__(
        self,
        connection: _Connection,
        tasks: SQLiteTaskCollection,
    ) -> None:
        """Wrap ``connection`` as a graph collection sharing ``tasks``."""
        self._connection = connection
        self._tasks = tasks

    def save(self, graph: TaskGraph) -> None:
        """Persist ``graph`` under its own ID."""
        self[graph.id] = graph

    def __setitem__(self, graph_id: str, graph: TaskGraph) -> None:
        """Atomically store ``graph`` metadata, membership, edges, and tasks.

        Existing membership and dependency rows for the graph are replaced
        wholesale. Everything runs on one connection and is committed together
        when the ``with`` block exits.

        Raises:
            TypeError: if ``graph`` is not a ``TaskGraph``.
            ValueError: if ``graph_id`` differs from ``graph.id``.
        """
        if not isinstance(graph, TaskGraph):
            raise TypeError(f"Expected TaskGraph, got {type(graph).__name__}")
        if graph_id != graph.id:
            raise ValueError("graph_id must match graph.id")

        members = list(graph)
        with self._connection.connect() as connection:
            for member in members:
                _upsert_task(connection, member)
            connection.execute(
                """
                INSERT INTO graphs (id, title, created_at)
                VALUES (?, ?, ?)
                ON CONFLICT(id) DO UPDATE SET
                    title = excluded.title,
                    created_at = excluded.created_at
                """,
                (graph.id, graph.title, graph.created_at.isoformat()),
            )
            connection.execute(
                "DELETE FROM graph_dependencies WHERE graph_id = ?", (graph.id,)
            )
            connection.execute(
                "DELETE FROM graph_tasks WHERE graph_id = ?", (graph.id,)
            )
            for position, task in enumerate(members):
                connection.execute(
                    """
                    INSERT INTO graph_tasks (
                        graph_id, task_id, is_finally, is_optional, position
                    )
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    (
                        graph.id,
                        task.id,
                        int(graph.is_finally(task)),
                        int(graph.is_optional(task)),
                        position,
                    ),
                )
            for task in members:
                for position, dep in enumerate(graph.dependencies(task)):
                    connection.execute(
                        """
                        INSERT INTO graph_dependencies (
                            graph_id, task_id, dependency_id, position
                        )
                        VALUES (?, ?, ?, ?)
                        """,
                        (graph.id, task.id, dep.id, position),
                    )

    def __getitem__(self, graph_id: str) -> TaskGraph:
        """Return a detached graph snapshot for ``graph_id``.

        Graph metadata, membership, edges, and every referenced task with its
        result are read over one connection. Each task is materialised exactly
        once, so a task that is both a member and a dependency is one shared
        snapshot object in the returned graph rather than two separate copies.
        The result is independent of any in-memory references.

        Raises:
            KeyError: if no graph with ``graph_id`` is stored, or a referenced
                task row is missing.
        """
        referenced = (graph_id, graph_id)
        with self._connection.connect() as connection:
            graph_row = connection.execute(
                "SELECT * FROM graphs WHERE id = ?", (graph_id,)
            ).fetchone()
            if graph_row is None:
                raise KeyError(graph_id)
            task_rows = connection.execute(
                """
                SELECT * FROM graph_tasks
                WHERE graph_id = ?
                ORDER BY position, task_id
                """,
                (graph_id,),
            ).fetchall()
            dependency_rows = connection.execute(
                """
                SELECT * FROM graph_dependencies
                WHERE graph_id = ?
                ORDER BY task_id, position, dependency_id
                """,
                (graph_id,),
            ).fetchall()
            # _REFERENCED_TASK_IDS is a class constant, not user input; only
            # the graph ID is bound as a parameter.
            stored_tasks = connection.execute(
                f"SELECT * FROM tasks WHERE id IN ({self._REFERENCED_TASK_IDS})",
                referenced,
            ).fetchall()
            stored_results = connection.execute(
                "SELECT * FROM task_results "
                f"WHERE task_id IN ({self._REFERENCED_TASK_IDS})",
                referenced,
            ).fetchall()

        results = {
            str(row["task_id"]): _result_from_row(row) for row in stored_results
        }
        snapshots = {
            str(row["id"]): _task_from_row(row, results.get(str(row["id"])))
            for row in stored_tasks
        }

        graph = TaskGraph(title=str(graph_row["title"]))
        object.__setattr__(graph, "_id", str(graph_row["id"]))
        graph.created_at = datetime.fromisoformat(str(graph_row["created_at"]))

        deps_by_task: dict[str, list[str]] = {
            str(row["task_id"]): [] for row in task_rows
        }
        for row in dependency_rows:
            deps_by_task[str(row["task_id"])].append(str(row["dependency_id"]))

        for row in task_rows:
            task_id = str(row["task_id"])
            task = snapshots[task_id]
            deps = [snapshots[dep_id] for dep_id in deps_by_task[task_id]]
            if bool(row["is_finally"]):
                graph.add(
                    task,
                    after=deps,
                    finally_=True,
                    required=not bool(row["is_optional"]),
                )
            else:
                graph[task] = deps

        return graph

    def __delitem__(self, graph_id: str) -> None:
        """Remove the graph metadata; member tasks remain in the task collection.

        Membership and dependency rows are removed via ``ON DELETE CASCADE``.

        Raises:
            KeyError: if no graph with ``graph_id`` is stored.
        """
        with self._connection.connect() as connection:
            cursor = connection.execute("DELETE FROM graphs WHERE id = ?", (graph_id,))
            if cursor.rowcount == 0:
                raise KeyError(graph_id)

    def __iter__(self) -> Iterator[str]:
        """Iterate over graph IDs ordered by stored ``created_at`` text, then id."""
        with self._connection.connect() as connection:
            rows = connection.execute(
                "SELECT id FROM graphs ORDER BY created_at, id"
            ).fetchall()
        return iter(str(row["id"]) for row in rows)

    def __len__(self) -> int:
        """Return the number of stored graphs."""
        with self._connection.connect() as connection:
            row = connection.execute("SELECT COUNT(*) AS count FROM graphs").fetchone()
        return int(row["count"])

    def __contains__(self, key: object) -> bool:
        """Return whether ``key`` (graph or graph id) is present."""
        if isinstance(key, TaskGraph):
            key = key.id
        if not isinstance(key, str):
            return False
        with self._connection.connect() as connection:
            row = connection.execute(
                "SELECT 1 FROM graphs WHERE id = ?", (key,)
            ).fetchone()
        return row is not None
441
+
442
+
443
class SQLiteStore:
    """Durable :class:`Store` on a single SQLite file.

    Exposes a ``tasks`` mapping and a ``graphs`` mapping. The graph mapping
    reuses the same task collection and connection helper.
    """

    def __init__(
        self,
        path: str | Path,
        *,
        allow_destructive_migration: bool = False,
    ) -> None:
        """Open the store at ``path``, creating the database if needed.

        With ``allow_destructive_migration=True``, an on-disk schema whose
        version differs from the current one is dropped and rebuilt.
        Otherwise such a mismatch, or a populated database lacking a
        ``schema_version`` row, raises ``RuntimeError``, so data loss is
        always an explicit choice.
        """
        self.path = Path(path)
        helper = _Connection(
            path, allow_destructive_migration=allow_destructive_migration
        )
        self._connection = helper
        task_collection = SQLiteTaskCollection(helper)
        self.tasks = task_collection
        self.graphs = SQLiteGraphCollection(helper, task_collection)

    def __repr__(self) -> str:
        """Show the class name and database path."""
        return f"SQLiteStore({self.path!s})"
470
+
471
+
472
def _upsert_task(connection: sqlite3.Connection, task: Task) -> None:
    """Insert or update the ``tasks`` row for ``task`` and sync its result row.

    If ``task.result`` is ``None``, any stored result row for ``task.id`` is
    deleted. Otherwise the result row is inserted or overwritten. Nothing is
    committed here: the caller's connection owns the transaction.
    """
    connection.execute(
        """
        INSERT INTO tasks (
            id, title, description, payload, type, status,
            error, timeout, blocked_by, created_at
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ON CONFLICT(id) DO UPDATE SET
            title = excluded.title,
            description = excluded.description,
            payload = excluded.payload,
            type = excluded.type,
            status = excluded.status,
            error = excluded.error,
            timeout = excluded.timeout,
            blocked_by = excluded.blocked_by,
            created_at = excluded.created_at
        """,
        _task_values(task),
    )
    outcome = task.result
    if outcome is not None:
        connection.execute(
            """
            INSERT INTO task_results (
                task_id, status, started_at, finished_at, duration,
                output, error, returncode, termination_reason
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(task_id) DO UPDATE SET
                status = excluded.status,
                started_at = excluded.started_at,
                finished_at = excluded.finished_at,
                duration = excluded.duration,
                output = excluded.output,
                error = excluded.error,
                returncode = excluded.returncode,
                termination_reason = excluded.termination_reason
            """,
            _result_values(outcome),
        )
        return
    connection.execute(
        "DELETE FROM task_results WHERE task_id = ?", (task.id,)
    )
518
+
519
+
520
def _task_values(task: Task) -> tuple[Any, ...]:
    """Flatten ``task`` into the ``tasks`` table's column order.

    ``type`` and ``status`` are stored by ``.value``; ``created_at`` is stored
    as ISO-8601 text.
    """
    identity = (task.id, task.title, task.description, task.payload)
    state = (task.type.value, task.status.value)
    extras = (task.error, task.timeout, task.blocked_by)
    return identity + state + extras + (task.created_at.isoformat(),)
534
+
535
+
536
def _result_values(result: TaskResult) -> tuple[Any, ...]:
    """Flatten ``result`` into the ``task_results`` table's column order.

    ``status`` is stored by ``.value``; both timestamps are stored as ISO-8601
    text.
    """
    head = (result.task_id, result.status.value)
    window = (result.started_at.isoformat(), result.finished_at.isoformat())
    tail = (
        result.duration,
        result.output,
        result.error,
        result.returncode,
        result.termination_reason,
    )
    return head + window + tail
549
+
550
+
551
def _task_from_row(row: sqlite3.Row, result: TaskResult | None) -> Task:
    """Rebuild a ``Task`` snapshot from a ``tasks`` row and its optional result.

    Fields accepted by the ``Task`` constructor are passed as keyword
    arguments. Status, result, and blocked-by are not constructor arguments
    here, so they are written onto the private ``_status``, ``_result`` and
    ``_blocked_by`` attributes with ``object.__setattr__``.
    """
    task = Task(
        title=str(row["title"]),
        description=str(row["description"]),
        payload=str(row["payload"]),
        type=TaskType(str(row["type"])),
        error=row["error"],
        timeout=row["timeout"],
        _id=str(row["id"]),
        created_at=datetime.fromisoformat(str(row["created_at"])),
    )
    raw_blocker = row["blocked_by"]
    restored = {
        "_status": TaskStatus(str(row["status"])),
        "_result": result,
        "_blocked_by": None if raw_blocker is None else str(raw_blocker),
    }
    for attribute, value in restored.items():
        object.__setattr__(task, attribute, value)
    return task
570
+
571
+
572
def _result_from_row(row: sqlite3.Row) -> TaskResult:
    """Rebuild a ``TaskResult`` from a ``task_results`` row.

    ``raw`` is never persisted, so the rebuilt result always has
    ``raw=None``. ``termination_reason`` is passed through as stored, with
    only a static-typing ``cast`` and no runtime validation.
    """

    def parse_timestamp(column: str) -> datetime:
        # Timestamps are stored as ISO-8601 text by ``_result_values``.
        return datetime.fromisoformat(str(row[column]))

    return TaskResult(
        task_id=str(row["task_id"]),
        status=TaskStatus(str(row["status"])),
        started_at=parse_timestamp("started_at"),
        finished_at=parse_timestamp("finished_at"),
        duration=float(row["duration"]),
        output=str(row["output"]),
        error=row["error"],
        returncode=row["returncode"],
        raw=None,
        termination_reason=cast(
            "TerminationReason | None", row["termination_reason"]
        ),
    )