coding-agent-roi 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,542 @@
1
+ """SQLite storage for interactions, with upsert + topic aggregation.
2
+
3
+ Local-first: a single SQLite file holds every collected interaction. Writes are
4
+ idempotent on the interaction ``id`` so re-running ingest never double-counts.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from collections.abc import Iterable
10
+ from dataclasses import dataclass
11
+ from datetime import datetime
12
+ from pathlib import Path
13
+ from typing import Any
14
+
15
+ from sqlalchemy import String, create_engine, func, select
16
+ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
17
+ from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
18
+
19
+ from agent_roi.core.models import (
20
+ Interaction,
21
+ InteractionView,
22
+ Rollup,
23
+ SessionDetail,
24
+ SessionSummary,
25
+ TimeSeriesBundle,
26
+ TimeSeriesPoint,
27
+ TimeSeriesSplitRow,
28
+ TopicBreakdown,
29
+ )
30
+ from agent_roi.core.pricing import cost_of
31
+
32
+
33
@dataclass
class UnclassifiedSession:
    """One session's worth of text handed to the topic classifier.

    Instances are built by ``Database._session_docs``; despite the name they
    are also used for already-classified sessions when the whole corpus is
    re-labelled.
    """

    # Identifier shared by every interaction of the session.
    session_id: str
    # First non-empty project name seen in the session, or "" when none.
    project: str
    # Newline-joined, de-duplicated interaction summaries (length-capped).
    summary: str
40
+
41
+
42
class Base(DeclarativeBase):
    """Declarative base whose metadata collects this module's ORM tables."""
44
+
45
+
46
class InteractionRow(Base):
    """ORM mapping for the ``interactions`` table: one row per interaction.

    The column order below is the order used when the table is created, so it
    is kept stable.
    """

    __tablename__ = "interactions"

    # --- identity and grouping keys -------------------------------------
    id: Mapped[str] = mapped_column(String, primary_key=True)
    tool: Mapped[str] = mapped_column(String, index=True)
    session_id: Mapped[str] = mapped_column(String, index=True)
    timestamp: Mapped[datetime] = mapped_column(index=True)
    model: Mapped[str] = mapped_column(String, index=True)

    # --- token counters --------------------------------------------------
    input_tokens: Mapped[int] = mapped_column(default=0)
    output_tokens: Mapped[int] = mapped_column(default=0)
    cache_read_tokens: Mapped[int] = mapped_column(default=0)
    cache_write_tokens: Mapped[int] = mapped_column(default=0)

    # --- context -----------------------------------------------------------
    cwd: Mapped[str] = mapped_column(String, default="")
    project: Mapped[str] = mapped_column(String, default="", index=True)
    summary: Mapped[str] = mapped_column(String, default="")
    # NULL means "not classified yet"; queries report it as "uncategorized".
    topic: Mapped[str | None] = mapped_column(String, nullable=True, index=True)

    # --- cost ----------------------------------------------------------------
    cost_usd: Mapped[float] = mapped_column(default=0.0)
    estimated: Mapped[bool] = mapped_column(default=False)
64
+
65
+
66
class Database:
    """SQLite-backed store for interactions and the aggregate queries on top.

    Every public method opens its own short-lived ``Session`` on
    ``self.engine``; no session is kept open between calls.
    """

    def __init__(self, path: Path) -> None:
        """Open (creating if needed) the SQLite file at ``path``.

        The parent directory is created, missing tables are created, and
        columns added in later versions are patched in by ``_migrate``.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        self.engine = create_engine(f"sqlite:///{path}")
        Base.metadata.create_all(self.engine)
        self._migrate()

    def _migrate(self) -> None:
        """Add columns introduced after a database was first created.

        ``create_all`` only creates missing *tables*, never missing *columns*, so
        a database from an older version is missing columns added later. We patch
        them in with ``ALTER TABLE`` (SQLite supports adding columns cheaply).
        """
        # Column name -> DDL fragment. Both sides are constants defined here,
        # so interpolating them into the statement is safe.
        expected = {
            "estimated": "BOOLEAN DEFAULT 0",
            "cwd": "TEXT DEFAULT ''",
            "project": "TEXT DEFAULT ''",
        }
        with self.engine.begin() as conn:
            rows = conn.exec_driver_sql("PRAGMA table_info(interactions)").fetchall()
            # PRAGMA table_info yields (cid, name, type, ...); index 1 is the name.
            existing = {row[1] for row in rows}
            for column, ddl in expected.items():
                if column not in existing:
                    conn.exec_driver_sql(
                        f"ALTER TABLE interactions ADD COLUMN {column} {ddl}"
                    )

    def upsert_many(self, interactions: Iterable[Interaction]) -> int:
        """Insert or update interactions. Returns the number processed.

        Rows are keyed on ``id``. On conflict every column except ``id`` is
        overwritten with the incoming value (``cost_usd`` is recomputed with
        ``cost_of``). Existing rows keep their ``topic`` unless the incoming
        row has one, so a re-ingest does not wipe classifications.
        """
        count = 0
        with Session(self.engine) as session:
            for itx in interactions:
                values = {
                    "id": itx.id,
                    "tool": itx.tool.value,
                    "session_id": itx.session_id,
                    "timestamp": itx.timestamp,
                    "model": itx.model,
                    "input_tokens": itx.input_tokens,
                    "output_tokens": itx.output_tokens,
                    "cache_read_tokens": itx.cache_read_tokens,
                    "cache_write_tokens": itx.cache_write_tokens,
                    "cwd": itx.cwd,
                    "project": itx.project,
                    "summary": itx.summary,
                    "topic": itx.topic,
                    "cost_usd": cost_of(itx),
                    "estimated": itx.estimated,
                }
                update_cols = {
                    k: v for k, v in values.items() if k not in ("id", "topic")
                }
                # Only an explicit incoming topic may replace the stored one;
                # a missing topic must never clear an existing classification.
                if itx.topic is not None:
                    update_cols["topic"] = itx.topic
                stmt = (
                    sqlite_insert(InteractionRow)
                    .values(**values)
                    .on_conflict_do_update(index_elements=["id"], set_=update_cols)
                )
                session.execute(stmt)
                count += 1
            session.commit()
        return count

    def unclassified(self, limit: int | None = None) -> list[InteractionRow]:
        """Return interactions whose ``topic`` is NULL, at most ``limit`` if given."""
        with Session(self.engine) as session:
            stmt = select(InteractionRow).where(InteractionRow.topic.is_(None))
            if limit is not None:
                stmt = stmt.limit(limit)
            return list(session.scalars(stmt))

    def set_topic(self, interaction_id: str, topic: str) -> None:
        """Set the topic of one interaction; unknown ids are silently ignored."""
        with Session(self.engine) as session:
            row = session.get(InteractionRow, interaction_id)
            if row is not None:
                row.topic = topic
                session.commit()

    def unclassified_sessions(self, limit: int | None = None) -> list[UnclassifiedSession]:
        """Sessions with at least one unclassified interaction (topic IS NULL)."""
        return self._session_docs(only_unclassified=True, limit=limit)

    def all_sessions(self, limit: int | None = None) -> list[UnclassifiedSession]:
        """Every session, classified or not.

        Semantic clustering works best when it sees all sessions together, so the
        classifier re-labels the whole corpus rather than only new sessions.
        """
        return self._session_docs(only_unclassified=False, limit=limit)

    def _session_docs(
        self, only_unclassified: bool, limit: int | None
    ) -> list[UnclassifiedSession]:
        """Build one ``UnclassifiedSession`` document per session.

        Rows are read ordered by session and timestamp. When
        ``only_unclassified`` is true, only rows with a NULL topic are read, so
        the summary is built from those rows alone. ``limit`` caps the number
        of sessions returned; it is applied after the rows are loaded.
        """
        with Session(self.engine) as session:
            stmt = select(InteractionRow).order_by(
                InteractionRow.session_id, InteractionRow.timestamp
            )
            if only_unclassified:
                stmt = stmt.where(InteractionRow.topic.is_(None))
            rows = list(session.scalars(stmt))

        by_session: dict[str, list[InteractionRow]] = {}
        for row in rows:
            by_session.setdefault(row.session_id, []).append(row)

        result: list[UnclassifiedSession] = []
        for session_id, items in by_session.items():
            # Checked before building so that ``limit=0`` yields no sessions.
            if limit is not None and len(result) >= limit:
                break
            # Up to 8 distinct, non-empty summaries in timestamp order.
            snippets: list[str] = []
            for it in items:
                if it.summary and it.summary not in snippets:
                    snippets.append(it.summary)
                if len(snippets) >= 8:
                    break
            project = next((it.project for it in items if it.project), "")
            result.append(
                UnclassifiedSession(
                    session_id=session_id,
                    project=project,
                    summary="\n".join(snippets)[:2000],
                )
            )
        return result

    def clear_topics(self) -> None:
        """Reset every interaction's topic so the next classify re-labels all."""
        # Function-scope import: ``update`` is not among the module's imports.
        from sqlalchemy import update

        with Session(self.engine) as session:
            # One UPDATE statement instead of loading every row into memory.
            session.execute(update(InteractionRow).values(topic=None))
            session.commit()

    def set_session_topic(self, session_id: str, topic: str) -> int:
        """Apply a topic to every interaction in a session. Returns rows updated."""
        # Function-scope import: ``update`` is not among the module's imports.
        from sqlalchemy import update

        with Session(self.engine) as session:
            result: Any = session.execute(
                update(InteractionRow)
                .where(InteractionRow.session_id == session_id)
                .values(topic=topic)
            )
            # Number of rows matched by the UPDATE's WHERE clause.
            updated = int(result.rowcount)
            session.commit()
        return updated

    # Columns that can be used as a grouping dimension.
    _DIMENSIONS = {
        "topic": func.coalesce(InteractionRow.topic, "uncategorized"),
        "tool": InteractionRow.tool,
        "model": InteractionRow.model,
        "project": func.coalesce(func.nullif(InteractionRow.project, ""), "unknown"),
    }

    def _grouped_rollups(
        self,
        key_col: Any,
        start: datetime | None,
        end: datetime | None,
        criterion: Any = None,
    ) -> list[Rollup]:
        """Aggregate usage grouped by ``key_col``, most expensive group first.

        ``criterion`` is an optional extra WHERE clause; ``start``/``end`` are
        applied through ``_apply_window``.
        """
        with Session(self.engine) as session:
            stmt: Any = select(
                key_col.label("key"),
                func.count().label("interactions"),
                func.sum(InteractionRow.input_tokens),
                func.sum(InteractionRow.output_tokens),
                func.sum(InteractionRow.cache_read_tokens),
                func.sum(InteractionRow.cache_write_tokens),
                func.sum(InteractionRow.cost_usd),
                # max() over a boolean column: true if any row is estimated.
                func.max(InteractionRow.estimated),
            )
            if criterion is not None:
                stmt = stmt.where(criterion)
            stmt = _apply_window(stmt, start, end)
            stmt = stmt.group_by(key_col).order_by(
                func.sum(InteractionRow.cost_usd).desc()
            )
            return [_row_to_rollup(row) for row in session.execute(stmt)]

    def rollup(
        self,
        dimension: str = "topic",
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> list[Rollup]:
        """Aggregate usage and cost grouped by ``dimension`` over an optional
        time window.

        ``dimension`` is one of 'topic', 'tool', 'model' or 'project'.

        Raises:
            ValueError: if ``dimension`` is not a supported grouping key.
        """
        if dimension not in self._DIMENSIONS:
            raise ValueError(f"Unknown dimension: {dimension!r}")
        return self._grouped_rollups(self._DIMENSIONS[dimension], start, end)

    def topic_breakdown(
        self,
        topic: str,
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> TopicBreakdown:
        """For one topic, return its total plus a split by tool and by model.

        The topic name "uncategorized" selects rows whose topic is NULL.
        """
        criterion = _topic_filter(topic)
        by_tool = self._grouped_rollups(InteractionRow.tool, start, end, criterion)
        by_model = self._grouped_rollups(InteractionRow.model, start, end, criterion)
        # Each row has exactly one tool, so summing the tool split gives the total.
        total = _sum_rollups(topic, by_tool)
        return TopicBreakdown(topic=topic, total=total, by_tool=by_tool, by_model=by_model)

    def _session_summaries(
        self,
        topic: str | None = None,
        start: datetime | None = None,
        end: datetime | None = None,
        limit: int | None = None,
        session_id: str | None = None,
    ) -> list[SessionSummary]:
        """Per-session aggregates, optionally restricted to one session id."""
        with Session(self.engine) as session:
            stmt: Any = select(
                InteractionRow.session_id,
                # Aggregated so the label is deterministic: a bare column in a
                # GROUP BY query would take its value from an arbitrary row.
                func.coalesce(func.max(InteractionRow.topic), "uncategorized"),
                func.max(InteractionRow.project),
                func.group_concat(InteractionRow.tool.distinct()),
                func.group_concat(InteractionRow.model.distinct()),
                func.min(InteractionRow.timestamp),
                func.max(InteractionRow.timestamp),
                func.count(),
                func.sum(InteractionRow.input_tokens),
                func.sum(InteractionRow.output_tokens),
                func.sum(InteractionRow.cache_read_tokens),
                func.sum(InteractionRow.cache_write_tokens),
                func.sum(InteractionRow.cost_usd),
                func.max(InteractionRow.estimated),
            )
            stmt = _apply_window(stmt, start, end)
            if topic is not None:
                stmt = stmt.where(_topic_filter(topic))
            if session_id is not None:
                stmt = stmt.where(InteractionRow.session_id == session_id)
            stmt = stmt.group_by(InteractionRow.session_id).order_by(
                func.sum(InteractionRow.cost_usd).desc()
            )
            if limit is not None:
                stmt = stmt.limit(limit)
            return [_row_to_session(row) for row in session.execute(stmt)]

    def sessions(
        self,
        topic: str | None = None,
        start: datetime | None = None,
        end: datetime | None = None,
        limit: int | None = None,
    ) -> list[SessionSummary]:
        """Aggregate interactions into per-session rows (optionally one topic).

        This is the middle of the topic -> session -> interaction drill-down: each
        row shows how a single session spent tokens, across which tools/models.
        Rows are ordered by total cost, highest first.
        """
        return self._session_summaries(topic=topic, start=start, end=end, limit=limit)

    def timeseries(
        self,
        start: datetime | None = None,
        end: datetime | None = None,
        granularity: str = "day",
        top_series: int = 8,
    ) -> TimeSeriesBundle:
        """Token/cost buckets plus splits by tool and model.

        ``granularity`` is 'day', 'week' or 'month'; anything else raises
        ``ValueError`` (from ``_timeseries_bucket``). ``top_series`` caps how
        many individual tools/models are kept before folding into "other".
        """
        bucket = _timeseries_bucket(granularity)
        totals = self._timeseries_totals(start, end, bucket)
        by_tool, tool_keys = self._timeseries_split(
            InteractionRow.tool, start, end, top_series, bucket
        )
        by_model, model_keys = self._timeseries_split(
            InteractionRow.model, start, end, top_series, bucket
        )
        return TimeSeriesBundle(
            totals=totals,
            by_tool=by_tool,
            by_model=by_model,
            tool_keys=tool_keys,
            model_keys=model_keys,
        )

    def _timeseries_totals(
        self,
        start: datetime | None,
        end: datetime | None,
        bucket: Any,
    ) -> list[TimeSeriesPoint]:
        """Return one ``TimeSeriesPoint`` per period, in ascending period order.

        ``bucket`` is the SQL expression that maps a timestamp to its period
        label.
        """
        period = bucket.label("period")
        with Session(self.engine) as session:
            stmt: Any = select(
                period,
                func.count(),
                func.sum(InteractionRow.input_tokens),
                func.sum(InteractionRow.output_tokens),
                func.sum(InteractionRow.cache_read_tokens),
                func.sum(InteractionRow.cache_write_tokens),
                func.sum(InteractionRow.cost_usd),
            )
            stmt = _apply_window(stmt, start, end)
            stmt = stmt.group_by(period).order_by(period)
            return [
                TimeSeriesPoint(
                    date=str(row[0]),
                    interactions=row[1],
                    input_tokens=row[2] or 0,
                    output_tokens=row[3] or 0,
                    cache_read_tokens=row[4] or 0,
                    cache_write_tokens=row[5] or 0,
                    cost_usd=row[6] or 0.0,
                )
                for row in session.execute(stmt)
            ]

    def _timeseries_split(
        self,
        key_col: Any,
        start: datetime | None,
        end: datetime | None,
        top: int,
        bucket: Any,
    ) -> tuple[list[TimeSeriesSplitRow], list[str]]:
        """Split per-period token totals by ``key_col``.

        Returns ``(rows, keys)``. ``keys`` holds the ``top`` series with the
        most tokens overall, followed by "other" when further series exist.
        Each row maps those keys to the period's token count (input + output
        + cache read + cache write) and carries the period's interaction
        count and cost summed over *all* series.
        """
        period = bucket.label("period")
        with Session(self.engine) as session:
            stmt: Any = select(
                period,
                key_col.label("series_key"),
                func.count(),
                func.sum(InteractionRow.input_tokens),
                func.sum(InteractionRow.output_tokens),
                func.sum(InteractionRow.cache_read_tokens),
                func.sum(InteractionRow.cache_write_tokens),
                func.sum(InteractionRow.cost_usd),
            )
            stmt = _apply_window(stmt, start, end)
            stmt = stmt.group_by(period, key_col).order_by(period)
            raw = list(session.execute(stmt))

        totals_by_key: dict[str, int] = {}
        by_day: dict[str, dict[str, int]] = {}
        meta: dict[str, tuple[int, float]] = {}
        for row in raw:
            d, key = str(row[0]), str(row[1])
            tokens = (row[3] or 0) + (row[4] or 0) + (row[5] or 0) + (row[6] or 0)
            totals_by_key[key] = totals_by_key.get(key, 0) + tokens
            day_values = by_day.setdefault(d, {})
            day_values[key] = day_values.get(key, 0) + tokens
            prev = meta.get(d, (0, 0.0))
            meta[d] = (prev[0] + row[2], prev[1] + (row[7] or 0.0))

        ranked = sorted(totals_by_key, key=lambda k: totals_by_key[k], reverse=True)
        # A negative ``top`` would otherwise slice from the wrong end.
        top = max(top, 0)
        keep = ranked[:top]
        other_label = "other"
        # Skip the append when a real series is already called "other", so the
        # returned key list never contains a duplicate.
        if len(ranked) > top and other_label not in keep:
            keep.append(other_label)
        kept = set(keep)

        rows: list[TimeSeriesSplitRow] = []
        for d in sorted(by_day):
            values: dict[str, int] = {}
            overflow = 0
            for key, tokens in by_day[d].items():
                if key in kept and key != other_label:
                    values[key] = values.get(key, 0) + tokens
                elif other_label in kept:
                    overflow += tokens
            if overflow:
                values[other_label] = overflow
            interactions, cost = meta.get(d, (0, 0.0))
            rows.append(
                TimeSeriesSplitRow(
                    date=d,
                    values=values,
                    interactions=interactions,
                    cost_usd=cost,
                )
            )
        return rows, keep

    def session_detail(self, session_id: str) -> SessionDetail | None:
        """A session's aggregate plus its individual interactions, newest first.

        Returns ``None`` when no interaction has this ``session_id``.
        """
        # Filter in SQL rather than aggregating every session and searching.
        summaries = self._session_summaries(session_id=session_id)
        if not summaries:
            return None
        with Session(self.engine) as session:
            rows = list(
                session.scalars(
                    select(InteractionRow)
                    .where(InteractionRow.session_id == session_id)
                    .order_by(InteractionRow.timestamp.desc())
                )
            )
        interactions = [
            InteractionView(
                id=r.id,
                tool=r.tool,
                model=r.model,
                timestamp=r.timestamp,
                input_tokens=r.input_tokens,
                output_tokens=r.output_tokens,
                cache_read_tokens=r.cache_read_tokens,
                cache_write_tokens=r.cache_write_tokens,
                cost_usd=r.cost_usd,
                estimated=r.estimated,
                summary=r.summary,
            )
            for r in rows
        ]
        return SessionDetail(session=summaries[0], interactions=interactions)
467
+
468
+
469
def _topic_filter(topic: str) -> Any:
    """Return the WHERE clause that selects rows belonging to ``topic``.

    The pseudo-topic "uncategorized" matches rows whose topic is NULL; any
    other name is matched by equality.
    """
    column = InteractionRow.topic
    return column.is_(None) if topic == "uncategorized" else column == topic
473
+
474
+
475
+ def _split_concat(value: Any) -> list[str]:
476
+ """Split SQLite group_concat output into a sorted, de-duplicated list."""
477
+ if not value:
478
+ return []
479
+ return sorted({part for part in str(value).split(",") if part})
480
+
481
+
482
def _row_to_session(row: Any) -> SessionSummary:
    """Convert one positional per-session aggregate row to a ``SessionSummary``.

    The first 14 columns are read in the order the sessions query selects
    them. NULL sums become zero and a NULL project becomes "".
    """
    (
        session_id,
        topic,
        project,
        tools,
        models,
        started,
        ended,
        interactions,
        input_tokens,
        output_tokens,
        cache_read_tokens,
        cache_write_tokens,
        cost_usd,
        estimated,
    ) = row[:14]
    return SessionSummary(
        session_id=str(session_id),
        topic=str(topic),
        project=str(project or ""),
        tools=_split_concat(tools),
        models=_split_concat(models),
        started=started,
        ended=ended,
        interactions=interactions,
        input_tokens=input_tokens or 0,
        output_tokens=output_tokens or 0,
        cache_read_tokens=cache_read_tokens or 0,
        cache_write_tokens=cache_write_tokens or 0,
        cost_usd=cost_usd or 0.0,
        estimated=bool(estimated),
    )
499
+
500
+
501
def _timeseries_bucket(granularity: str) -> Any:
    """Return the ``strftime`` expression that labels a timestamp's period.

    Supported granularities are "week", "month" and "day".

    Raises:
        ValueError: for any other granularity.
    """
    formats = (
        ("week", "%Y-W%W"),
        ("month", "%Y-%m"),
        ("day", "%Y-%m-%d"),
    )
    for name, pattern in formats:
        if granularity == name:
            return func.strftime(pattern, InteractionRow.timestamp)
    raise ValueError(f"Unknown granularity: {granularity!r}")
509
+
510
+
511
def _apply_window(stmt: Any, start: datetime | None, end: datetime | None) -> Any:
    """Restrict ``stmt`` to the half-open time window ``[start, end)``.

    A bound that is ``None`` is left open; with both ``None`` the statement
    is returned unchanged.
    """
    lower = None if start is None else InteractionRow.timestamp >= start
    upper = None if end is None else InteractionRow.timestamp < end
    for clause in (lower, upper):
        if clause is not None:
            stmt = stmt.where(clause)
    return stmt
517
+
518
+
519
def _row_to_rollup(row: Any) -> Rollup:
    """Convert one positional grouped-aggregate row to a ``Rollup``.

    The first 8 columns are read as: key, interaction count, the four token
    sums, the cost sum and the estimated flag. NULL sums become zero.
    """
    (
        key,
        interactions,
        input_tokens,
        output_tokens,
        cache_read_tokens,
        cache_write_tokens,
        cost_usd,
        estimated,
    ) = row[:8]
    return Rollup(
        key=str(key),
        interactions=interactions,
        input_tokens=input_tokens or 0,
        output_tokens=output_tokens or 0,
        cache_read_tokens=cache_read_tokens or 0,
        cache_write_tokens=cache_write_tokens or 0,
        cost_usd=cost_usd or 0.0,
        estimated=bool(estimated),
    )
530
+
531
+
532
def _sum_rollups(key: str, rollups: list[Rollup]) -> Rollup:
    """Fold several rollups into a single one labelled ``key``.

    Numeric fields are added up; ``estimated`` is true when any input rollup
    is estimated.
    """

    def total(field: str) -> Any:
        # sum() is kept (rather than a manual loop) so numeric results match.
        return sum(getattr(r, field) for r in rollups)

    return Rollup(
        key=key,
        interactions=total("interactions"),
        input_tokens=total("input_tokens"),
        output_tokens=total("output_tokens"),
        cache_read_tokens=total("cache_read_tokens"),
        cache_write_tokens=total("cache_write_tokens"),
        cost_usd=total("cost_usd"),
        estimated=any(r.estimated for r in rollups),
    )