spooling 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spooling/ingest.py ADDED
@@ -0,0 +1,496 @@
1
+ """Ingestion pipeline - parse AI coding sessions from multiple providers and store in pgvector."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+
6
+ from rich.console import Console
7
+ from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
8
+
9
+ from spooling.config import MODEL_PRICING, DEFAULT_PRICING
10
+ from spooling.db import get_connection
11
+ from spooling.parser import ParsedSession
12
+ from spooling.embeddings import embed_texts, chunk_text
13
+ from spooling.providers import get_provider, get_all_providers
14
+ from spooling.tracing import Trace, compute_trace_metrics
15
+
16
+
17
+ def _scrub(s):
18
+ """Strip NUL bytes PostgreSQL text cols can't hold."""
19
+ if s is None:
20
+ return None
21
+ if isinstance(s, str):
22
+ return s.replace("\x00", "")
23
+ return s
24
+
25
# Module-wide rich console; the sync routines in this file print their
# status and error messages through it.
console = Console()
26
+
27
+
28
def _estimate_cost(
    input_tokens: int,
    output_tokens: int,
    model: str | None,
    provider_id: str | None = None,
) -> float:
    """Price a session's token totals using ``spooling.pricing.get_rates``.

    The rate object is looked up from ``model`` and ``provider_id`` and its
    ``cost`` method is applied to the two token counts. ``provider_id`` is
    forwarded so the pricing layer can choose rates when the parser did not
    capture a model (NOTE(review): the fallback itself lives in
    ``spooling.pricing`` and is not visible from this module).
    """
    # Imported at call time, exactly as the module has always done.
    from spooling.pricing import get_rates

    session_rates = get_rates(model, provider_id=provider_id)
    estimated = session_rates.cost(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
    )
    return estimated
46
+
47
+
48
+ def _get_synced_files(conn) -> dict[str, int]:
49
+ """Get map of file_path -> last_size for already-synced files."""
50
+ rows = conn.execute("SELECT file_path, last_size FROM sync_state").fetchall()
51
+ return {r["file_path"]: r["last_size"] for r in rows}
52
+
53
+
54
+ def _mark_synced(conn, file_path: str, size: int, provider_id: str = "jsonl-session"):
55
+ conn.execute(
56
+ "INSERT INTO sync_state (file_path, last_size, provider_id) VALUES (%s, %s, %s) "
57
+ "ON CONFLICT (file_path) DO UPDATE SET last_size = %s, provider_id = %s, last_synced_at = now()",
58
+ (file_path, size, provider_id, size, provider_id),
59
+ )
60
+
61
+
62
def _store_session(conn, session: ParsedSession):
    """Upsert a parsed session, then insert its messages and tool calls.

    The ``sessions`` row is inserted or, on an id conflict, has its counters,
    cost, end time, model and title refreshed. Messages without a ``uuid``
    are skipped; the rest are inserted with ``ON CONFLICT (id) DO NOTHING``.

    Tool calls are only written for messages that were actually inserted by
    this call. ``tool_calls`` rows are plain INSERTs with no conflict clause,
    so re-processing a session file that has grown would otherwise duplicate
    every tool call of every message that is already stored.

    Args:
        conn: Open database connection; the caller commits.
        session: Parsed session to persist.
    """
    cost = _estimate_cost(
        session.estimated_input_tokens,
        session.estimated_output_tokens,
        session.model,
        provider_id=session.provider_id,
    )

    conn.execute(
        """INSERT INTO sessions (id, provider_id, project, cwd, git_branch, started_at, ended_at,
               message_count, tool_call_count, estimated_input_tokens, estimated_output_tokens,
               estimated_cost_usd, agent_version, model, title)
           VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
           ON CONFLICT (id) DO UPDATE SET
               provider_id = EXCLUDED.provider_id,
               message_count = EXCLUDED.message_count,
               tool_call_count = EXCLUDED.tool_call_count,
               estimated_input_tokens = EXCLUDED.estimated_input_tokens,
               estimated_output_tokens = EXCLUDED.estimated_output_tokens,
               estimated_cost_usd = EXCLUDED.estimated_cost_usd,
               ended_at = EXCLUDED.ended_at,
               model = EXCLUDED.model,
               title = EXCLUDED.title""",
        (
            session.session_id, session.provider_id, _scrub(session.project), _scrub(session.cwd),
            _scrub(session.git_branch), session.started_at, session.ended_at, session.message_count,
            session.tool_call_count, session.estimated_input_tokens,
            session.estimated_output_tokens, cost, _scrub(session.agent_version),
            _scrub(session.model), _scrub(session.title),
        ),
    )

    # Upsert messages
    for msg in session.messages:
        if not msg.uuid:
            continue
        inserted = conn.execute(
            """INSERT INTO messages (id, session_id, role, content, timestamp, tools_used,
                   cwd, git_branch, estimated_tokens)
               VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
               ON CONFLICT (id) DO NOTHING""",
            (
                msg.uuid, session.session_id, msg.role, _scrub(msg.content),
                msg.timestamp, json.dumps(msg.tools_used), _scrub(msg.cwd),
                _scrub(msg.git_branch), msg.estimated_tokens,
            ),
        )

        # rowcount == 0 means the conflict clause skipped the insert, i.e. an
        # earlier sync already stored this message together with its tool
        # calls. Only a definite zero is treated as "already stored"; a
        # cursor without a usable rowcount keeps the previous behaviour.
        if getattr(inserted, "rowcount", None) == 0:
            continue

        # Store tool calls (with rich details when available).
        if getattr(msg, "tool_details", None):
            for td in msg.tool_details:
                conn.execute(
                    """INSERT INTO tool_calls (session_id, message_id, tool_name, tool_input, tool_result_preview, timestamp)
                       VALUES (%s, %s, %s, %s, %s, %s)""",
                    (
                        session.session_id, msg.uuid, _scrub(td.name),
                        _scrub(td.input_summary),
                        # Empty previews are stored as NULL rather than ''.
                        _scrub(td.result_preview) or None,
                        msg.timestamp,
                    ),
                )
        else:
            for tool_name in msg.tools_used:
                conn.execute(
                    """INSERT INTO tool_calls (session_id, message_id, tool_name, timestamp)
                       VALUES (%s, %s, %s, %s)""",
                    (session.session_id, msg.uuid, _scrub(tool_name), msg.timestamp),
                )
131
+
132
+
133
def _store_trace(conn, trace: Trace):
    """Persist a trace, its spans and their span events.

    Re-runnable per ``trace.id``: existing spans for the trace are deleted
    and re-inserted, and the ``traces`` row is upserted so the aggregate
    metrics from ``compute_trace_metrics`` stay fresh. (The delete is
    expected to cascade to ``span_events`` and ``evals`` through foreign
    keys; that is a schema property not visible from this module.)

    Every JSON attribute payload (trace, span and span event) is serialised
    with ``default=str`` so one non-JSON-native value, such as a datetime,
    cannot abort the sync. Text fields go through ``_scrub``.

    Finally, when the trace belongs to a session, the session's
    ``estimated_cost_usd`` is overwritten with the summed cost of its
    ``llm_call`` spans, provided that sum is positive.

    Args:
        conn: Open database connection; the caller commits.
        trace: Trace to store. ``None`` or a trace without spans is a no-op.
    """
    if trace is None or not trace.spans:
        return

    m = compute_trace_metrics(trace)

    # Delete existing spans (expected to cascade to span_events and evals).
    conn.execute("DELETE FROM spans WHERE trace_id = %s", (trace.id,))

    # Upsert the trace row.
    conn.execute(
        """INSERT INTO traces (
               id, session_id, provider_id, project, title,
               started_at, ended_at, duration_ms,
               span_count, agent_count, tool_count, llm_count, error_count,
               total_input_tokens, total_output_tokens,
               total_cache_read_tokens, total_cache_write_tokens,
               total_cost_usd, cwd, git_branch, model,
               vendor_count, top_vendors, attrs
           ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
                     %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
           ON CONFLICT (id) DO UPDATE SET
               session_id = EXCLUDED.session_id,
               provider_id = EXCLUDED.provider_id,
               project = EXCLUDED.project,
               title = EXCLUDED.title,
               started_at = EXCLUDED.started_at,
               ended_at = EXCLUDED.ended_at,
               duration_ms = EXCLUDED.duration_ms,
               span_count = EXCLUDED.span_count,
               agent_count = EXCLUDED.agent_count,
               tool_count = EXCLUDED.tool_count,
               llm_count = EXCLUDED.llm_count,
               error_count = EXCLUDED.error_count,
               total_input_tokens = EXCLUDED.total_input_tokens,
               total_output_tokens = EXCLUDED.total_output_tokens,
               total_cache_read_tokens = EXCLUDED.total_cache_read_tokens,
               total_cache_write_tokens = EXCLUDED.total_cache_write_tokens,
               total_cost_usd = EXCLUDED.total_cost_usd,
               cwd = EXCLUDED.cwd,
               git_branch = EXCLUDED.git_branch,
               model = EXCLUDED.model,
               vendor_count = EXCLUDED.vendor_count,
               top_vendors = EXCLUDED.top_vendors,
               attrs = EXCLUDED.attrs
        """,
        (
            trace.id, trace.session_id, trace.provider_id,
            _scrub(trace.project), _scrub(trace.title),
            trace.started_at, trace.ended_at, trace.duration_ms,
            m["span_count"], m["agent_count"], m["tool_count"], m["llm_count"], m["error_count"],
            m["input_tokens"], m["output_tokens"],
            m["cache_read_tokens"], m["cache_write_tokens"],
            m["cost_usd"], _scrub(trace.cwd), _scrub(trace.git_branch), _scrub(trace.model),
            m["vendor_count"], json.dumps(m["top_vendors"]),
            json.dumps(trace.attrs or {}, default=str),
        ),
    )

    # Insert spans in sequence order so parent rows always exist first.
    for span in sorted(trace.spans, key=lambda s: s.sequence):
        conn.execute(
            """INSERT INTO spans (
                   id, trace_id, parent_id, kind, name, status,
                   started_at, ended_at, duration_ms, depth, sequence,
                   input_tokens, output_tokens, cache_read_tokens, cache_write_tokens,
                   cost_usd, model, tool_name, tool_input, tool_output, tool_is_error,
                   agent_type, agent_prompt, vendor, category, attrs
               ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
                         %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """,
            (
                span.id, span.trace_id, span.parent_id, span.kind.value, _scrub(span.name), span.status.value,
                span.started_at, span.ended_at, span.duration_ms, span.depth, span.sequence,
                span.input_tokens, span.output_tokens, span.cache_read_tokens, span.cache_write_tokens,
                # span.model is scrubbed like trace.model and the other text fields.
                span.cost_usd, _scrub(span.model), _scrub(span.tool_name),
                json.dumps(span.tool_input, default=str) if span.tool_input is not None else None,
                _scrub(span.tool_output), span.tool_is_error,
                _scrub(span.agent_type), _scrub(span.agent_prompt),
                _scrub(span.vendor), _scrub(span.category),
                json.dumps(span.attrs or {}, default=str),
            ),
        )

        for ev in span.events:
            conn.execute(
                """INSERT INTO span_events (span_id, trace_id, name, timestamp, attrs)
                   VALUES (%s, %s, %s, %s, %s)""",
                (
                    span.id, trace.id, _scrub(ev.name), ev.timestamp,
                    # default=str matches the trace/span attrs serialisation above.
                    json.dumps(ev.attrs or {}, default=str),
                ),
            )

    # After spans land, sync the session's headline cost to the sum of its
    # llm_call span costs. Only spans with a positive cost and a real model
    # (not NULL, not '<synthetic>') count, and the session is only touched
    # when that sum is positive -- sessions with no usable llm_call cost keep
    # the estimate written by the session upsert.
    if trace.session_id:
        conn.execute(
            """UPDATE sessions ss
               SET estimated_cost_usd = sub.span_cost
               FROM (
                   SELECT t.session_id, SUM(s.cost_usd)::numeric(10, 4) AS span_cost
                   FROM spans s JOIN traces t ON s.trace_id = t.id
                   WHERE t.session_id = %s
                     AND s.kind = 'llm_call'
                     AND s.cost_usd IS NOT NULL
                     AND s.cost_usd > 0
                     AND s.model IS NOT NULL
                     AND s.model <> '<synthetic>'
                   GROUP BY t.session_id
               ) sub
               WHERE ss.id = sub.session_id AND sub.span_cost > 0""",
            (trace.session_id,),
        )
257
+
258
+
259
def _embed_session(conn, session: ParsedSession) -> int:
    """Chunk and embed a session's messages into the ``chunks`` table.

    Existing chunks for the session are deleted first, so calling this again
    re-embeds the session from scratch. Message text is NUL-scrubbed before
    chunking and the project name before insertion, matching how the same
    values are stored by the session writer; PostgreSQL text columns cannot
    hold NUL characters. Messages with missing or blank content are skipped.

    Args:
        conn: Open database connection; the caller commits.
        session: Parsed session whose messages should be embedded.

    Returns:
        The number of chunks written (0 when no message has content).
    """
    # Delete existing chunks for this session (re-embed on update)
    conn.execute("DELETE FROM chunks WHERE session_id = %s", (session.session_id,))

    project = _scrub(session.project)
    all_chunks = []
    chunk_meta = []

    for msg in session.messages:
        # Treat missing content as empty instead of crashing on .strip().
        content = _scrub(msg.content) or ""
        if not content.strip():
            continue
        for chunk in chunk_text(content):
            all_chunks.append(chunk)
            chunk_meta.append({
                "session_id": session.session_id,
                "message_id": msg.uuid,
                "role": msg.role,
                "project": project,
                "timestamp": msg.timestamp,
            })

    if not all_chunks:
        return 0

    # Batch embed
    vectors = embed_texts(all_chunks)

    for chunk, vec, meta in zip(all_chunks, vectors, chunk_meta):
        conn.execute(
            """INSERT INTO chunks (session_id, message_id, content, role, project, timestamp, embedding)
               VALUES (%s, %s, %s, %s, %s, %s, %s::vector)""",
            (
                meta["session_id"], meta["message_id"], chunk,
                meta["role"], meta["project"], meta["timestamp"],
                # The vector is sent as its string form and cast server-side.
                str(vec),
            ),
        )

    return len(all_chunks)
299
+
300
+
301
+ def _get_connected_providers(conn) -> list[dict]:
302
+ """Get all connected providers from the database."""
303
+ rows = conn.execute(
304
+ "SELECT id, type, data_path, config FROM providers WHERE status = 'connected' AND type != 'agent'"
305
+ ).fetchall()
306
+ return [dict(r) for r in rows]
307
+
308
+
309
def _sync_remote_provider(conn, provider, prov_info: dict, embed: bool) -> tuple[int, int, int]:
    """Sync one remote provider via ``iter_sessions`` and persist its cursor.

    Each yielded session is stored (plus its trace and, when *embed* is
    true, its embeddings) and committed on its own. When the marker is a
    remote marker carrying a cursor, the cursor is written into the
    provider's ``config['sync_state']`` in the same transaction, so a crash
    mid-sync still makes progress on the next run.

    A failure is reported on the console and ends the sync for this
    provider. The in-flight transaction is rolled back, so the connection
    stays usable for the caller and a half-written session is not committed
    by a later ``commit()``. The returned counters only include sessions
    that were committed.

    Args:
        conn: Open database connection.
        provider: Provider object exposing ``iter_sessions`` and ``name``.
        prov_info: Provider row; ``config`` may be a dict, a JSON string or
            absent, and ``id`` identifies the row to update.
        embed: Whether to chunk and embed each session.

    Returns:
        ``(sessions, messages, chunks)`` synced by this call.
    """
    config = prov_info.get("config") or {}
    if isinstance(config, str):
        config = json.loads(config)
    state = config.get("sync_state") or {}

    total_sessions = 0
    total_messages = 0
    total_chunks = 0

    try:
        for session, marker in provider.iter_sessions(config=config, state=state):
            _store_session(conn, session)
            if session.trace is not None:
                _store_trace(conn, session.trace)
            session_chunks = _embed_session(conn, session) if embed else 0

            # Advance the cursor opportunistically so a crash mid-sync
            # still makes progress on the next run.
            if marker.get("kind") == "remote" and marker.get("cursor"):
                state["updated_after"] = marker["cursor"]
                config["sync_state"] = state
                conn.execute(
                    "UPDATE providers SET config = %s WHERE id = %s",
                    (json.dumps(config), prov_info["id"]),
                )

            # One commit per session: a later failure can then only discard
            # the session that was in flight, never earlier ones.
            conn.commit()

            total_sessions += 1
            total_messages += session.message_count
            total_chunks += session_chunks
    except Exception as e:
        # Deliberately broad: one provider failing must not stop the others.
        # Roll back first -- after a database error the transaction is
        # aborted and every further statement on this connection would fail.
        conn.rollback()
        console.print(f"[red]Remote sync failed for {provider.name}: {e}[/red]")

    return total_sessions, total_messages, total_chunks
345
+
346
+
347
def _refresh_provider_stats(conn, provider_id: str) -> None:
    """Recount a provider's sessions, stamp ``last_synced_at`` and commit."""
    conn.execute(
        """UPDATE providers SET
               session_count = (SELECT COUNT(*) FROM sessions WHERE provider_id = %s),
               last_synced_at = now()
           WHERE id = %s""",
        (provider_id, provider_id),
    )
    conn.commit()


def sync(embed: bool = True, provider_filter: str | None = None):
    """Sync sessions from all connected providers to the database.

    Providers come from the ``providers`` table (connected, non-agent rows),
    supplemented by any registered provider type that has no row at all but
    reports itself as locally available. Remote providers are delegated to
    ``_sync_remote_provider``. File-based providers have their session files
    discovered, filtered to those whose size differs from ``sync_state``,
    parsed, stored and committed file by file.

    The connection is always closed, including when a provider raises. The
    size recorded in ``sync_state`` is the one observed when the file was
    selected, not a fresh ``stat()`` after parsing, so content appended
    while the file was being parsed triggers a re-sync next time.

    Args:
        embed: When true, chunk and embed every synced session.
        provider_filter: If given, only providers of this type are synced.
    """
    conn = get_connection()
    try:
        connected = _get_connected_providers(conn)

        # Always supplement the DB-connected list with any provider that
        # `spooling init` would detect as locally available but doesn't yet
        # have a row in the providers table. Skip providers that DO have a
        # row but aren't 'connected' (user explicitly disabled). Lets a fresh
        # install of e.g. Kiro start syncing on the next `spooling sync`
        # without forcing the user back through the Connections page.
        known_types = {row["type"] for row in conn.execute("SELECT type FROM providers").fetchall()}
        for type_id, prov in get_all_providers().items():
            if type_id in known_types:
                continue
            if prov.is_available():
                connected.append({
                    "id": type_id,
                    "type": type_id,
                    "data_path": str(prov.resolved_data_path()),
                })

        if provider_filter:
            connected = [p for p in connected if p["type"] == provider_filter]

        if not connected:
            console.print("[yellow]No providers to sync. Run 'spooling init' to see what's detected locally, or connect one via the UI.[/yellow]")
            return

        synced = _get_synced_files(conn)
        grand_total_sessions = 0
        grand_total_messages = 0
        grand_total_chunks = 0

        for prov_info in connected:
            provider = get_provider(prov_info["type"])
            if not provider:
                console.print(f"[yellow]Unknown provider type: {prov_info['type']}[/yellow]")
                continue

            # Remote providers (GitLab, etc.) don't have files on disk --
            # delegate to iter_sessions and persist the cursor cleanly.
            if provider.is_remote:
                console.print(f"[bold]{provider.name}:[/bold] Fetching from API...")
                ns, nm, nc = _sync_remote_provider(conn, provider, prov_info, embed)
                if ns:
                    console.print(
                        f"  [green]Synced {ns} sessions, {nm} messages, "
                        f"{nc} chunks embedded.[/green]"
                    )
                _refresh_provider_stats(conn, prov_info["id"])
                grand_total_sessions += ns
                grand_total_messages += nm
                grand_total_chunks += nc
                continue

            # Use custom data_path if set (and present on disk), otherwise
            # let the provider fall back to its default location.
            data_path = None
            if prov_info.get("data_path"):
                expanded = Path(prov_info["data_path"]).expanduser()
                if expanded.exists():
                    data_path = expanded

            files = provider.discover_session_files(data_path)
            if not files:
                continue

            # Filter to new or changed files, remembering the size that
            # triggered the selection so that exact size is what gets recorded.
            to_process = []
            for f in files:
                size = f.stat().st_size
                if str(f) not in synced or synced[str(f)] != size:
                    to_process.append((f, size))

            if not to_process:
                continue

            console.print(
                f"[bold]{provider.name}:[/bold] Found {len(to_process)} new/updated session files."
            )

            total_messages = 0
            total_chunks = 0
            total_sessions = 0

            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                TextColumn("{task.completed}/{task.total}"),
                console=console,
            ) as progress:
                task = progress.add_task(f"Syncing {provider.name}...", total=len(to_process))

                for f, observed_size in to_process:
                    for session in provider.parse_session_file(f):
                        _store_session(conn, session)
                        if session.trace is not None:
                            _store_trace(conn, session.trace)
                        total_messages += session.message_count
                        total_sessions += 1

                        if embed:
                            total_chunks += _embed_session(conn, session)

                    # Record the pre-parse size: if the file grew while it
                    # was being read, the next run sees a mismatch and
                    # re-processes it instead of silently skipping the tail.
                    _mark_synced(conn, str(f), observed_size, prov_info["type"])
                    conn.commit()
                    progress.advance(task)

            # Update provider stats
            _refresh_provider_stats(conn, prov_info["id"])

            grand_total_sessions += total_sessions
            grand_total_messages += total_messages
            grand_total_chunks += total_chunks

            console.print(
                f"  [green]Synced {total_sessions} sessions, "
                f"{total_messages} messages, "
                f"{total_chunks} chunks embedded.[/green]"
            )
    finally:
        # Release the connection on every exit path, including errors.
        conn.close()

    if grand_total_sessions == 0:
        console.print("[green]All sessions already synced.[/green]")
    else:
        console.print(
            f"\n[green]Total: {grand_total_sessions} sessions, "
            f"{grand_total_messages} messages, "
            f"{grand_total_chunks} chunks embedded.[/green]"
        )