ursa-ai 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51):
  1. ursa/__init__.py +3 -0
  2. ursa/agents/__init__.py +32 -0
  3. ursa/agents/acquisition_agents.py +812 -0
  4. ursa/agents/arxiv_agent.py +429 -0
  5. ursa/agents/base.py +728 -0
  6. ursa/agents/chat_agent.py +60 -0
  7. ursa/agents/code_review_agent.py +341 -0
  8. ursa/agents/execution_agent.py +915 -0
  9. ursa/agents/hypothesizer_agent.py +614 -0
  10. ursa/agents/lammps_agent.py +465 -0
  11. ursa/agents/mp_agent.py +204 -0
  12. ursa/agents/optimization_agent.py +410 -0
  13. ursa/agents/planning_agent.py +219 -0
  14. ursa/agents/rag_agent.py +304 -0
  15. ursa/agents/recall_agent.py +54 -0
  16. ursa/agents/websearch_agent.py +196 -0
  17. ursa/cli/__init__.py +363 -0
  18. ursa/cli/hitl.py +516 -0
  19. ursa/cli/hitl_api.py +75 -0
  20. ursa/observability/metrics_charts.py +1279 -0
  21. ursa/observability/metrics_io.py +11 -0
  22. ursa/observability/metrics_session.py +750 -0
  23. ursa/observability/pricing.json +97 -0
  24. ursa/observability/pricing.py +321 -0
  25. ursa/observability/timing.py +1466 -0
  26. ursa/prompt_library/__init__.py +0 -0
  27. ursa/prompt_library/code_review_prompts.py +51 -0
  28. ursa/prompt_library/execution_prompts.py +50 -0
  29. ursa/prompt_library/hypothesizer_prompts.py +17 -0
  30. ursa/prompt_library/literature_prompts.py +11 -0
  31. ursa/prompt_library/optimization_prompts.py +131 -0
  32. ursa/prompt_library/planning_prompts.py +79 -0
  33. ursa/prompt_library/websearch_prompts.py +131 -0
  34. ursa/tools/__init__.py +0 -0
  35. ursa/tools/feasibility_checker.py +114 -0
  36. ursa/tools/feasibility_tools.py +1075 -0
  37. ursa/tools/run_command.py +27 -0
  38. ursa/tools/write_code.py +42 -0
  39. ursa/util/__init__.py +0 -0
  40. ursa/util/diff_renderer.py +128 -0
  41. ursa/util/helperFunctions.py +142 -0
  42. ursa/util/logo_generator.py +625 -0
  43. ursa/util/memory_logger.py +183 -0
  44. ursa/util/optimization_schema.py +78 -0
  45. ursa/util/parse.py +405 -0
  46. ursa_ai-0.9.1.dist-info/METADATA +304 -0
  47. ursa_ai-0.9.1.dist-info/RECORD +51 -0
  48. ursa_ai-0.9.1.dist-info/WHEEL +5 -0
  49. ursa_ai-0.9.1.dist-info/entry_points.txt +2 -0
  50. ursa_ai-0.9.1.dist-info/licenses/LICENSE +8 -0
  51. ursa_ai-0.9.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,750 @@
1
+ # metrics_session.py
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ import os
6
+ import re
7
+ from dataclasses import dataclass
8
+ from datetime import datetime, timezone
9
+ from typing import Dict, List, Tuple
10
+
11
+ import matplotlib
12
+
13
+ matplotlib.use("Agg")
14
+ from collections import defaultdict
15
+
16
+ import matplotlib.dates as mdates
17
+ import matplotlib.pyplot as plt
18
+ import pandas as pd
19
+
20
+ from ursa.observability.metrics_charts import (
21
+ compute_attribution,
22
+ extract_llm_token_stats, # reuse per-file tokens
23
+ extract_time_breakdown, # reuse per-file extraction
24
+ )
25
+ from ursa.observability.metrics_io import load_metrics
26
+
27
+
28
def _dt(x: str) -> datetime:
    """Parse an ISO-8601 timestamp into a ``datetime``.

    The argument is coerced with ``str()`` and every ``"Z"`` in it is replaced
    by ``"+00:00"`` before it is handed to ``datetime.fromisoformat``.  No
    timezone conversion is applied afterwards, so input without an offset
    yields a naive ``datetime``.

    Raises:
        ValueError: propagated from ``datetime.fromisoformat`` when the text
            is not a valid ISO timestamp.
    """
    iso_text = str(x).replace("Z", "+00:00")
    return datetime.fromisoformat(iso_text)
30
+
31
+
32
# Compact, predictable layout shared by the static timeline figure.
# The *_y entries are passed as the y position to fig.text / fig.legend and
# ax_rect is passed to fig.add_axes in plot_thread_timeline.
_LAYOUT = {
    "header_y": 0.965,
    "subtitle_y": 0.915,
    "legend_y": 0.885,  # legend sits just under the subtitle
    "ax_rect": (0.10, 0.28, 0.86, 0.58),  # [left, bottom, width, height]
    "footer1_y": 0.105,
    "footer2_y": 0.070,
}

# Matches a trailing "Z" so it can be rewritten as an explicit "+00:00" offset.
ISO_RE = re.compile(r"Z$")
43
+
44
+
45
def _parse_iso(ts: str | None) -> datetime | None:
    """Best-effort ISO-8601 parse that always yields UTC or ``None``.

    A trailing ``Z`` is rewritten to ``+00:00`` (via ``ISO_RE``) before
    parsing, and the parsed value is converted with
    ``astimezone(timezone.utc)``.

    Returns:
        The parsed timestamp in UTC, or ``None`` when *ts* is falsy or any
        error occurs while parsing/converting it.
    """
    if not ts:
        return None
    try:
        normalized = ISO_RE.sub("+00:00", ts)
        parsed = datetime.fromisoformat(normalized)
        return parsed.astimezone(timezone.utc)
    except Exception:
        # Deliberately lenient: callers treat "unparseable" the same as "missing".
        return None
54
+
55
+
56
def _fmt_iso_pretty(ts: str | None) -> str:
    """Format an ISO timestamp as ``YYYY-MM-DD HH:MM:SS UTC`` for display.

    When ``_parse_iso`` cannot parse *ts*, the original text is returned
    unchanged (or an empty string if *ts* is falsy).
    """
    parsed = _parse_iso(ts)
    if parsed:
        return parsed.strftime("%Y-%m-%d %H:%M:%S UTC")
    return ts or ""
59
+
60
+
61
@dataclass
class RunRecord:
    """One agent run discovered from a metrics JSON file.

    Instances are built by ``scan_directory_for_threads`` from the file's
    ``context`` section.
    """

    # Path of the metrics JSON file this record was read from.
    path: str
    # Values of context["agent"], context["thread_id"] and context["run_id"].
    agent: str
    thread_id: str
    run_id: str
    # Run boundaries (context["started_at"] / context["ended_at"]).
    started_at: datetime
    ended_at: datetime

    @property
    def duration_s(self) -> float:
        """Run length in seconds, clamped so it is never negative."""
        elapsed = (self.ended_at - self.started_at).total_seconds()
        return max(0.0, elapsed)
73
+
74
+
75
+ # -------------------------------
76
+ # Directory scan
77
+ # -------------------------------
78
def scan_directory_for_threads(dir_path: str) -> Dict[str, List[RunRecord]]:
    """Collect metrics JSON files in *dir_path* and group them by thread.

    Only direct entries of *dir_path* whose name ends in ``.json``
    (case-insensitive) are considered.  A file contributes a ``RunRecord``
    only when its ``context`` has non-empty ``agent``, ``thread_id`` and
    ``run_id`` plus parseable ``started_at`` / ``ended_at``; anything else,
    including files that fail to open or parse, is skipped silently.

    Returns:
        ``{thread_id: [RunRecord, ...]}`` with each list sorted by start time.
    """
    grouped: Dict[str, List[RunRecord]] = {}

    for entry in sorted(os.listdir(dir_path)):
        if not entry.lower().endswith(".json"):
            continue
        full_path = os.path.join(dir_path, entry)
        try:
            with open(full_path, "r", encoding="utf-8") as handle:
                data = json.load(handle)
            context = data.get("context") or {}
            agent = str(context.get("agent") or "")
            thread_id = str(context.get("thread_id") or "")
            run_id = str(context.get("run_id") or "")
            started = _parse_iso(context.get("started_at"))
            ended = _parse_iso(context.get("ended_at"))
            if thread_id and agent and run_id and started and ended:
                record = RunRecord(
                    path=full_path,
                    agent=agent,
                    thread_id=thread_id,
                    run_id=run_id,
                    started_at=started,
                    ended_at=ended,
                )
                grouped.setdefault(thread_id, []).append(record)
        except Exception:
            # Best-effort scan: one unreadable or malformed file must not
            # abort the whole directory.
            continue

    for thread_runs in grouped.values():
        thread_runs.sort(key=lambda rec: rec.started_at)
    return grouped
114
+
115
+
116
def list_threads_summary(
    sessions: Dict[str, List[RunRecord]],
) -> List[Tuple[str, int]]:
    """Return ``[(thread_id, run_count)]``, busiest thread first.

    Ties on run count are broken by ascending ``thread_id``.
    """
    counts = ((thread_id, len(thread_runs)) for thread_id, thread_runs in sessions.items())
    return sorted(counts, key=lambda pair: (-pair[1], pair[0]))
123
+
124
+
125
+ # -------------------------------
126
+ # Timeline plot
127
+ # -------------------------------
128
+
129
+
130
def _abbrev_agent(agent: str) -> str:
    """Shorten an agent name for use in chart labels.

    A trailing ``"Agent"`` is dropped, surrounding whitespace is stripped,
    and anything longer than 10 characters is cut to its first 9 characters
    followed by an ellipsis.
    """
    suffix = "Agent"
    trimmed = agent
    if trimmed.endswith(suffix):
        trimmed = trimmed[: -len(suffix)]
    trimmed = trimmed.strip()
    if len(trimmed) > 10:
        return trimmed[:9] + "…"
    return trimmed
135
+
136
+
137
def _label_for_run(r: RunRecord, seconds_precision: int = 1) -> str:
    """Build the bar label ``"<agent>:<run id prefix> (<duration>s)"``.

    The agent name is abbreviated with ``_abbrev_agent``, the run id is cut
    to its first 6 characters, and the duration is printed with
    *seconds_precision* decimal places.
    """
    short_name = _abbrev_agent(r.agent)
    run_prefix = r.run_id[:6]
    duration = f"{r.duration_s:.{seconds_precision}f}"
    return f"{short_name}:{run_prefix} ({duration}s)"
139
+
140
+
141
def plot_thread_timeline(
    runs: List[RunRecord],
    out_path: str,
    *,
    title: str = "Agent runs timeline",
    stacked_rows: bool = False,  # False = single lane (“flow”); True = one row per run
    min_label_sec: float = 0.50,  # skip labels for bars shorter than this
) -> str:
    """Draw a static timeline of the agent runs of a single thread.

    Each run becomes a horizontal bar from its start to its end, coloured by
    agent.  The x axis is locked to exactly first start -> last end, bar
    labels are shortened (and, in single-lane mode, alternate between an
    upper and a lower band), and the legend is drawn above the axes.

    Args:
        runs: Runs to draw; the header uses ``runs[0].thread_id``.
        out_path: Destination passed to ``fig.savefig``.
        title: Subtitle text shown under the ``Thread: <id>`` header.
        stacked_rows: ``True`` draws one row per run; ``False`` draws every
            run in a single lane.
        min_label_sec: Runs shorter than this many seconds get no text label.

    Returns:
        *out_path*, after the figure has been written.

    Raises:
        ValueError: if *runs* is empty.
    """
    if not runs:
        raise ValueError("No runs provided for timeline")

    # Timeline bounds (exact): first start -> last end.
    t0 = min(r.started_at for r in runs)
    t1 = max(r.ended_at for r in runs)

    thread_id = runs[0].thread_id
    header = f"Thread: {thread_id}"
    subtitle = title
    footer1 = f"runs: {len(runs)}"
    footer2 = f"{t0.strftime('%Y-%m-%d %H:%M:%S UTC')} \u2192 {t1.strftime('%Y-%m-%d %H:%M:%S UTC')}"

    # Map agents to stable colours, in first-appearance order.
    uniq_agents: List[str] = list(dict.fromkeys(r.agent for r in runs))
    # A prop cycle without a "color" entry would leave nothing to index into;
    # fall back to a single fixed colour instead of raising IndexError.
    color_cycle = plt.rcParams["axes.prop_cycle"].by_key().get("color") or [
        "tab:blue"
    ]
    color_map = {
        a: color_cycle[i % len(color_cycle)] for i, a in enumerate(uniq_agents)
    }

    fig_h = max(2.5, 0.35 * len(runs)) if stacked_rows else 2.4
    fig = plt.figure(figsize=(10, fig_h))
    try:
        ax = fig.add_axes(list(_LAYOUT["ax_rect"]))

        if stacked_rows:
            bar_h = 0.8
            y_positions = list(range(len(runs)))
            for yi, r in zip(y_positions, runs):
                xs = mdates.date2num(r.started_at)
                w = mdates.date2num(r.ended_at) - xs
                ax.barh(
                    yi,
                    w,
                    left=xs,
                    height=bar_h,
                    edgecolor="black",
                    color=color_map.get(r.agent),
                )
                if r.duration_s >= min_label_sec:
                    ax.text(
                        xs + w / 2,
                        yi,
                        _label_for_run(r, 1),
                        ha="center",
                        va="center",
                        fontsize="small",
                    )
            ax.set_yticks(y_positions)
            ax.set_yticklabels([
                f"{_abbrev_agent(r.agent)}:{r.run_id[:6]}" for r in runs
            ])
            ax.invert_yaxis()
        else:
            # Single lane “flow” with alternating label bands.
            lane_y, lane_h = 0.0, 1.0

            segs, colors = [], []
            for r in runs:
                xs = mdates.date2num(r.started_at)
                w = mdates.date2num(r.ended_at) - xs
                segs.append((xs, w))
                colors.append(color_map.get(r.agent))
            ax.broken_barh(
                segs, (lane_y, lane_h), facecolors=colors, edgecolor="black"
            )

            # Alternate label positions between the upper and lower halves of
            # the lane so neighbouring labels are less likely to collide.
            upper_y = lane_y + lane_h * 0.72
            lower_y = lane_y + lane_h * 0.28
            for idx, ((xs, w), r) in enumerate(zip(segs, runs)):
                if r.duration_s < min_label_sec:
                    continue
                y = upper_y if (idx % 2 == 0) else lower_y
                ax.text(
                    xs + w / 2,
                    y,
                    _label_for_run(r, 1),
                    ha="center",
                    va="center",
                    fontsize="small",
                )

            ax.set_yticks([])

        # x axis formatting (exact bounds, no margins).
        locator = mdates.AutoDateLocator()
        formatter = mdates.ConciseDateFormatter(locator, show_offset=False)
        ax.xaxis.set_major_locator(locator)
        ax.xaxis.set_major_formatter(formatter)
        ax.set_xlim(mdates.date2num(t0), mdates.date2num(t1))
        ax.margins(x=0)  # ensure bars touch the exact edges
        ax.set_xlabel("Time (UTC)", fontsize="small")
        ax.tick_params(labelsize="small")

        # Header / subtitle.
        fig.text(0.5, _LAYOUT["header_y"], header, ha="center", va="top")
        fig.text(
            0.5,
            _LAYOUT["subtitle_y"],
            subtitle,
            ha="center",
            va="top",
            fontsize="small",
        )

        # Legend (outside the axes, above the chart).
        if uniq_agents:
            handles = [
                plt.Line2D([0], [0], color=color_map[a], lw=6)
                for a in uniq_agents
            ]
            fig.legend(
                handles,
                [_abbrev_agent(a) for a in uniq_agents],
                loc="center",
                bbox_to_anchor=(0.5, _LAYOUT["legend_y"]),
                ncol=max(1, min(4, len(uniq_agents))),
                frameon=False,
                fontsize="x-small",
            )

        # Footers.
        fig.text(
            0.5,
            _LAYOUT["footer1_y"],
            footer1,
            ha="center",
            va="center",
            fontsize="x-small",
        )
        fig.text(
            0.5,
            _LAYOUT["footer2_y"],
            footer2,
            ha="center",
            va="center",
            fontsize="x-small",
        )

        fig.savefig(out_path, dpi=150, bbox_inches="tight", pad_inches=0.02)
    finally:
        # Always release the figure, even when drawing or saving fails, so it
        # does not stay registered with pyplot.
        plt.close(fig)
    return out_path
304
+
305
+
306
def runs_to_dataframe(runs: list[RunRecord]) -> pd.DataFrame:
    """Convert runs into a DataFrame with one row per run, sorted by start.

    Columns: ``thread_id``, ``agent``, ``run_id``, ``label``
    (``"<agent>:<first 8 chars of run_id>"``), ``start`` / ``end`` (the
    datetime objects), ``duration_s``, and ``started_at`` / ``ended_at``
    (the same instants formatted as ``YYYY-MM-DD HH:MM:SS UTC`` strings).

    An empty *runs* list yields an empty frame that still has all columns.
    """
    columns = [
        "thread_id",
        "agent",
        "run_id",
        "label",
        "start",
        "end",
        "duration_s",
        "started_at",
        "ended_at",
    ]
    rows = [
        {
            "thread_id": r.thread_id,
            "agent": r.agent,
            "run_id": r.run_id,
            "label": f"{r.agent}:{r.run_id[:8]}",
            "start": r.started_at,
            "end": r.ended_at,
            "duration_s": r.duration_s,
            "started_at": r.started_at.strftime("%Y-%m-%d %H:%M:%S UTC"),
            "ended_at": r.ended_at.strftime("%Y-%m-%d %H:%M:%S UTC"),
        }
        for r in runs
    ]
    # Passing the columns explicitly keeps "start" present for an empty input;
    # without it sort_values("start") raises KeyError on a column-less frame.
    frame = pd.DataFrame(rows, columns=columns)
    return frame.sort_values("start").reset_index(drop=True)
322
+
323
+
324
def plot_thread_timeline_interactive(
    runs: list[RunRecord],
    out_html: str,
    *,
    group_by: str = "agent",  # "agent" (few lanes) or "run" (one lane per run)
) -> str:
    """Write an interactive (plotly) timeline of one thread's runs as HTML.

    Args:
        runs: Runs to plot; the title uses ``runs[0].thread_id``.
        out_html: Destination passed to ``fig.write_html``.
        group_by: ``"agent"`` puts all runs of an agent on one lane; any
            other value uses one lane per run (the ``label`` column).

    Returns:
        *out_html*, after the file has been written.

    Raises:
        ValueError: if *runs* is empty.  This is checked before plotly is
            imported so invalid input is reported the same way whether or
            not the optional dependency is installed.
    """
    if not runs:
        raise ValueError("No runs provided")

    # Imported lazily: plotly is only needed for this interactive export.
    import plotly.express as px

    df = runs_to_dataframe(runs)
    thread_id = runs[0].thread_id

    ycol = "agent" if group_by == "agent" else "label"
    fig = px.timeline(
        df,
        x_start="start",
        x_end="end",
        y=ycol,
        color="agent",
        hover_data={
            "agent": True,
            "run_id": True,
            "duration_s": ":.2f",
            "started_at": True,
            "ended_at": True,
            "label": False,
            "start": False,
            "end": False,
        },
        title=f"Thread: {thread_id} — Agent runs timeline (interactive)",
    )
    fig.update_yaxes(autorange="reversed")
    fig.update_layout(
        legend_title_text="Agent",
        margin=dict(l=40, r=20, t=60, b=40),
        hoverlabel=dict(namelength=-1),
    )
    # Lock bounds to the exact workflow window.
    fig.update_xaxes(range=[df["start"].min(), df["end"].max()])
    fig.write_html(out_html, include_plotlyjs="cdn", full_html=True)
    return out_html
366
+
367
+
368
def export_thread_csv(runs: list[RunRecord], out_csv: str) -> str:
    """Write the runs table (see ``runs_to_dataframe``) to *out_csv*.

    The CSV is written without the index column.  Returns *out_csv*.
    """
    runs_to_dataframe(runs).to_csv(out_csv, index=False)
    return out_csv
372
+
373
+
374
def aggregate_thread_context(runs: list[RunRecord]) -> dict:
    """Build the chart context dict describing a whole thread.

    Returns an empty dict for no runs.  Otherwise the result carries
    ``thread_id`` from the first run, an empty ``run_id``, and the earliest
    start / latest end across all runs as UTC ISO strings.
    """
    if not runs:
        return {}

    earliest = min(run.started_at for run in runs)
    latest = max(run.ended_at for run in runs)

    # agent is intentionally "Thread" so chart headers read "Thread : <id>".
    context = {
        "agent": "Thread",
        "thread_id": runs[0].thread_id,
        "run_id": "",
    }
    context["started_at"] = earliest.astimezone(timezone.utc).isoformat()
    context["ended_at"] = latest.astimezone(timezone.utc).isoformat()
    return context
389
+
390
+
391
def extract_thread_time_breakdown(
    runs: list[RunRecord],
    *,
    group_llm: bool = False,
) -> tuple[float, list[tuple[str, float]], dict]:
    """Sum the per-run time breakdowns of every run in a thread.

    Each run's metrics file is loaded with ``load_metrics`` and passed to
    ``extract_time_breakdown`` (with *group_llm* forwarded); totals are added
    up and parts with the same label are merged.

    Returns:
        ``(total_seconds, parts, context)`` where *parts* is a list of
        ``(label, seconds)`` ordered by seconds descending and *context*
        comes from ``aggregate_thread_context``.
    """
    grand_total = 0.0
    seconds_by_label: dict[str, float] = {}

    for run in runs:
        run_total, run_parts = extract_time_breakdown(
            load_metrics(run.path), group_llm=group_llm
        )
        grand_total += float(run_total or 0.0)
        for label, sec in run_parts:
            previous = seconds_by_label.get(label, 0.0)
            seconds_by_label[label] = previous + float(sec or 0.0)

    # Stable order: largest share first.
    ordered_parts = sorted(
        seconds_by_label.items(), key=lambda item: item[1], reverse=True
    )
    return grand_total, ordered_parts, aggregate_thread_context(runs)
413
+
414
+
415
def extract_thread_token_stats(
    runs: list[RunRecord],
) -> tuple[dict[str, int], dict[str, list[int]], dict]:
    """Merge token usage across every run of a thread.

    For each run the metrics file is loaded and ``extract_llm_token_stats``
    supplies that run's totals and per-call samples; totals are summed and
    samples are concatenated in run order.

    Returns:
        ``(totals, samples, context)`` keyed by ``input_tokens``,
        ``output_tokens``, ``reasoning_tokens``, ``cached_tokens`` and
        ``total_tokens``; *context* comes from ``aggregate_thread_context``.
    """
    token_keys = (
        "input_tokens",
        "output_tokens",
        "reasoning_tokens",
        "cached_tokens",
        "total_tokens",
    )
    totals = {key: 0 for key in token_keys}
    samples = {key: [] for key in token_keys}

    for run in runs:
        run_totals, run_samples = extract_llm_token_stats(load_metrics(run.path))
        # Accumulate the totals first ...
        for key in token_keys:
            totals[key] += int(run_totals.get(key, 0) or 0)
        # ... then append this run's per-call samples.
        for key in token_keys:
            samples[key].extend(list(run_samples.get(key, []) or []))

    return totals, samples, aggregate_thread_context(runs)
443
+
444
+
445
def compute_thread_time_bases(runs: list[RunRecord]) -> tuple[float, float]:
    """Return ``(llm_active_seconds, thread_elapsed_seconds)`` for a thread.

    - ``llm_active_seconds``: sum over all runs of the ``llm_total_s`` value
      reported by ``compute_attribution`` for that run's metrics.
    - ``thread_elapsed_seconds``: latest ``ended_at`` minus earliest
      ``started_at``, never negative.

    An empty *runs* list gives ``(0.0, 0.0)``.
    """
    if not runs:
        return (0.0, 0.0)

    llm_seconds = sum(
        (
            float(
                compute_attribution(load_metrics(run.path)).get("llm_total_s", 0.0)
                or 0.0
            )
            for run in runs
        ),
        0.0,
    )

    first_start = min(run.started_at for run in runs).astimezone(timezone.utc)
    last_end = max(run.ended_at for run in runs).astimezone(timezone.utc)
    elapsed_seconds = max(0.0, (last_end - first_start).total_seconds())
    return (llm_seconds, elapsed_seconds)
462
+
463
+
464
def extract_run_tokens_by_model(
    payload: dict,
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, float]]:
    """Split one run's token usage and LLM time by model.

    The model is taken from an entry's ``name``: the text after a leading
    ``"llm:"`` prefix, otherwise the name itself, otherwise ``"unknown"``.

    Returns:
        tokens_by_model: ``{model: {input_tokens, ..., total_tokens}}``,
            summed from ``payload["llm_events"][*]["metrics"]["usage_rollup"]``.
        seconds_by_model: ``{model: seconds}``, summed from the ``total_s``
            field of ``payload["tables"]["llm"]`` rows.
    """
    token_keys = (
        "input_tokens",
        "output_tokens",
        "reasoning_tokens",
        "cached_tokens",
        "total_tokens",
    )

    def model_of(raw_name: str) -> str:
        # "llm:<model>" -> "<model>"; any other non-empty name is used as is.
        if raw_name.startswith("llm:"):
            return raw_name.split("llm:", 1)[-1]
        return raw_name or "unknown"

    tokens: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
    seconds: Dict[str, float] = defaultdict(float)

    # 1) Tokens by model, from llm_events.
    for event in payload.get("llm_events") or []:
        model = model_of(event.get("name") or "")
        rollup = ((event.get("metrics") or {}).get("usage_rollup")) or {}
        for key in token_keys:
            try:
                tokens[model][key] += int(float(rollup.get(key, 0) or 0))
            except Exception:
                # Unusable value: leave this counter as it is.
                pass

    # 2) LLM-active seconds by model, from tables.llm.
    for row in (payload.get("tables") or {}).get("llm") or []:
        model = model_of(row.get("name") or "")
        try:
            seconds[model] += float(row.get("total_s") or 0.0)
        except Exception:
            pass

    # Hand back plain dicts rather than defaultdicts.
    plain_tokens = {model: dict(counts) for model, counts in tokens.items()}
    return plain_tokens, dict(seconds)
513
+
514
+
515
def extract_thread_tokens_by_model(
    runs,
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, float], float, dict]:
    """Aggregate per-model token totals and LLM seconds for one thread.

    Every run's metrics file is loaded and split by model with
    ``extract_run_tokens_by_model``.  The time window is derived from the
    ``started_at`` / ``ended_at`` strings in each payload's ``context``
    (parsed with ``_dt``), not from the ``RunRecord`` timestamps.

    Returns:
        ``(tokens_by_model, seconds_by_model, window_seconds, context)``.
        ``window_seconds`` is ``0.0`` when no start or no end was found.
    """
    tokens: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
    seconds: Dict[str, float] = defaultdict(float)
    earliest = None
    latest = None

    for run in runs:
        payload = load_metrics(run.path)
        run_tokens, run_seconds = extract_run_tokens_by_model(payload)

        for model, counts in run_tokens.items():
            for key, value in counts.items():
                tokens[model][key] += int(value or 0)
        for model, secs in run_seconds.items():
            seconds[model] += float(secs or 0.0)

        # Widen the observed window with this run's own timestamps.
        context = payload.get("context") or {}
        raw_start = context.get("started_at")
        raw_end = context.get("ended_at")
        if raw_start:
            candidate = _dt(raw_start)
            if earliest is None or candidate < earliest:
                earliest = candidate
        if raw_end:
            candidate = _dt(raw_end)
            if latest is None or candidate > latest:
                latest = candidate

    if earliest and latest:
        window_seconds = (latest - earliest).total_seconds()
    else:
        window_seconds = 0.0

    ctx_out = {
        "agent": "Thread",
        "thread_id": runs[0].thread_id if runs else "",
        "run_id": "",
        "started_at": earliest.isoformat() if earliest else "",
        "ended_at": latest.isoformat() if latest else "",
    }
    plain_tokens = {model: dict(counts) for model, counts in tokens.items()}
    return plain_tokens, dict(seconds), window_seconds, ctx_out
568
+
569
+
570
def aggregate_super_tokens_by_model(
    thread_dirs: List[str],
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, float], float, dict]:
    """Aggregate per-model token totals and LLM seconds over many directories.

    Each directory is scanned with ``scan_directory_for_threads``; every run
    found is loaded and split by model with ``extract_run_tokens_by_model``.
    The time window spans the earliest ``started_at`` to the latest
    ``ended_at`` found in the payload contexts (parsed with ``_dt``).

    Returns:
        ``(tokens_by_model, seconds_by_model, window_seconds, context)``.
        The context's ``thread_id`` is the text ``"<N> thread dirs"``.
    """
    tokens: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
    seconds: Dict[str, float] = defaultdict(float)
    earliest = None
    latest = None

    for directory in thread_dirs:
        sessions = scan_directory_for_threads(directory)
        for thread_runs in (sessions or {}).values():
            for run in thread_runs:
                payload = load_metrics(run.path)
                run_tokens, run_seconds = extract_run_tokens_by_model(payload)

                for model, counts in run_tokens.items():
                    for key, value in counts.items():
                        tokens[model][key] += int(value or 0)
                for model, secs in run_seconds.items():
                    seconds[model] += float(secs or 0.0)

                # Widen the observed window with this run's own timestamps.
                context = payload.get("context") or {}
                raw_start = context.get("started_at")
                raw_end = context.get("ended_at")
                if raw_start:
                    candidate = _dt(raw_start)
                    if earliest is None or candidate < earliest:
                        earliest = candidate
                if raw_end:
                    candidate = _dt(raw_end)
                    if latest is None or candidate > latest:
                        latest = candidate

    if earliest and latest:
        window_seconds = (latest - earliest).total_seconds()
    else:
        window_seconds = 0.0

    ctx_out = {
        "agent": "SUPER",
        "thread_id": f"{len(thread_dirs)} thread dirs",
        "run_id": "",
        "started_at": earliest.isoformat() if earliest else "",
        "ended_at": latest.isoformat() if latest else "",
    }
    plain_tokens = {model: dict(counts) for model, counts in tokens.items()}
    return plain_tokens, dict(seconds), window_seconds, ctx_out
631
+
632
+
633
def extract_thread_token_stats_by_agent(
    runs: List[RunRecord],
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, float], float]:
    """Aggregate token totals and LLM-active seconds per agent for one thread.

    Returns:
        ``(totals_by_agent, llm_secs_by_agent, thread_window_seconds)``
        - totals_by_agent: ``{agent: {token category: count}}`` summed from
          ``extract_llm_token_stats`` for each run of that agent.
        - llm_secs_by_agent: ``{agent: seconds}`` summed from the
          ``llm_total_s`` value of ``compute_attribution``.
        - thread_window_seconds: latest end minus earliest start across all
          runs, never negative.

        An empty *runs* list gives ``({}, {}, 0.0)``.
    """
    if not runs:
        return {}, {}, 0.0

    # Thread window: earliest start to latest end.
    window_start = min(run.started_at for run in runs)
    window_end = max(run.ended_at for run in runs)
    thread_window_seconds = max(0.0, (window_end - window_start).total_seconds())

    totals: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
    llm_seconds: Dict[str, float] = defaultdict(float)

    for run in runs:
        payload = load_metrics(run.path)

        # Tokens for this run are credited to the run's agent.
        run_totals, _samples = extract_llm_token_stats(payload)
        for key, value in run_totals.items():
            try:
                totals[run.agent][key] += int(value or 0)
            except Exception:
                pass

        # LLM-active seconds for this run are credited to the same agent.
        try:
            attribution = compute_attribution(payload)
            llm_seconds[run.agent] += float(
                attribution.get("llm_total_s", 0.0) or 0.0
            )
        except Exception:
            pass

    # Hand back plain dicts rather than defaultdicts.
    plain_totals = {agent: dict(counts) for agent, counts in totals.items()}
    return plain_totals, dict(llm_seconds), float(thread_window_seconds)
683
+
684
+
685
def aggregate_super_token_stats_by_agent(
    sessions: Dict[str, List[RunRecord]],
) -> Tuple[
    Dict[str, Dict[str, int]], Dict[str, float], float, Dict[str, float]
]:
    """Aggregate token totals and LLM seconds per agent across all threads.

    Args:
        sessions: ``{thread_id: [RunRecord, ...]}`` as produced by
            ``scan_directory_for_threads``.  ``None`` is treated like an
            empty mapping.

    Returns:
        ``(totals_by_agent, llm_secs_by_agent, sum_thread_secs, summary)``
        - totals_by_agent: ``{agent: {token category: count}}``
        - llm_secs_by_agent: ``{agent: seconds}``
        - sum_thread_secs: sum of each thread's (latest end - earliest start),
          each clamped to be non-negative
        - summary: ``{"n_threads": ..., "n_runs": ..., "sum_thread_secs": ...}``
          with all values as floats; ``n_threads`` counts every key of
          *sessions*, including threads whose run list is empty.
    """
    # Normalise once: the loop below already tolerated None, but the summary
    # used len(sessions) and raised TypeError for it.
    sessions = sessions or {}

    totals_by_agent: Dict[str, Dict[str, int]] = defaultdict(
        lambda: defaultdict(int)
    )
    llm_secs_by_agent: Dict[str, float] = defaultdict(float)
    sum_thread_secs = 0.0
    n_runs = 0

    for runs in sessions.values():
        if not runs:
            continue
        n_runs += len(runs)

        # Window of this thread: earliest start to latest end.
        t_min = min(r.started_at for r in runs)
        t_max = max(r.ended_at for r in runs)
        sum_thread_secs += max(0.0, (t_max - t_min).total_seconds())

        # Per-run aggregation, bucketed by the run's agent.
        for r in runs:
            payload = load_metrics(r.path)

            # Tokens.
            run_totals, _ = extract_llm_token_stats(payload)
            for k, v in run_totals.items():
                try:
                    totals_by_agent[r.agent][k] += int(v or 0)
                except Exception:
                    # Unusable value: leave this counter as it is.
                    pass

            # LLM seconds.
            try:
                att = compute_attribution(payload)
                llm_secs_by_agent[r.agent] += float(
                    att.get("llm_total_s", 0.0) or 0.0
                )
            except Exception:
                pass

    summary = {
        "n_threads": float(len(sessions)),
        "n_runs": float(n_runs),
        "sum_thread_secs": float(sum_thread_secs),
    }

    # Hand back plain dicts rather than defaultdicts.
    return (
        {a: dict(d) for a, d in totals_by_agent.items()},
        dict(llm_secs_by_agent),
        float(sum_thread_secs),
        summary,
    )