osmosis-ai 0.1.8__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of osmosis-ai might be problematic.

Files changed (36)
  1. osmosis_ai/__init__.py +19 -132
  2. osmosis_ai/cli.py +50 -0
  3. osmosis_ai/cli_commands.py +181 -0
  4. osmosis_ai/cli_services/__init__.py +60 -0
  5. osmosis_ai/cli_services/config.py +410 -0
  6. osmosis_ai/cli_services/dataset.py +175 -0
  7. osmosis_ai/cli_services/engine.py +421 -0
  8. osmosis_ai/cli_services/errors.py +7 -0
  9. osmosis_ai/cli_services/reporting.py +307 -0
  10. osmosis_ai/cli_services/session.py +174 -0
  11. osmosis_ai/cli_services/shared.py +209 -0
  12. osmosis_ai/consts.py +2 -16
  13. osmosis_ai/providers/__init__.py +36 -0
  14. osmosis_ai/providers/anthropic_provider.py +85 -0
  15. osmosis_ai/providers/base.py +60 -0
  16. osmosis_ai/providers/gemini_provider.py +314 -0
  17. osmosis_ai/providers/openai_family.py +607 -0
  18. osmosis_ai/providers/shared.py +92 -0
  19. osmosis_ai/rubric_eval.py +356 -0
  20. osmosis_ai/rubric_types.py +49 -0
  21. osmosis_ai/utils.py +284 -89
  22. osmosis_ai-0.2.4.dist-info/METADATA +314 -0
  23. osmosis_ai-0.2.4.dist-info/RECORD +27 -0
  24. osmosis_ai-0.2.4.dist-info/entry_points.txt +4 -0
  25. {osmosis_ai-0.1.8.dist-info → osmosis_ai-0.2.4.dist-info}/licenses/LICENSE +1 -1
  26. osmosis_ai/adapters/__init__.py +0 -9
  27. osmosis_ai/adapters/anthropic.py +0 -502
  28. osmosis_ai/adapters/langchain.py +0 -674
  29. osmosis_ai/adapters/langchain_anthropic.py +0 -338
  30. osmosis_ai/adapters/langchain_openai.py +0 -596
  31. osmosis_ai/adapters/openai.py +0 -900
  32. osmosis_ai/logger.py +0 -77
  33. osmosis_ai-0.1.8.dist-info/METADATA +0 -281
  34. osmosis_ai-0.1.8.dist-info/RECORD +0 -15
  35. {osmosis_ai-0.1.8.dist-info → osmosis_ai-0.2.4.dist-info}/WHEEL +0 -0
  36. {osmosis_ai-0.1.8.dist-info → osmosis_ai-0.2.4.dist-info}/top_level.txt +0 -0
osmosis_ai/cli_services/reporting.py
@@ -0,0 +1,307 @@
+ from __future__ import annotations
+
+ import json
+ from datetime import datetime, timezone
+ from pathlib import Path
+ from typing import Any, Callable, Optional
+
+ from .engine import EvaluationRecordResult, EvaluationReport, EvaluationRun
+ from .errors import CLIError
+ from .shared import calculate_stat_deltas
+
+
+ class TextReportFormatter:
+     """Builds human-readable text lines for an evaluation report."""
+
+     def build(
+         self,
+         report: EvaluationReport,
+         baseline: Optional["BaselineStatistics"] = None,
+     ) -> list[str]:
+         lines: list[str] = []
+         provider = str(report.rubric_config.model_info.get("provider", "")).strip() or "<unknown>"
+         model_name = str(report.rubric_config.model_info.get("model", "")).strip() or "<unspecified>"
+
+         lines.append(
+             f"Rubric '{report.rubric_config.rubric_id}' "
+             f"({report.rubric_config.source_label}) -> provider '{provider}' model '{model_name}'"
+         )
+         lines.append(f"Loaded {len(report.record_results)} matching record(s) from {report.data_path}")
+         lines.append(f"Running {report.number} evaluation(s) per record")
+         lines.append("")
+
+         for record_result in report.record_results:
+             lines.extend(self._format_record(record_result))
+         if baseline is not None:
+             lines.extend(self._format_baseline(report, baseline))
+         return lines
+
+     def _format_record(self, record_result: EvaluationRecordResult) -> list[str]:
+         lines: list[str] = [f"[{record_result.conversation_label}]"]
+         total_runs = len(record_result.runs)
+         for index, run in enumerate(record_result.runs):
+             lines.extend(self._format_run(run))
+             if index < total_runs - 1:
+                 lines.append("")
+
+         summary_lines = self._format_summary(record_result.statistics, len(record_result.runs))
+         if summary_lines:
+             lines.extend(summary_lines)
+         lines.append("")
+         return lines
+
+     def _format_run(self, run: EvaluationRun) -> list[str]:
+         lines: list[str] = [f" Run {run.run_index:02d} [{run.status.upper()}]"]
+         if run.status == "success":
+             score_text = "n/a" if run.score is None else f"{run.score:.4f}"
+             lines.append(self._format_detail_line("score", score_text))
+             if run.preview:
+                 lines.append(self._format_detail_line("preview", run.preview))
+             explanation = run.explanation or "(no explanation provided)"
+             lines.append(self._format_detail_line("explanation", explanation))
+         else:
+             error_text = run.error or "(no error message provided)"
+             lines.append(self._format_detail_line("error", error_text))
+             if run.preview:
+                 lines.append(self._format_detail_line("preview", run.preview))
+             if run.explanation:
+                 lines.append(self._format_detail_line("explanation", run.explanation))
+         lines.append(self._format_detail_line("duration", f"{run.duration_seconds:.2f}s"))
+         return lines
+
+     def _format_summary(self, statistics: dict[str, float], total_runs: int) -> list[str]:
+         if not statistics:
+             return []
+         success_count = int(round(statistics.get("success_count", total_runs)))
+         failure_count = int(round(statistics.get("failure_count", total_runs - success_count)))
+         if total_runs <= 1 and failure_count == 0:
+             return []
+
+         lines = [" Summary:"]
+         lines.append(f" total: {int(round(statistics.get('total_runs', total_runs)))}")
+         lines.append(f" successes: {success_count}")
+         lines.append(f" failures: {failure_count}")
+         if success_count > 0:
+             lines.append(f" average: {statistics.get('average', 0.0):.4f}")
+             lines.append(f" variance: {statistics.get('variance', 0.0):.6f}")
+             lines.append(f" stdev: {statistics.get('stdev', 0.0):.4f}")
+             lines.append(f" min/max: {statistics.get('min', 0.0):.4f} / {statistics.get('max', 0.0):.4f}")
+         else:
+             lines.append(" average: n/a")
+             lines.append(" variance: n/a")
+             lines.append(" stdev: n/a")
+             lines.append(" min/max: n/a")
+         return lines
+
+     def _format_baseline(
+         self,
+         report: EvaluationReport,
+         baseline: "BaselineStatistics",
+     ) -> list[str]:
+         lines = [f"Baseline comparison (source: {baseline.source_path}):"]
+         deltas = baseline.delta(report.overall_statistics)
+         keys = ["average", "variance", "stdev", "min", "max", "success_count", "failure_count", "total_runs"]
+
+         for key in keys:
+             if key not in baseline.statistics or key not in report.overall_statistics:
+                 continue
+             baseline_value = float(baseline.statistics[key])
+             current_value = float(report.overall_statistics[key])
+             delta_value = float(deltas.get(key, current_value - baseline_value))
+
+             if key in {"success_count", "failure_count", "total_runs"}:
+                 baseline_str = f"{int(round(baseline_value))}"
+                 current_str = f"{int(round(current_value))}"
+                 delta_str = f"{delta_value:+.0f}"
+             else:
+                 precision = 6 if key == "variance" else 4
+                 baseline_str = format(baseline_value, f".{precision}f")
+                 current_str = format(current_value, f".{precision}f")
+                 delta_str = format(delta_value, f"+.{precision}f")
+
+             lines.append(
+                 f" {key:12s} baseline={baseline_str} current={current_str} delta={delta_str}"
+             )
+         return lines
+
+     @staticmethod
+     def _format_detail_line(label: str, value: str, *, indent: int = 4) -> str:
+         indent_str = " " * indent
+         value_str = str(value)
+         continuation_indent = indent_str + " "
+         value_str = value_str.replace("\n", f"\n{continuation_indent}")
+         return f"{indent_str}{label}: {value_str}"
+
+
+ class ConsoleReportRenderer:
+     """Pretty prints evaluation reports to stdout (or any printer function)."""
+
+     def __init__(
+         self,
+         printer: Callable[[str], None] = print,
+         formatter: Optional[TextReportFormatter] = None,
+     ):
+         self._printer = printer
+         self._formatter = formatter or TextReportFormatter()
+
+     def render(
+         self,
+         report: EvaluationReport,
+         baseline: Optional["BaselineStatistics"] = None,
+     ) -> None:
+         for line in self._formatter.build(report, baseline):
+             self._printer(line)
+
+
+ class BaselineStatistics:
+     def __init__(self, source_path: Path, statistics: dict[str, float]):
+         self.source_path = source_path
+         self.statistics = statistics
+
+     def delta(self, current: dict[str, float]) -> dict[str, float]:
+         return calculate_stat_deltas(self.statistics, current)
+
+
+ class BaselineComparator:
+     """Loads baseline JSON payloads and extracts statistics."""
+
+     def load(self, path: Path) -> BaselineStatistics:
+         if not path.exists():
+             raise CLIError(f"Baseline path '{path}' does not exist.")
+         if path.is_dir():
+             raise CLIError(f"Baseline path '{path}' is a directory, expected JSON file.")
+
+         try:
+             payload = json.loads(path.read_text(encoding="utf-8"))
+         except json.JSONDecodeError as exc:
+             raise CLIError(f"Failed to parse baseline JSON: {exc}") from exc
+
+         if not isinstance(payload, dict):
+             raise CLIError("Baseline JSON must contain an object.")
+
+         source = None
+         if isinstance(payload.get("overall_statistics"), dict):
+             source = payload["overall_statistics"]
+         elif all(key in payload for key in ("average", "variance", "stdev")):
+             source = payload
+         if source is None:
+             raise CLIError(
+                 "Baseline JSON must include an 'overall_statistics' object or top-level statistics."
+             )
+
+         statistics: dict[str, float] = {}
+         for key, value in source.items():
+             try:
+                 statistics[key] = float(value)
+             except (TypeError, ValueError):
+                 continue
+         if not statistics:
+             raise CLIError("Baseline statistics could not be parsed into numeric values.")
+
+         return BaselineStatistics(source_path=path, statistics=statistics)
+
+
+ class JsonReportFormatter:
+     """Builds JSON-serialisable payloads for evaluation reports."""
+
+     def build(
+         self,
+         report: EvaluationReport,
+         *,
+         output_identifier: Optional[str],
+         baseline: Optional[BaselineStatistics],
+     ) -> dict[str, Any]:
+         provider = str(report.rubric_config.model_info.get("provider", "")).strip() or "<unknown>"
+         model_name = str(report.rubric_config.model_info.get("model", "")).strip() or "<unspecified>"
+
+         generated: dict[str, Any] = {
+             "generated_at": datetime.now(timezone.utc).isoformat(),
+             "rubric_id": report.rubric_config.rubric_id,
+             "rubric_source": report.rubric_config.source_label,
+             "provider": provider,
+             "model": model_name,
+             "number": report.number,
+             "config_path": str(report.config_path),
+             "data_path": str(report.data_path),
+             "overall_statistics": _normalise_statistics(report.overall_statistics),
+             "records": [],
+         }
+         if output_identifier is not None:
+             generated["output_identifier"] = output_identifier
+
+         for record_result in report.record_results:
+             conversation_label = record_result.conversation_label
+             record_identifier = record_result.record.record_identifier(conversation_label)
+
+             record_payload: dict[str, Any] = {
+                 "id": record_identifier,
+                 "record_index": record_result.record_index,
+                 "conversation_id": conversation_label,
+                 "input_record": record_result.record.payload,
+                 "statistics": _normalise_statistics(record_result.statistics),
+                 "runs": [],
+             }

+             for run in record_result.runs:
+                 record_payload["runs"].append(
+                     {
+                         "run_index": run.run_index,
+                         "status": run.status,
+                         "started_at": run.started_at.isoformat(),
+                         "completed_at": run.completed_at.isoformat(),
+                         "duration_seconds": run.duration_seconds,
+                         "score": run.score,
+                         "explanation": run.explanation,
+                         "preview": run.preview,
+                         "error": run.error,
+                         "raw": run.raw,
+                     }
+                 )
+
+             generated["records"].append(record_payload)
+
+         if baseline is not None:
+             generated["baseline_comparison"] = {
+                 "source_path": str(baseline.source_path),
+                 "baseline_statistics": _normalise_statistics(baseline.statistics),
+                 "delta_statistics": _normalise_statistics(baseline.delta(report.overall_statistics)),
+             }
+
+         return generated
+
+
+ class JsonReportWriter:
+     """Serialises an evaluation report to disk."""
+
+     def __init__(self, formatter: Optional[JsonReportFormatter] = None):
+         self._formatter = formatter or JsonReportFormatter()
+
+     def write(
+         self,
+         report: EvaluationReport,
+         *,
+         output_path: Path,
+         output_identifier: Optional[str],
+         baseline: Optional[BaselineStatistics],
+     ) -> Path:
+         parent_dir = output_path.parent
+         if parent_dir and not parent_dir.exists():
+             parent_dir.mkdir(parents=True, exist_ok=True)
+
+         payload = self._formatter.build(
+             report,
+             output_identifier=output_identifier,
+             baseline=baseline,
+         )
+         with output_path.open("w", encoding="utf-8") as fh:
+             json.dump(payload, fh, indent=2, ensure_ascii=False)
+         return output_path
+
+ def _normalise_statistics(stats: dict[str, float]) -> dict[str, Any]:
+     normalised: dict[str, Any] = {}
+     for key, value in stats.items():
+         if key in {"success_count", "failure_count", "total_runs"}:
+             normalised[key] = int(round(value))
+         else:
+             normalised[key] = float(value)
+     return normalised
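
For orientation, a minimal usage sketch of the baseline-comparison pieces added above. The file name and statistic values are invented for illustration; only the BaselineComparator/BaselineStatistics API comes from this diff (ConsoleReportRenderer and JsonReportWriter additionally need an EvaluationReport produced by the engine, so they are omitted here).

import json
from pathlib import Path

from osmosis_ai.cli_services.reporting import BaselineComparator

# A baseline file may either wrap the stats in "overall_statistics" or, as here,
# provide top-level average/variance/stdev keys.
baseline_path = Path("baseline.json")  # hypothetical path
baseline_path.write_text(
    json.dumps({"average": 0.72, "variance": 0.010, "stdev": 0.100}),
    encoding="utf-8",
)

baseline = BaselineComparator().load(baseline_path)
# delta() subtracts the baseline value from the current value, key by key.
print(baseline.delta({"average": 0.80, "variance": 0.020, "stdev": 0.140}))
# e.g. {'average': 0.08, 'variance': 0.01, 'stdev': 0.04} (up to float rounding)
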
osmosis_ai/cli_services/session.py
@@ -0,0 +1,174 @@
+ from __future__ import annotations
+
+ import time
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Callable, Optional, Sequence
+
+ from ..rubric_eval import ensure_api_key_available
+ from ..rubric_types import MissingAPIKeyError
+ from .config import RubricConfig, RubricSuite, discover_rubric_config_path, load_rubric_suite
+ from .dataset import DatasetLoader, DatasetRecord
+ from .engine import RubricEvaluationEngine, EvaluationReport
+ from .errors import CLIError
+ from .reporting import BaselineComparator, BaselineStatistics, JsonReportWriter
+
+
+ _CACHE_ROOT = Path("~/.cache/osmosis/eval_result").expanduser()
+
+
+ def _sanitise_rubric_folder(rubric_id: str) -> str:
+     """Produce a filesystem-safe folder name for the rubric id."""
+     clean = "".join(ch if ch.isalnum() or ch in {"-", "_"} else "_" for ch in rubric_id.strip())
+     clean = clean.strip("_") or "rubric"
+     return clean.lower()
+
+
+ @dataclass(frozen=True)
+ class EvaluationSessionRequest:
+     rubric_id: str
+     data_path: Path
+     number: int = 1
+     config_path: Optional[Path] = None
+     output_path: Optional[Path] = None
+     output_identifier: Optional[str] = None
+     baseline_path: Optional[Path] = None
+
+
+ @dataclass
+ class EvaluationSessionResult:
+     request: EvaluationSessionRequest
+     config_path: Path
+     data_path: Path
+     rubric_config: RubricConfig
+     records: Sequence[DatasetRecord]
+     report: EvaluationReport
+     baseline: Optional[BaselineStatistics]
+     written_path: Optional[Path]
+     output_identifier: Optional[str]
+
+
+ class EvaluationSession:
+     """Coordinates rubric evaluation end-to-end for reusable orchestration."""
+
+     def __init__(
+         self,
+         *,
+         config_locator: Callable[[Optional[str], Path], Path] = discover_rubric_config_path,
+         suite_loader: Callable[[Path], RubricSuite] = load_rubric_suite,
+         dataset_loader: Optional[DatasetLoader] = None,
+         engine: Optional[RubricEvaluationEngine] = None,
+         baseline_comparator: Optional[BaselineComparator] = None,
+         report_writer: Optional[JsonReportWriter] = None,
+         identifier_factory: Optional[Callable[[], str]] = None,
+     ):
+         self._config_locator = config_locator
+         self._suite_loader = suite_loader
+         self._dataset_loader = dataset_loader or DatasetLoader()
+         self._engine = engine or RubricEvaluationEngine()
+         self._baseline_comparator = baseline_comparator or BaselineComparator()
+         self._report_writer = report_writer or JsonReportWriter()
+         self._identifier_factory = identifier_factory or self._default_identifier
+
+     def execute(self, request: EvaluationSessionRequest) -> EvaluationSessionResult:
+         rubric_id = request.rubric_id.strip()
+         if not rubric_id:
+             raise CLIError("Rubric identifier cannot be empty.")
+
+         number_value = request.number if request.number is not None else 1
+         number = int(number_value)
+         if number < 1:
+             raise CLIError("Number of runs must be a positive integer.")
+
+         data_path = request.data_path.expanduser()
+         if not data_path.exists():
+             raise CLIError(f"Data path '{data_path}' does not exist.")
+         if data_path.is_dir():
+             raise CLIError(f"Expected a JSONL file but received directory '{data_path}'.")
+
+         config_override = str(request.config_path.expanduser()) if request.config_path else None
+         config_path = self._config_locator(config_override, data_path)
+         suite = self._suite_loader(config_path)
+         rubric_config = suite.get(rubric_id)
+
+         try:
+             ensure_api_key_available(rubric_config.model_info)
+         except (MissingAPIKeyError, TypeError) as exc:
+             raise CLIError(str(exc)) from exc
+
+         all_records = self._dataset_loader.load(data_path)
+         matching_records = [
+             record for record in all_records if record.rubric_id.lower() == rubric_id.lower()
+         ]
+         if not matching_records:
+             raise CLIError(f"No records in '{data_path}' reference rubric '{rubric_id}'.")
+
+         baseline_stats = self._load_baseline(request.baseline_path)
+
+         resolved_output_path, resolved_identifier = self._resolve_output_path(
+             request.output_path,
+             request.output_identifier,
+             rubric_id=rubric_id,
+         )
+
+         report = self._engine.execute(
+             rubric_config=rubric_config,
+             config_path=config_path,
+             data_path=data_path,
+             records=matching_records,
+             number=number,
+         )
+
+         written_path = None
+         if resolved_output_path is not None:
+             written_path = self._report_writer.write(
+                 report,
+                 output_path=resolved_output_path,
+                 output_identifier=resolved_identifier,
+                 baseline=baseline_stats,
+             )
+
+         return EvaluationSessionResult(
+             request=request,
+             config_path=config_path,
+             data_path=data_path,
+             rubric_config=rubric_config,
+             records=matching_records,
+             report=report,
+             baseline=baseline_stats,
+             written_path=written_path,
+             output_identifier=resolved_identifier,
+         )
+
+     def _load_baseline(self, baseline_path: Optional[Path]) -> Optional[BaselineStatistics]:
+         if baseline_path is None:
+             return None
+         resolved = baseline_path.expanduser()
+         return self._baseline_comparator.load(resolved)
+
+     def _resolve_output_path(
+         self,
+         output_candidate: Optional[Path],
+         output_identifier: Optional[str],
+         *,
+         rubric_id: str,
+     ) -> tuple[Optional[Path], Optional[str]]:
+         if output_candidate is None:
+             identifier = output_identifier or self._identifier_factory()
+             target_dir = _CACHE_ROOT / _sanitise_rubric_folder(rubric_id)
+             target_dir.mkdir(parents=True, exist_ok=True)
+             return target_dir / f"rubric_eval_result_{identifier}.json", identifier
+
+         candidate = output_candidate.expanduser()
+         if candidate.suffix:
+             if candidate.exists() and candidate.is_dir():
+                 raise CLIError(f"Output path '{candidate}' is a directory.")
+             return candidate, output_identifier
+
+         candidate.mkdir(parents=True, exist_ok=True)
+         identifier = output_identifier or self._identifier_factory()
+         return candidate / f"rubric_eval_result_{identifier}.json", identifier
+
+     @staticmethod
+     def _default_identifier() -> str:
+         return str(int(time.time()))
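
The EvaluationSession above can also be driven programmatically, not only through the CLI entry point. A rough sketch under assumed local paths and rubric id (a rubric config must be discoverable and the provider API key must be set, as enforced by ensure_api_key_available):

from pathlib import Path

from osmosis_ai.cli_services.session import EvaluationSession, EvaluationSessionRequest

request = EvaluationSessionRequest(
    rubric_id="my_rubric",                  # hypothetical; must exist in the rubric config
    data_path=Path("data/records.jsonl"),   # hypothetical JSONL dataset
    number=3,                               # evaluations per matching record
    baseline_path=Path("baseline.json"),    # optional; omit to skip the comparison
)
result = EvaluationSession().execute(request)

# With no output_path given, the report JSON lands under
# ~/.cache/osmosis/eval_result/<sanitised rubric id>/rubric_eval_result_<identifier>.json
print(result.written_path)
print(result.report.overall_statistics)
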
osmosis_ai/cli_services/shared.py
@@ -0,0 +1,209 @@
+ from __future__ import annotations
+
+ from statistics import mean, pvariance, pstdev
+ from typing import Any, Collection, Optional, Set
+
+ from .errors import CLIError
+
+
+ def coerce_optional_float(value: Any, field_name: str, source_label: str) -> Optional[float]:
+     if value is None:
+         return None
+     if isinstance(value, (int, float)) and not isinstance(value, bool):
+         return float(value)
+     raise CLIError(
+         f"Expected '{field_name}' in {source_label} to be numeric, got {type(value).__name__}."
+     )
+
+
+ def collapse_preview_text(value: Any, *, max_length: int = 140) -> Optional[str]:
+     if not isinstance(value, str):
+         return None
+     collapsed = " ".join(value.strip().split())
+     if not collapsed:
+         return None
+     if len(collapsed) > max_length:
+         collapsed = collapsed[: max_length - 3].rstrip() + "..."
+     return collapsed
+
+
+ def calculate_statistics(scores: list[float]) -> dict[str, float]:
+     if not scores:
+         return {
+             "average": 0.0,
+             "variance": 0.0,
+             "stdev": 0.0,
+             "min": 0.0,
+             "max": 0.0,
+         }
+     average = mean(scores)
+     variance = pvariance(scores)
+     std_dev = pstdev(scores)
+     return {
+         "average": average,
+         "variance": variance,
+         "stdev": std_dev,
+         "min": min(scores),
+         "max": max(scores),
+     }
+
+
+ def calculate_stat_deltas(baseline: dict[str, float], current: dict[str, float]) -> dict[str, float]:
+     delta: dict[str, float] = {}
+     for key, current_value in current.items():
+         if key not in baseline:
+             continue
+         try:
+             baseline_value = float(baseline[key])
+             current_numeric = float(current_value)
+         except (TypeError, ValueError):
+             continue
+         delta[key] = current_numeric - baseline_value
+     return delta
+
+
+ def gather_text_fragments(
+     node: Any,
+     fragments: list[str],
+     *,
+     allow_free_strings: bool = False,
+     seen: Optional[Set[int]] = None,
+     string_key_allowlist: Optional[Collection[str]] = None,
+ ) -> None:
+     """Collect textual snippets from nested message-like structures.
+
+     The traversal favours common chat-completions shapes (e.g. ``{"type": "text"}``
+     blocks) and avoids indiscriminately pulling in metadata values such as IDs.
+     ``allow_free_strings`` controls whether bare strings encountered at the current
+     level should be considered textual content (useful for raw message content but
+     typically disabled for metadata fields).
+     """
+
+     if seen is None:
+         seen = set()
+
+     if isinstance(node, str):
+         if allow_free_strings:
+             stripped = node.strip()
+             if stripped:
+                 fragments.append(stripped)
+         return
+
+     if isinstance(node, list):
+         for item in node:
+             gather_text_fragments(
+                 item,
+                 fragments,
+                 allow_free_strings=allow_free_strings,
+                 seen=seen,
+                 string_key_allowlist=string_key_allowlist,
+             )
+         return
+
+     if not isinstance(node, dict):
+         return
+
+     node_id = id(node)
+     if node_id in seen:
+         return
+     seen.add(node_id)
+
+     allowlist = {"text", "value", "message"}
+     if string_key_allowlist is not None:
+         allowlist = {key.lower() for key in string_key_allowlist}
+     else:
+         allowlist = {key.lower() for key in allowlist}
+
+     prioritized_keys = ("text", "value")
+     handled_keys: Set[str] = {
+         "text",
+         "value",
+         "content",
+         "message",
+         "parts",
+         "input_text",
+         "output_text",
+         "type",
+         "role",
+         "name",
+         "id",
+         "index",
+         "finish_reason",
+         "reason",
+         "tool_call_id",
+         "metadata",
+     }
+
+     for key in prioritized_keys:
+         if key not in node:
+             continue
+         before_count = len(fragments)
+         gather_text_fragments(
+             node[key],
+             fragments,
+             allow_free_strings=True,
+             seen=seen,
+             string_key_allowlist=string_key_allowlist,
+         )
+         if len(fragments) > before_count:
+             break
+
+     if node.get("type") == "tool_result" and "content" in node:
+         gather_text_fragments(
+             node["content"],
+             fragments,
+             allow_free_strings=True,
+             seen=seen,
+             string_key_allowlist=string_key_allowlist,
+         )
+     elif "content" in node:
+         gather_text_fragments(
+             node["content"],
+             fragments,
+             allow_free_strings=True,
+             seen=seen,
+             string_key_allowlist=string_key_allowlist,
+         )
+
+     for key in ("message", "parts", "input_text", "output_text"):
+         if key in node:
+             gather_text_fragments(
+                 node[key],
+                 fragments,
+                 allow_free_strings=True,
+                 seen=seen,
+                 string_key_allowlist=string_key_allowlist,
+             )
+
+     for key, value in node.items():
+         if key in handled_keys:
+             continue
+         if isinstance(value, (list, dict)):
+             gather_text_fragments(
+                 value,
+                 fragments,
+                 allow_free_strings=False,
+                 seen=seen,
+                 string_key_allowlist=string_key_allowlist,
+             )
+         elif isinstance(value, str) and key.lower() in allowlist:
+             stripped = value.strip()
+             if stripped:
+                 fragments.append(stripped)
+
+
+ def collect_text_fragments(
+     node: Any,
+     *,
+     allow_free_strings: bool = False,
+     string_key_allowlist: Optional[Collection[str]] = None,
+ ) -> list[str]:
+     fragments: list[str] = []
+     gather_text_fragments(
+         node,
+         fragments,
+         allow_free_strings=allow_free_strings,
+         seen=set(),
+         string_key_allowlist=string_key_allowlist,
+     )
+     return fragments
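
The shared helpers are self-contained enough to try directly; a small sketch with an invented chat-style payload (the message dict below is illustrative, not taken from the package):

from osmosis_ai.cli_services.shared import (
    calculate_statistics,
    collapse_preview_text,
    collect_text_fragments,
)

# Metadata keys such as "role" and "id" are ignored; text is pulled from
# "text" blocks and tool_result content.
message = {
    "role": "assistant",
    "id": "msg_123",
    "content": [
        {"type": "text", "text": "The answer is 42."},
        {"type": "tool_result", "content": "lookup ok"},
    ],
}
print(collect_text_fragments(message))
# -> ['The answer is 42.', 'lookup ok']

print(collapse_preview_text("  spread   over\nseveral   lines  "))
# -> 'spread over several lines'

print(calculate_statistics([0.5, 0.75, 1.0]))
# population variance/stdev: {'average': 0.75, 'variance': ~0.0417, 'stdev': ~0.2041, 'min': 0.5, 'max': 1.0}
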