loopllm 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
loopllm/store.py ADDED
@@ -0,0 +1,1126 @@
1
+ """SQLite-backed persistence for priors, observations, and sessions."""
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ import sqlite3
6
+ import threading
7
+ from contextlib import contextmanager
8
+ from dataclasses import asdict
9
+ from datetime import datetime, timezone
10
+ from pathlib import Path
11
+ from typing import Any, cast, Iterator
12
+
13
+ import structlog
14
+
15
+ from loopllm.priors import (
16
+ AdaptivePriors,
17
+ BetaPrior,
18
+ CallObservation,
19
+ IterationProfile,
20
+ NormalPrior,
21
+ TaskModelPrior,
22
+ )
23
+
24
# Module-level structlog logger; LoopStore's schema setup and migration code
# emit their debug/info events through it.
logger = structlog.get_logger(__name__)

# Schema version this module creates and migrates to. LoopStore._migrate has
# one branch per step (v2, v3, v4); when bumping this, add the matching
# _SCHEMA_V<n>_SQL script and migration branch as well.
SCHEMA_VERSION = 4

# Base schema: version bookkeeping, per task/model priors, the raw observation
# log, question-effectiveness counters, elicitation sessions, hierarchical
# tasks, and lookup indexes. Every statement uses IF NOT EXISTS, so re-running
# the script is harmless.
# NOTE: LoopStore._connection enables PRAGMA foreign_keys, so a task row's
# parent_id / session_id must reference rows that already exist.
_SCHEMA_SQL = """\
CREATE TABLE IF NOT EXISTS schema_version (
version INTEGER NOT NULL
);

CREATE TABLE IF NOT EXISTS priors (
key TEXT PRIMARY KEY,
task_type TEXT NOT NULL,
model_id TEXT NOT NULL,
data TEXT NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS observations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_type TEXT NOT NULL,
model_id TEXT NOT NULL,
data TEXT NOT NULL,
recorded_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS questions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
question_type TEXT NOT NULL,
task_type TEXT NOT NULL,
asked_count INTEGER NOT NULL DEFAULT 0,
positive_impact INTEGER NOT NULL DEFAULT 0,
negative_impact INTEGER NOT NULL DEFAULT 0,
avg_info_gain REAL NOT NULL DEFAULT 0.0,
updated_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS sessions (
session_id TEXT PRIMARY KEY,
original_prompt TEXT NOT NULL,
task_type TEXT,
model_id TEXT,
questions_json TEXT NOT NULL DEFAULT '[]',
answers_json TEXT NOT NULL DEFAULT '{}',
spec_json TEXT,
final_score REAL,
created_at TEXT NOT NULL,
completed_at TEXT
);

CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY,
parent_id TEXT,
session_id TEXT,
title TEXT NOT NULL,
description TEXT NOT NULL DEFAULT '',
state TEXT NOT NULL DEFAULT 'pending',
dependencies TEXT NOT NULL DEFAULT '[]',
spec_json TEXT,
result_json TEXT,
metadata_json TEXT NOT NULL DEFAULT '{}',
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
FOREIGN KEY (parent_id) REFERENCES tasks(id),
FOREIGN KEY (session_id) REFERENCES sessions(session_id)
);

CREATE INDEX IF NOT EXISTS idx_observations_task_model
ON observations(task_type, model_id);
CREATE INDEX IF NOT EXISTS idx_questions_type
ON questions(question_type, task_type);
CREATE INDEX IF NOT EXISTS idx_tasks_session
ON tasks(session_id);
CREATE INDEX IF NOT EXISTS idx_tasks_state
ON tasks(state);
"""


# v2: per-prompt quality history (overall score, per-dimension scores, routing
# decision, grade) plus indexes on timestamp and grade.
_SCHEMA_V2_SQL = """\
CREATE TABLE IF NOT EXISTS prompt_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
prompt_text TEXT NOT NULL,
quality_score REAL NOT NULL,
specificity REAL NOT NULL DEFAULT 0.0,
constraint_clarity REAL NOT NULL DEFAULT 0.0,
context_completeness REAL NOT NULL DEFAULT 0.0,
ambiguity REAL NOT NULL DEFAULT 0.0,
format_spec REAL NOT NULL DEFAULT 0.0,
task_type TEXT NOT NULL DEFAULT 'general',
complexity REAL NOT NULL DEFAULT 0.0,
route_chosen TEXT NOT NULL DEFAULT 'refine',
word_count INTEGER NOT NULL DEFAULT 0,
grade TEXT NOT NULL DEFAULT 'C',
session_context TEXT NOT NULL DEFAULT 'default'
);

CREATE INDEX IF NOT EXISTS idx_prompt_history_ts
ON prompt_history(timestamp);
CREATE INDEX IF NOT EXISTS idx_prompt_history_grade
ON prompt_history(grade);
"""

# v3: plans stored as whole JSON blobs keyed by plan_id.
_SCHEMA_V3_SQL = """\
CREATE TABLE IF NOT EXISTS plans (
plan_id TEXT PRIMARY KEY,
goal TEXT NOT NULL,
data TEXT NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
"""

# v4: a single-row table (CHECK id = 1) holding the learned scoring weights as
# JSON, plus update count and last loss.
_SCHEMA_V4_SQL = """\
CREATE TABLE IF NOT EXISTS learned_weights (
id INTEGER PRIMARY KEY CHECK (id = 1),
weights TEXT NOT NULL,
n_updates INTEGER NOT NULL DEFAULT 0,
last_loss REAL NOT NULL DEFAULT 0.0,
updated_at TEXT NOT NULL
);
"""
146
+
147
+
148
+ class LoopStore:
149
+ """SQLite-backed store for loop-llm state.
150
+
151
+ Thread-safe via a reentrant lock around all database operations.
152
+ Uses WAL journal mode for concurrent read performance.
153
+
154
+ Args:
155
+ db_path: Path to the SQLite database file. Use ``":memory:"``
156
+ for an ephemeral in-memory store (useful for testing).
157
+ """
158
+
159
def __init__(self, db_path: Path | str = ":memory:") -> None:
    """Bind the store to *db_path* and bring its schema up to date.

    The SQLite connection is not opened here directly; ``_ensure_schema``
    opens it lazily through ``_connection``.

    Args:
        db_path: SQLite database file, or ``":memory:"`` for a throwaway
            in-process database.
    """
    # Serialises every access to the single shared connection below.
    self._lock = threading.RLock()
    # Opened lazily on first use by ``_connection``.
    self._conn: sqlite3.Connection | None = None
    self.db_path = str(db_path)
    self._ensure_schema()
164
+
165
+ # -- connection management -----------------------------------------------
166
+
167
+ @contextmanager
168
+ def _connection(self) -> Iterator[sqlite3.Connection]:
169
+ """Yield a thread-safe connection with WAL mode enabled."""
170
+ with self._lock:
171
+ if self._conn is None:
172
+ self._conn = sqlite3.connect(self.db_path)
173
+ self._conn.execute("PRAGMA journal_mode=WAL")
174
+ self._conn.execute("PRAGMA foreign_keys=ON")
175
+ self._conn.row_factory = sqlite3.Row
176
+ yield self._conn
177
+
178
def _ensure_schema(self) -> None:
    """Bring the database schema up to ``SCHEMA_VERSION``.

    A database without a ``schema_version`` table is treated as brand new:
    every DDL script is applied and ``SCHEMA_VERSION`` is recorded.
    Otherwise the recorded version is read (0 when the table has no row), and
    ``_migrate`` is invoked when it is older than ``SCHEMA_VERSION``.
    """
    with self._connection() as conn:
        has_version_table = (
            conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_version'"
            ).fetchone()
            is not None
        )

        if has_version_table:
            version_row = conn.execute("SELECT version FROM schema_version").fetchone()
            recorded = 0 if version_row is None else version_row["version"]
            if recorded < SCHEMA_VERSION:
                self._migrate(conn, recorded, SCHEMA_VERSION)
            return

        # Fresh database: apply the base schema and every versioned addition.
        for script in (_SCHEMA_SQL, _SCHEMA_V2_SQL, _SCHEMA_V3_SQL, _SCHEMA_V4_SQL):
            conn.executescript(script)
        conn.execute(
            "INSERT INTO schema_version (version) VALUES (?)",
            (SCHEMA_VERSION,),
        )
        conn.commit()
        logger.debug("store_schema_created", version=SCHEMA_VERSION)
201
+
202
def _migrate(self, conn: sqlite3.Connection, from_v: int, to_v: int) -> None:
    """Upgrade the schema from version *from_v* to *to_v*.

    Each versioned DDL script newer than *from_v* is applied in order. All of
    them use ``IF NOT EXISTS``, so re-applying one is harmless. The recorded
    version is then set to *to_v* and the transaction is committed.

    Args:
        conn: Active database connection (``_ensure_schema`` calls this while
            holding the store lock).
        from_v: Schema version currently recorded (0 when no row exists).
        to_v: Target schema version.
    """
    logger.info("store_migrating", from_version=from_v, to_version=to_v)
    if from_v < 2:
        conn.executescript(_SCHEMA_V2_SQL)
    if from_v < 3:
        conn.executescript(_SCHEMA_V3_SQL)
    if from_v < 4:
        conn.executescript(_SCHEMA_V4_SQL)
    cursor = conn.execute("UPDATE schema_version SET version = ?", (to_v,))
    if cursor.rowcount == 0:
        # schema_version exists but holds no row (e.g. a first-time setup that
        # stopped after creating the tables but before inserting the version).
        # UPDATE matched nothing, so without this INSERT the version would never
        # be recorded and this migration would re-run on every open.
        conn.execute("INSERT INTO schema_version (version) VALUES (?)", (to_v,))
    conn.commit()
219
+
220
+ def close(self) -> None:
221
+ """Close the database connection."""
222
+ with self._lock:
223
+ if self._conn is not None:
224
+ self._conn.close()
225
+ self._conn = None
226
+
227
+ # -- priors CRUD ---------------------------------------------------------
228
+
229
+ def save_prior(self, key: str, prior: TaskModelPrior) -> None:
230
+ """Upsert a serialised :class:`TaskModelPrior`.
231
+
232
+ Args:
233
+ key: Storage key (typically ``task_type::model_id``).
234
+ prior: The prior to persist.
235
+ """
236
+ data = self._serialize_task_model_prior(prior)
237
+ now = datetime.now(timezone.utc).isoformat()
238
+ with self._connection() as conn:
239
+ conn.execute(
240
+ """INSERT INTO priors (key, task_type, model_id, data, created_at, updated_at)
241
+ VALUES (?, ?, ?, ?, ?, ?)
242
+ ON CONFLICT(key) DO UPDATE SET
243
+ data = excluded.data,
244
+ updated_at = excluded.updated_at""",
245
+ (key, prior.task_type, prior.model_id, json.dumps(data), now, now),
246
+ )
247
+ conn.commit()
248
+
249
+ def load_prior(self, key: str) -> TaskModelPrior | None:
250
+ """Load a :class:`TaskModelPrior` by key.
251
+
252
+ Args:
253
+ key: Storage key (typically ``task_type::model_id``).
254
+
255
+ Returns:
256
+ The deserialized prior, or ``None`` if not found.
257
+ """
258
+ with self._connection() as conn:
259
+ row = conn.execute(
260
+ "SELECT data FROM priors WHERE key = ?", (key,)
261
+ ).fetchone()
262
+ if row is None:
263
+ return None
264
+ return self._deserialize_task_model_prior(json.loads(row["data"]))
265
+
266
+ def load_all_priors(self) -> dict[str, TaskModelPrior]:
267
+ """Load every stored prior.
268
+
269
+ Returns:
270
+ Dict mapping keys to :class:`TaskModelPrior` instances.
271
+ """
272
+ with self._connection() as conn:
273
+ rows = conn.execute("SELECT key, data FROM priors").fetchall()
274
+ result: dict[str, TaskModelPrior] = {}
275
+ for row in rows:
276
+ result[row["key"]] = self._deserialize_task_model_prior(
277
+ json.loads(row["data"])
278
+ )
279
+ return result
280
+
281
+ def delete_prior(self, key: str) -> bool:
282
+ """Delete a prior by key.
283
+
284
+ Args:
285
+ key: Storage key to delete.
286
+
287
+ Returns:
288
+ True if a row was deleted.
289
+ """
290
+ with self._connection() as conn:
291
+ cursor = conn.execute("DELETE FROM priors WHERE key = ?", (key,))
292
+ conn.commit()
293
+ return cursor.rowcount > 0
294
+
295
+ # -- observations --------------------------------------------------------
296
+
297
+ def record_observation(self, obs: CallObservation) -> int:
298
+ """Append an observation to the log.
299
+
300
+ Args:
301
+ obs: The observation to record.
302
+
303
+ Returns:
304
+ The auto-generated row ID.
305
+ """
306
+ now = datetime.now(timezone.utc).isoformat()
307
+ data = asdict(obs)
308
+ with self._connection() as conn:
309
+ cursor = conn.execute(
310
+ """INSERT INTO observations (task_type, model_id, data, recorded_at)
311
+ VALUES (?, ?, ?, ?)""",
312
+ (obs.task_type, obs.model_id, json.dumps(data), now),
313
+ )
314
+ conn.commit()
315
+ return cursor.lastrowid or 0
316
+
317
+ def get_observations(
318
+ self,
319
+ task_type: str | None = None,
320
+ model_id: str | None = None,
321
+ limit: int = 100,
322
+ ) -> list[CallObservation]:
323
+ """Query observations with optional filters.
324
+
325
+ Args:
326
+ task_type: Filter by task type (optional).
327
+ model_id: Filter by model ID (optional).
328
+ limit: Maximum rows to return.
329
+
330
+ Returns:
331
+ List of :class:`CallObservation` instances, most recent first.
332
+ """
333
+ clauses: list[str] = []
334
+ params: list[Any] = []
335
+ if task_type is not None:
336
+ clauses.append("task_type = ?")
337
+ params.append(task_type)
338
+ if model_id is not None:
339
+ clauses.append("model_id = ?")
340
+ params.append(model_id)
341
+
342
+ where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
343
+ query = f"SELECT data FROM observations {where} ORDER BY id DESC LIMIT ?"
344
+ params.append(limit)
345
+
346
+ with self._connection() as conn:
347
+ rows = conn.execute(query, params).fetchall()
348
+
349
+ results: list[CallObservation] = []
350
+ for row in rows:
351
+ d = json.loads(row["data"])
352
+ results.append(CallObservation(**d))
353
+ return results
354
+
355
+ def count_observations(
356
+ self, task_type: str | None = None, model_id: str | None = None
357
+ ) -> int:
358
+ """Count observations with optional filters.
359
+
360
+ Args:
361
+ task_type: Filter by task type (optional).
362
+ model_id: Filter by model ID (optional).
363
+
364
+ Returns:
365
+ Row count.
366
+ """
367
+ clauses: list[str] = []
368
+ params: list[Any] = []
369
+ if task_type is not None:
370
+ clauses.append("task_type = ?")
371
+ params.append(task_type)
372
+ if model_id is not None:
373
+ clauses.append("model_id = ?")
374
+ params.append(model_id)
375
+
376
+ where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
377
+ query = f"SELECT COUNT(*) as cnt FROM observations {where}"
378
+
379
+ with self._connection() as conn:
380
+ row = conn.execute(query, params).fetchone()
381
+ return row["cnt"] if row else 0
382
+
383
+ # -- question effectiveness tracking -------------------------------------
384
+
385
+ def update_question_stats(
386
+ self,
387
+ question_type: str,
388
+ task_type: str,
389
+ *,
390
+ positive: bool,
391
+ info_gain: float = 0.0,
392
+ ) -> None:
393
+ """Update effectiveness statistics for a question type.
394
+
395
+ Args:
396
+ question_type: Category of the question (e.g. ``"scope"``).
397
+ task_type: Task type context in which it was asked.
398
+ positive: Whether the question had a positive impact on outcome.
399
+ info_gain: Measured information gain from the answer.
400
+ """
401
+ now = datetime.now(timezone.utc).isoformat()
402
+ with self._connection() as conn:
403
+ existing = conn.execute(
404
+ "SELECT id, asked_count, positive_impact, negative_impact, avg_info_gain "
405
+ "FROM questions WHERE question_type = ? AND task_type = ?",
406
+ (question_type, task_type),
407
+ ).fetchone()
408
+
409
+ if existing is None:
410
+ conn.execute(
411
+ """INSERT INTO questions
412
+ (question_type, task_type, asked_count, positive_impact,
413
+ negative_impact, avg_info_gain, updated_at)
414
+ VALUES (?, ?, 1, ?, ?, ?, ?)""",
415
+ (
416
+ question_type,
417
+ task_type,
418
+ 1 if positive else 0,
419
+ 0 if positive else 1,
420
+ info_gain,
421
+ now,
422
+ ),
423
+ )
424
+ else:
425
+ new_count = existing["asked_count"] + 1
426
+ new_pos = existing["positive_impact"] + (1 if positive else 0)
427
+ new_neg = existing["negative_impact"] + (0 if positive else 1)
428
+ # Running average of info gain
429
+ old_avg = existing["avg_info_gain"]
430
+ new_avg = old_avg + (info_gain - old_avg) / new_count
431
+ conn.execute(
432
+ """UPDATE questions SET
433
+ asked_count = ?, positive_impact = ?, negative_impact = ?,
434
+ avg_info_gain = ?, updated_at = ?
435
+ WHERE id = ?""",
436
+ (new_count, new_pos, new_neg, new_avg, now, existing["id"]),
437
+ )
438
+ conn.commit()
439
+
440
+ def get_question_stats(
441
+ self, task_type: str | None = None
442
+ ) -> list[dict[str, Any]]:
443
+ """Retrieve question effectiveness statistics.
444
+
445
+ Args:
446
+ task_type: Optional filter by task type.
447
+
448
+ Returns:
449
+ List of dicts with question stats.
450
+ """
451
+ if task_type is not None:
452
+ query = "SELECT * FROM questions WHERE task_type = ? ORDER BY avg_info_gain DESC"
453
+ params: tuple[Any, ...] = (task_type,)
454
+ else:
455
+ query = "SELECT * FROM questions ORDER BY avg_info_gain DESC"
456
+ params = ()
457
+
458
+ with self._connection() as conn:
459
+ rows = conn.execute(query, params).fetchall()
460
+
461
+ return [
462
+ {
463
+ "question_type": row["question_type"],
464
+ "task_type": row["task_type"],
465
+ "asked_count": row["asked_count"],
466
+ "positive_impact": row["positive_impact"],
467
+ "negative_impact": row["negative_impact"],
468
+ "avg_info_gain": row["avg_info_gain"],
469
+ "effectiveness": (
470
+ row["positive_impact"] / row["asked_count"]
471
+ if row["asked_count"] > 0
472
+ else 0.0
473
+ ),
474
+ }
475
+ for row in rows
476
+ ]
477
+
478
+ # -- learned weights (online SGD) ----------------------------------------
479
+
480
# Baseline per-dimension scoring weights (they sum to 1.0).
# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably the fallback when no learned weights are stored — confirm.
_DEFAULT_WEIGHTS = dict(
    specificity=0.25,
    constraint_clarity=0.20,
    context_completeness=0.20,
    ambiguity=0.20,
    format_spec=0.15,
)
487
+
488
+ def save_learned_weights(
489
+ self,
490
+ weights: dict[str, float],
491
+ n_updates: int,
492
+ last_loss: float,
493
+ ) -> None:
494
+ """Persist learned scoring weights.
495
+
496
+ Args:
497
+ weights: Dimension-name → weight mapping (must sum to ~1.0).
498
+ n_updates: Cumulative number of SGD steps applied.
499
+ last_loss: MSE from the most recent SGD step.
500
+ """
501
+ now = datetime.now(timezone.utc).isoformat()
502
+ with self._connection() as conn:
503
+ conn.execute(
504
+ """INSERT INTO learned_weights (id, weights, n_updates, last_loss, updated_at)
505
+ VALUES (1, ?, ?, ?, ?)
506
+ ON CONFLICT(id) DO UPDATE SET
507
+ weights = excluded.weights,
508
+ n_updates = excluded.n_updates,
509
+ last_loss = excluded.last_loss,
510
+ updated_at = excluded.updated_at""",
511
+ (json.dumps(weights), n_updates, last_loss, now),
512
+ )
513
+ conn.commit()
514
+
515
+ def load_learned_weights(self) -> dict[str, float] | None:
516
+ """Return the learned weight dict, or ``None`` if never saved.
517
+
518
+ Returns:
519
+ Dict mapping dimension names to weights, or None.
520
+ """
521
+ with self._connection() as conn:
522
+ row = conn.execute(
523
+ "SELECT weights FROM learned_weights WHERE id = 1"
524
+ ).fetchone()
525
+ if row is None:
526
+ return None
527
+ return cast(dict[str, float], json.loads(row["weights"]))
528
+
529
+ def get_learned_weight_meta(self) -> dict[str, Any]:
530
+ """Return metadata about the current learned weights.
531
+
532
+ Returns:
533
+ Dict with keys ``n_updates``, ``last_loss``, ``updated_at``.
534
+ If no weights exist yet, returns zeroed defaults.
535
+ """
536
+ with self._connection() as conn:
537
+ row = conn.execute(
538
+ "SELECT n_updates, last_loss, updated_at FROM learned_weights WHERE id = 1"
539
+ ).fetchone()
540
+ if row is None:
541
+ return {"n_updates": 0, "last_loss": 0.0, "updated_at": None}
542
+ return dict(row)
543
+
544
+ # -- session management --------------------------------------------------
545
+
546
+ def create_session(
547
+ self,
548
+ session_id: str,
549
+ original_prompt: str,
550
+ task_type: str | None = None,
551
+ model_id: str | None = None,
552
+ ) -> None:
553
+ """Create a new elicitation session.
554
+
555
+ Args:
556
+ session_id: Unique session identifier.
557
+ original_prompt: The user's original prompt.
558
+ task_type: Optional task type classification.
559
+ model_id: Optional model being used.
560
+ """
561
+ now = datetime.now(timezone.utc).isoformat()
562
+ with self._connection() as conn:
563
+ conn.execute(
564
+ """INSERT INTO sessions
565
+ (session_id, original_prompt, task_type, model_id, created_at)
566
+ VALUES (?, ?, ?, ?, ?)""",
567
+ (session_id, original_prompt, task_type, model_id, now),
568
+ )
569
+ conn.commit()
570
+
571
+ def update_session(
572
+ self,
573
+ session_id: str,
574
+ *,
575
+ questions: list[dict[str, Any]] | None = None,
576
+ answers: dict[str, str] | None = None,
577
+ spec: dict[str, Any] | None = None,
578
+ final_score: float | None = None,
579
+ ) -> None:
580
+ """Update fields on an existing session.
581
+
582
+ Args:
583
+ session_id: Session to update.
584
+ questions: Updated questions list.
585
+ answers: Updated answers dict.
586
+ spec: The refined IntentSpec as a dict.
587
+ final_score: Final quality score achieved.
588
+ """
589
+ updates: list[str] = []
590
+ params: list[Any] = []
591
+
592
+ if questions is not None:
593
+ updates.append("questions_json = ?")
594
+ params.append(json.dumps(questions))
595
+ if answers is not None:
596
+ updates.append("answers_json = ?")
597
+ params.append(json.dumps(answers))
598
+ if spec is not None:
599
+ updates.append("spec_json = ?")
600
+ params.append(json.dumps(spec))
601
+ if final_score is not None:
602
+ updates.append("final_score = ?")
603
+ params.append(final_score)
604
+ updates.append("completed_at = ?")
605
+ params.append(datetime.now(timezone.utc).isoformat())
606
+
607
+ if not updates:
608
+ return
609
+
610
+ params.append(session_id)
611
+ sql = f"UPDATE sessions SET {', '.join(updates)} WHERE session_id = ?"
612
+
613
+ with self._connection() as conn:
614
+ conn.execute(sql, params)
615
+ conn.commit()
616
+
617
+ def get_session(self, session_id: str) -> dict[str, Any] | None:
618
+ """Retrieve a session by ID.
619
+
620
+ Args:
621
+ session_id: Session identifier.
622
+
623
+ Returns:
624
+ Session data as a dict, or ``None``.
625
+ """
626
+ with self._connection() as conn:
627
+ row = conn.execute(
628
+ "SELECT * FROM sessions WHERE session_id = ?", (session_id,)
629
+ ).fetchone()
630
+ if row is None:
631
+ return None
632
+ return {
633
+ "session_id": row["session_id"],
634
+ "original_prompt": row["original_prompt"],
635
+ "task_type": row["task_type"],
636
+ "model_id": row["model_id"],
637
+ "questions": json.loads(row["questions_json"]),
638
+ "answers": json.loads(row["answers_json"]),
639
+ "spec": json.loads(row["spec_json"]) if row["spec_json"] else None,
640
+ "final_score": row["final_score"],
641
+ "created_at": row["created_at"],
642
+ "completed_at": row["completed_at"],
643
+ }
644
+
645
+ # -- task management -----------------------------------------------------
646
+
647
+ def save_task(self, task_data: dict[str, Any]) -> None:
648
+ """Insert or update a task record.
649
+
650
+ Args:
651
+ task_data: Dict with keys matching the ``tasks`` table columns.
652
+ Must include ``id`` and ``title``.
653
+ """
654
+ now = datetime.now(timezone.utc).isoformat()
655
+ with self._connection() as conn:
656
+ conn.execute(
657
+ """INSERT INTO tasks
658
+ (id, parent_id, session_id, title, description, state,
659
+ dependencies, spec_json, result_json, metadata_json,
660
+ created_at, updated_at)
661
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
662
+ ON CONFLICT(id) DO UPDATE SET
663
+ state = excluded.state,
664
+ result_json = excluded.result_json,
665
+ metadata_json = excluded.metadata_json,
666
+ updated_at = excluded.updated_at""",
667
+ (
668
+ task_data["id"],
669
+ task_data.get("parent_id"),
670
+ task_data.get("session_id"),
671
+ task_data["title"],
672
+ task_data.get("description", ""),
673
+ task_data.get("state", "pending"),
674
+ json.dumps(task_data.get("dependencies", [])),
675
+ json.dumps(task_data.get("spec")) if task_data.get("spec") else None,
676
+ json.dumps(task_data.get("result")) if task_data.get("result") else None,
677
+ json.dumps(task_data.get("metadata", {})),
678
+ task_data.get("created_at", now),
679
+ now,
680
+ ),
681
+ )
682
+ conn.commit()
683
+
684
+ def get_task(self, task_id: str) -> dict[str, Any] | None:
685
+ """Retrieve a task by ID.
686
+
687
+ Args:
688
+ task_id: Task identifier.
689
+
690
+ Returns:
691
+ Task data as a dict, or ``None``.
692
+ """
693
+ with self._connection() as conn:
694
+ row = conn.execute(
695
+ "SELECT * FROM tasks WHERE id = ?", (task_id,)
696
+ ).fetchone()
697
+ if row is None:
698
+ return None
699
+ return self._row_to_task(row)
700
+
701
+ def get_tasks(
702
+ self,
703
+ session_id: str | None = None,
704
+ state: str | None = None,
705
+ limit: int = 100,
706
+ ) -> list[dict[str, Any]]:
707
+ """Query tasks with optional filters.
708
+
709
+ Args:
710
+ session_id: Filter by session.
711
+ state: Filter by state.
712
+ limit: Maximum rows to return.
713
+
714
+ Returns:
715
+ List of task dicts, most recently updated first.
716
+ """
717
+ clauses: list[str] = []
718
+ params: list[Any] = []
719
+ if session_id is not None:
720
+ clauses.append("session_id = ?")
721
+ params.append(session_id)
722
+ if state is not None:
723
+ clauses.append("state = ?")
724
+ params.append(state)
725
+
726
+ where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
727
+ query = f"SELECT * FROM tasks {where} ORDER BY updated_at DESC LIMIT ?"
728
+ params.append(limit)
729
+
730
+ with self._connection() as conn:
731
+ rows = conn.execute(query, params).fetchall()
732
+ return [self._row_to_task(row) for row in rows]
733
+
734
+ def update_task_state(self, task_id: str, state: str) -> None:
735
+ """Update the state of a task.
736
+
737
+ Args:
738
+ task_id: Task identifier.
739
+ state: New state value.
740
+ """
741
+ now = datetime.now(timezone.utc).isoformat()
742
+ with self._connection() as conn:
743
+ conn.execute(
744
+ "UPDATE tasks SET state = ?, updated_at = ? WHERE id = ?",
745
+ (state, now, task_id),
746
+ )
747
+ conn.commit()
748
+
749
+ @staticmethod
750
+ def _row_to_task(row: sqlite3.Row) -> dict[str, Any]:
751
+ """Convert a database row to a task dict."""
752
+ return {
753
+ "id": row["id"],
754
+ "parent_id": row["parent_id"],
755
+ "session_id": row["session_id"],
756
+ "title": row["title"],
757
+ "description": row["description"],
758
+ "state": row["state"],
759
+ "dependencies": json.loads(row["dependencies"]),
760
+ "spec": json.loads(row["spec_json"]) if row["spec_json"] else None,
761
+ "result": json.loads(row["result_json"]) if row["result_json"] else None,
762
+ "metadata": json.loads(row["metadata_json"]),
763
+ "created_at": row["created_at"],
764
+ "updated_at": row["updated_at"],
765
+ }
766
+
767
+ # -- prompt history --------------------------------------------------------
768
+
769
+ def record_prompt(self, entry: dict[str, Any]) -> int:
770
+ """Append a prompt quality record to history.
771
+
772
+ Args:
773
+ entry: Dict with prompt analysis data including ``prompt_text``,
774
+ ``quality_score``, dimension scores, ``task_type``, etc.
775
+
776
+ Returns:
777
+ The auto-generated row ID.
778
+ """
779
+ now = datetime.now(timezone.utc).isoformat()
780
+ with self._connection() as conn:
781
+ cursor = conn.execute(
782
+ """INSERT INTO prompt_history
783
+ (timestamp, prompt_text, quality_score, specificity,
784
+ constraint_clarity, context_completeness, ambiguity,
785
+ format_spec, task_type, complexity, route_chosen,
786
+ word_count, grade, session_context)
787
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
788
+ (
789
+ now,
790
+ entry.get("prompt_text", ""),
791
+ entry.get("quality_score", 0.0),
792
+ entry.get("specificity", 0.0),
793
+ entry.get("constraint_clarity", 0.0),
794
+ entry.get("context_completeness", 0.0),
795
+ entry.get("ambiguity", 0.0),
796
+ entry.get("format_spec", 0.0),
797
+ entry.get("task_type", "general"),
798
+ entry.get("complexity", 0.0),
799
+ entry.get("route_chosen", "refine"),
800
+ entry.get("word_count", 0),
801
+ entry.get("grade", "C"),
802
+ entry.get("session_context", "default"),
803
+ ),
804
+ )
805
+ conn.commit()
806
+ return cursor.lastrowid or 0
807
+
808
+ def get_prompt_history(
809
+ self,
810
+ limit: int = 100,
811
+ session_context: str | None = None,
812
+ ) -> list[dict[str, Any]]:
813
+ """Retrieve prompt history records.
814
+
815
+ Args:
816
+ limit: Maximum rows to return.
817
+ session_context: Optional filter by session context.
818
+
819
+ Returns:
820
+ List of prompt history dicts, most recent first.
821
+ """
822
+ if session_context is not None:
823
+ query = (
824
+ "SELECT * FROM prompt_history WHERE session_context = ? "
825
+ "ORDER BY id DESC LIMIT ?"
826
+ )
827
+ params: tuple[Any, ...] = (session_context, limit)
828
+ else:
829
+ query = "SELECT * FROM prompt_history ORDER BY id DESC LIMIT ?"
830
+ params = (limit,)
831
+
832
+ with self._connection() as conn:
833
+ rows = conn.execute(query, params).fetchall()
834
+
835
+ return [
836
+ {
837
+ "id": row["id"],
838
+ "timestamp": row["timestamp"],
839
+ "prompt_text": row["prompt_text"],
840
+ "quality_score": row["quality_score"],
841
+ "specificity": row["specificity"],
842
+ "constraint_clarity": row["constraint_clarity"],
843
+ "context_completeness": row["context_completeness"],
844
+ "ambiguity": row["ambiguity"],
845
+ "format_spec": row["format_spec"],
846
+ "task_type": row["task_type"],
847
+ "complexity": row["complexity"],
848
+ "route_chosen": row["route_chosen"],
849
+ "word_count": row["word_count"],
850
+ "grade": row["grade"],
851
+ "session_context": row["session_context"],
852
+ }
853
+ for row in rows
854
+ ]
855
+
856
+ def get_prompt_stats(
857
+ self,
858
+ window: int = 50,
859
+ session_context: str | None = None,
860
+ ) -> dict[str, Any]:
861
+ """Compute aggregate prompt quality statistics.
862
+
863
+ Args:
864
+ window: Number of recent prompts to consider.
865
+ session_context: Optional filter by session context.
866
+
867
+ Returns:
868
+ Dict with stats: total, averages, trend, weak/strong areas.
869
+ """
870
+ history = self.get_prompt_history(limit=window, session_context=session_context)
871
+
872
+ if not history:
873
+ return {
874
+ "total_prompts": 0,
875
+ "avg_quality": 0.0,
876
+ "trend": "no_data",
877
+ "grade_distribution": {},
878
+ "learning_curve": [],
879
+ "weak_areas": [],
880
+ "strong_areas": [],
881
+ }
882
+
883
+ # Reverse to chronological order
884
+ history = list(reversed(history))
885
+
886
+ scores = [h["quality_score"] for h in history]
887
+ total = len(scores)
888
+
889
+ # Compute avg for recent vs older
890
+ avg_all = sum(scores) / total
891
+ recent_10 = scores[-10:] if total >= 10 else scores
892
+ avg_recent = sum(recent_10) / len(recent_10)
893
+
894
+ if total >= 10:
895
+ older_10 = scores[:10]
896
+ avg_older = sum(older_10) / len(older_10)
897
+ if avg_recent > avg_older + 0.05:
898
+ trend = "improving"
899
+ elif avg_recent < avg_older - 0.05:
900
+ trend = "declining"
901
+ else:
902
+ trend = "stable"
903
+ else:
904
+ trend = "insufficient_data"
905
+
906
+ # Grade distribution
907
+ grades: dict[str, int] = {}
908
+ for h in history:
909
+ g = h["grade"]
910
+ grades[g] = grades.get(g, 0) + 1
911
+
912
+ # Dimension averages
913
+ dims = {
914
+ "specificity": [h["specificity"] for h in history],
915
+ "constraint_clarity": [h["constraint_clarity"] for h in history],
916
+ "context_completeness": [h["context_completeness"] for h in history],
917
+ "format_spec": [h["format_spec"] for h in history],
918
+ }
919
+ dim_avgs = {k: sum(v) / len(v) for k, v in dims.items()}
920
+
921
+ # Weak/strong
922
+ sorted_dims = sorted(dim_avgs.items(), key=lambda x: x[1])
923
+ weak = [d[0] for d in sorted_dims[:2] if d[1] < 0.6]
924
+ strong = [d[0] for d in sorted_dims[-2:] if d[1] >= 0.6]
925
+
926
+ # Learning curve (group by chunks of 5)
927
+ curve: list[float] = []
928
+ chunk_size = max(1, total // 10) if total >= 10 else 1
929
+ for i in range(0, total, chunk_size):
930
+ chunk = scores[i: i + chunk_size]
931
+ curve.append(round(sum(chunk) / len(chunk), 3))
932
+
933
+ return {
934
+ "total_prompts": total,
935
+ "avg_quality": round(avg_all, 3),
936
+ "avg_quality_recent_10": round(avg_recent, 3),
937
+ "trend": trend,
938
+ "grade_distribution": grades,
939
+ "learning_curve": curve,
940
+ "dimension_averages": dim_avgs,
941
+ "weak_areas": weak,
942
+ "strong_areas": strong,
943
+ }
944
+
945
+ # -- serialization helpers -----------------------------------------------
946
+
947
+ @staticmethod
948
+ def _serialize_normal(p: NormalPrior) -> dict[str, Any]:
949
+ return {
950
+ "mean": p.mean,
951
+ "variance": p.variance,
952
+ "n_observations": p.n_observations,
953
+ "_m2": p._m2,
954
+ "decay": p.decay,
955
+ }
956
+
957
@staticmethod
def _deserialize_normal(d: dict[str, Any]) -> NormalPrior:
    """Rebuild a :class:`NormalPrior` from a ``_serialize_normal`` dict.

    Every key is forwarded unchanged as a constructor keyword argument.
    """
    constructor_kwargs = dict(d)
    return NormalPrior(**constructor_kwargs)
960
+
961
+ @staticmethod
962
+ def _serialize_beta(p: BetaPrior) -> dict[str, Any]:
963
+ return {"alpha": p.alpha, "beta": p.beta}
964
+
965
@staticmethod
def _deserialize_beta(d: dict[str, Any]) -> BetaPrior:
    """Rebuild a :class:`BetaPrior` from a ``_serialize_beta`` dict.

    Every key is forwarded unchanged as a constructor keyword argument.
    """
    constructor_kwargs = dict(d)
    return BetaPrior(**constructor_kwargs)
968
+
969
def _serialize_task_model_prior(self, prior: TaskModelPrior) -> dict[str, Any]:
    """Flatten a :class:`TaskModelPrior` into a JSON-safe dict.

    Per-iteration profiles are keyed by ``str(iteration)`` because JSON
    object keys must be strings. Normal-distributed fields go through
    ``_serialize_normal`` and Beta fields through ``_serialize_beta``.
    Scalar metadata (task/model ids, timestamps, call count) is copied
    unchanged.
    """
    as_normal = self._serialize_normal
    as_beta = self._serialize_beta

    profiles_json = {
        str(depth): {
            "score": as_normal(profile.score),
            "score_delta": as_normal(profile.score_delta),
            "converge_prob": as_beta(profile.converge_prob),
            "latency_ms": as_normal(profile.latency_ms),
        }
        for depth, profile in prior.iterations.items()
    }

    payload: dict[str, Any] = {
        "task_type": prior.task_type,
        "model_id": prior.model_id,
        "created_at": prior.created_at,
        "updated_at": prior.updated_at,
        "total_calls": prior.total_calls,
        "iterations": profiles_json,
    }
    payload["optimal_depth"] = as_normal(prior.optimal_depth)
    payload["overall_converge_rate"] = as_beta(prior.overall_converge_rate)
    payload["first_call_quality"] = as_normal(prior.first_call_quality)
    return payload
990
+
991
def _deserialize_task_model_prior(self, data: dict[str, Any]) -> TaskModelPrior:
    """Rebuild a :class:`TaskModelPrior` from its serialised dict form.

    Iteration keys are stored as strings and converted back with ``int()``.
    A missing ``"iterations"`` entry yields an empty profile mapping. All
    other top-level keys are required.

    Raises:
        KeyError: If a required top-level or per-iteration key is absent.
        ValueError: If an iteration key is not an integer string.
    """
    from_normal = self._deserialize_normal
    from_beta = self._deserialize_beta

    profiles: dict[int, IterationProfile] = {
        int(depth): IterationProfile(
            score=from_normal(entry["score"]),
            score_delta=from_normal(entry["score_delta"]),
            converge_prob=from_beta(entry["converge_prob"]),
            latency_ms=from_normal(entry["latency_ms"]),
        )
        for depth, entry in data.get("iterations", {}).items()
    }

    return TaskModelPrior(
        task_type=data["task_type"],
        model_id=data["model_id"],
        created_at=data["created_at"],
        updated_at=data["updated_at"],
        total_calls=data["total_calls"],
        iterations=profiles,
        optimal_depth=from_normal(data["optimal_depth"]),
        overall_converge_rate=from_beta(data["overall_converge_rate"]),
        first_call_quality=from_normal(data["first_call_quality"]),
    )
1012
+
1013
+ # -- plan persistence ----------------------------------------------------
1014
+
1015
def save_plan(self, plan_dict: dict[str, Any]) -> None:
    """Insert or update a plan, stored as one JSON blob, in the ``plans`` table.

    The ``goal`` column mirrors ``plan_dict["goal"]``. It is refreshed on
    every save so it cannot drift from the goal stored inside ``data``.
    ``created_at`` is written only on the first insert. ``updated_at`` is
    set to the current UTC time on every call.

    Args:
        plan_dict: Output of ``Plan.to_dict()``; must include ``plan_id``.
            A missing ``goal`` is stored as an empty string.

    Raises:
        KeyError: If ``plan_dict`` has no ``plan_id`` key.
        TypeError: If ``plan_dict`` contains values ``json.dumps`` cannot encode.
    """
    now = datetime.now(timezone.utc).isoformat()
    plan_id = plan_dict["plan_id"]
    goal = plan_dict.get("goal", "")
    data = json.dumps(plan_dict)
    with self._connection() as conn:
        # Upsert: on an existing plan_id, keep created_at but refresh every
        # column derived from plan_dict (goal included) plus updated_at.
        conn.execute(
            """INSERT INTO plans (plan_id, goal, data, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?)
            ON CONFLICT(plan_id) DO UPDATE SET
                goal = excluded.goal,
                data = excluded.data,
                updated_at = excluded.updated_at""",
            (plan_id, goal, data, now, now),
        )
        # NOTE(review): explicit commit presumes _connection() does not
        # auto-commit on exit — confirm against its implementation.
        conn.commit()
1035
+
1036
def load_plan(self, plan_id: str) -> dict[str, Any] | None:
    """Fetch a single persisted plan by its identifier.

    Args:
        plan_id: Identifier the plan was saved under.

    Returns:
        The decoded plan dict, or ``None`` when no row has that ``plan_id``.
    """
    with self._connection() as conn:
        cursor = conn.execute(
            "SELECT data FROM plans WHERE plan_id = ?", (plan_id,)
        )
        hit = cursor.fetchone()
        if hit is not None:
            # Column accessed by name — presumably _connection() installs
            # sqlite3.Row as the row factory; verify.
            return cast(dict[str, Any], json.loads(hit["data"]))
        return None
1052
+
1053
def load_all_plans(self) -> list[dict[str, Any]]:
    """Return every persisted plan, decoded from its JSON blob.

    Returns:
        Plan dicts ordered by ``updated_at`` descending, so the most
        recently saved plan comes first. An empty table gives ``[]``.
    """
    with self._connection() as conn:
        cursor = conn.execute("SELECT data FROM plans ORDER BY updated_at DESC")
        fetched = cursor.fetchall()
        decoded: list[dict[str, Any]] = []
        for record in fetched:
            decoded.append(json.loads(record["data"]))
        return decoded
1064
+
1065
def delete_plan(self, plan_id: str) -> bool:
    """Remove the plan stored under ``plan_id``, if any.

    Args:
        plan_id: Identifier of the plan to remove.

    Returns:
        ``True`` when the DELETE affected at least one row, otherwise ``False``.
    """
    statement = "DELETE FROM plans WHERE plan_id = ?"
    with self._connection() as conn:
        result = conn.execute(statement, (plan_id,))
        conn.commit()
        return result.rowcount >= 1
1080
+
1081
+
1082
class SQLiteBackedPriors(AdaptivePriors):
    """:class:`AdaptivePriors` whose persistence lives in a :class:`LoopStore`.

    The parent's JSON-file persistence is replaced by store calls.
    ``_save`` and ``_load`` are overridden to write to and read from SQLite.
    ``observe`` additionally logs each raw observation and writes back the
    affected prior. The in-memory ``_priors`` mapping is seeded from the
    store at construction.

    Args:
        store: The :class:`LoopStore` used for all reads and writes.
    """

    def __init__(self, store: LoopStore) -> None:
        # AdaptivePriors.__init__ is intentionally not called, so no JSON
        # file is loaded; only the attributes set here are initialised.
        self._store = store
        self._priors: dict[str, TaskModelPrior] = {}
        self.store_path = None
        self._load_from_store()

    def _load_from_store(self) -> None:
        """Replace the in-memory priors with everything ``load_all_priors`` returns."""
        self._priors = self._store.load_all_priors()

    def _load(self) -> None:
        """Reload all priors from SQLite (parent hook override)."""
        self._load_from_store()

    def _save(self) -> None:
        """Write every in-memory prior to SQLite, one ``save_prior`` call per key."""
        for prior_key, task_prior in self._priors.items():
            self._store.save_prior(prior_key, task_prior)

    def observe(self, observation: CallObservation) -> None:
        """Incorporate an observation in memory and persist it.

        Steps: the parent ``observe`` updates the in-memory priors; the raw
        observation is appended via ``record_observation``; then the prior
        for ``(task_type, model_id)`` is saved if it is present in memory.

        NOTE(review): if the parent ``observe`` itself calls ``_save``,
        every prior is rewritten on each observation — confirm.

        Args:
            observation: The observation to incorporate.
        """
        super().observe(observation)
        self._store.record_observation(observation)
        task_type, model_id = observation.task_type, observation.model_id
        prior_key = self._key(task_type, model_id)
        if prior_key not in self._priors:
            return
        self._store.save_prior(prior_key, self._priors[prior_key])