ursa-ai 0.4.2__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ursa-ai might be problematic. Click here for more details.

@@ -0,0 +1,1441 @@
1
+ # ursa/observability/timing.py
2
+ from __future__ import annotations
3
+
4
+ import collections
5
+ import datetime
6
+ import importlib
7
+ import json
8
+ import os
9
+ import re
10
+ import time
11
+ import uuid
12
+ from collections import defaultdict
13
+ from dataclasses import dataclass, field
14
+ from functools import wraps
15
+ from threading import Lock
16
+ from typing import Any, Callable, Dict, Iterable, List, Tuple
17
+
18
+ from langchain_core.callbacks import BaseCallbackHandler
19
+ from rich import get_console
20
+ from rich.box import HEAVY
21
+ from rich.console import Group
22
+ from rich.panel import Panel
23
+ from rich.rule import Rule
24
+ from rich.table import Table
25
+ from rich.text import Text
26
+
27
# Fixed column widths (in characters) shared by every timing table so that
# consecutive tables line up column-for-column.
NAME_W = 30
COUNT_W = 7
TOTAL_W = 12
AVG_W = 12
MAX_W = 12
COL_PAD = (0, 1)  # top/bottom, left/right padding in the Rich table cells
29
+
30
+
31
def _get_pricing_module():
    """Return the first importable pricing module, or ``None``.

    Candidates are tried in a fixed order; a module only qualifies when it
    exposes both ``load_registry`` and ``price_payload``. Any exception raised
    while importing or probing a candidate is swallowed and the next one is
    tried.
    """
    required = ("load_registry", "price_payload")
    for module_name in (
        "ursa.observability.pricing",
        "ursa.observability.llm_pricing",
        "pricing",
        "llm_pricing",
    ):
        try:
            candidate = importlib.import_module(module_name)
            if all(hasattr(candidate, attr) for attr in required):
                return candidate
        except Exception:
            # Missing or broken candidate: fall through to the next one.
            continue
    return None
46
+
47
+
48
def _to_snake(s: str) -> str:
    """Convert a display name to lower snake_case.

    An underscore is inserted before every uppercase letter except a leading
    one, hyphens and spaces become underscores, and the result is lowercased.
    Non-string input is converted with ``str`` first.
    """
    spaced = re.sub(r"(?<!^)(?=[A-Z])", "_", str(s))  # CamelCase -> snake_case
    return spaced.translate(str.maketrans({"-": "_", " ": "_"})).lower()
53
+
54
+
55
# Process-wide registry of session rollups keyed by thread_id; filled in by
# _session_ingest and read by render_session_summary.
_SESSIONS: dict[str, "SessionRollup"] = dict()
56
+
57
+
58
@dataclass
class _Bucket:
    """Running count / total / max of durations for one name in a session."""

    count: int = 0
    total_ms: float = 0.0
    max_ms: float = 0.0

    def add(self, count: int, total_s: float, max_ms: float):
        """Fold in one pre-aggregated row; ``None``/falsy values count as zero.

        ``total_s`` is in seconds and is stored internally in milliseconds;
        ``max_ms`` is already in milliseconds.
        """
        self.count += int(count or 0)
        self.total_ms += float(total_s or 0.0) * 1000.0
        peak = float(max_ms or 0.0)
        self.max_ms = peak if peak > self.max_ms else self.max_ms

    def as_row(self, name: str):
        """Return ``(name, count, total_s, avg_ms, max_ms)`` for table rendering."""
        mean_ms = self.total_ms / self.count if self.count else 0.0
        return name, self.count, self.total_ms / 1000.0, mean_ms, self.max_ms
72
+
73
+
74
@dataclass
class SessionRollup:
    """Running aggregate of every telemetry payload that shares one thread_id.

    ``ingest`` folds a payload's ``context``, ``tables`` and ``costs``
    sections into run/agent counts, summed wall time, per-name timing
    buckets, per-model costs and the earliest start / latest end seen.
    """

    thread_id: str
    runs: int = 0
    agents: set = field(default_factory=set)

    # times/costs
    wall_sum_s: float = 0.0  # sum of each ingested run's own wall time
    llm_total_s: float = 0.0  # recomputed from llm_by_name on every ingest
    tool_total_s: float = 0.0  # recomputed from tool_by_name on every ingest
    cost_total_usd: float = 0.0

    # breakdowns (defaultdicts hand out a fresh bucket / 0.0 for unseen keys)
    runnable_by_name: dict[str, _Bucket] = field(
        default_factory=lambda: defaultdict(_Bucket)
    )
    tool_by_name: dict[str, _Bucket] = field(
        default_factory=lambda: defaultdict(_Bucket)
    )
    llm_by_name: dict[str, _Bucket] = field(
        default_factory=lambda: defaultdict(_Bucket)
    )
    cost_by_model_usd: dict[str, float] = field(
        default_factory=lambda: defaultdict(float)
    )

    # temporal bounds, kept as the ISO-8601 strings found in payload context
    started_at: str | None = None
    ended_at: str | None = None

    def ingest(self, payload: dict) -> None:
        """Merge one run payload into this rollup, mutating it in place."""
        from datetime import datetime

        def parse_stamp(stamp):
            # Missing or malformed timestamps map to None instead of raising.
            try:
                return datetime.fromisoformat((stamp or "").replace("Z", "+00:00"))
            except Exception:
                return None

        context = payload.get("context") or {}
        agent_name = context.get("agent") or "agent"
        start_iso = context.get("started_at")
        end_iso = context.get("ended_at")
        run_start = parse_stamp(start_iso)
        run_end = parse_stamp(end_iso)

        self.runs += 1
        self.agents.add(agent_name)

        # Only runs with both timestamps contribute wall time and may widen
        # the session's [started_at, ended_at] window.
        if run_start and run_end:
            self.wall_sum_s += max(0.0, (run_end - run_start).total_seconds())
            earliest = parse_stamp(self.started_at) if self.started_at else None
            if not self.started_at or (earliest and run_start < earliest):
                self.started_at = start_iso
            latest = parse_stamp(self.ended_at) if self.ended_at else None
            if not self.ended_at or (latest and run_end > latest):
                self.ended_at = end_iso

        # Per-name timing rows, section by section.
        tables = payload.get("tables") or {}
        for section, target in (
            ("runnable", self.runnable_by_name),
            ("tool", self.tool_by_name),
            ("llm", self.llm_by_name),
        ):
            for row in tables.get(section) or []:
                target[row["name"]].add(row["count"], row["total_s"], row["max_ms"])

        # Derived totals always reflect the full bucket contents.
        self.llm_total_s = (
            sum(bucket.total_ms for bucket in self.llm_by_name.values()) / 1000.0
        )
        self.tool_total_s = (
            sum(bucket.total_ms for bucket in self.tool_by_name.values()) / 1000.0
        )

        # Costs, when the payload was priced; unparseable per-model amounts
        # are skipped.
        costs = payload.get("costs") or {}
        self.cost_total_usd += float(costs.get("total_usd") or 0.0)
        for model_name, amount in (costs.get("by_model_usd") or {}).items():
            try:
                self.cost_by_model_usd[model_name] += float(amount)
            except Exception:
                pass
164
+
165
+
166
def _session_ingest(payload: dict) -> None:
    """Route a run payload to the SessionRollup for its ``context.thread_id``.

    Payloads without a truthy thread id are ignored. A rollup is created in
    ``_SESSIONS`` the first time a thread id is seen.
    """
    thread_id = (payload.get("context") or {}).get("thread_id")
    if not thread_id:
        return
    if thread_id not in _SESSIONS:
        _SESSIONS[thread_id] = SessionRollup(thread_id=thread_id)
    _SESSIONS[thread_id].ingest(payload)
171
+
172
+
173
def _rows_from_bucket_map(
    d: dict[str, _Bucket],
) -> list[tuple[str, int, float, float, float]]:
    """Turn a name->bucket map into table rows ordered by total seconds, descending.

    Ties keep the map's iteration order (the sort is stable).
    """
    return sorted(
        (bucket.as_row(label) for label, bucket in d.items()),
        key=lambda row: row[2],
        reverse=True,
    )
179
+
180
+
181
def render_session_summary(thread_id: str):
    """Print a Rich panel summarising every run recorded for ``thread_id``.

    Data comes from the ``SessionRollup`` stored in ``_SESSIONS`` by
    ``_session_ingest``. The panel contains a header (runs, agents, time
    window, wall times), per-node / per-tool / per-LLM timing tables, a
    cost-by-model table and a totals block, and is printed to the global
    Rich console.

    Returns:
        The "no session data" message when nothing was recorded for
        ``thread_id``; otherwise ``None``.
    """
    from datetime import datetime

    from rich.markup import escape

    roll = _SESSIONS.get(thread_id)
    console = get_console()
    if not roll:
        msg = f"No session data for thread_id '{thread_id}'."
        # Text() renders literally: thread ids are caller-supplied and a "["
        # in them would otherwise be parsed as Rich markup.
        console.print(
            Panel(Text(msg), title="[bold]Session Summary[/]", border_style="red")
        )
        return msg

    # Strings that originate outside this module are escaped before being
    # embedded in markup so brackets in them render verbatim.
    safe_tid = escape(str(thread_id))
    agents_list = escape(", ".join(sorted(roll.agents)) or "—")

    # ---- header ----
    header_lines = [
        f"[bold magenta]Session[/] • thread [bold]{safe_tid}[/] "
        f"[dim]• runs {roll.runs} • agents {len(roll.agents)}[/]"
    ]
    # Show both the elapsed window and the sum of per-run wall times.
    elapsed = None
    if roll.started_at and roll.ended_at:
        header_lines.append(
            f"[dim]{escape(str(roll.started_at))} → {escape(str(roll.ended_at))}[/dim]"
        )
        try:
            s = datetime.fromisoformat(roll.started_at.replace("Z", "+00:00"))
            e = datetime.fromisoformat(roll.ended_at.replace("Z", "+00:00"))
            elapsed = max(0.0, (e - s).total_seconds())
        except Exception:
            # Unparseable or naive/aware-mixed stamps: only the sum is shown.
            elapsed = None
    if elapsed is not None:
        header_lines[-1] += (
            f" [bold]wall (elapsed)[/]: {elapsed:,.2f}s [bold]wall (sum)[/]: {roll.wall_sum_s:,.2f}s"
        )
    else:
        header_lines.append(f"[bold]wall (sum)[/]: {roll.wall_sum_s:,.2f}s")

    # ---- timing tables (shared fixed widths so they align) ----
    t_nodes = _mk_table(
        "Per-Node / Runnable Timing (session)",
        _rows_from_bucket_map(roll.runnable_by_name),
    )
    t_tools = _mk_table(
        "Per-Tool Timing (session)", _rows_from_bucket_map(roll.tool_by_name)
    )
    t_llms = _mk_table(
        "Per-LLM Timing (session)", _rows_from_bucket_map(roll.llm_by_name)
    )

    # ---- cost-by-model table (same name/total widths as the timing tables) ----
    t_cost = Table(
        title="Cost by Model (USD)",
        title_style="bold white",
        box=HEAVY,
        expand=False,
        pad_edge=False,
        header_style="bold",
        padding=COL_PAD,
    )
    t_cost.add_column(
        "Model",
        style="cyan",
        no_wrap=True,
        width=NAME_W,
        min_width=NAME_W,
        max_width=NAME_W,
    )
    t_cost.add_column(
        "Cost",
        justify="right",
        width=TOTAL_W,
        min_width=TOTAL_W,
        max_width=TOTAL_W,
    )
    if roll.cost_by_model_usd:
        for model, amt in sorted(
            roll.cost_by_model_usd.items(), key=lambda kv: kv[1], reverse=True
        ):
            # Model names come from providers; render them without markup.
            t_cost.add_row(Text(str(model)), f"${amt:,.6f}")
    else:
        t_cost.add_row("—", "$0.000000")

    # ---- attribution block ----
    attrib = [
        "[bold]Session Totals[/]",
        f"  LLM total:  {roll.llm_total_s:,.2f}s",
        f"  Tool total: {roll.tool_total_s:,.2f}s",
        (f"  Wall (elapsed): {elapsed:,.2f}s" if elapsed is not None else None),
        f"  Wall (sum):     {roll.wall_sum_s:,.2f}s",
        f"[bold]Cost total:[/] [bold green]${roll.cost_total_usd:,.6f}[/]",
        f"[dim]Agents:[/] {agents_list}",
    ]
    attrib = [a for a in attrib if a is not None]

    renderables = [
        Text.from_markup("\n".join(header_lines)),
        Rule(),
        t_nodes,
        t_tools,
        t_llms,
        Rule(),
        t_cost,
        Rule(),
        Text.from_markup("\n".join(attrib)),
    ]
    panel = Panel.fit(
        Group(*renderables),
        title=f"[bold white]Session Summary[/] • [cyan]{safe_tid}[/]",
        border_style="bright_magenta",
        padding=(1, 2),
        box=HEAVY,
    )
    console.print(panel)
297
+
298
+
299
+ # ---------------------------
300
+ # Aggregators
301
+ # ---------------------------
302
+
303
+
304
@dataclass
class _Agg:
    """Lock-protected accumulator of ``(name, elapsed_ms, ok)`` timing records."""

    # list of (name, elapsed_ms, ok)
    records: List[Tuple[str, float, bool]] = field(default_factory=list)
    _lock: Lock = field(default_factory=Lock, repr=False)

    def add(self, name: str, elapsed_ms: float, ok: bool) -> None:
        """Append one record while holding the lock."""
        record = (name, elapsed_ms, ok)
        with self._lock:
            self.records.append(record)

    def buckets(self) -> List[Tuple[str, int, float, float, float]]:
        """Summarise records per name, largest total first.

        Returns ``[(name, count, total_secs, avg_ms, max_ms), ...]``; the
        ``ok`` flag does not affect the summary.
        """
        with self._lock:
            snapshot = list(self.records)
        grouped: Dict[str, List[float]] = defaultdict(list)
        for label, elapsed, _ok in snapshot:
            grouped[label].append(elapsed)
        summary = [
            (label, len(ms), sum(ms) / 1000.0, sum(ms) / len(ms), max(ms))
            for label, ms in grouped.items()
        ]
        return sorted(summary, key=lambda row: row[2], reverse=True)
332
+
333
+
334
+ # ---------------------------------
335
+ # Callback Handlers
336
+ # ---------------------------------
337
+
338
+
339
class PerToolTimer(BaseCallbackHandler):
    """Callback handler that times each tool call.

    Timing is driven by LangChain's tool start/end/error callbacks, so it is
    independent of decorator order. Durations (ms) are added to ``agg`` keyed
    by tool name, with ``ok=False`` for errors. An end/error event without a
    matching start is recorded as ``unknown_tool`` with a ~0 ms duration.
    """

    def __init__(self, agg: _Agg | None = None):
        self.agg = agg or _Agg()
        # run_id -> (tool name, perf_counter() at start)
        self._starts: Dict[Any, Tuple[str, float]] = {}

    def _name(self, serialized) -> str:
        """Best-effort tool name from ``serialized`` (dict, str, or anything else)."""
        if isinstance(serialized, str):
            return serialized
        if not isinstance(serialized, dict):
            return "unknown_tool"
        label = serialized.get("name")
        if label:
            return label
        ident = serialized.get("id")
        if isinstance(ident, dict):
            label = ident.get("name") or ident.get("id")
        elif isinstance(ident, (list, tuple)):
            label = "/".join(map(str, ident))
        elif ident is not None:
            label = str(ident)
        return label or "unknown_tool"

    def _close_span(self, run_id, ok: bool) -> None:
        """Pop the start entry for ``run_id`` and record its elapsed time."""
        label, started = self._starts.pop(
            run_id, ("unknown_tool", time.perf_counter())
        )
        self.agg.add(label, (time.perf_counter() - started) * 1000.0, ok)

    def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
        self._starts[run_id] = (self._name(serialized), time.perf_counter())

    def on_tool_end(self, output, *, run_id, **kwargs):
        self._close_span(run_id, True)

    def on_tool_error(self, error, *, run_id, **kwargs):
        self._close_span(run_id, False)
378
+
379
+
380
class PerRunnableTimer(BaseCallbackHandler):
    """Time LangChain/LangGraph runnables (chains, graphs, graph nodes).

    * Root spans (no parent run) are recorded as ``graph:<name>``.
    * Child spans are recorded as ``node:<ns>:<node>`` only when their
      metadata carries an ``ursa_ns`` namespace set by our wrapper. Every
      other child span is ignored completely: neither its start nor its end
      is recorded. This filters out internal ``graph:step:N:<node>``
      duplicates.

    Node names are usually found in ``serialized.get('name')`` or
    ``serialized.get('id')``. Durations (ms) go into ``agg``.
    """

    def __init__(self, agg: _Agg | None = None):
        self.agg = agg or _Agg()
        # run_id -> (span name, perf_counter() at start); only kept spans.
        self._starts: Dict[Any, Tuple[str, float]] = {}

    def _name(self, serialized) -> str:
        """Best-effort runnable name from ``serialized`` (dict, str, or None)."""
        if isinstance(serialized, dict):
            name = serialized.get("name")
            if not name:
                sid = serialized.get("id")
                if isinstance(sid, dict):
                    name = sid.get("name") or sid.get("id")
                elif isinstance(sid, (list, tuple)):
                    name = "/".join(map(str, sid))
                elif sid is not None:
                    name = str(sid)
            return name or "runnable"
        if isinstance(serialized, str):
            return serialized
        return "runnable"

    # Chains/graphs/nodes map onto these events:
    def on_chain_start(
        self,
        serialized,
        inputs,
        *,
        run_id,
        parent_run_id=None,
        tags=None,
        metadata=None,
        **kwargs,
    ):
        base_name = self._name(serialized)

        # Root span: always kept.
        if parent_run_id is None:
            name = base_name
            if name == "runnable" and tags:
                name = tags[-1]  # e.g., "graph"
            self._starts[run_id] = (f"graph:{name}", time.perf_counter())
            return

        # ---- Child span (graph node) ----
        md = metadata if isinstance(metadata, dict) else {}

        # Only spans our wrapper marked with a namespace are kept.
        ns = md.get("ursa_ns")
        if not ns:
            return

        # Node base name (prefer explicit metadata).
        node_base = (
            md.get("langgraph_node")
            or md.get("node_name")
            or md.get("langgraph:node")
            or base_name
        )

        # Canonicalise "graph:step:N:<node>" -> "<node>". maxsplit=3 keeps any
        # colons that belong to <node> itself ("graph:step:2:a:b" -> "a:b").
        if isinstance(node_base, str) and node_base.startswith("graph:step:"):
            parts = node_base.split(":", 3)
            if len(parts) == 4:
                node_base = parts[3]

        # Namespace is snake_cased with the module-level helper.
        name = f"node:{_to_snake(ns)}:{node_base}"
        self._starts[run_id] = (name, time.perf_counter())

    def _record_end(self, run_id, ok: bool) -> None:
        """Record the elapsed time for ``run_id`` if its start was kept."""
        started = self._starts.pop(run_id, None)
        if started is None:
            # Span was filtered out in on_chain_start; recording it here would
            # add a bogus ~0 ms "runnable" row for every ignored child span.
            return
        name, t0 = started
        self.agg.add(name, (time.perf_counter() - t0) * 1000.0, ok)

    def on_chain_end(self, outputs, *, run_id, **kwargs):
        self._record_end(run_id, True)

    def on_chain_error(self, error, *, run_id, **kwargs):
        self._record_end(run_id, False)
476
+
477
+
478
def _to_int(x, default=0):
    """Best-effort integer conversion for token counts.

    ints and floats are truncated with ``int``; strings are parsed through
    ``float`` so both "340" and "340.0" work. Any other type, or any value
    that fails to convert (e.g. "abc", NaN), yields ``default``.
    """
    try:
        if isinstance(x, str):
            return int(float(x))
        if isinstance(x, (int, float)):
            return int(x)
    except Exception:
        return default
    return default
490
+
491
+
492
def _acc_from(d: dict, roll: dict):
    """Accumulate one provider usage dict into the running ``roll`` totals.

    Either naming scheme is accepted (``input_tokens``/``output_tokens`` or
    ``prompt_tokens``/``completion_tokens``) and counts are mirrored into both
    sets of keys in ``roll``. ``total_tokens`` falls back to input+output when
    the key is absent. Reasoning and cached counts are read from flat keys or
    the ``*_tokens_details`` sub-dicts. ``*_cost`` values are added as floats,
    and unparseable costs are skipped. ``roll`` must already hold every token
    key incremented here.
    """
    tokens_in = _to_int(d.get("input_tokens", d.get("prompt_tokens")))
    tokens_out = _to_int(d.get("output_tokens", d.get("completion_tokens")))
    tokens_total = _to_int(d.get("total_tokens", tokens_in + tokens_out))

    for key, amount in (
        ("input_tokens", tokens_in),
        ("output_tokens", tokens_out),
        ("total_tokens", tokens_total),
    ):
        roll[key] += amount

    # Mirror into the prompt/completion naming as well.
    roll["prompt_tokens"] += _to_int(d.get("prompt_tokens", tokens_in))
    roll["completion_tokens"] += _to_int(d.get("completion_tokens", tokens_out))

    # Reasoning tokens: flat key first, then the OpenAI-style details dict.
    roll["reasoning_tokens"] += _to_int(
        d.get("reasoning_tokens")
        or (d.get("completion_tokens_details") or {}).get("reasoning_tokens")
    )

    # Cached prompt tokens under any of the known synonyms.
    cached_count = (
        d.get("cached_tokens")
        or d.get("cached_input_tokens")
        or (d.get("prompt_tokens_details") or {}).get("cached_tokens")
        or d.get("prompt_cache_hits")
    )
    roll["cached_tokens"] += _to_int(cached_count)

    # Provider-reported costs, if exposed.
    for cost_key in ("input_cost", "output_cost", "total_cost"):
        value = d.get(cost_key)
        if value is None:
            continue
        try:
            roll[cost_key] += float(value)
        except Exception:
            pass
529
+
530
+
531
def _maybe_add_extras(d: dict, roll: dict):
    """Add reasoning/cached token counts from ``d`` into ``roll``.

    Non-dict input is ignored. Counts come from flat keys or the
    ``*_tokens_details`` sub-dicts; missing values count as zero.
    """
    if isinstance(d, dict):
        reasoning_count = d.get("reasoning_tokens") or (
            d.get("completion_tokens_details") or {}
        ).get("reasoning_tokens")
        roll["reasoning_tokens"] += _to_int(reasoning_count)

        cached_count = (
            d.get("cached_tokens")
            or d.get("cached_input_tokens")
            or (d.get("prompt_tokens_details") or {}).get("cached_tokens")
            or d.get("prompt_cache_hits")
        )
        roll["cached_tokens"] += _to_int(cached_count)
547
+
548
+
549
class PerLLMTimer(BaseCallbackHandler):
    """Time LLM calls (chat/completions) and capture usage/metrics.

    Every finished call is added to ``agg`` under an ``llm:<name>`` key. A
    sample dict (name, ms, ok, tags, metadata, metrics) is appended to
    ``samples``, a deque holding at most ``keep_max`` entries.
    """

    def __init__(self, agg: _Agg | None = None, keep_max: int = 1000):
        self.agg = agg or _Agg()
        # run_id -> (name, perf_counter() at start, tags, metadata)
        self._starts: Dict[Any, Tuple[str, float, list, dict]] = {}
        self.samples: collections.deque = collections.deque(maxlen=keep_max)

    def _name(self, serialized, metadata, tags) -> str:
        """Return ``llm:<x>``.

        ``x`` is taken from ``metadata["model"]``, else from ``serialized``,
        else from the last tag.
        """
        model = (metadata or {}).get("model")
        if model:
            return f"llm:{model}"
        if isinstance(serialized, dict):
            name = serialized.get("name")
            if not name:
                sid = serialized.get("id")
                if isinstance(sid, dict):
                    name = sid.get("name") or sid.get("id")
                elif isinstance(sid, (list, tuple)):
                    name = "/".join(map(str, sid))
                elif sid is not None:
                    name = str(sid)
            return f"llm:{name or 'unknown'}"
        if isinstance(serialized, str):
            return f"llm:{serialized}"
        if tags:
            return f"llm:{tags[-1]}"
        return "llm:unknown"

    def on_llm_start(
        self, serialized, prompts, *, run_id, tags=None, metadata=None, **kwargs
    ):
        name = self._name(serialized, metadata, tags)
        self._starts[run_id] = (
            name,
            time.perf_counter(),
            tags or [],
            metadata or {},
        )

    @staticmethod
    def _extract_extras(d) -> dict:
        """Read reasoning/cached token counts from one usage dict (0 when absent)."""
        if not isinstance(d, dict):
            return {"reasoning_tokens": 0, "cached_tokens": 0}
        rt = d.get("reasoning_tokens") or (
            d.get("completion_tokens_details") or {}
        ).get("reasoning_tokens")
        cached = (
            d.get("cached_tokens")
            or d.get("cached_input_tokens")
            or (d.get("prompt_tokens_details") or {}).get("cached_tokens")
            or d.get("prompt_cache_hits")
        )
        return {"reasoning_tokens": _to_int(rt), "cached_tokens": _to_int(cached)}

    def _extract_metrics(self, response) -> dict:
        """
        Aggregate usage/metadata from multiple providers into a consistent shape.

        Rollup priority: usage_metadata (if any) > response_metadata.token_usage
        > llm_output.{token_usage|usage}. Raw sources are included alongside
        the normalized ``usage_rollup``. Reasoning/cached counts from the
        non-selected sources only raise the rollup's values (max), never add
        to them, because those sources often repeat the same numbers. Any
        parsing failure is reported under ``parse_error``.
        """
        out = {}
        sources_token_usage = []  # coerced token_usage dicts from response_metadata
        sources_usage_meta = []  # raw usage_metadata dicts
        roll = {
            "input_tokens": 0,
            "output_tokens": 0,
            "total_tokens": 0,
            "input_cost": 0.0,
            "output_cost": 0.0,
            "total_cost": 0.0,
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "reasoning_tokens": 0,
            "cached_tokens": 0,
        }

        try:
            # 1) llm_output
            llm_output = getattr(response, "llm_output", None)
            if isinstance(llm_output, dict):
                out["llm_output"] = llm_output
                tu = llm_output.get("token_usage") or llm_output.get("usage")
                coerced_tu = _coerce_usage(tu)
                if coerced_tu:
                    out["llm_output_token_usage"] = coerced_tu  # clean copy

            # 2) generations -> response_metadata / usage_metadata
            gens = getattr(response, "generations", None)
            resp_meta_list, usage_meta_list = [], []
            if gens:
                for gen_list in gens:
                    for gen in (
                        gen_list
                        if isinstance(gen_list, (list, tuple))
                        else [gen_list]
                    ):
                        msg = getattr(gen, "message", None)
                        if msg is None:
                            continue
                        rm = getattr(msg, "response_metadata", None)
                        if isinstance(rm, dict):
                            resp_meta_list.append(rm)
                            tu = rm.get("token_usage") or rm.get("usage")
                            coerced = _coerce_usage(tu)
                            if coerced:
                                sources_token_usage.append(coerced)
                        um = getattr(msg, "usage_metadata", None)
                        if isinstance(um, dict):
                            usage_meta_list.append(um)
                            sources_usage_meta.append(dict(um))

            if resp_meta_list:
                out["response_metadata"] = resp_meta_list
            if usage_meta_list:
                out["usage_metadata"] = usage_meta_list

            # 3) Build the normalized rollup with priority
            if sources_usage_meta:
                for d in sources_usage_meta:
                    _acc_from(d, roll)
                out["usage_source"] = "usage_metadata"
            elif sources_token_usage:
                for d in sources_token_usage:
                    _acc_from(d, roll)
                out["usage_source"] = "response_metadata.token_usage"
            else:
                # fall back to llm_output if we coerced anything
                coerced = out.get("llm_output_token_usage") or {}
                if coerced:
                    _acc_from(coerced, roll)
                    out["usage_source"] = "llm_output.token_usage"

            # Enrich reasoning/cached counts from the non-selected sources.
            src = out.get("usage_source")
            extras_candidates = []
            if src != "llm_output.token_usage":
                extras_candidates.append(
                    self._extract_extras(out.get("llm_output_token_usage") or {})
                )
            if src != "response_metadata.token_usage":
                for d in sources_token_usage:
                    extras_candidates.append(self._extract_extras(d or {}))

            if extras_candidates:
                # The selected source may already carry these counts (e.g.
                # OpenAI repeats token_usage in response_metadata and
                # llm_output). Keep the strongest signal instead of adding
                # the same numbers a second time.
                for key in ("reasoning_tokens", "cached_tokens"):
                    roll[key] = max(
                        roll[key], max(e[key] for e in extras_candidates)
                    )

            # Final consistency guards
            if roll["prompt_tokens"] == 0 and roll["input_tokens"] > 0:
                roll["prompt_tokens"] = roll["input_tokens"]
            if roll["completion_tokens"] == 0 and roll["output_tokens"] > 0:
                roll["completion_tokens"] = roll["output_tokens"]
            # Ensure total is at least input+output (some providers omit total)
            roll["total_tokens"] = max(
                roll["total_tokens"],
                roll["input_tokens"] + roll["output_tokens"],
                roll["prompt_tokens"] + roll["completion_tokens"],
            )

            if any(roll.values()):
                out["usage_rollup"] = roll

        except Exception as e:
            out["parse_error"] = repr(e)

        return out

    def on_llm_end(self, response, *, run_id, **kwargs):
        name, t0, tags, metadata = self._starts.pop(
            run_id, ("llm:unknown", time.perf_counter(), [], {})
        )
        ms = (time.perf_counter() - t0) * 1000.0
        self.agg.add(name, ms, True)
        metrics = self._extract_metrics(response)
        self.samples.append({
            "name": name,
            "ms": ms,
            "ok": True,
            "tags": tags,
            "metadata": metadata,
            "metrics": metrics,  # <- ALL captured metrics live here
        })

    def on_llm_error(self, error, *, run_id, **kwargs):
        name, t0, tags, metadata = self._starts.pop(
            run_id, ("llm:unknown", time.perf_counter(), [], {})
        )
        ms = (time.perf_counter() - t0) * 1000.0
        self.agg.add(name, ms, False)
        self.samples.append({
            "name": name,
            "ms": ms,
            "ok": False,
            "tags": tags,
            "metadata": metadata,
            "metrics": {"error": repr(error)},
        })
765
+
766
+
767
def _coerce_usage(obj) -> dict:
    """
    Best-effort normalize a provider token-usage object into a dict.

    Strategies, tried in order:
      * ``None`` -> ``{}``
      * any mapping (dict, MappingProxyType, ...) -> shallow ``dict`` copy
      * objects that can dump themselves (``model_dump``, ``dict``,
        ``to_dict``, ``_asdict``) -> that dump
      * objects with token-count attributes (``prompt_tokens`` etc.) -> those
        values as ints, plus known ``*_tokens_details`` sub-counts
      * strings like ``'Usage(prompt_tokens=..., ...)'`` -> every
        ``name=<int>`` pair found
    Anything else yields ``{}``. Returns a (possibly empty) dict.
    """
    from collections.abc import Mapping

    if obj is None:
        return {}

    if isinstance(obj, Mapping):
        return dict(obj)

    # model_dump comes first: pydantic v2 models still have .dict(), but it
    # is deprecated there and emits a warning on every call.
    for meth in ("model_dump", "dict", "to_dict", "_asdict"):
        if hasattr(obj, meth):
            try:
                return dict(getattr(obj, meth)())
            except Exception:
                pass  # unusable dump method; try the next strategy

    # Objects with attributes
    attrs = (
        "prompt_tokens",
        "completion_tokens",
        "total_tokens",
        "input_tokens",
        "output_tokens",
    )
    if any(hasattr(obj, a) for a in attrs):
        d = {}
        for a in attrs:
            v = getattr(obj, a, None)
            if v is not None:
                try:
                    d[a] = int(v)
                except Exception:
                    pass  # non-numeric count: leave it out

        # Common nested details (OpenAI-style usage objects)
        detail_keys = {
            "completion_tokens_details": (
                "reasoning_tokens",
                "accepted_prediction_tokens",
                "rejected_prediction_tokens",
                "audio_tokens",
                "text_tokens",
            ),
            "prompt_tokens_details": (
                "cached_tokens",
                "audio_tokens",
                "image_tokens",
                "text_tokens",
            ),
        }
        try:
            for detail_attr, keys in detail_keys.items():
                details = getattr(obj, detail_attr, None)
                if details is None:
                    continue
                picked = {}
                for k in keys:
                    val = getattr(details, k, None)
                    if isinstance(val, (int, float)):
                        picked[k] = int(val)
                if picked:
                    d[detail_attr] = picked
        except Exception:
            pass  # details are optional; keep what was collected so far

        return d

    # String repr like "Usage(completion_tokens=340, prompt_tokens=328, total_tokens=668, ...)"
    if isinstance(obj, str):
        pairs = {k: int(v) for k, v in re.findall(r"(\w+)=([0-9]+)", obj)}
        # pull some nested detail hints if present
        for probe in ("reasoning_tokens", "cached_tokens"):
            if probe not in pairs:
                m = re.search(rf"{probe}=([0-9]+)", obj)
                if m:
                    pairs[probe] = int(m.group(1))
        return pairs

    return {}
855
+
856
+
857
+ # ---------------------------------
858
+ # Decorator
859
+ # ---------------------------------
860
+
861
+
862
+ # Keep the decorator, but move it out of base.py to avoid bloat.
863
def timed_tool(tool_name: str, sink: _Agg | None = None):
    """
    Timing decorator for tools; complements the PerToolTimer callbacks.

    Each call of the decorated function is recorded once in ``sink`` as
    ``(tool_name, elapsed_ms, ok)``. ``ok`` is False whenever the call raises
    anything, including BaseException subclasses such as cancellation, and
    True otherwise. ``async def`` functions get an async wrapper, so the
    awaited work is timed rather than just coroutine creation.

    Args:
        tool_name: key recorded in the sink.
        sink: object with ``add(name, elapsed_ms, ok)``; a private ``_Agg`` is
            created when omitted. The sink actually used is exposed on the
            wrapper as ``timing_sink`` so default-sink data stays reachable.
    """
    import inspect

    sink = sink if sink is not None else _Agg()

    def _record(t0: float, ok: bool) -> None:
        sink.add(tool_name, (time.perf_counter() - t0) * 1000.0, ok)

    def deco(fn: Callable):
        if inspect.iscoroutinefunction(fn):

            @wraps(fn)
            async def async_wrapper(*args, **kwargs):
                t0 = time.perf_counter()
                ok = False  # flipped only after a normal return
                try:
                    result = await fn(*args, **kwargs)
                    ok = True
                    return result
                finally:
                    _record(t0, ok)

            async_wrapper.timing_sink = sink
            return async_wrapper

        @wraps(fn)
        def wrapper(*args, **kwargs):
            t0 = time.perf_counter()
            ok = False  # flipped only after a normal return
            try:
                result = fn(*args, **kwargs)
                ok = True
                return result
            finally:
                _record(t0, ok)

        wrapper.timing_sink = sink
        return wrapper

    return deco
886
+
887
+
888
+ # ---------------------------------
889
+ # Rendering helpers
890
+ # ---------------------------------
891
+
892
+
893
def render_table(
    title: str, rows: Iterable[Tuple[str, int, float, float, float]]
) -> str:
    """Render timing rows as a fixed-width, box-drawn text table.

    Each row is ``(name, count, total_s, avg_ms, max_ms)``; names are cut to
    28 characters. The result starts with a newline followed by ``title``.
    """
    column_widths = (30, 7, 11, 10, 10)

    def border(left: str, join: str, right: str, fill: str) -> str:
        # One horizontal rule spanning all five columns.
        return left + join.join(fill * w for w in column_widths) + right

    header = "┃ {0:<28} ┃ {1:>5} ┃ {2:>9} ┃ {3:>8} ┃ {4:>8} ┃".format(
        "Name", "Count", "Total(s)", "Avg(ms)", "Max(ms)"
    )
    body = [
        "│ {0:<28} │ {1:>5} │ {2:>9.2f} │ {3:>8.0f} │ {4:>8.0f} │".format(
            name[:28], cnt, tot_s, avg_ms, max_ms
        )
        for name, cnt, tot_s, avg_ms, max_ms in rows
    ]
    lines = [
        f"\n{title}",
        border("┏", "┳", "┓", "━"),
        header,
        border("┡", "╇", "┩", "━"),
        *body,
        border("└", "┴", "┘", "─"),
    ]
    return "\n".join(lines)
926
+
927
+
928
def _parse_iso(ts: str | None):
    """Parse an ISO-8601 timestamp, accepting a trailing ``Z`` for UTC.

    Returns ``None`` for empty input or anything ``fromisoformat`` rejects.
    """
    if ts:
        try:
            return datetime.datetime.fromisoformat(ts.replace("Z", "+00:00"))
        except Exception:
            pass
    return None
936
+
937
+
938
def _mk_table(
    title: str, rows: list[tuple[str, int, float, float, float]]
) -> Table:
    """Build a Rich timing table with locked column widths.

    Every table built here uses the same NAME_W/COUNT_W/TOTAL_W/AVG_W/MAX_W
    widths, so consecutive tables line up. Rows are
    ``(name, count, total_s, avg_ms, max_ms)``. Names starting with
    ``graph:`` are highlighted, and an empty ``rows`` yields one placeholder
    row.
    """
    t = Table(
        title=title,
        title_style="bold white",
        box=HEAVY,
        show_lines=False,
        expand=False,  # don't stretch columns differently per table
        pad_edge=False,
        padding=COL_PAD,
        header_style="bold",
    )

    # Lock all column widths so every table renders identically.
    t.add_column(
        "Name",
        style="cyan",
        no_wrap=True,
        overflow="ellipsis",
        width=NAME_W,
        min_width=NAME_W,
        max_width=NAME_W,
    )
    for header, width in (
        ("Count", COUNT_W),
        ("Total(s)", TOTAL_W),
        ("Avg(ms)", AVG_W),
        ("Max(ms)", MAX_W),
    ):
        t.add_column(
            header,
            justify="right",
            width=width,
            min_width=width,
            max_width=width,
        )

    if not rows:
        t.add_row("—", "0", f"{0.00:,.2f}", f"{0:,.0f}", f"{0:,.0f}")
        return t

    for name, count, total_s, avg_ms, max_ms in rows:
        label = str(name)
        # Text cells are never parsed as markup, so names containing "[...]"
        # (tool ids, model names) render verbatim instead of as style tags.
        name_cell = (
            Text(label, style="bright_magenta")
            if label.startswith("graph:")
            else Text(label)
        )
        t.add_row(
            name_cell,
            f"{count:,}",  # right-justified by column, with thousands separator
            f"{total_s:,.2f}",
            f"{avg_ms:,.0f}",
            f"{max_ms:,.0f}",
        )
    return t
1010
+
1011
+
1012
def _truncate_pad(s: str, width: int) -> str:
    """Fit ``str(s)`` into exactly ``width`` characters (for ``width >= 0``).

    Short text is left-justified with spaces. Long text is cut and ends in
    "..." when ``width > 3``; otherwise it is hard-cut.
    """
    text = str(s)
    if len(text) > width:
        return text[:width] if width <= 3 else text[: width - 3] + "..."
    return text.ljust(width)
1019
+
1020
+
1021
def _plain_table(rows):
    """Render timing rows as a plain-text table with " | " separators.

    Uses the shared NAME_W/COUNT_W/TOTAL_W/AVG_W/MAX_W widths and fits names
    with ``_truncate_pad``. Rows are ``(name, count, total_s, avg_ms,
    max_ms)``; ``count`` must be an int because it is formatted with ``d``.
    With no rows a single zero placeholder line follows the header.
    """

    def data_line(name_cell, count, total_s, avg_ms, max_ms):
        # Numeric cells are right-aligned to their column widths.
        return " | ".join((
            name_cell,
            f"{count:>{COUNT_W}d}",
            f"{total_s:>{TOTAL_W},.2f}",
            f"{avg_ms:>{AVG_W},.0f}",
            f"{max_ms:>{MAX_W},.0f}",
        ))

    header = " | ".join((
        f"{'Name':<{NAME_W}}",
        f"{'Count':>{COUNT_W}}",
        f"{'Total(s)':>{TOTAL_W}}",
        f"{'Avg(ms)':>{AVG_W}}",
        f"{'Max(ms)':>{MAX_W}}",
    ))

    if not rows:
        return "\n".join((header, data_line(f"{'—':<{NAME_W}}", 0, 0.00, 0, 0)))

    body = [
        data_line(_truncate_pad(n, NAME_W), c, ts, am, mm)
        for n, c, ts, am, mm in rows
    ]
    return "\n".join([header, *body])
1052
+
1053
+
1054
# ---------------------------------
# Public facade: Telemetry bundles the timers and renders/saves results
# ---------------------------------
1057
@dataclass
class Telemetry:
    """Timing/cost telemetry facade for a single agent.

    Bundles three LangChain callback handlers (tools, runnables, LLMs),
    records run-scoped context via ``begin_run``, and summarizes what the
    timers collected via ``render`` (a Rich panel plus an optional JSON file).
    """

    enable: bool = True
    debug_raw: bool = False  # render(raw=None) falls back to this: append raw snapshot
    output_dir: str = "metrics"  # directory for auto-named JSON files
    save_json_default: bool = True  # render(save_json=None) falls back to this (autosave ON)

    tool: PerToolTimer = field(default_factory=PerToolTimer)
    runnable: PerRunnableTimer = field(default_factory=PerRunnableTimer)
    llm: PerLLMTimer = field(default_factory=PerLLMTimer)

    # Run-scoped context embedded in the JSON filename and body.
    context: Dict[str, Any] = field(default_factory=dict)

    # ---------- JSON/export helpers ----------
    def begin_run(self, *, agent: str, thread_id: str) -> None:
        """Reset the run context; call at the start of BaseAgent.invoke().

        Stores the agent name, thread id, a fresh random ``run_id`` and a
        UTC ISO-8601 ``started_at`` timestamp.
        """
        self.context.clear()
        self.context.update({
            "agent": agent,
            "thread_id": thread_id,
            "run_id": uuid.uuid4().hex,
            "started_at": datetime.datetime.now(
                datetime.timezone.utc
            ).isoformat(),
        })

    @property
    def callbacks(self) -> List[BaseCallbackHandler]:
        """Handlers to pass as LangChain ``callbacks`` (empty when disabled)."""
        return [] if not self.enable else [self.tool, self.runnable, self.llm]

    def _snapshot(self) -> dict:
        """Collect raw timer internals (pending starts, aggregator state, rows)."""

        def _as_dict(obj):
            try:
                return dict(vars(obj))
            except Exception:
                return repr(obj)

        # Keys like run_id can be UUIDs; stringify so JSON can encode them.
        def _stringify_keys(d):
            try:
                return {str(k): v for k, v in d.items()}
            except Exception:
                return repr(d)

        def _timer_snapshot(timer):
            # Look everything up defensively and in the same way for all timers.
            agg = getattr(timer, "agg", {})
            buckets = getattr(agg, "buckets", None)
            return {
                "_starts": _stringify_keys(getattr(timer, "_starts", {})),
                "agg": _as_dict(agg),
                "buckets": list(buckets()) if callable(buckets) else [],
            }

        return {
            "runnable": _timer_snapshot(self.runnable),
            "tool": _timer_snapshot(self.tool),
            "llm": _timer_snapshot(self.llm),
        }

    def _records_struct(self) -> dict:
        """Per-call records of each timer as JSON-friendly dicts."""

        def _normalize(rec_list):
            # The aggregators are expected to store (name, ms, ok) tuples.
            out = []
            for r in rec_list:
                try:
                    name, ms, ok = r
                except (TypeError, ValueError):
                    # Unexpected shape: keep a readable trace, mark the rest unknown.
                    name, ms, ok = str(r), None, None
                out.append({
                    "name": name,
                    "ms": ms,
                    # Preserve "unknown" instead of reporting it as a failure.
                    "ok": None if ok is None else bool(ok),
                })
            return out

        return {
            "runnable": _normalize(getattr(self.runnable.agg, "records", [])),
            "tool": _normalize(getattr(self.tool.agg, "records", [])),
            "llm": _normalize(getattr(self.llm.agg, "records", [])),
        }

    def _tables_struct(self) -> dict:
        """Structured tables ready for JSON, one list of row dicts per timer."""

        def _rows(rows):
            # rows are (name, count, total_s, avg_ms, max_ms)
            return [
                {"name": n, "count": c, "total_s": ts, "avg_ms": a, "max_ms": m}
                for (n, c, ts, a, m) in rows
            ]

        return {
            "runnable": _rows(self.runnable.agg.buckets()),
            "tool": _rows(self.tool.agg.buckets()),
            "llm": _rows(self.llm.agg.buckets()),
        }

    def _totals(self, tables: dict) -> dict:
        """Sum ``total_s`` per table.

        Runnable buckets can overlap, so ``unattributed_s`` here is
        runnable-minus-(llm + tool), floored at zero. It is not the
        wall-clock figure that ``render`` prints.
        """
        tot = {k: sum(r["total_s"] for r in v) for k, v in tables.items()}
        unattributed = max(
            0.0,
            tot.get("runnable", 0.0)
            - (tot.get("llm", 0.0) + tot.get("tool", 0.0)),
        )
        return {
            "graph_total_s": tot.get("runnable", 0.0),
            "llm_total_s": tot.get("llm", 0.0),
            "tool_total_s": tot.get("tool", 0.0),
            "unattributed_s": unattributed,
        }

    def _ensure_dir(self, path: str) -> None:
        os.makedirs(path, exist_ok=True)

    @staticmethod
    def _safe_filename_part(value: Any) -> str:
        """Replace every character other than word chars, '.' and '-' with '_'.

        Keeps path separators (and thus ``..`` traversal) out of the
        generated filename; spaces become '_' as before.
        """
        return re.sub(r"[^\w.\-]", "_", str(value))

    def _default_filepath(self) -> str:
        """Build ``<output_dir>/<local-ts>_<agent>_<thread>_<run8>.json``."""
        ts = datetime.datetime.now().strftime("%Y%m%d-%H%M%S-%f")
        agent = self._safe_filename_part(self.context.get("agent") or "agent")
        thread_id = self._safe_filename_part(
            self.context.get("thread_id") or "thread"
        )
        run_id = self._safe_filename_part(
            str(self.context.get("run_id") or "run")[:8]
        )
        fname = f"{ts}_{agent}_{thread_id}_{run_id}.json"
        return os.path.join(self.output_dir, fname)

    def _json_default(self, o):
        """``json.dump`` fallback for objects the encoder can't handle natively."""
        import dataclasses

        if dataclasses.is_dataclass(o) and not isinstance(o, type):
            try:
                return dataclasses.asdict(o)
            except Exception:
                pass  # asdict deep-copies fields; e.g. locks can't be copied
        if isinstance(o, (set, frozenset)):
            return list(o)
        if isinstance(o, (datetime.datetime, datetime.date)):
            return o.isoformat()
        if isinstance(o, uuid.UUID):
            return str(o)
        # Everything else (locks, functions, callbacks, etc.) --> repr string
        return repr(o)

    def _save_json(self, payload: dict, filepath: str | None = None) -> str:
        """Write ``payload`` as indented UTF-8 JSON; return the path written."""
        path = filepath or self._default_filepath()
        self._ensure_dir(os.path.dirname(path) or ".")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(
                payload,
                f,
                ensure_ascii=False,
                indent=2,
                default=self._json_default,
            )
        return path

    def to_json(
        self, *, include_raw_snapshot: bool, include_raw_records: bool
    ) -> dict:
        """Assemble the metrics payload (context, tables, totals, LLM events)."""
        tables = self._tables_struct()
        out = {
            "context": {
                **self.context,
                "ended_at": datetime.datetime.now(
                    datetime.timezone.utc
                ).isoformat(),
            },
            "tables": tables,
            "totals": self._totals(tables),
            "llm_events": list(getattr(self.llm, "samples", [])),
        }
        if include_raw_snapshot:
            out["raw_snapshot"] = self._snapshot()
        if include_raw_records:
            out["raw_records"] = self._records_struct()
        return out

    def _apply_pricing(self, payload: dict) -> Tuple[dict, List[str]]:
        """Attach cost estimates to ``payload`` if a pricing module is found.

        Returns the (possibly priced) payload and Rich-markup summary lines.
        Pricing is optional, so a failure in it becomes a summary line
        rather than aborting ``render``.
        """
        from rich.markup import escape

        pricing_mod = _get_pricing_module()
        if not pricing_mod or not payload.get("llm_events"):
            return payload, []
        try:
            registry = pricing_mod.load_registry(
                path=os.environ.get("URSA_PRICING_JSON")
            )
            priced = pricing_mod.price_payload(
                payload, registry=registry, overwrite=False
            )
        except Exception as exc:
            return payload, [
                f"[dim]Cost Summary unavailable: {escape(repr(exc))}[/]"
            ]
        if isinstance(priced, dict):
            payload = priced

        costs = payload.get("costs") or {}
        total_usd = costs.get("total_usd") or 0.0
        by_model = costs.get("by_model_usd") or {}
        src_counts = costs.get("event_sources") or {}

        lines = [
            "[bold]Cost Summary (USD)[/]",
            f" total: [bold green]${total_usd:,.6f}[/]",
        ]
        # Model names are external strings; escape so "[...]" isn't parsed.
        lines.extend(
            f" {escape(str(model))}: ${amt:,.6f}"
            for model, amt in by_model.items()
        )
        if src_counts:
            lines.append(
                f" (events: provider={src_counts.get('provider', 0)}, "
                f"computed={src_counts.get('computed', 0)}, "
                f"no_usage={src_counts.get('no_usage', 0)}, "
                f"no_pricing={src_counts.get('no_pricing', 0)})"
            )
        return payload, lines

    def render(
        self,
        raw: bool | None = None,
        save_json: bool | None = None,
        filepath: str | None = None,
        save_raw_snapshot: bool | None = None,
        save_raw_records: bool | None = None,
    ) -> str:
        """Print a Rich metrics panel and optionally save the payload as JSON.

        Args:
            raw: append a pretty-printed raw snapshot to the panel
                (None -> ``self.debug_raw``).
            save_json: write the payload to disk (None -> ``save_json_default``).
            filepath: explicit JSON path; default is ``_default_filepath()``.
            save_raw_snapshot: include ``raw_snapshot`` in the payload (None -> True).
            save_raw_records: include ``raw_records`` in the payload (None -> True).

        Returns:
            Always ``""``; output goes to the Rich console.
        """
        if not self.enable:
            return ""

        import pprint

        from rich.markup import escape

        # --- Gather tables ---
        r_rows = self.runnable.agg.buckets()
        t_rows = self.tool.agg.buckets()
        l_rows = self.llm.agg.buckets()

        # --- Build the payload early (it also carries the run context) ---
        payload = self.to_json(
            include_raw_snapshot=(
                True if save_raw_snapshot is None else save_raw_snapshot
            ),
            include_raw_records=(
                True if save_raw_records is None else save_raw_records
            ),
        )
        ctx = payload.get("context") or {}
        agent_name = ctx.get("agent") or type(self).__name__
        thread_id = (
            ctx.get("thread_id") or getattr(self, "thread_id", None) or "—"
        )
        run_id = ctx.get("run_id", "—")
        started_at = ctx.get("started_at")
        ended_at = ctx.get("ended_at")
        start_dt = _parse_iso(started_at)
        end_dt = _parse_iso(ended_at)
        wall_secs = (
            (end_dt - start_dt).total_seconds()
            if (start_dt and end_dt)
            else None
        )

        # Optional human alias set post-construction: executor.name = "exec-A"
        human_alias = getattr(self, "name", None) or getattr(
            self, "alias", None
        )
        base_label = human_alias or agent_name

        # Per-instance short id, stable for the object's lifetime.
        if not hasattr(self, "_short_id"):
            self._short_id = uuid.uuid4().hex[:6]
        agent_label = f"{base_label} [{self._short_id}]"
        # The label contains "[abc123]"; unescaped, Rich would parse ids that
        # start with a-f as a style tag and mis-pair the closing "[/]".
        agent_label_m = escape(agent_label)

        # --- Totals (use wall clock for unattributed) ---
        def _total(rows):
            return sum((row[2] for row in rows), 0.0)

        llm_total = _total(l_rows)
        tool_total = _total(t_rows)
        unattributed = (
            max(0.0, wall_secs - (llm_total + tool_total))
            if wall_secs is not None
            else None
        )
        graph_bucket_sum = _total(r_rows)  # informative only (overlaps)

        # --- Pricing (optional) ---
        payload, pricing_text_lines = self._apply_pricing(payload)

        # --- Save JSON (if requested) ---
        do_save = self.save_json_default if save_json is None else save_json
        saved_path = (
            self._save_json(payload, filepath=filepath) if do_save else None
        )

        console = get_console()

        # --- Header & attribution lines (markup-aware; external values escaped) ---
        header_lines = [
            f"[bold magenta]{agent_label_m}[/] [dim]•[/] "
            f"thread [bold]{escape(str(thread_id))}[/] [dim]•[/] "
            f"run [bold]{escape(str(run_id))}[/]"
        ]
        if start_dt and end_dt:
            header_lines.append(
                f"[dim]{escape(str(started_at))} → {escape(str(ended_at))}[/dim]"
                f" [bold]wall[/]: {wall_secs:,.2f}s"
            )

        attrib_lines = ["[bold]Attribution[/]"]
        if wall_secs is not None:
            attrib_lines.append(
                f" Total run (wall): [bold]{wall_secs:,.2f}s[/]"
            )
        attrib_lines.append(f" LLM total: {llm_total:,.2f}s")
        attrib_lines.append(f" Tool total: {tool_total:,.2f}s")
        if unattributed is not None:
            attrib_lines.append(f" Unattributed: {unattributed:,.2f}s")
        attrib_lines.append(
            f"[dim] Sum of runnable buckets (non-additive): {graph_bucket_sum:,.2f}s[/]"
        )
        if saved_path:
            attrib_lines.append(
                f"[dim]Saved metrics JSON to:[/] {escape(str(saved_path))}"
            )

        renderables = [
            Text.from_markup("\n".join(header_lines)),
            Rule(),
            _mk_table("Per-Node / Runnable Timing", r_rows),
            _mk_table("Per-Tool Timing", t_rows),
            _mk_table("Per-LLM Timing", l_rows),
            Rule(),
            Text.from_markup("\n".join(attrib_lines)),
        ]
        if pricing_text_lines:
            renderables += [
                Rule(),
                Text.from_markup("\n".join(pricing_text_lines)),
            ]

        include_raw = self.debug_raw if raw is None else raw
        if include_raw:
            # Plain Text (no markup parsing): reprs routinely contain "[...]".
            renderables += [
                Rule(),
                Text("Raw Debug Snapshot", style="bold"),
                Text(
                    pprint.pformat(self._snapshot(), sort_dicts=True, width=120)
                ),
            ]

        panel = Panel.fit(
            Group(*renderables),  # Panel takes a single renderable
            title=f"[bold white]Metrics[/] • [cyan]{agent_label_m}[/]",
            border_style="bright_magenta",
            padding=(1, 2),
            box=HEAVY,  # beefy border with corners
        )
        console.print(panel)

        _session_ingest(payload)
        return ""