ursa-ai 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. ursa/__init__.py +3 -0
  2. ursa/agents/__init__.py +32 -0
  3. ursa/agents/acquisition_agents.py +812 -0
  4. ursa/agents/arxiv_agent.py +429 -0
  5. ursa/agents/base.py +728 -0
  6. ursa/agents/chat_agent.py +60 -0
  7. ursa/agents/code_review_agent.py +341 -0
  8. ursa/agents/execution_agent.py +915 -0
  9. ursa/agents/hypothesizer_agent.py +614 -0
  10. ursa/agents/lammps_agent.py +465 -0
  11. ursa/agents/mp_agent.py +204 -0
  12. ursa/agents/optimization_agent.py +410 -0
  13. ursa/agents/planning_agent.py +219 -0
  14. ursa/agents/rag_agent.py +304 -0
  15. ursa/agents/recall_agent.py +54 -0
  16. ursa/agents/websearch_agent.py +196 -0
  17. ursa/cli/__init__.py +363 -0
  18. ursa/cli/hitl.py +516 -0
  19. ursa/cli/hitl_api.py +75 -0
  20. ursa/observability/metrics_charts.py +1279 -0
  21. ursa/observability/metrics_io.py +11 -0
  22. ursa/observability/metrics_session.py +750 -0
  23. ursa/observability/pricing.json +97 -0
  24. ursa/observability/pricing.py +321 -0
  25. ursa/observability/timing.py +1466 -0
  26. ursa/prompt_library/__init__.py +0 -0
  27. ursa/prompt_library/code_review_prompts.py +51 -0
  28. ursa/prompt_library/execution_prompts.py +50 -0
  29. ursa/prompt_library/hypothesizer_prompts.py +17 -0
  30. ursa/prompt_library/literature_prompts.py +11 -0
  31. ursa/prompt_library/optimization_prompts.py +131 -0
  32. ursa/prompt_library/planning_prompts.py +79 -0
  33. ursa/prompt_library/websearch_prompts.py +131 -0
  34. ursa/tools/__init__.py +0 -0
  35. ursa/tools/feasibility_checker.py +114 -0
  36. ursa/tools/feasibility_tools.py +1075 -0
  37. ursa/tools/run_command.py +27 -0
  38. ursa/tools/write_code.py +42 -0
  39. ursa/util/__init__.py +0 -0
  40. ursa/util/diff_renderer.py +128 -0
  41. ursa/util/helperFunctions.py +142 -0
  42. ursa/util/logo_generator.py +625 -0
  43. ursa/util/memory_logger.py +183 -0
  44. ursa/util/optimization_schema.py +78 -0
  45. ursa/util/parse.py +405 -0
  46. ursa_ai-0.9.1.dist-info/METADATA +304 -0
  47. ursa_ai-0.9.1.dist-info/RECORD +51 -0
  48. ursa_ai-0.9.1.dist-info/WHEEL +5 -0
  49. ursa_ai-0.9.1.dist-info/entry_points.txt +2 -0
  50. ursa_ai-0.9.1.dist-info/licenses/LICENSE +8 -0
  51. ursa_ai-0.9.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1279 @@
1
+ # metrics_charts.py
2
+ from __future__ import annotations
3
+
4
+ from typing import Any, Dict, List, Tuple
5
+
6
+ import matplotlib
7
+
8
+ matplotlib.use("Agg") # safe for headless environments
9
+ import datetime as _dt
10
+
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ from scipy.stats import gaussian_kde # type: ignore
14
+
15
+ # Layout spec for compact charts (fractions of figure size)
16
# Shared placement values for the compact charts, expressed as fractions of
# the figure size. Text placed above 1.0 is still captured because the charts
# are saved with bbox_inches="tight".
_LAYOUT = {
    "header_y": 1.1,  # big header (agent : thread_id)
    "subtitle_y": 1.0,  # smaller subtitle (title/total)
    "ax_rect": (0.12, 0.310, 0.84, 0.58),  # [left, bottom, width, height]
    "footer1_y": 0.105,  # run_id
    "footer2_y": 0.050,  # started → ended
}
23
+
24
+
25
def compute_llm_wall_seconds(payload: dict) -> float:
    """
    Return the wall-clock seconds covered by the union of all LLM event intervals.

    Reads payload["llm_events"]; each event contributes (t_start, t_end) only
    when both are int/float and t_end >= t_start. Overlapping or touching
    intervals are merged so time is not counted twice. Returns 0.0 when no
    valid interval exists.
    """
    spans = []
    for event in payload.get("llm_events") or []:
        begin = event.get("t_start")
        finish = event.get("t_end")
        if not isinstance(begin, (int, float)):
            continue
        if not isinstance(finish, (int, float)):
            continue
        if finish >= begin:
            spans.append((begin, finish))

    if not spans:
        return 0.0

    spans.sort()
    covered = 0.0
    lo, hi = spans[0]
    for begin, finish in spans[1:]:
        if begin > hi:
            # Gap before this interval: close out the merged run so far.
            covered += hi - lo
            lo, hi = begin, finish
        elif finish > hi:
            # Overlap (or touch): extend the current merged run.
            hi = finish
    covered += hi - lo
    return float(covered)
49
+
50
+
51
def compute_attribution(payload: Dict[str, Any]) -> Dict[str, float]:
    """
    Summarize how much of the graph total the LLM and tool tables account for.

    Returns a dict with totals useful for validation/printing:
      total_s        = value returned by _find_graph_total_seconds(payload)
      llm_total_s    = sum of total_s over tables.llm rows
      tool_total_s   = sum of total_s over tables.tool rows
      unattributed_s = max(0, total_s - (llm_total_s + tool_total_s))
      overage_s      = max(0, (llm_total_s + tool_total_s) - total_s)
    """
    tables = payload.get("tables") or {}

    # Only the total is needed here, so ask the finder directly instead of
    # building (and discarding) the full time breakdown.
    total_s = _find_graph_total_seconds(payload)

    llm_total_s = sum(
        float(r.get("total_s") or 0.0) for r in (tables.get("llm") or [])
    )
    tool_total_s = sum(
        float(r.get("total_s") or 0.0) for r in (tables.get("tool") or [])
    )
    attributed_s = llm_total_s + tool_total_s

    # At most one of these is non-zero: either time is missing from the
    # tables, or the tables claim more than the graph total.
    unattributed_s = max(0.0, total_s - attributed_s)
    overage_s = max(0.0, attributed_s - total_s)
    return {
        "total_s": total_s,
        "llm_total_s": llm_total_s,
        "tool_total_s": tool_total_s,
        "unattributed_s": unattributed_s,
        "overage_s": overage_s,
    }
81
+
82
+
83
def _extract_context(payload: Dict[str, Any]) -> Dict[str, str]:
    """
    Pull the chart context out of payload["context"] as plain strings.

    The result always has the keys agent, thread_id, run_id, started_at and
    ended_at; a missing or falsy value becomes "".
    """
    raw = payload.get("context") or {}
    wanted = ("agent", "thread_id", "run_id", "started_at", "ended_at")
    return {field: str(raw.get(field) or "") for field in wanted}
93
+
94
+
95
def _fmt_iso_pretty(ts: str) -> str:
    """
    Render an ISO8601 timestamp as 'YYYY-MM-DD HH:MM:SS UTC'.

    Best-effort: an empty input yields "", and anything that fails to parse or
    convert is returned unchanged.
    """
    if ts:
        try:
            # fromisoformat() on older Pythons rejects a trailing "Z", so spell
            # the UTC offset out explicitly before parsing.
            parsed = _dt.datetime.fromisoformat(ts.replace("Z", "+00:00"))
            as_utc = parsed.astimezone(_dt.timezone.utc)
            return as_utc.strftime("%Y-%m-%d %H:%M:%S UTC")
        except Exception:
            return ts
    return ""
105
+
106
+
107
def _find_graph_total_seconds(payload: Dict[str, Any]) -> float:
    """
    Return the total graph time in seconds.

    Lookup order:
      1. the tables.runnable row whose name is exactly 'graph:graph';
      2. otherwise the first tables.runnable row whose name starts with 'graph:';
      3. otherwise totals.graph_total_s.
    A missing or non-numeric value at the chosen location yields 0.0.
    """

    def _as_seconds(value: Any) -> float:
        # Tolerate None / malformed values the same way for rows and totals.
        try:
            return float(value or 0.0)
        except (TypeError, ValueError):
            return 0.0

    tables = payload.get("tables") or {}
    first_graph_row = None
    for row in tables.get("runnable") or []:
        name = str(row.get("name", ""))
        if name == "graph:graph":
            # Exact match wins even if another graph:* row appeared earlier.
            return _as_seconds(row.get("total_s"))
        if first_graph_row is None and name.startswith("graph:"):
            first_graph_row = row

    if first_graph_row is not None:
        return _as_seconds(first_graph_row.get("total_s"))

    totals = payload.get("totals") or {}
    return _as_seconds(totals.get("graph_total_s"))
122
+
123
+
124
def _aggregate_tools_seconds(
    payload: Dict[str, Any],
) -> List[Tuple[str, float]]:
    """
    Return one ("tool:<name>", seconds) pair per row of tables.tool.

    A row without a name is labelled "tool:unknown"; a missing or non-numeric
    total_s counts as 0.0. Row order is preserved.
    """
    tables = payload.get("tables") or {}
    out: List[Tuple[str, float]] = []
    for row in tables.get("tool") or []:
        # The "tool:" prefix is added below, so the fallback must not carry it
        # as well (that used to produce "tool:tool:unknown").
        name = str(row.get("name") or "unknown")
        try:
            total_s = float(row.get("total_s") or 0.0)
        except (TypeError, ValueError):
            total_s = 0.0
        out.append((f"tool:{name}", total_s))
    return out
138
+
139
+
140
def _aggregate_llm_seconds(
    payload: Dict[str, Any], *, group_llm: bool
) -> List[Tuple[str, float]]:
    """
    Return (label, seconds) pairs built from tables.llm.

    With group_llm=True all rows collapse into a single ("llm:total", sum)
    entry, or an empty list when the sum is not positive. Otherwise each row
    yields (name, seconds), with "llm:unknown" standing in for a missing name.
    Rows whose total_s cannot be read as a float contribute 0.0.
    """

    def _row_seconds(row: Any) -> float:
        try:
            return float(row.get("total_s") or 0.0)
        except Exception:
            return 0.0

    rows = (payload.get("tables") or {}).get("llm") or []

    if not group_llm:
        per_row: List[Tuple[str, float]] = []
        for row in rows:
            label = str(row.get("name") or "llm:unknown")
            per_row.append((label, _row_seconds(row)))
        return per_row

    combined = 0.0
    for row in rows:
        combined += _row_seconds(row)
    if combined > 0:
        return [("llm:total", combined)]
    return []
163
+
164
+
165
def extract_time_breakdown(
    payload: Dict[str, Any], *, group_llm: bool = False
) -> Tuple[float, List[Tuple[str, float]]]:
    """
    Split the graph total into labelled time components.

    Returns (total_seconds, parts). parts lists (label, seconds) for every
    tool row, then every llm row (or one grouped llm entry when group_llm is
    set), then an "other" entry holding whatever part of the total is left
    (never negative). Entries that are not strictly positive are dropped.
    """
    total = _find_graph_total_seconds(payload)

    components: List[Tuple[str, float]] = [
        *_aggregate_tools_seconds(payload),
        *_aggregate_llm_seconds(payload, group_llm=group_llm),
    ]

    accounted = sum(seconds for _, seconds in components)
    components.append(("other", max(0.0, total - accounted)))

    # Zero-length slices only clutter the charts.
    visible = [
        (label, seconds) for label, seconds in components if seconds > 0.0
    ]
    return total, visible
184
+
185
+
186
def plot_time_breakdown(
    total: float,
    parts: List[Tuple[str, float]],
    out_path: str,
    *,
    title: str = "",
    chart: str = "pie",  # "pie" or "bar"
    min_label_pct: float = 1.0,
    context: Dict[str, Any] | None = None,
) -> str:
    """
    Draw a time-breakdown chart and save it to out_path.

    Args:
        total: Total seconds; only used in the default subtitle text.
        parts: (label, seconds) pairs to plot.
        out_path: Destination passed to Figure.savefig.
        title: Subtitle override; defaults to "Time Breakdown (total = ...s)".
        chart: "bar" draws a single stacked horizontal bar; any other value
            draws a pie.
        min_label_pct: Bar segments below this percentage are not labelled
            (bar chart only).
        context: Optional mapping with agent/thread_id/run_id/started_at/
            ended_at, used for the header and footer lines.

    Returns:
        out_path.
    """
    names = [name for name, _ in parts]
    seconds = [secs for _, secs in parts]
    # Shares are relative to the plotted parts; guard against dividing by zero.
    denom = sum(seconds) or 1.0

    # ----- header/footer text from context -----
    meta = context or {}
    agent = str(meta.get("agent") or "")
    thread_id = str(meta.get("thread_id") or "")
    run_id = str(meta.get("run_id") or "")
    started = _fmt_iso_pretty(str(meta.get("started_at") or ""))
    ended = _fmt_iso_pretty(str(meta.get("ended_at") or ""))
    heading = " : ".join(piece for piece in (agent, thread_id) if piece)
    caption = title or f"Time Breakdown (total = {total:.3f}s)"

    if chart == "bar":
        fig = plt.figure(figsize=(8, 1.8))
        ax = fig.add_axes([0.12, 0.30, 0.84, 0.56])

        cursor = 0.0
        for name, secs in parts:
            share = secs / denom
            ax.barh([0], [share], left=cursor, edgecolor="black")
            share_pct = share * 100.0
            if share_pct >= min_label_pct:
                ax.text(
                    cursor + share / 2.0,
                    0,
                    f"{name} ({share_pct:.1f}%)",
                    ha="center",
                    va="center",
                )
            cursor += share

        ax.set_xlim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("Share of graph:graph wall time")
    else:
        fig = plt.figure(figsize=(6, 6))
        ax = fig.add_axes([0.08, 0.22, 0.84, 0.70])

        def _wedge_label(pct):
            # Convert the wedge percentage back into seconds for the label.
            return f"{pct:.1f}%\n{(pct / 100.0) * denom:.3f}s"

        ax.pie(seconds, labels=names, autopct=_wedge_label, startangle=90)
        ax.axis("equal")

    # Header, subtitle and footers sit at the same figure coordinates for
    # both chart types.
    if heading:
        fig.text(0.5, 0.965, heading, ha="center", va="top")
        fig.text(0.5, 0.915, caption, ha="center", va="top", fontsize="small")
    else:
        fig.text(0.5, 0.945, caption, ha="center", va="top")

    footers = []
    if run_id:
        footers.append((0.10, f"run_id: {run_id}"))
    if started or ended:
        footers.append((0.06, f"{started} \u2192 {ended}"))
    for y_pos, line in footers:
        fig.text(
            0.5, y_pos, line, ha="center", va="center", fontsize="x-small"
        )

    fig.savefig(out_path, dpi=150, bbox_inches="tight", pad_inches=0.02)
    plt.close(fig)
    return out_path
316
+
317
+
318
def plot_lollipop_time(
    total: float,
    parts: List[Tuple[str, float]],
    out_path: str,
    *,
    title: str = "",
    log_x: bool = True,
    min_label_pct: float = 0.0,
    show_seconds: bool = True,
    show_percent: bool = True,
    exclude_zero: bool = True,
    context: Dict[str, Any] | None = None,
) -> str:
    """
    Draw a horizontal lollipop chart of seconds per component and save it.

    Args:
        total: Total seconds; used for the percentage annotations and the
            default subtitle.
        parts: (label, seconds) pairs; plotted in ascending order of seconds.
        out_path: Destination passed to Figure.savefig.
        title: Subtitle override.
        log_x: Use a logarithmic x axis.
        min_label_pct: Dots whose share of total is below this percentage are
            not annotated.
        show_seconds / show_percent: Which pieces appear in each annotation.
        exclude_zero: Drop parts whose value is not strictly positive.
        context: Optional mapping with agent/thread_id/run_id/started_at/
            ended_at, used for the header and footer lines.

    Returns:
        out_path.
    """
    rows = [
        (name, secs)
        for name, secs in parts
        if not exclude_zero or secs > 0
    ]
    rows.sort(key=lambda pair: pair[1])
    names = [name for name, _ in rows]
    values = [secs for _, secs in rows]

    # ----- context header/footer text -----
    meta = context or {}
    agent = str(meta.get("agent") or "")
    thread_id = str(meta.get("thread_id") or "")
    run_id = str(meta.get("run_id") or "")
    started = _fmt_iso_pretty(str(meta.get("started_at") or ""))
    ended = _fmt_iso_pretty(str(meta.get("ended_at") or ""))
    heading = " : ".join(piece for piece in (agent, thread_id) if piece)
    caption = title or f"Time (seconds) by component (total = {total:.3f}s)"

    # One Axes at an explicit position; height grows with the row count.
    fig = plt.figure(figsize=(8, max(2.2, 0.35 * max(1, len(values)))))
    ax = fig.add_axes(list(_LAYOUT["ax_rect"]))

    positions = range(len(values))
    ax.hlines(positions, xmin=0, xmax=values, linewidth=1)
    ax.plot(values, positions, "o")

    if log_x:
        # A log axis needs a strictly positive left limit.
        floor = min(values) if values else 0.0
        if floor <= 0:
            floor = min([v for v in values if v > 0] + [1e-6])
        ax.set_xscale("log")
        ceiling = max(values) * 1.1 if values else 1.0
        ax.set_xlim(left=floor * 0.8, right=ceiling)

    ax.set_yticks(list(positions))
    ax.set_yticklabels(names)
    ax.set_xlabel(
        "Seconds (log scale)" if log_x else "Seconds", fontsize="small"
    )
    ax.tick_params(labelsize="small")

    # Figure-level text so the header never pushes the axes around.
    if heading:
        fig.text(0.5, _LAYOUT["header_y"], heading, ha="center", va="top")
        fig.text(
            0.5,
            _LAYOUT["subtitle_y"],
            caption,
            ha="center",
            va="top",
            fontsize="small",
        )
    else:
        fig.text(
            0.5, (_LAYOUT["header_y"] - 0.02), caption, ha="center", va="top"
        )

    # Annotate each dot with its share and/or absolute seconds.
    for row_idx, secs in enumerate(values):
        share_pct = (secs / total * 100.0) if total > 0 else 0.0
        if share_pct < min_label_pct:
            continue
        pieces = []
        if show_percent:
            pieces.append(f"{share_pct:.2f}%")
        if show_seconds:
            pieces.append(f"{secs:.3f}s")
        ax.text(
            secs,
            row_idx,
            " " + " ".join(pieces),
            va="center",
            ha="left",
            fontsize="small",
        )

    footers = []
    if run_id:
        footers.append((_LAYOUT["footer1_y"], f"run_id: {run_id}"))
    if started or ended:
        footers.append((_LAYOUT["footer2_y"], f"{started} \u2192 {ended}"))
    for y_pos, line in footers:
        fig.text(
            0.5, y_pos, line, ha="center", va="center", fontsize="x-small"
        )

    fig.savefig(out_path, dpi=150, bbox_inches="tight", pad_inches=0.02)
    plt.close(fig)
    return out_path
430
+
431
+
432
def extract_llm_token_stats(
    payload: Dict[str, Any],
) -> Tuple[Dict[str, int], Dict[str, List[int]]]:
    """
    Collect LLM token usage from payload["llm_events"].

    Returns (totals, samples):
      - totals: per-category sum across all events
      - samples: per-category list with one entry per event (reasoning and
        cached entries are only recorded when non-zero)
    Categories: input_tokens, output_tokens, reasoning_tokens, cached_tokens,
    total_tokens. Counts are read from each event's metrics.usage_rollup.
    """
    categories = (
        "input_tokens",
        "output_tokens",
        "reasoning_tokens",
        "cached_tokens",
        "total_tokens",
    )
    totals = dict.fromkeys(categories, 0)
    samples = {name: [] for name in categories}

    def _as_int(usage: Any, key: str) -> int:
        # Anything missing or unparsable counts as zero tokens.
        try:
            raw = usage.get(key)
            return 0 if raw is None else int(float(raw))
        except Exception:
            return 0

    for event in payload.get("llm_events") or []:
        usage = (event.get("metrics") or {}).get("usage_rollup") or {}

        # Explicit input/output keys win; prompt/completion are the mirrors.
        n_in = _as_int(usage, "input_tokens") or _as_int(usage, "prompt_tokens")
        n_out = _as_int(usage, "output_tokens") or _as_int(
            usage, "completion_tokens"
        )
        n_reason = _as_int(usage, "reasoning_tokens")
        n_cached = _as_int(usage, "cached_tokens")
        # A reported total is never allowed to be below input + output.
        n_total = max(_as_int(usage, "total_tokens"), n_in + n_out)

        per_call = {
            "input_tokens": n_in,
            "output_tokens": n_out,
            "reasoning_tokens": n_reason,
            "cached_tokens": n_cached,
            "total_tokens": n_total,
        }
        for name, count in per_call.items():
            totals[name] += count
            # Reasoning/cached samples are skipped when the call had none.
            if count or name not in ("reasoning_tokens", "cached_tokens"):
                samples[name].append(count)

    return totals, samples
489
+
490
+
491
def plot_token_totals_bar(
    totals: Dict[str, int],
    out_path: str,
    *,
    title: str = "",
    context: Dict[str, Any] | None = None,
) -> str:
    """
    Save a horizontal bar chart of token totals, one bar per category.

    Args:
        totals: Mapping of category name to token count; missing categories
            are plotted as 0.
        out_path: Destination passed to Figure.savefig.
        title: Subtitle override; defaults to "LLM Token Totals by Category".
        context: Optional mapping with agent/thread_id/run_id/started_at/
            ended_at, used for the header and footer lines.

    Returns:
        out_path.
    """
    categories = (
        "input_tokens",
        "output_tokens",
        "reasoning_tokens",
        "cached_tokens",
        "total_tokens",
    )
    names = [cat.replace("_", " ") for cat in categories]
    counts = [int(totals.get(cat, 0)) for cat in categories]

    meta = context or {}
    agent = str(meta.get("agent") or "")
    thread = str(meta.get("thread_id") or "")
    run_id = str(meta.get("run_id") or "")
    started = _fmt_iso_pretty(str(meta.get("started_at") or ""))
    ended = _fmt_iso_pretty(str(meta.get("ended_at") or ""))

    heading = " : ".join(piece for piece in (agent, thread) if piece)
    caption = title or "LLM Token Totals by Category"

    fig = plt.figure(figsize=(8, max(2.2, 0.35 * len(names))))
    ax = fig.add_axes(list(_LAYOUT["ax_rect"]))

    rows = list(range(len(names)))
    ax.barh(rows, counts, edgecolor="black")
    ax.set_yticks(rows)
    ax.set_yticklabels(names)
    ax.invert_yaxis()  # first category (input tokens) ends up on top
    ax.set_xlabel("Tokens", fontsize="small")
    ax.tick_params(labelsize="small")

    # Print the count just past the end of each bar.
    for row, count in zip(rows, counts):
        ax.text(
            count, row, f" {count:,}", va="center", ha="left", fontsize="small"
        )

    if heading:
        fig.text(0.5, _LAYOUT["header_y"], heading, ha="center", va="top")
        fig.text(
            0.5,
            _LAYOUT["subtitle_y"],
            caption,
            ha="center",
            va="top",
            fontsize="small",
        )
    else:
        fig.text(
            0.5, (_LAYOUT["header_y"] - 0.02), caption, ha="center", va="top"
        )

    footers = []
    if run_id:
        footers.append((_LAYOUT["footer1_y"], f"run_id: {run_id}"))
    if started or ended:
        footers.append((_LAYOUT["footer2_y"], f"{started} \u2192 {ended}"))
    for y_pos, line in footers:
        fig.text(
            0.5, y_pos, line, ha="center", va="center", fontsize="x-small"
        )

    fig.savefig(out_path, dpi=150, bbox_inches="tight", pad_inches=0.02)
    plt.close(fig)
    return out_path
575
+
576
+
577
def plot_token_kde(
    samples: Dict[str, List[int]],
    out_path: str,
    *,
    title: str = "",
    context: Dict[str, Any] | None = None,
    log_x: bool = False,
    bandwidth: float | None = None,
    fill_alpha: float = 0.15,
    line_alpha: float = 0.85,
) -> str:
    """
    Overlay density curves for input/output/reasoning/cached/total tokens.

    - Series with fewer than two finite samples are skipped; when nothing is
      left, a placeholder figure reading "No token samples to plot" is saved.
    - Densities come from scipy.stats.gaussian_kde (bw_method=bandwidth).
    - gaussian_kde cannot be fitted to samples that are all identical
      (singular covariance), so such a series is drawn as a narrow Gaussian
      spike centred on the repeated value instead of raising.
    - No seaborn.

    Args:
        samples: Mapping of category name to per-call token counts.
        out_path: Destination passed to Figure.savefig.
        title: Subtitle override.
        context: Optional mapping with agent/thread_id/run_id/started_at/
            ended_at, used for the header and footer lines.
        log_x: Use a logarithmic x axis.
        bandwidth: Passed to gaussian_kde as bw_method.
        fill_alpha / line_alpha: Opacity of the filled area and of the curve.

    Returns:
        out_path.
    """

    # categories & pretty labels (skip empty series automatically)
    order = [
        ("input_tokens", "input tokens"),
        ("output_tokens", "output tokens"),
        ("reasoning_tokens", "reasoning tokens"),
        ("cached_tokens", "cached tokens"),
        ("total_tokens", "total tokens"),
    ]

    # Gather usable arrays
    series = []
    for key, label in order:
        arr = np.asarray(samples.get(key, []), dtype=float)
        arr = arr[np.isfinite(arr)]
        if arr.size >= 2:  # need at least 2 for KDE
            series.append((key, label, arr))

    if not series:
        # Nothing to plot; create an empty figure with a note
        fig = plt.figure(figsize=(8, 2.0))
        fig.text(0.5, 0.5, "No token samples to plot", ha="center", va="center")
        fig.savefig(out_path, dpi=150, bbox_inches="tight", pad_inches=0.02)
        plt.close(fig)
        return out_path

    # Context
    ctx = context or {}
    agent = str(ctx.get("agent") or "")
    thread = str(ctx.get("thread_id") or "")
    run_id = str(ctx.get("run_id") or "")
    started = _fmt_iso_pretty(str(ctx.get("started_at") or ""))
    ended = _fmt_iso_pretty(str(ctx.get("ended_at") or ""))

    header = " : ".join([p for p in [agent, thread] if p]) or ""
    subtitle = title or "LLM Token Usage — KDE"

    # Build x-grid across all series
    all_max = max(float(np.max(a)) for _, _, a in series)
    x_min = 0.0
    x_max = max(1.0, all_max * 1.05)
    x = np.linspace(x_min, x_max, 600)

    def _density(arr: np.ndarray) -> np.ndarray:
        if float(arr.max() - arr.min()) > 0.0:
            return gaussian_kde(arr, bw_method=bandwidth).evaluate(x)
        # Zero-variance data: gaussian_kde would raise on the singular
        # covariance, so show a unit-area spike at the repeated value.
        sigma = max(1.0, 0.01 * (x_max - x_min))
        z = (x - float(arr[0])) / sigma
        return np.exp(-0.5 * z * z) / (sigma * np.sqrt(2.0 * np.pi))

    # Plot
    fig = plt.figure(figsize=(8, 2.8))
    ax = fig.add_axes(list(_LAYOUT["ax_rect"]))

    for _, label, arr in series:
        y = _density(arr)
        ax.plot(x, y, alpha=line_alpha, label=label)
        ax.fill_between(x, 0, y, alpha=fill_alpha)

    if log_x:
        # Avoid log(0)
        ax.set_xscale("log")
        ax.set_xlim(left=max(1e-6, x_min + 1e-6), right=x_max)

    ax.set_xlabel(
        "Tokens" + (" (log scale)" if log_x else ""), fontsize="small"
    )
    ax.set_ylabel("Density", fontsize="small")
    ax.tick_params(labelsize="small")
    ax.legend(loc="upper right", fontsize="x-small", frameon=False)

    if header:
        fig.text(0.5, _LAYOUT["header_y"], header, ha="center", va="top")
        fig.text(
            0.5,
            _LAYOUT["subtitle_y"],
            subtitle,
            ha="center",
            va="top",
            fontsize="small",
        )
    else:
        fig.text(
            0.5, (_LAYOUT["header_y"] - 0.02), subtitle, ha="center", va="top"
        )

    if run_id:
        fig.text(
            0.5,
            _LAYOUT["footer1_y"],
            f"run_id: {run_id}",
            ha="center",
            va="center",
            fontsize="x-small",
        )
    if started or ended:
        fig.text(
            0.5,
            _LAYOUT["footer2_y"],
            f"{started} \u2192 {ended}",
            ha="center",
            va="center",
            fontsize="x-small",
        )

    fig.savefig(out_path, dpi=150, bbox_inches="tight", pad_inches=0.02)
    plt.close(fig)
    return out_path
699
+
700
+
701
def plot_token_rates_bar(
    totals: dict[str, int],
    llm_seconds: float,
    window_seconds: float,
    out_path: str,
    *,
    title: str = "Tokens per second (two baselines)",
    context: dict | None = None,
    categories: list[str] | None = None,
) -> str:
    """
    Save a grouped bar chart of tokens/second against two time baselines.

    For each category two bars are drawn: the total divided by llm_seconds
    and the total divided by window_seconds. A non-positive or missing
    denominator gives a rate of 0.0.

    Args:
        totals: Mapping of category name to token count.
        llm_seconds: Summed LLM-active seconds (first baseline).
        window_seconds: Window seconds (second baseline).
        out_path: Destination passed to Figure.savefig.
        title: Chart title/subtitle text.
        context: Optional mapping with agent/thread_id/run_id/started_at/
            ended_at, used for the header and footer lines.
        categories: Categories to plot; defaults to the five token categories.

    Returns:
        out_path.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    chosen = categories or [
        "input_tokens",
        "output_tokens",
        "reasoning_tokens",
        "cached_tokens",
        "total_tokens",
    ]
    tick_names = [cat.replace("_", " ") for cat in chosen]
    counts = [int(totals.get(cat, 0) or 0) for cat in chosen]

    def _per_second(count: int, denom: float) -> float:
        if denom and denom > 0:
            return float(count) / denom
        return 0.0

    per_llm = [_per_second(count, llm_seconds) for count in counts]
    per_window = [_per_second(count, window_seconds) for count in counts]

    meta = context or {}
    agent = str(meta.get("agent") or "")
    thread = str(meta.get("thread_id") or "")
    run_id = str(meta.get("run_id") or "")
    # Only the first 8 characters of run_id go into the footer.
    short_id = f"{run_id[:8]}" if run_id else ""
    started = _fmt_iso_pretty(str(meta.get("started_at") or ""))
    ended = _fmt_iso_pretty(str(meta.get("ended_at") or ""))

    heading = " : ".join(piece for piece in (agent, thread) if piece)

    # ---------------- footer text (1 to 3 lines) ----------------
    footer_lines = [
        f"LLM-active (sum): {llm_seconds:.3f}s • window: {window_seconds:.3f}s"
    ]
    if bool(window_seconds > 0 and llm_seconds > window_seconds):
        overlap = llm_seconds / window_seconds
        footer_lines.append(
            f"Note: LLM sum exceeds window → parallel LLM work (~{overlap:.2f}× overlap)."
        )
    if short_id or started or ended:
        id_part = f"run_id: {short_id}" if short_id else ""
        span_part = f"{started} \u2192 {ended}" if (started or ended) else ""
        footer_lines.append(
            " ".join(part for part in (id_part, span_part) if part)
        )

    # Extra footer lines beyond two make the figure taller and lift the axes.
    extra_lines = max(0, len(footer_lines) - 2)
    fig_height = 2.8 + 0.35 * extra_lines

    # The figure is wide enough to hold a legend outside the axes.
    fig = plt.figure(figsize=(10.0, fig_height))
    ax = fig.add_axes([0.07, 0.24 + 0.05 * extra_lines, 0.75, 0.58])

    slots = np.arange(len(tick_names))
    bar_w = 0.38

    llm_bars = ax.bar(
        slots - bar_w / 2,
        per_llm,
        bar_w,
        label="per LLM-sec (sum)",
        color="tab:purple",
        edgecolor="black",
    )
    window_bars = ax.bar(
        slots + bar_w / 2,
        per_window,
        bar_w,
        label="per thread-sec",
        color="tab:gray",
        edgecolor="black",
    )

    ax.set_xticks(slots)
    ax.set_xticklabels(tick_names, rotation=0)
    ax.set_ylabel("tokens / second", fontsize="small")
    ax.tick_params(labelsize="small")
    ax.legend(
        loc="upper left",
        bbox_to_anchor=(1.005, 1.0),  # just outside the axes, top-right corner
        borderaxespad=0.0,
        frameon=False,
        fontsize="small",
    )

    # Value label centred on top of every bar.
    for bar_group in (llm_bars, window_bars):
        for bar in bar_group:
            top = bar.get_height()
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                top,
                f"{top:.2f}",
                ha="center",
                va="bottom",
                fontsize="x-small",
            )

    # header + subtitle
    header_y = 0.985
    subtitle_y = 0.955
    if heading:
        fig.text(0.5, header_y, heading, ha="center", va="top")
        fig.text(
            0.5, subtitle_y, title, ha="center", va="top", fontsize="small"
        )
    else:
        fig.text(0.5, (header_y - 0.015), title, ha="center", va="top")

    # Footer lines are stacked upwards from the bottom of the figure.
    base_y, line_gap = 0.09, 0.035
    for idx, text in enumerate(footer_lines):
        fig.text(
            0.5,
            base_y + idx * line_gap,
            text,
            ha="center",
            va="center",
            fontsize="x-small",
        )

    fig.savefig(out_path, dpi=150, bbox_inches="tight", pad_inches=0.03)
    plt.close(fig)
    return out_path
849
+
850
+
851
def plot_tokens_bar_by_model(
    totals_by_model: Dict[str, Dict[str, int]],
    out_path: str,
    *,
    title: str = "LLM Token Totals by Category — by model",
    context: dict | None = None,
) -> str:
    """
    Save one horizontal token-totals bar chart per model, stacked vertically.

    Args:
        totals_by_model: Mapping of model name to {category: token count}.
            Models are plotted in sorted order.
        out_path: Destination passed to Figure.savefig.
        title: Subtitle text; skipped when empty.
        context: Optional mapping with agent/thread_id/started_at/ended_at.
            The start/end footer is omitted when the agent is "SUPER"
            (compared case-insensitively).

    Returns:
        out_path.

    Raises:
        ValueError: If totals_by_model is empty.
    """
    import matplotlib.pyplot as plt

    categories = [
        "input_tokens",
        "output_tokens",
        "reasoning_tokens",
        "cached_tokens",
        "total_tokens",
    ]
    bar_names = [cat.replace("_", " ") for cat in categories]
    model_names = sorted(totals_by_model.keys())
    if not model_names:
        raise ValueError("No models to plot.")

    n_rows = len(model_names)
    fig, axes = plt.subplots(
        n_rows, 1, figsize=(12, 2.2 + 2.2 * n_rows), sharex=False
    )
    # subplots() hands back a bare Axes when there is a single row.
    panels = [axes] if n_rows == 1 else axes

    for row, model in enumerate(model_names):
        ax = panels[row]
        per_model = totals_by_model.get(model, {})
        counts = [int(per_model.get(cat, 0) or 0) for cat in categories]
        for bar, count in zip(
            ax.barh(bar_names, counts, edgecolor="black"), counts
        ):
            ax.text(
                bar.get_width(),
                bar.get_y() + bar.get_height() / 2,
                f" {count:,}",
                va="center",
                ha="left",
                fontsize=9,
            )
        ax.set_title(model, loc="left", fontsize=11)
        ax.set_xlabel("Tokens")

    # Header/subtitle/footer
    if context:
        heading = (
            f"{context.get('agent', '')}: {context.get('thread_id', '')}".strip(
                " :"
            )
        )
        if heading:
            fig.suptitle(heading, fontsize=14, y=0.98)
    if title:
        fig.text(0.5, 0.94, title, ha="center", fontsize=12)
    if context and (context.get("agent") or "").upper() != "SUPER":
        # A global start→end span is not meaningful for SUPER aggregates
        # (overlapping threads), so it is only shown for other agents.
        begin, finish = context.get("started_at"), context.get("ended_at")
        if begin and finish:
            fig.text(0.5, 0.02, f"{begin} → {finish}", ha="center", fontsize=9)

    fig.tight_layout(rect=(0.02, 0.06, 1, 0.92))
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return out_path
918
+
919
+
920
def plot_token_rates_by_model(
    totals_by_model: Dict[str, Dict[str, int]],
    llm_seconds_by_model: Dict[str, float],
    window_seconds: float,
    out_path: str,
    *,
    title: str = "Tokens per second — by model",
    context: dict | None = None,
) -> str:
    """
    Save one grouped tokens/second bar chart per model, stacked vertically.

    Each panel shows, per token category, the total divided by that model's
    LLM seconds and the total divided by window_seconds. A missing or
    non-positive denominator gives a rate of 0.0 (matching
    plot_token_rates_bar) rather than an inflated value.

    Args:
        totals_by_model: Mapping of model name to {category: token count}.
            Models are plotted in sorted order.
        llm_seconds_by_model: Mapping of model name to LLM-active seconds.
        window_seconds: Window seconds used as the second baseline; None is
            treated as 0.
        out_path: Destination passed to Figure.savefig.
        title: Subtitle text; skipped when empty.
        context: Optional mapping with agent/thread_id/started_at/ended_at.
            When the agent is "SUPER" (compared case-insensitively, as in
            plot_tokens_bar_by_model) the start/end line, the window figure
            and the overlap note are left out of the footer.

    Returns:
        out_path.

    Raises:
        ValueError: If totals_by_model is empty.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    cats = [
        "input_tokens",
        "output_tokens",
        "reasoning_tokens",
        "cached_tokens",
        "total_tokens",
    ]
    xlabels = [c.replace("_", " ") for c in cats]
    models = sorted(totals_by_model.keys())
    if not models:
        raise ValueError("No models to plot.")

    rows = len(models)
    fig_h = 2.4 + 2.4 * rows
    fig, axes = plt.subplots(rows, 1, figsize=(12, fig_h), sharex=True)
    if rows == 1:
        axes = [axes]

    x = np.arange(len(cats))
    w = 0.38
    # Normalize once so None/0 behave the same in rates, checks and footer.
    window = float(window_seconds or 0.0)

    def _rate(count: int, denom: float) -> float:
        # No usable time baseline means no meaningful rate.
        return (count / denom) if denom > 0 else 0.0

    for i, model in enumerate(models):
        ax = axes[i]
        totals = [
            int(totals_by_model.get(model, {}).get(k, 0) or 0) for k in cats
        ]
        llm_s = float(llm_seconds_by_model.get(model, 0.0) or 0.0)

        per_llm = [_rate(v, llm_s) for v in totals]
        per_thread = [_rate(v, window) for v in totals]

        b1 = ax.bar(
            x - w / 2,
            per_llm,
            width=w,
            label="per LLM-sec (sum)",
            edgecolor="black",
        )
        b2 = ax.bar(
            x + w / 2,
            per_thread,
            width=w,
            label="per thread-sec",
            edgecolor="black",
        )

        # Label only bars that actually have height.
        for bx in (b1, b2):
            for rect in bx:
                h = rect.get_height()
                if h > 0:
                    ax.text(
                        rect.get_x() + rect.get_width() / 2,
                        h,
                        f"{h:,.2f}",
                        ha="center",
                        va="bottom",
                        fontsize=8,
                    )

        ax.set_xticks(x, xlabels)
        ax.set_ylabel("tokens / second")
        ax.set_title(model, loc="left", fontsize=11)
        if i == 0:
            ax.legend(loc="upper right")

    # Header/subtitle/footer
    if context:
        header = (
            f"{context.get('agent', '')}: {context.get('thread_id', '')}".strip(
                " :"
            )
        )
        if header:
            fig.suptitle(header, fontsize=14, y=0.98)
    if title:
        fig.text(0.5, 0.94, title, ha="center", fontsize=12)

    if context:
        is_super = str(context.get("agent") or "").upper() == "SUPER"

        s, e = context.get("started_at"), context.get("ended_at")
        llm_sum = sum(
            float(llm_seconds_by_model.get(m, 0.0) or 0.0) for m in models
        )

        # Only show overlap note when we are actually showing a single window denominator
        overlap_note = ""
        if (not is_super) and window > 0 and llm_sum > window * 1.05:
            overlap_note = f" Note: LLM sum exceeds window → parallel LLM work (~{llm_sum / window:.2f}× overlap)."

        # Hide the global start→end line for SUPER aggregates
        if (not is_super) and s and e:
            fig.text(0.5, 0.04, f"{s} → {e}", ha="center", fontsize=9)

        # For SUPER, drop the "• window: …" part entirely
        if is_super:
            footer = f"LLM-active (sum across models): {llm_sum:.3f}s"
        else:
            footer = f"LLM-active (sum across models): {llm_sum:.3f}s • window: {window:.3f}s{overlap_note}"

        fig.text(0.5, 0.02, footer, ha="center", fontsize=9)

    fig.tight_layout(rect=(0.02, 0.08, 1, 0.92))
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return out_path
1045
+
1046
+
1047
def plot_tokens_by_agent_stacked(
    totals_by_agent: Dict[str, Dict[str, int]],
    out_path: str,
    *,
    title: str = "LLM Token Totals by Agent (thread)",
    footer_lines: List[str] | None = None,
) -> str:
    """Save a stacked bar chart of token totals with one stack per agent.

    Each stack is built from the ``input_tokens``, ``output_tokens``,
    ``reasoning_tokens`` and ``cached_tokens`` entries of the agent's
    mapping; missing or ``None`` entries count as zero.  Agents are ordered
    left to right by descending sum of those four categories, and that sum
    is printed above each non-empty stack.

    Args:
        totals_by_agent: Mapping of agent name to its per-category token
            counts.  When empty, a small placeholder figure containing the
            text "No data" is written instead of a chart.
        out_path: Destination passed to ``Figure.savefig``.
        title: Figure title.
        footer_lines: Optional lines of text drawn centred below the axes.
            Lines are stacked upward from the bottom of the figure, so the
            first entry is the lowest one.

    Returns:
        ``out_path``, unchanged.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    cats = [
        "input_tokens",
        "output_tokens",
        "reasoning_tokens",
        "cached_tokens",
    ]
    pretty = {
        "input_tokens": "input tokens",
        "output_tokens": "output tokens",
        "reasoning_tokens": "reasoning tokens",
        "cached_tokens": "cached tokens",
    }

    if not totals_by_agent:
        fig = plt.figure(figsize=(10, 2.0))
        fig.text(0.5, 0.5, "No data", ha="center", va="center")
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        return out_path

    # Resolve every agent's per-category counts once; both the ordering and
    # the bar heights are derived from this table.
    counts = {
        agent: [int(per_cat.get(k, 0) or 0) for k in cats]
        for agent, per_cat in totals_by_agent.items()
    }
    agents = sorted(counts, key=lambda a: sum(counts[a]), reverse=True)
    totals_per_agent = [sum(counts[a]) for a in agents]
    max_total = max(totals_per_agent, default=0)

    x = np.arange(len(agents))
    width = 0.65

    # Height grows slowly once there are more than six agents; width scales
    # with the agent count so the rotated tick labels keep some room.
    fig_h = 3.0 + 0.18 * max(0, len(agents) - 6)
    fig = plt.figure(figsize=(max(10, 1.2 * len(agents)), fig_h))
    ax = fig.add_axes([0.08, 0.28, 0.78, 0.60])  # leave right margin for legend

    # Draw one layer per category, each resting on the running stack height.
    bottoms = np.zeros(len(agents), dtype=float)
    for idx, key in enumerate(cats):
        vals = [counts[a][idx] for a in agents]
        ax.bar(
            x,
            vals,
            width=width,
            bottom=bottoms,
            edgecolor="black",
            label=pretty[key],
        )
        bottoms = bottoms + np.array(vals, dtype=float)

    # Axis & labels
    ax.set_xticks(x)
    ax.set_xticklabels(agents, rotation=28, ha="right")
    ax.set_ylabel("Tokens", fontsize="small")
    ax.tick_params(labelsize="small")
    # 8% headroom so the total printed above the tallest stack fits.
    ax.set_ylim(0, max_total * 1.08 if max_total > 0 else 1.0)

    # Annotate totals on top of each stack (only if > 0)
    for xi, tot in enumerate(totals_per_agent):
        if tot > 0:
            ax.text(
                xi,
                tot,
                f"{tot:,}",
                ha="center",
                va="bottom",
                fontsize="x-small",
            )

    # Legend outside the axes, in the right-hand margin
    ax.legend(
        loc="upper left",
        bbox_to_anchor=(1.005, 1.0),
        frameon=False,
        fontsize="small",
    )

    # Title / footer
    fig.suptitle(title, y=0.97)
    y0 = 0.08
    if footer_lines:
        for i, line in enumerate(footer_lines):
            fig.text(
                0.5,
                y0 + i * 0.035,
                line,
                ha="center",
                va="center",
                fontsize="x-small",
            )

    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return out_path
1157
+
1158
+
1159
def plot_tps_by_agent_grouped(
    totals_by_agent: Dict[str, Dict[str, int]],
    llm_secs_by_agent: Dict[str, float],
    thread_window_seconds: float,
    out_path: str,
    *,
    title: str = "Tokens per second by Agent (thread)",
    footer_lines: List[str] | None = None,
) -> str:
    """Save a grouped bar chart of total tokens per second for each agent.

    Two bars are drawn per agent.  Both use the sum of the agent's
    ``input_tokens``, ``output_tokens``, ``reasoning_tokens`` and
    ``cached_tokens`` rates; they differ only in the denominator:

    * "per LLM-sec (sum)" divides by the agent's entry in
      ``llm_secs_by_agent``;
    * "per thread-sec" divides by ``thread_window_seconds``.

    A missing, ``None``, zero or negative denominator yields a rate of 0.
    Agents are ordered left to right by descending token total.

    Args:
        totals_by_agent: Mapping of agent name to its per-category token
            counts.  When empty, a small placeholder figure containing the
            text "No data" is written instead of a chart.
        llm_secs_by_agent: Mapping of agent name to the seconds used as the
            first denominator.
        thread_window_seconds: Seconds used as the second denominator, shared
            by all agents.
        out_path: Destination passed to ``Figure.savefig``.
        title: Figure title.
        footer_lines: Optional lines of text drawn centred below the axes.
            Lines are stacked upward from the bottom of the figure, so the
            first entry is the lowest one.

    Returns:
        ``out_path``, unchanged.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    categories = (
        "input_tokens",
        "output_tokens",
        "reasoning_tokens",
        "cached_tokens",
    )

    if not totals_by_agent:
        placeholder = plt.figure(figsize=(10, 2.0))
        placeholder.text(0.5, 0.5, "No data", ha="center", va="center")
        placeholder.savefig(out_path, dpi=150, bbox_inches="tight")
        plt.close(placeholder)
        return out_path

    def _token_total(agent: str) -> int:
        # Sum of the four categories; used only to order the agents.
        return sum(int(totals_by_agent[agent].get(c, 0) or 0) for c in categories)

    def _total_rate(agent: str, seconds: float) -> float:
        # Per-category tokens/second, summed; 0.0 when the denominator is not
        # a positive number of seconds.
        seconds = float(seconds or 0.0)
        per_cat = [
            int(totals_by_agent.get(agent, {}).get(c, 0) or 0)
            for c in categories
        ]
        if seconds <= 0:
            return 0.0
        return sum([count / seconds for count in per_cat])

    # Largest token consumers first.
    ordered = sorted(totals_by_agent.keys(), key=_token_total, reverse=True)

    llm_rates = [
        _total_rate(name, llm_secs_by_agent.get(name, 0.0)) for name in ordered
    ]
    thread_rates = [
        _total_rate(name, thread_window_seconds) for name in ordered
    ]

    count = len(ordered)
    positions = np.arange(count)
    bar_w = 0.38

    # Height grows slowly once there are more than six agents; width scales
    # with the agent count.
    height = 3.0 + 0.18 * max(0, count - 6)
    fig = plt.figure(figsize=(max(12, 1.3 * count), height))
    ax = fig.add_axes([0.08, 0.28, 0.78, 0.60])

    # One pair of side-by-side bars per agent, one bar per denominator.
    llm_bars = ax.bar(
        positions - bar_w / 2,
        llm_rates,
        width=bar_w,
        label="per LLM-sec (sum)",
        edgecolor="black",
    )
    thread_bars = ax.bar(
        positions + bar_w / 2,
        thread_rates,
        width=bar_w,
        label="per thread-sec",
        edgecolor="black",
    )

    # Print the value above every bar that has a positive height.
    for rect in [*llm_bars, *thread_bars]:
        top = rect.get_height()
        if not top > 0:
            continue
        ax.text(
            rect.get_x() + rect.get_width() / 2,
            top,
            f"{top:,.2f}",
            ha="center",
            va="bottom",
            fontsize="x-small",
        )

    ax.set_xticks(positions)
    ax.set_xticklabels(ordered, rotation=28, ha="right")
    ax.set_ylabel("tokens / second", fontsize="small")
    ax.tick_params(labelsize="small")
    # Legend sits outside the axes, in the right-hand margin.
    ax.legend(
        loc="upper left",
        bbox_to_anchor=(1.005, 1.0),
        frameon=False,
        fontsize="small",
    )

    fig.suptitle(title, y=0.97)
    footer_base = 0.08
    for row, text in enumerate(footer_lines or ()):
        fig.text(
            0.5,
            footer_base + row * 0.035,
            text,
            ha="center",
            va="center",
            fontsize="x-small",
        )

    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return out_path