nexo-brain 3.0.1 → 3.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,7 +25,7 @@ from client_preferences import (
25
25
  resolve_client_runtime_profile,
26
26
  )
27
27
  from cron_recovery import should_run_at_load
28
- from doctor.models import DoctorCheck
28
+ from doctor.models import DoctorCheck, safe_check
29
29
 
30
30
  NEXO_HOME = Path(os.environ.get("NEXO_HOME", str(Path.home() / ".nexo")))
31
31
  NEXO_CODE = Path(os.environ.get("NEXO_CODE", str(Path(__file__).resolve().parents[2])))
@@ -406,21 +406,22 @@ def _load_active_conditioned_learnings() -> list[dict]:
406
406
  import sqlite3
407
407
 
408
408
  conn = sqlite3.connect(str(db_path), timeout=2)
409
- conn.row_factory = sqlite3.Row
410
- table = conn.execute(
411
- "SELECT name FROM sqlite_master WHERE type='table' AND name='learnings'"
412
- ).fetchone()
413
- if not table:
409
+ try:
410
+ conn.row_factory = sqlite3.Row
411
+ table = conn.execute(
412
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='learnings'"
413
+ ).fetchone()
414
+ if not table:
415
+ return []
416
+ rows = conn.execute(
417
+ """SELECT id, title, applies_to
418
+ FROM learnings
419
+ WHERE status = 'active' AND COALESCE(applies_to, '') != ''
420
+ ORDER BY updated_at DESC, id DESC"""
421
+ ).fetchall()
422
+ return [dict(row) for row in rows]
423
+ finally:
414
424
  conn.close()
415
- return []
416
- rows = conn.execute(
417
- """SELECT id, title, applies_to
418
- FROM learnings
419
- WHERE status = 'active' AND COALESCE(applies_to, '') != ''
420
- ORDER BY updated_at DESC, id DESC"""
421
- ).fetchall()
422
- conn.close()
423
- return [dict(row) for row in rows]
424
425
  except Exception:
425
426
  return []
426
427
 
@@ -595,22 +596,23 @@ def _open_protocol_debt_summary(*debt_types: str) -> dict:
595
596
 
596
597
  try:
597
598
  conn = sqlite3.connect(str(db_path), timeout=2)
598
- conn.row_factory = sqlite3.Row
599
- table = conn.execute(
600
- "SELECT name FROM sqlite_master WHERE type='table' AND name='protocol_debt'"
601
- ).fetchone()
602
- if not table:
599
+ try:
600
+ conn.row_factory = sqlite3.Row
601
+ table = conn.execute(
602
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='protocol_debt'"
603
+ ).fetchone()
604
+ if not table:
605
+ return summary
606
+ placeholders = ",".join("?" for _ in debt_types)
607
+ rows = conn.execute(
608
+ f"""SELECT debt_type, COUNT(*) AS total
609
+ FROM protocol_debt
610
+ WHERE status = 'open' AND debt_type IN ({placeholders})
611
+ GROUP BY debt_type""",
612
+ tuple(debt_types),
613
+ ).fetchall()
614
+ finally:
603
615
  conn.close()
604
- return summary
605
- placeholders = ",".join("?" for _ in debt_types)
606
- rows = conn.execute(
607
- f"""SELECT debt_type, COUNT(*) AS total
608
- FROM protocol_debt
609
- WHERE status = 'open' AND debt_type IN ({placeholders})
610
- GROUP BY debt_type""",
611
- tuple(debt_types),
612
- ).fetchall()
613
- conn.close()
614
616
  except Exception:
615
617
  return summary
616
618
 
@@ -1172,14 +1174,16 @@ def check_stale_sessions() -> DoctorCheck:
1172
1174
  summary="No DB to check sessions",
1173
1175
  )
1174
1176
  conn = sqlite3.connect(str(db_path), timeout=2)
1175
- conn.row_factory = sqlite3.Row
1176
- cutoff = time.time() - 7200
1177
- day_ago = time.time() - 86400
1178
- rows = conn.execute(
1179
- "SELECT COUNT(*) as cnt FROM sessions WHERE last_update_epoch < ? AND last_update_epoch > ?",
1180
- (cutoff, day_ago),
1181
- ).fetchone()
1182
- conn.close()
1177
+ try:
1178
+ conn.row_factory = sqlite3.Row
1179
+ cutoff = time.time() - 7200
1180
+ day_ago = time.time() - 86400
1181
+ rows = conn.execute(
1182
+ "SELECT COUNT(*) as cnt FROM sessions WHERE last_update_epoch < ? AND last_update_epoch > ?",
1183
+ (cutoff, day_ago),
1184
+ ).fetchone()
1185
+ finally:
1186
+ conn.close()
1183
1187
  count = rows["cnt"] if rows else 0
1184
1188
  if count > 0:
1185
1189
  return DoctorCheck(
@@ -1221,24 +1225,25 @@ def check_cron_freshness() -> DoctorCheck:
1221
1225
  summary="No DB to check cron runs",
1222
1226
  )
1223
1227
  conn = sqlite3.connect(str(db_path), timeout=2)
1224
- # Check if cron_runs table exists
1225
- tables = conn.execute(
1226
- "SELECT name FROM sqlite_master WHERE type='table' AND name='cron_runs'"
1227
- ).fetchone()
1228
- if not tables:
1228
+ try:
1229
+ # Check if cron_runs table exists
1230
+ tables = conn.execute(
1231
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='cron_runs'"
1232
+ ).fetchone()
1233
+ if not tables:
1234
+ return DoctorCheck(
1235
+ id="runtime.cron_freshness",
1236
+ tier="runtime",
1237
+ status="healthy",
1238
+ severity="info",
1239
+ summary="No cron_runs table yet",
1240
+ )
1241
+ # Latest run per cron
1242
+ rows = conn.execute(
1243
+ "SELECT cron_id, MAX(started_at) as last_run FROM cron_runs GROUP BY cron_id"
1244
+ ).fetchall()
1245
+ finally:
1229
1246
  conn.close()
1230
- return DoctorCheck(
1231
- id="runtime.cron_freshness",
1232
- tier="runtime",
1233
- status="healthy",
1234
- severity="info",
1235
- summary="No cron_runs table yet",
1236
- )
1237
- # Latest run per cron
1238
- rows = conn.execute(
1239
- "SELECT cron_id, MAX(started_at) as last_run FROM cron_runs GROUP BY cron_id"
1240
- ).fetchall()
1241
- conn.close()
1242
1247
 
1243
1248
  stale = []
1244
1249
  expectations = _cron_expectations()
@@ -2167,32 +2172,36 @@ def check_protocol_compliance() -> DoctorCheck:
2167
2172
  db_path = NEXO_HOME / "data" / "nexo.db"
2168
2173
  if db_path.is_file():
2169
2174
  conn = sqlite3.connect(str(db_path), timeout=2)
2170
- conn.row_factory = sqlite3.Row
2171
- tables = {
2172
- row["name"]
2173
- for row in conn.execute(
2174
- "SELECT name FROM sqlite_master WHERE type='table' AND name IN ('protocol_tasks', 'protocol_debt')"
2175
- ).fetchall()
2176
- }
2177
- if {"protocol_tasks", "protocol_debt"}.issubset(tables):
2178
- window = "-7 days"
2179
- tasks = conn.execute(
2180
- """SELECT * FROM protocol_tasks
2181
- WHERE opened_at >= datetime('now', ?)
2182
- ORDER BY opened_at DESC""",
2183
- (window,),
2184
- ).fetchall()
2185
- debt_rows = conn.execute(
2186
- """SELECT severity, debt_type, COUNT(*) AS total
2187
- FROM protocol_debt
2188
- WHERE status = 'open' AND created_at >= datetime('now', ?)
2189
- GROUP BY severity, debt_type
2190
- ORDER BY total DESC, debt_type ASC""",
2191
- (window,),
2192
- ).fetchall()
2175
+ try:
2176
+ conn.row_factory = sqlite3.Row
2177
+ tables = {
2178
+ row["name"]
2179
+ for row in conn.execute(
2180
+ "SELECT name FROM sqlite_master WHERE type='table' AND name IN ('protocol_tasks', 'protocol_debt')"
2181
+ ).fetchall()
2182
+ }
2183
+ tasks = None
2184
+ debt_rows = None
2185
+ if {"protocol_tasks", "protocol_debt"}.issubset(tables):
2186
+ window = "-7 days"
2187
+ tasks = conn.execute(
2188
+ """SELECT * FROM protocol_tasks
2189
+ WHERE opened_at >= datetime('now', ?)
2190
+ ORDER BY opened_at DESC""",
2191
+ (window,),
2192
+ ).fetchall()
2193
+ debt_rows = conn.execute(
2194
+ """SELECT severity, debt_type, COUNT(*) AS total
2195
+ FROM protocol_debt
2196
+ WHERE status = 'open' AND created_at >= datetime('now', ?)
2197
+ GROUP BY severity, debt_type
2198
+ ORDER BY total DESC, debt_type ASC""",
2199
+ (window,),
2200
+ ).fetchall()
2201
+ finally:
2193
2202
  conn.close()
2194
2203
 
2195
- if tasks or debt_rows:
2204
+ if tasks is not None and debt_rows is not None and (tasks or debt_rows):
2196
2205
  closed_tasks = [row for row in tasks if row["status"] != "open"]
2197
2206
  verify_required = [row for row in closed_tasks if row["must_verify"] and row["status"] == "done"]
2198
2207
  verify_ok = [row for row in verify_required if (row["close_evidence"] or "").strip()]
@@ -2410,11 +2419,13 @@ def check_state_watchers() -> DoctorCheck:
2410
2419
  if db_path.is_file():
2411
2420
  try:
2412
2421
  conn = sqlite3.connect(str(db_path))
2413
- row = conn.execute(
2414
- "SELECT COUNT(*) FROM state_watchers WHERE status = 'active'"
2415
- ).fetchone()
2416
- conn.close()
2417
- active_watchers = int(row[0] or 0) if row else 0
2422
+ try:
2423
+ row = conn.execute(
2424
+ "SELECT COUNT(*) FROM state_watchers WHERE status = 'active'"
2425
+ ).fetchone()
2426
+ active_watchers = int(row[0] or 0) if row else 0
2427
+ finally:
2428
+ conn.close()
2418
2429
  except Exception:
2419
2430
  active_watchers = 0
2420
2431
 
@@ -2518,37 +2529,38 @@ def check_automation_telemetry(days: int = 7) -> DoctorCheck:
2518
2529
 
2519
2530
  try:
2520
2531
  conn = sqlite3.connect(str(db_path), timeout=2)
2521
- conn.row_factory = sqlite3.Row
2522
- table = conn.execute(
2523
- "SELECT name FROM sqlite_master WHERE type='table' AND name='automation_runs'"
2524
- ).fetchone()
2525
- if not table:
2526
- conn.close()
2527
- return DoctorCheck(
2528
- id="runtime.automation_telemetry",
2529
- tier="runtime",
2530
- status="degraded",
2531
- severity="warn",
2532
- summary="Automation telemetry schema is missing",
2533
- evidence=["table automation_runs not found"],
2534
- repair_plan=["Run NEXO migrations before trusting automation cost/parity metrics"],
2535
- escalation_prompt="Shared automation runs are happening without the telemetry table that release metrics depend on.",
2536
- )
2532
+ try:
2533
+ conn.row_factory = sqlite3.Row
2534
+ table = conn.execute(
2535
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='automation_runs'"
2536
+ ).fetchone()
2537
+ if not table:
2538
+ return DoctorCheck(
2539
+ id="runtime.automation_telemetry",
2540
+ tier="runtime",
2541
+ status="degraded",
2542
+ severity="warn",
2543
+ summary="Automation telemetry schema is missing",
2544
+ evidence=["table automation_runs not found"],
2545
+ repair_plan=["Run NEXO migrations before trusting automation cost/parity metrics"],
2546
+ escalation_prompt="Shared automation runs are happening without the telemetry table that release metrics depend on.",
2547
+ )
2537
2548
 
2538
- row = conn.execute(
2539
- """
2540
- SELECT
2541
- COUNT(*) AS runs,
2542
- SUM(CASE WHEN (input_tokens + cached_input_tokens + output_tokens) > 0 THEN 1 ELSE 0 END) AS usage_runs,
2543
- SUM(CASE WHEN total_cost_usd IS NOT NULL THEN 1 ELSE 0 END) AS cost_runs,
2544
- SUM(CASE WHEN cost_source = 'pricing_unavailable' THEN 1 ELSE 0 END) AS pricing_gaps,
2545
- GROUP_CONCAT(DISTINCT backend) AS backends
2546
- FROM automation_runs
2547
- WHERE created_at >= datetime('now', ?)
2548
- """,
2549
- (f"-{days} days",),
2550
- ).fetchone()
2551
- conn.close()
2549
+ row = conn.execute(
2550
+ """
2551
+ SELECT
2552
+ COUNT(*) AS runs,
2553
+ SUM(CASE WHEN (input_tokens + cached_input_tokens + output_tokens) > 0 THEN 1 ELSE 0 END) AS usage_runs,
2554
+ SUM(CASE WHEN total_cost_usd IS NOT NULL THEN 1 ELSE 0 END) AS cost_runs,
2555
+ SUM(CASE WHEN cost_source = 'pricing_unavailable' THEN 1 ELSE 0 END) AS pricing_gaps,
2556
+ GROUP_CONCAT(DISTINCT backend) AS backends
2557
+ FROM automation_runs
2558
+ WHERE created_at >= datetime('now', ?)
2559
+ """,
2560
+ (f"-{days} days",),
2561
+ ).fetchone()
2562
+ finally:
2563
+ conn.close()
2552
2564
  except Exception as exc:
2553
2565
  return DoctorCheck(
2554
2566
  id="runtime.automation_telemetry",
@@ -2623,22 +2635,22 @@ def check_automation_telemetry(days: int = 7) -> DoctorCheck:
2623
2635
  def run_runtime_checks(fix: bool = False) -> list[DoctorCheck]:
2624
2636
  """Run all runtime-tier checks. Read-only by default."""
2625
2637
  return [
2626
- check_immune_status(),
2627
- check_watchdog_status(),
2628
- check_stale_sessions(),
2629
- check_cron_freshness(),
2630
- check_client_backend_preferences(),
2631
- check_client_bootstrap_parity(fix=fix),
2632
- check_codex_session_parity(),
2633
- check_codex_conditioned_file_discipline(),
2634
- check_claude_desktop_shared_brain(),
2635
- check_transcript_source_parity(),
2636
- check_client_assumption_regressions(),
2637
- check_protocol_compliance(),
2638
- check_automation_telemetry(),
2639
- check_state_watchers(),
2640
- check_release_artifact_sync(),
2641
- check_launchagent_integrity(fix=fix),
2642
- check_personal_script_registry(fix=fix),
2643
- check_skill_health(fix=fix),
2638
+ safe_check(check_immune_status),
2639
+ safe_check(check_watchdog_status),
2640
+ safe_check(check_stale_sessions),
2641
+ safe_check(check_cron_freshness),
2642
+ safe_check(check_client_backend_preferences),
2643
+ safe_check(check_client_bootstrap_parity, fix=fix),
2644
+ safe_check(check_codex_session_parity),
2645
+ safe_check(check_codex_conditioned_file_discipline),
2646
+ safe_check(check_claude_desktop_shared_brain),
2647
+ safe_check(check_transcript_source_parity),
2648
+ safe_check(check_client_assumption_regressions),
2649
+ safe_check(check_protocol_compliance),
2650
+ safe_check(check_automation_telemetry),
2651
+ safe_check(check_state_watchers),
2652
+ safe_check(check_release_artifact_sync),
2653
+ safe_check(check_launchagent_integrity, fix=fix),
2654
+ safe_check(check_personal_script_registry, fix=fix),
2655
+ safe_check(check_skill_health, fix=fix),
2644
2656
  ]
@@ -120,52 +120,54 @@ def save_objective(obj: dict):
120
120
  def get_week_data(db_path: str) -> dict:
121
121
  """Gather last 7 days of learnings, decisions, changes, diaries."""
122
122
  conn = sqlite3.connect(db_path, timeout=10)
123
- conn.row_factory = sqlite3.Row
124
- cutoff_epoch = time.time() - 7 * 86400
125
- cutoff_date = (date.today() - timedelta(days=7)).isoformat()
126
-
127
- data = {}
128
-
129
- rows = conn.execute(
130
- "SELECT category, title, content FROM learnings WHERE created_at > ? ORDER BY created_at DESC LIMIT 50",
131
- (cutoff_epoch,)
132
- ).fetchall()
133
- data["learnings"] = [dict(r) for r in rows]
134
-
135
- rows = conn.execute(
136
- "SELECT domain, decision, alternatives, based_on, confidence, outcome FROM decisions "
137
- "WHERE created_at > ? ORDER BY created_at DESC LIMIT 20",
138
- (cutoff_date,)
139
- ).fetchall()
140
- data["decisions"] = [dict(r) for r in rows]
141
-
142
- rows = conn.execute(
143
- "SELECT files, what_changed, why, affects, risks FROM change_log "
144
- "WHERE created_at > ? ORDER BY created_at DESC LIMIT 30",
145
- (cutoff_date,)
146
- ).fetchall()
147
- data["changes"] = [dict(r) for r in rows]
148
-
149
- rows = conn.execute(
150
- "SELECT summary, decisions as diary_decisions, pending, mental_state, domain, user_signals "
151
- "FROM session_diary WHERE created_at > ? ORDER BY created_at DESC LIMIT 20",
152
- (cutoff_date,)
153
- ).fetchall()
154
- data["diaries"] = [dict(r) for r in rows]
155
-
156
- rows = conn.execute(
157
- "SELECT * FROM evolution_log ORDER BY id DESC LIMIT 20"
158
- ).fetchall()
159
- data["evolution_history"] = [dict(r) for r in rows]
160
-
161
- rows = conn.execute(
162
- "SELECT dimension, score, delta, measured_at FROM evolution_metrics "
163
- "WHERE id IN (SELECT MAX(id) FROM evolution_metrics GROUP BY dimension)"
164
- ).fetchall()
165
- data["current_metrics"] = {r["dimension"]: dict(r) for r in rows}
166
-
167
- conn.close()
168
- return data
123
+ try:
124
+ conn.row_factory = sqlite3.Row
125
+ cutoff_epoch = time.time() - 7 * 86400
126
+ cutoff_date = (date.today() - timedelta(days=7)).isoformat()
127
+
128
+ data = {}
129
+
130
+ rows = conn.execute(
131
+ "SELECT category, title, content FROM learnings WHERE created_at > ? ORDER BY created_at DESC LIMIT 50",
132
+ (cutoff_epoch,)
133
+ ).fetchall()
134
+ data["learnings"] = [dict(r) for r in rows]
135
+
136
+ rows = conn.execute(
137
+ "SELECT domain, decision, alternatives, based_on, confidence, outcome FROM decisions "
138
+ "WHERE created_at > ? ORDER BY created_at DESC LIMIT 20",
139
+ (cutoff_date,)
140
+ ).fetchall()
141
+ data["decisions"] = [dict(r) for r in rows]
142
+
143
+ rows = conn.execute(
144
+ "SELECT files, what_changed, why, affects, risks FROM change_log "
145
+ "WHERE created_at > ? ORDER BY created_at DESC LIMIT 30",
146
+ (cutoff_date,)
147
+ ).fetchall()
148
+ data["changes"] = [dict(r) for r in rows]
149
+
150
+ rows = conn.execute(
151
+ "SELECT summary, decisions as diary_decisions, pending, mental_state, domain, user_signals "
152
+ "FROM session_diary WHERE created_at > ? ORDER BY created_at DESC LIMIT 20",
153
+ (cutoff_date,)
154
+ ).fetchall()
155
+ data["diaries"] = [dict(r) for r in rows]
156
+
157
+ rows = conn.execute(
158
+ "SELECT * FROM evolution_log ORDER BY id DESC LIMIT 20"
159
+ ).fetchall()
160
+ data["evolution_history"] = [dict(r) for r in rows]
161
+
162
+ rows = conn.execute(
163
+ "SELECT dimension, score, delta, measured_at FROM evolution_metrics "
164
+ "WHERE id IN (SELECT MAX(id) FROM evolution_metrics GROUP BY dimension)"
165
+ ).fetchall()
166
+ data["current_metrics"] = {r["dimension"]: dict(r) for r in rows}
167
+
168
+ return data
169
+ finally:
170
+ conn.close()
169
171
 
170
172
 
171
173
  def create_snapshot(files_to_backup: list) -> str:
@@ -147,25 +147,27 @@ def backfill_decisions() -> int:
147
147
  def backfill_somatic() -> int:
148
148
  """Read somatic_markers from cognitive.db → create file/area nodes with risk."""
149
149
  cdb = _cognitive_db()
150
- rows = cdb.execute(
151
- "SELECT target, target_type, risk_score, incident_count FROM somatic_markers"
152
- ).fetchall()
153
- count = 0
154
- for row in rows:
155
- target_type = row["target_type"] or "file"
156
- node_ref = f"{target_type}:{row['target']}"
157
- kg.upsert_node(
158
- node_type=target_type,
159
- node_ref=node_ref,
160
- label=os.path.basename(row["target"]) or row["target"],
161
- properties={
162
- "risk_score": row["risk_score"],
163
- "incident_count": row["incident_count"],
164
- },
165
- )
166
- count += 1
167
- cdb.close()
168
- return count
150
+ try:
151
+ rows = cdb.execute(
152
+ "SELECT target, target_type, risk_score, incident_count FROM somatic_markers"
153
+ ).fetchall()
154
+ count = 0
155
+ for row in rows:
156
+ target_type = row["target_type"] or "file"
157
+ node_ref = f"{target_type}:{row['target']}"
158
+ kg.upsert_node(
159
+ node_type=target_type,
160
+ node_ref=node_ref,
161
+ label=os.path.basename(row["target"]) or row["target"],
162
+ properties={
163
+ "risk_score": row["risk_score"],
164
+ "incident_count": row["incident_count"],
165
+ },
166
+ )
167
+ count += 1
168
+ return count
169
+ finally:
170
+ cdb.close()
169
171
 
170
172
 
171
173
  def run_full_backfill() -> dict:
@@ -1,7 +1,7 @@
1
1
  """Opportunistic maintenance — run overdue tasks on MCP startup."""
2
2
 
3
3
  import time
4
- from datetime import datetime
4
+ from datetime import datetime, timezone
5
5
  from db import get_db
6
6
 
7
7
 
@@ -16,7 +16,7 @@ def check_and_run_overdue():
16
16
  if last_run:
17
17
  try:
18
18
  last_dt = datetime.strptime(last_run, "%Y-%m-%dT%H:%M:%S")
19
- hours_since = (datetime.now(datetime.timezone.utc).replace(tzinfo=None) - last_dt).total_seconds() / 3600
19
+ hours_since = (datetime.now(timezone.utc).replace(tzinfo=None) - last_dt).total_seconds() / 3600
20
20
  if hours_since < interval:
21
21
  continue
22
22
  except (ValueError, TypeError):
@@ -28,7 +28,7 @@ def check_and_run_overdue():
28
28
  conn.execute(
29
29
  "UPDATE maintenance_schedule SET last_run_at = ?, last_duration_ms = ?, "
30
30
  "run_count = run_count + 1 WHERE task_name = ?",
31
- (datetime.now(datetime.timezone.utc).replace(tzinfo=None).strftime("%Y-%m-%dT%H:%M:%S"), duration_ms, task))
31
+ (datetime.now(timezone.utc).replace(tzinfo=None).strftime("%Y-%m-%dT%H:%M:%S"), duration_ms, task))
32
32
  conn.commit()
33
33
  ran.append({"task": task, "duration_ms": duration_ms})
34
34
  except Exception as e:
@@ -30,15 +30,17 @@ MODELS = {
30
30
  def verify():
31
31
  """Check current embedding dimensions in the database."""
32
32
  conn = sqlite3.connect(DB_PATH)
33
- for table in ["stm_memories", "ltm_memories"]:
34
- count = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
35
- if count == 0:
36
- print(f" {table}: {count} rows (empty)")
37
- continue
38
- row = conn.execute(f"SELECT embedding FROM {table} LIMIT 1").fetchone()
39
- vec = np.frombuffer(row[0], dtype=np.float32)
40
- print(f" {table}: {count} rows, embedding dim = {len(vec)}")
41
- conn.close()
33
+ try:
34
+ for table in ["stm_memories", "ltm_memories"]:
35
+ count = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
36
+ if count == 0:
37
+ print(f" {table}: {count} rows (empty)")
38
+ continue
39
+ row = conn.execute(f"SELECT embedding FROM {table} LIMIT 1").fetchone()
40
+ vec = np.frombuffer(row[0], dtype=np.float32)
41
+ print(f" {table}: {count} rows, embedding dim = {len(vec)}")
42
+ finally:
43
+ conn.close()
42
44
 
43
45
 
44
46
  def upgrade():
@@ -62,31 +64,31 @@ def upgrade():
62
64
  model = TextEmbedding(model_name)
63
65
 
64
66
  conn = sqlite3.connect(DB_PATH)
65
-
66
- for table in ["stm_memories", "ltm_memories"]:
67
- rows = conn.execute(f"SELECT id, content FROM {table}").fetchall()
68
- if not rows:
69
- print(f"\n{table}: empty, skipping")
70
- continue
71
-
72
- print(f"\n{table}: re-embedding {len(rows)} memories...")
73
- t0 = time.time()
74
-
75
- # Batch embed for speed
76
- contents = [r[1] for r in rows]
77
- ids = [r[0] for r in rows]
78
-
79
- embeddings = list(model.embed(contents))
80
-
81
- for mem_id, emb in zip(ids, embeddings):
82
- blob = np.array(emb, dtype=np.float32).tobytes()
83
- conn.execute(f"UPDATE {table} SET embedding = ? WHERE id = ?", (blob, mem_id))
84
-
85
- conn.commit()
86
- elapsed = time.time() - t0
87
- print(f" Done: {len(rows)} memories in {elapsed:.1f}s ({elapsed/len(rows)*1000:.0f}ms/memory)")
88
-
89
- conn.close()
67
+ try:
68
+ for table in ["stm_memories", "ltm_memories"]:
69
+ rows = conn.execute(f"SELECT id, content FROM {table}").fetchall()
70
+ if not rows:
71
+ print(f"\n{table}: empty, skipping")
72
+ continue
73
+
74
+ print(f"\n{table}: re-embedding {len(rows)} memories...")
75
+ t0 = time.time()
76
+
77
+ # Batch embed for speed
78
+ contents = [r[1] for r in rows]
79
+ ids = [r[0] for r in rows]
80
+
81
+ embeddings = list(model.embed(contents))
82
+
83
+ for mem_id, emb in zip(ids, embeddings):
84
+ blob = np.array(emb, dtype=np.float32).tobytes()
85
+ conn.execute(f"UPDATE {table} SET embedding = ? WHERE id = ?", (blob, mem_id))
86
+
87
+ conn.commit()
88
+ elapsed = time.time() - t0
89
+ print(f" Done: {len(rows)} memories in {elapsed:.1f}s ({elapsed/len(rows)*1000:.0f}ms/memory)")
90
+ finally:
91
+ conn.close()
90
92
 
91
93
  print("\nAfter upgrade:")
92
94
  verify()