dataforge-07 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. dataforge/__init__.py +204 -0
  2. dataforge/__main__.py +5 -0
  3. dataforge/agent/__init__.py +16 -0
  4. dataforge/agent/providers.py +259 -0
  5. dataforge/agent/scratchpad.py +183 -0
  6. dataforge/agent/tool_actions.py +343 -0
  7. dataforge/bench/__init__.py +31 -0
  8. dataforge/bench/core.py +426 -0
  9. dataforge/bench/groq_client.py +386 -0
  10. dataforge/bench/methods.py +443 -0
  11. dataforge/bench/report.py +309 -0
  12. dataforge/bench/runner.py +247 -0
  13. dataforge/causal/__init__.py +21 -0
  14. dataforge/causal/dag.py +174 -0
  15. dataforge/causal/pc.py +232 -0
  16. dataforge/causal/root_cause.py +193 -0
  17. dataforge/cli/__init__.py +50 -0
  18. dataforge/cli/audit.py +70 -0
  19. dataforge/cli/bench.py +154 -0
  20. dataforge/cli/common.py +267 -0
  21. dataforge/cli/constraints.py +407 -0
  22. dataforge/cli/profile.py +147 -0
  23. dataforge/cli/release.py +166 -0
  24. dataforge/cli/repair.py +407 -0
  25. dataforge/cli/revert.py +139 -0
  26. dataforge/cli/watch.py +144 -0
  27. dataforge/datasets/__init__.py +25 -0
  28. dataforge/datasets/embedded/hospital/clean.csv +11 -0
  29. dataforge/datasets/embedded/hospital/dirty.csv +11 -0
  30. dataforge/datasets/real_world.py +290 -0
  31. dataforge/datasets/registry.py +103 -0
  32. dataforge/detectors/__init__.py +80 -0
  33. dataforge/detectors/base.py +145 -0
  34. dataforge/detectors/decimal_shift.py +166 -0
  35. dataforge/detectors/fd_violation.py +157 -0
  36. dataforge/detectors/type_mismatch.py +173 -0
  37. dataforge/engine/__init__.py +39 -0
  38. dataforge/engine/repair.py +905 -0
  39. dataforge/env/__init__.py +22 -0
  40. dataforge/env/environment.py +883 -0
  41. dataforge/env/observation.py +61 -0
  42. dataforge/env/openenv_core.py +161 -0
  43. dataforge/env/reward.py +128 -0
  44. dataforge/env/server.py +176 -0
  45. dataforge/evaluation_contract.py +76 -0
  46. dataforge/fixtures/hospital_10rows.csv +11 -0
  47. dataforge/fixtures/hospital_schema.yaml +17 -0
  48. dataforge/http/__init__.py +1 -0
  49. dataforge/http/problem.py +103 -0
  50. dataforge/integrations/__init__.py +1 -0
  51. dataforge/integrations/dbt.py +164 -0
  52. dataforge/observability.py +76 -0
  53. dataforge/py.typed +1 -0
  54. dataforge/release/__init__.py +1 -0
  55. dataforge/release/doctor.py +367 -0
  56. dataforge/release/full_vision.py +702 -0
  57. dataforge/release/gate.py +861 -0
  58. dataforge/release/playground_check.py +411 -0
  59. dataforge/repair_contract.py +468 -0
  60. dataforge/repairers/__init__.py +88 -0
  61. dataforge/repairers/base.py +77 -0
  62. dataforge/repairers/decimal_shift.py +43 -0
  63. dataforge/repairers/fd_violation.py +225 -0
  64. dataforge/repairers/type_mismatch.py +73 -0
  65. dataforge/safety/__init__.py +5 -0
  66. dataforge/safety/adversarial/attack_01_phone_pii.yaml +8 -0
  67. dataforge/safety/adversarial/attack_02_phone_pii.yaml +8 -0
  68. dataforge/safety/adversarial/attack_03_phone_pii.yaml +8 -0
  69. dataforge/safety/adversarial/attack_04_phone_pii.yaml +8 -0
  70. dataforge/safety/adversarial/attack_05_phone_pii.yaml +8 -0
  71. dataforge/safety/adversarial/attack_06_phone_pii.yaml +8 -0
  72. dataforge/safety/adversarial/attack_07_phone_pii.yaml +8 -0
  73. dataforge/safety/adversarial/attack_08_phone_pii.yaml +8 -0
  74. dataforge/safety/adversarial/attack_09_phone_pii.yaml +8 -0
  75. dataforge/safety/adversarial/attack_10_phone_pii.yaml +8 -0
  76. dataforge/safety/adversarial/attack_11_ssn_pii.yaml +8 -0
  77. dataforge/safety/adversarial/attack_12_ssn_pii.yaml +8 -0
  78. dataforge/safety/adversarial/attack_13_ssn_pii.yaml +8 -0
  79. dataforge/safety/adversarial/attack_14_ssn_pii.yaml +8 -0
  80. dataforge/safety/adversarial/attack_15_ssn_pii.yaml +8 -0
  81. dataforge/safety/adversarial/attack_16_ssn_pii.yaml +8 -0
  82. dataforge/safety/adversarial/attack_17_ssn_pii.yaml +8 -0
  83. dataforge/safety/adversarial/attack_18_ssn_pii.yaml +8 -0
  84. dataforge/safety/adversarial/attack_19_ssn_pii.yaml +8 -0
  85. dataforge/safety/adversarial/attack_20_ssn_pii.yaml +8 -0
  86. dataforge/safety/adversarial/attack_21_email_pii.yaml +8 -0
  87. dataforge/safety/adversarial/attack_22_email_pii.yaml +8 -0
  88. dataforge/safety/adversarial/attack_23_email_pii.yaml +8 -0
  89. dataforge/safety/adversarial/attack_24_email_pii.yaml +8 -0
  90. dataforge/safety/adversarial/attack_25_email_pii.yaml +8 -0
  91. dataforge/safety/adversarial/attack_26_email_pii.yaml +8 -0
  92. dataforge/safety/adversarial/attack_27_email_pii.yaml +8 -0
  93. dataforge/safety/adversarial/attack_28_email_pii.yaml +8 -0
  94. dataforge/safety/adversarial/attack_29_email_pii.yaml +8 -0
  95. dataforge/safety/adversarial/attack_30_email_pii.yaml +8 -0
  96. dataforge/safety/adversarial/attack_31_row_delete.yaml +7 -0
  97. dataforge/safety/adversarial/attack_32_row_delete.yaml +8 -0
  98. dataforge/safety/adversarial/attack_33_row_delete.yaml +7 -0
  99. dataforge/safety/adversarial/attack_34_row_delete.yaml +7 -0
  100. dataforge/safety/adversarial/attack_35_row_delete.yaml +7 -0
  101. dataforge/safety/adversarial/attack_36_row_delete.yaml +11 -0
  102. dataforge/safety/adversarial/attack_37_row_delete.yaml +7 -0
  103. dataforge/safety/adversarial/attack_38_row_delete.yaml +7 -0
  104. dataforge/safety/adversarial/attack_39_row_delete.yaml +8 -0
  105. dataforge/safety/adversarial/attack_40_row_delete.yaml +7 -0
  106. dataforge/safety/adversarial/attack_41_row_delete.yaml +7 -0
  107. dataforge/safety/adversarial/attack_42_row_delete.yaml +7 -0
  108. dataforge/safety/adversarial/attack_43_row_delete.yaml +7 -0
  109. dataforge/safety/adversarial/attack_44_row_delete.yaml +7 -0
  110. dataforge/safety/adversarial/attack_45_row_delete.yaml +8 -0
  111. dataforge/safety/adversarial/attack_46_row_delete.yaml +8 -0
  112. dataforge/safety/adversarial/attack_47_row_delete.yaml +7 -0
  113. dataforge/safety/adversarial/attack_48_row_delete.yaml +7 -0
  114. dataforge/safety/adversarial/attack_49_row_delete.yaml +8 -0
  115. dataforge/safety/adversarial/attack_50_row_delete.yaml +7 -0
  116. dataforge/safety/constitution.py +307 -0
  117. dataforge/safety/constitutions/default.yaml +40 -0
  118. dataforge/safety/filter.py +134 -0
  119. dataforge/schema_inference.py +620 -0
  120. dataforge/stores/__init__.py +46 -0
  121. dataforge/stores/base.py +73 -0
  122. dataforge/stores/cloud.py +78 -0
  123. dataforge/stores/csv.py +94 -0
  124. dataforge/stores/duckdb.py +313 -0
  125. dataforge/stores/patch_plan.py +178 -0
  126. dataforge/stores/registry.py +82 -0
  127. dataforge/stores/repair.py +121 -0
  128. dataforge/stores/revert.py +22 -0
  129. dataforge/stores/sql.py +27 -0
  130. dataforge/table.py +228 -0
  131. dataforge/transactions/__init__.py +34 -0
  132. dataforge/transactions/files.py +96 -0
  133. dataforge/transactions/log.py +613 -0
  134. dataforge/transactions/revert.py +102 -0
  135. dataforge/transactions/txn.py +104 -0
  136. dataforge/ui/__init__.py +1 -0
  137. dataforge/ui/profile_view.py +136 -0
  138. dataforge/ui/repair_diff.py +91 -0
  139. dataforge/verifier/__init__.py +55 -0
  140. dataforge/verifier/constraint_ir.py +155 -0
  141. dataforge/verifier/explain.py +47 -0
  142. dataforge/verifier/gate.py +5 -0
  143. dataforge/verifier/schema.py +111 -0
  144. dataforge/verifier/smt.py +433 -0
  145. dataforge_07-0.1.0.dist-info/METADATA +436 -0
  146. dataforge_07-0.1.0.dist-info/RECORD +150 -0
  147. dataforge_07-0.1.0.dist-info/WHEEL +5 -0
  148. dataforge_07-0.1.0.dist-info/entry_points.txt +3 -0
  149. dataforge_07-0.1.0.dist-info/licenses/LICENSE +176 -0
  150. dataforge_07-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,309 @@
1
+ """Benchmark report rendering and README marker updates."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from collections import defaultdict
7
+ from pathlib import Path
8
+ from typing import cast
9
+
10
+ from dataforge.bench.core import AggregateBenchmarkResult, BenchmarkRunOutput
11
+
12
+
13
+ def _format_metric(mean_value: float | None, std_value: float | None) -> str:
14
+ """Format a mean/std metric cell for markdown tables."""
15
+ if mean_value is None:
16
+ return "Skipped"
17
+ if std_value is None:
18
+ return f"{mean_value:.4f}"
19
+ return f"{mean_value:.4f} +/- {std_value:.4f}"
20
+
21
+
22
+ def _render_table(headers: list[str], rows: list[list[str]]) -> str:
23
+ """Render a simple markdown table."""
24
+ lines = [
25
+ "| " + " | ".join(headers) + " |",
26
+ "| " + " | ".join("---" for _ in headers) + " |",
27
+ ]
28
+ for row in rows:
29
+ lines.append("| " + " | ".join(row) + " |")
30
+ return "\n".join(lines)
31
+
32
+
33
def load_agent_output(path: Path) -> BenchmarkRunOutput:
    """Read ``path`` as UTF-8 JSON and validate it into a ``BenchmarkRunOutput``.

    JSON decoding errors and validation errors from ``model_validate`` propagate.
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    return BenchmarkRunOutput.model_validate(payload)
36
+
37
+
38
def load_sota_output(path: Path) -> dict[str, object]:
    """Read the citation-only SOTA comparison JSON at ``path``.

    Raises:
        ValueError: if the decoded JSON is not a top-level object.
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(payload, dict):
        return cast(dict[str, object], payload)
    raise ValueError("SOTA comparison JSON must be a top-level object.")
44
+
45
+
46
def replace_benchmark_block(readme_text: str, block_text: str) -> str:
    """Replace the text between the benchmark markers idempotently.

    The content between the first ``<!-- BENCH:START -->`` and the first
    ``<!-- BENCH:END -->`` that *follows* it is replaced by ``block_text`` (stripped),
    surrounded by single newlines. Both markers are kept. Applying the same block
    twice yields the same text.

    Raises:
        ValueError: if the start marker is missing, or no end marker follows it.
    """
    start_marker = "<!-- BENCH:START -->"
    end_marker = "<!-- BENCH:END -->"
    start_index = readme_text.find(start_marker)
    if start_index == -1 or end_marker not in readme_text:
        raise ValueError("Benchmark markers are missing.")
    start = start_index + len(start_marker)
    # Search for the end marker only after the start marker. A stray end marker
    # earlier in the document must not be paired with it, because that would
    # duplicate content.
    end = readme_text.find(end_marker, start)
    if end == -1:
        raise ValueError("Benchmark end marker must follow the start marker.")
    return readme_text[:start] + "\n" + block_text.strip() + "\n" + readme_text[end:]
55
+
56
+
57
+ def _aggregate_across_datasets(aggregates: list[AggregateBenchmarkResult]) -> list[list[str]]:
58
+ """Build a simple cross-dataset local summary table."""
59
+ grouped: dict[str, list[AggregateBenchmarkResult]] = defaultdict(list)
60
+ skipped: dict[str, str | None] = {}
61
+ for aggregate in aggregates:
62
+ if aggregate.status == "ok":
63
+ grouped[aggregate.method].append(aggregate)
64
+ else:
65
+ skipped.setdefault(aggregate.method, aggregate.skip_reason)
66
+
67
+ rows: list[list[str]] = []
68
+ methods = sorted(set(grouped) | set(skipped))
69
+ for method in methods:
70
+ ok_rows = grouped.get(method, [])
71
+ if not ok_rows:
72
+ rows.append([method, "Skipped", "Skipped", "Skipped", "Skipped", "Skipped", "Skipped"])
73
+ continue
74
+ p_mean = sum(row.precision_mean or 0.0 for row in ok_rows) / len(ok_rows)
75
+ r_mean = sum(row.recall_mean or 0.0 for row in ok_rows) / len(ok_rows)
76
+ f_mean = sum(row.f1_mean or 0.0 for row in ok_rows) / len(ok_rows)
77
+ step_mean = sum(row.avg_steps_mean or 0.0 for row in ok_rows) / len(ok_rows)
78
+ quota_mean = sum(row.quota_units_mean or 0.0 for row in ok_rows) / len(ok_rows)
79
+ gpu_hours_mean = sum(row.gpu_hours_mean or 0.0 for row in ok_rows) / len(ok_rows)
80
+ rows.append(
81
+ [
82
+ method,
83
+ f"{p_mean:.4f}",
84
+ f"{r_mean:.4f}",
85
+ f"{f_mean:.4f}",
86
+ f"{step_mean:.2f}",
87
+ f"{quota_mean:.4f}",
88
+ f"{gpu_hours_mean:.4f}",
89
+ ]
90
+ )
91
+ return rows
92
+
93
+
94
+ def _collect_skip_reasons(aggregates: list[AggregateBenchmarkResult]) -> list[str]:
95
+ """Collect distinct aggregate skip reasons in stable order."""
96
+ reasons: list[str] = []
97
+ for aggregate in aggregates:
98
+ reason = aggregate.skip_reason
99
+ if aggregate.status == "ok" or reason is None or reason in reasons:
100
+ continue
101
+ reasons.append(reason)
102
+ return reasons
103
+
104
+
105
+ def _metadata_list(metadata: dict[str, object], key: str) -> list[str]:
106
+ value = metadata.get(key, [])
107
+ if not isinstance(value, list):
108
+ return []
109
+ return [str(item) for item in value]
110
+
111
+
112
+ def _dataset_revision_summary(agent_output: BenchmarkRunOutput) -> str:
113
+ raw_evidence = agent_output.metadata.get("dataset_evidence", [])
114
+ if not isinstance(raw_evidence, list):
115
+ return ""
116
+ revisions: list[str] = []
117
+ dataset_names: list[str] = []
118
+ for item in raw_evidence:
119
+ if not isinstance(item, dict):
120
+ continue
121
+ name = str(item.get("name", "")).strip()
122
+ revision = str(item.get("source_revision", "")).strip()
123
+ if name:
124
+ dataset_names.append(name)
125
+ if revision and revision not in revisions:
126
+ revisions.append(revision)
127
+ if not revisions:
128
+ return ""
129
+ return (
130
+ "\n\nDataset bytes are pinned to BigDaMa/raha revision "
131
+ f"`{', '.join(revisions)}` for {', '.join(dataset_names)}; "
132
+ "dirty/clean SHA-256s are recorded in the JSON metadata."
133
+ )
134
+
135
+
136
def build_readme_benchmark_block(agent_output: BenchmarkRunOutput, report_path: Path) -> str:
    """Assemble the markdown placed between the README ``BENCH`` markers.

    The block contains, in order:

    - a provenance line (schema version, seeds, first 12 chars of the git commit,
      dirty flag);
    - the cross-dataset summary table;
    - a pointer to ``report_path.name``;
    - the dataset revision note;
    - any skip reasons.
    """
    meta = agent_output.metadata
    summary_table = _render_table(
        ["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units", "GPU Hours"],
        _aggregate_across_datasets(agent_output.aggregates),
    )
    # Prefer the explicit seed list; fall back to the raw "seeds" value.
    seeds_text = ", ".join(_metadata_list(meta, "seed_list") or [str(meta.get("seeds", ""))])
    schema_version = str(meta.get("schema_version", "legacy"))
    short_commit = str(meta.get("git_commit", "unknown"))[:12]
    dirty_flag = str(meta.get("git_dirty", "unknown")).lower()

    reasons = _collect_skip_reasons(agent_output.aggregates)
    skip_note = ("\n\nSkipped methods in this run: " + "; ".join(reasons)) if reasons else ""

    provenance = (
        "Generated from `eval/results/agent_comparison.json` "
        f"(schema `{schema_version}`, seeds `{seeds_text}`, "
        f"git `{short_commit}`, dirty `{dirty_flag}`).\n\n"
    )
    pointer = (
        f"See `{report_path.name}` for per-dataset tables, error bars, "
        "and citation-only SOTA rows."
    )
    return "".join(
        [
            provenance,
            summary_table + "\n\n",
            pointer,
            _dataset_revision_summary(agent_output),
            skip_note,
        ]
    )
163
+
164
+
165
# Column headers shared by the cross-dataset and per-dataset local result tables.
_LOCAL_METRIC_HEADERS = [
    "Method",
    "Precision",
    "Recall",
    "F1",
    "Avg Steps",
    "Quota Units",
    "GPU Hours",
]


def _render_per_dataset_sections(aggregates: list[AggregateBenchmarkResult]) -> list[str]:
    """Render one ``### <Dataset>`` heading plus a mean +/- std table per dataset.

    Datasets appear in first-seen order; rows within a dataset keep their input order.
    """
    by_dataset: dict[str, list[AggregateBenchmarkResult]] = defaultdict(list)
    for aggregate in aggregates:
        by_dataset[aggregate.dataset].append(aggregate)

    sections: list[str] = []
    for dataset, rows in by_dataset.items():
        table_rows = [
            [
                row.method,
                _format_metric(row.precision_mean, row.precision_std),
                _format_metric(row.recall_mean, row.recall_std),
                _format_metric(row.f1_mean, row.f1_std),
                _format_metric(row.avg_steps_mean, row.avg_steps_std),
                _format_metric(row.quota_units_mean, row.quota_units_std),
                _format_metric(row.gpu_hours_mean, row.gpu_hours_std),
            ]
            for row in rows
        ]
        sections.append(
            f"### {dataset.title()}\n\n" + _render_table(_LOCAL_METRIC_HEADERS, table_rows)
        )
    return sections


def _render_sota_rows(sota_output: dict[str, object]) -> list[list[str]]:
    """Convert the citation-only SOTA ``rows`` list into markdown table cells.

    A non-list ``rows`` yields no rows, and non-dict entries are skipped. A missing
    or ``null`` ``note`` falls back to ``"Citation-only literature result."``.

    Raises:
        KeyError: if a row lacks ``method``, ``dataset``, ``precision``, ``recall`` or ``f1``.
        ValueError: if a metric value is a string that ``float()`` cannot parse.
        TypeError: if a metric value has a type ``float()`` rejects (e.g. ``None``).
    """
    raw_rows = sota_output.get("rows", [])
    if not isinstance(raw_rows, list):
        return []
    rendered: list[list[str]] = []
    for row in raw_rows:
        if not isinstance(row, dict):
            continue
        note = row.get("note")
        rendered.append(
            [
                str(row["method"]),
                str(row["dataset"]),
                f"{float(row['precision']):.3f}",
                f"{float(row['recall']):.3f}",
                f"{float(row['f1']):.3f}",
                "Citation-only literature result." if note is None else str(note),
            ]
        )
    return rendered


def render_benchmark_report(
    agent_output: BenchmarkRunOutput,
    sota_output: dict[str, object],
) -> str:
    """Render the full markdown benchmark report.

    Sections, in order:

    - reproduction command;
    - run configuration (read from ``agent_output.metadata``);
    - cross-dataset local summary;
    - per-dataset local tables;
    - citation-only SOTA reference (read from ``sota_output``);
    - methodology notes.

    Args:
        agent_output: Reproduced local benchmark run (aggregates and metadata).
        sota_output: Citation-only SOTA JSON object with optional ``rows`` and ``source``.

    Raises:
        KeyError, ValueError, TypeError: for malformed SOTA rows (see ``_render_sota_rows``).
    """
    metadata = agent_output.metadata
    per_dataset_sections = _render_per_dataset_sections(agent_output.aggregates)
    local_summary = _render_table(
        _LOCAL_METRIC_HEADERS,
        _aggregate_across_datasets(agent_output.aggregates),
    )
    sota_rows = _render_sota_rows(sota_output)

    # Narrow the citation source once; a non-dict ``source`` renders placeholder values.
    raw_source = sota_output.get("source", {})
    source = raw_source if isinstance(raw_source, dict) else {}
    source_title = source.get("title", "Unknown source")
    source_url = source.get("url", "")
    source_table = source.get("table", "")
    source_hash = source.get("source_sha256", "")
    source_retrieved = source.get("retrieved_at_utc", "")

    skip_reasons = _collect_skip_reasons(agent_output.aggregates)
    skip_note = ""
    if skip_reasons:
        skip_note = "\nSkipped methods in this reproduced run: " + "; ".join(skip_reasons) + "\n"

    methods = _metadata_list(metadata, "methods")
    datasets = _metadata_list(metadata, "datasets")
    seed_list = _metadata_list(metadata, "seed_list")
    seeds = str(metadata.get("seeds", ""))
    reproduction_command = str(metadata.get("reproduction_command", ""))
    schema_version = str(metadata.get("schema_version", "legacy"))
    git_commit = str(metadata.get("git_commit", "unknown"))
    git_dirty = str(metadata.get("git_dirty", "unknown")).lower()
    dataset_summary = _dataset_revision_summary(agent_output).strip()

    return (
        "# Benchmark Report\n\n"
        "## Reproduction\n\n"
        f"`{reproduction_command}`\n\n"
        "## Configuration\n\n"
        f"- Methods: {', '.join(methods)}\n"
        f"- Datasets: {', '.join(datasets)}\n"
        f"- Seeds: {seeds}\n"
        f"- Exact seed list: {', '.join(seed_list) if seed_list else seeds}\n"
        f"- Evidence schema: `{schema_version}`\n"
        f"- Git commit: `{git_commit}`; dirty worktree: `{git_dirty}`\n"
        "- Free-tier quota units: `max(llm_calls / 1000, (prompt_tokens + completion_tokens) / 100000)`\n"
        "- GRPO compute cost is reported as free-tier GPU-hours, not dollars.\n"
        + (f"- {dataset_summary}\n" if dataset_summary else "")
        + f"{skip_note}\n"
        + "## Cross-Dataset Local Results\n\n"
        + f"{local_summary}\n\n"
        + "## Per-Dataset Local Results\n\n"
        + "\n\n".join(per_dataset_sections)
        + "\n\n## Citation-Only SOTA Reference\n\n"
        + f"Source: [{source_title}]({source_url}); {source_table}; "
        + f"source SHA-256 `{source_hash}`; retrieved `{source_retrieved}`.\n\n"
        + "HoloClean rows are transcribed from BClean Table 4; see "
        + "[HoloClean 2017](https://www.vldb.org/pvldb/vol10/p1190-rekatsinas.pdf) "
        + "for the original system description.\n\n"
        + _render_table(
            ["Method", "Dataset", "Precision", "Recall", "F1", "Note"],
            sota_rows,
        )
        + "\n\n## Methodology\n\n"
        + "Local rows are reproduced from generated JSON. Citation-only SOTA rows are copied "
        + "from literature and are not rerun in this repository. LLM quota units are free-tier "
        + "fractions; GRPO compute cost is GPU-hours, not dollars.\n"
    )
285
+
286
+
287
def write_benchmark_outputs(
    *,
    agent_json_path: Path,
    sota_json_path: Path,
    report_path: Path,
    readme_path: Path,
    homepage_path: Path | None = None,
) -> None:
    """Generate the benchmark report and patch generated public evidence blocks.

    Every output (report text, patched README, optional patched homepage) is
    computed before any file is written. A missing or misordered ``BENCH`` marker
    therefore fails before anything is touched, instead of leaving a regenerated
    report next to a stale README/homepage. The writes themselves are not atomic.

    Args:
        agent_json_path: Agent comparison JSON to load.
        sota_json_path: Citation-only SOTA JSON to load.
        report_path: Destination of the full markdown report; its file name is
            referenced from the README block.
        readme_path: README whose benchmark marker block is replaced.
        homepage_path: Optional second document whose marker block is replaced too.

    Raises:
        ValueError: from ``load_sota_output`` for a non-object SOTA JSON, or from
            ``replace_benchmark_block`` when a target lacks valid markers.
    """
    agent_output = load_agent_output(agent_json_path)
    sota_output = load_sota_output(sota_json_path)
    report_text = render_benchmark_report(agent_output, sota_output)
    benchmark_block = build_readme_benchmark_block(agent_output, report_path)

    updated_readme = replace_benchmark_block(
        readme_path.read_text(encoding="utf-8"), benchmark_block
    )
    updated_homepage: str | None = None
    if homepage_path is not None:
        updated_homepage = replace_benchmark_block(
            homepage_path.read_text(encoding="utf-8"), benchmark_block
        )

    # All inputs validated; only now touch the filesystem.
    report_path.write_text(report_text, encoding="utf-8")
    readme_path.write_text(updated_readme, encoding="utf-8")
    if homepage_path is not None and updated_homepage is not None:
        homepage_path.write_text(updated_homepage, encoding="utf-8")
@@ -0,0 +1,247 @@
1
+ """Top-level benchmark orchestration for agent comparison runs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ from dotenv import load_dotenv
10
+
11
+ from dataforge.bench.core import (
12
+ AggregateBenchmarkResult,
13
+ BenchmarkRunOutput,
14
+ SeedBenchmarkResult,
15
+ aggregate_seed_results,
16
+ build_benchmark_metadata,
17
+ build_seed_list,
18
+ dataset_evidence_from_loaded,
19
+ estimate_llm_calls,
20
+ validate_estimated_calls,
21
+ write_run_output,
22
+ )
23
+ from dataforge.bench.groq_client import GroqBenchClient
24
+ from dataforge.bench.methods import (
25
+ run_heuristic_episode,
26
+ run_llm_react_episode,
27
+ run_llm_zeroshot_episode,
28
+ run_random_episode,
29
+ )
30
+ from dataforge.datasets.real_world import load_real_world_dataset
31
+ from dataforge.datasets.registry import DATASET_REGISTRY
32
+
33
+ _SUPPORTED_METHODS = frozenset({"random", "heuristic", "llm_zeroshot", "llm_react"})
34
+
35
+
36
def _validate_inputs(methods: list[str], datasets: list[str]) -> None:
    """Reject benchmark methods or datasets the runner does not know about.

    Methods are checked against ``_SUPPORTED_METHODS``. Datasets are checked
    against the entries of ``DATASET_REGISTRY`` (its keys, if it is a mapping).
    Unknown methods are reported before unknown datasets.

    Raises:
        ValueError: listing the sorted unknown methods, or else the sorted unknown datasets.
    """
    bad_methods = sorted(set(methods).difference(_SUPPORTED_METHODS))
    if bad_methods:
        raise ValueError(f"Unknown benchmark methods: {bad_methods}")
    bad_datasets = sorted(set(datasets).difference(DATASET_REGISTRY))
    if bad_datasets:
        raise ValueError(f"Unknown benchmark datasets: {bad_datasets}")
44
+
45
+
46
+ def _reproduction_command(
47
+ methods: list[str],
48
+ datasets: list[str],
49
+ *,
50
+ seed_count: int,
51
+ seed_list: list[int] | None,
52
+ ) -> str:
53
+ """Build the canonical command for reproducing a benchmark run."""
54
+ command = (
55
+ "dataforge bench "
56
+ f"--methods {','.join(methods)} "
57
+ f"--datasets {','.join(datasets)} "
58
+ f"--seeds {seed_count}"
59
+ )
60
+ if seed_list is not None:
61
+ command += f" --seed-list {','.join(str(seed) for seed in seed_list)}"
62
+ return command
63
+
64
+
65
+ def _llm_skip_reason() -> str | None:
66
+ """Return a skip reason when LLM methods cannot run."""
67
+ provider = os.environ.get("DATAFORGE_LLM_PROVIDER", "").strip().lower()
68
+ if provider != "groq":
69
+ return "DATAFORGE_LLM_PROVIDER must be set to groq."
70
+ if not os.environ.get("GROQ_API_KEY"):
71
+ return "GROQ_API_KEY is not set."
72
+ return None
73
+
74
+
75
def _skipped_result(
    *,
    method: str,
    dataset: str,
    seed: int,
    reason: str,
    reproduction_command: str,
) -> SeedBenchmarkResult:
    """Return a ``status="skipped"`` seed record that explains why the episode did not run.

    All usage and cost counters are zero, and provider/model are left unset.

    NOTE(review): the ``"provider_unset"`` warning is attached for every skip reason,
    including a missing API key. Confirm that downstream consumers expect this.
    """
    return SeedBenchmarkResult(
        # Identity of the (method, dataset, seed) cell being skipped.
        method=method,
        dataset=dataset,
        seed=seed,
        # Why it was skipped.
        status="skipped",
        skip_reason=reason,
        warnings=["provider_unset"],
        provider=None,
        model=None,
        # Nothing ran, so nothing was consumed.
        llm_calls=0,
        prompt_tokens=0,
        completion_tokens=0,
        quota_units=0.0,
        runtime_s=0.0,
        reproduction_command=reproduction_command,
    )
100
+
101
+
102
def _env_float(name: str, default: float) -> float:
    """Return environment variable ``name`` parsed as a float.

    Falls back to ``default`` when the variable is unset or not a valid float.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError:
        return default


def _env_int(name: str, default: int) -> int:
    """Return environment variable ``name`` parsed as an int.

    Falls back to ``default`` when the variable is unset or not a valid int.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        return default


def _build_groq_client() -> GroqBenchClient:
    """Create the Groq benchmark client, tuned by ``DATAFORGE_GROQ_*`` environment variables.

    The overrides allow env-driven tuning for tiny CI checks; unparsable values fall
    back to the defaults. Reads ``GROQ_API_KEY`` with ``os.environ[...]``, so it raises
    ``KeyError`` when unset. Callers check ``_llm_skip_reason`` first.
    """
    return GroqBenchClient(
        api_key=os.environ["GROQ_API_KEY"],
        model=os.environ.get("DATAFORGE_GROQ_MODEL", "llama-3.3-70b-versatile"),
        min_interval_s=_env_float("DATAFORGE_GROQ_MIN_INTERVAL_S", 1.0),
        max_tokens=_env_int("DATAFORGE_GROQ_MAX_TOKENS", 256),
        max_retries=_env_int("DATAFORGE_GROQ_MAX_RETRIES", 3),
        timeout_s=_env_float("DATAFORGE_GROQ_TIMEOUT_S", 30.0),
    )


def run_agent_comparison(
    *,
    methods: list[str],
    datasets: list[str],
    seeds: int,
    output_json: Path,
    really_run_big_bench: bool,
    cache_root: Path | None = None,
    reproduction_command: str | None = None,
    seed_list: list[int] | None = None,
    verify_dataset_hashes: bool = True,
) -> BenchmarkRunOutput:
    """Run the selected benchmark methods across real-world datasets.

    One episode runs for every ``dataset x method x seed`` combination. LLM methods
    get a skipped record instead when no Groq client is available. Per-seed records
    are aggregated, and the full output is written to ``output_json`` before being
    returned.

    Args:
        methods: Method names; must be in ``_SUPPORTED_METHODS``.
        datasets: Dataset names; must be known to ``DATASET_REGISTRY``.
        seeds: Number of seeds, passed to ``build_seed_list`` with ``seed_list``.
        output_json: Path the run output is written to.
        really_run_big_bench: Passed to ``validate_estimated_calls``. Presumably this
            opts in to large LLM call budgets; confirm in ``dataforge.bench.core``.
        cache_root: Dataset cache directory forwarded to ``load_real_world_dataset``.
        reproduction_command: Command recorded in every record; derived from the
            arguments when omitted.
        seed_list: Optional explicit seeds.
        verify_dataset_hashes: Forwarded to ``load_real_world_dataset``.

    Raises:
        ValueError: from ``_validate_inputs`` for unknown methods or datasets. Errors
            raised by ``validate_estimated_calls`` or the dataset loaders propagate.
    """
    load_dotenv()
    _validate_inputs(methods, datasets)
    resolved_seed_list = build_seed_list(seeds=seeds, seed_list=seed_list)

    # Validate call budget before any client instantiation or dataset loads that could
    # trigger network access in tests with environment variables set.
    estimated_calls = estimate_llm_calls(
        methods=methods,
        datasets=datasets,
        seeds=len(resolved_seed_list),
    )
    validate_estimated_calls(
        estimated_calls=estimated_calls,
        really_run_big_bench=really_run_big_bench,
    )

    reproduction_command = reproduction_command or _reproduction_command(
        methods,
        datasets,
        seed_count=len(resolved_seed_list),
        seed_list=seed_list,
    )
    loaded_datasets = {
        dataset_name: load_real_world_dataset(
            dataset_name,
            cache_root=cache_root,
            verify_hashes=verify_dataset_hashes,
        )
        for dataset_name in datasets
    }

    llm_methods_requested = any(method.startswith("llm_") for method in methods)
    skip_reason = _llm_skip_reason() if llm_methods_requested else None
    client: GroqBenchClient | None = None
    if llm_methods_requested and skip_reason is None:
        client = _build_groq_client()

    verbose = bool(os.environ.get("DATAFORGE_BENCH_VERBOSE"))
    records: list[SeedBenchmarkResult] = []
    for dataset_name in datasets:
        dataset = loaded_datasets[dataset_name]
        for method in methods:
            for seed in resolved_seed_list:
                if verbose:
                    print(
                        f"[dataforge bench] start method={method} dataset={dataset_name} seed={seed}",
                        file=sys.stderr,
                        flush=True,
                    )
                if method == "random":
                    result = run_random_episode(dataset, seed=seed)
                elif method == "heuristic":
                    result = run_heuristic_episode(dataset, seed=seed)
                elif client is None or skip_reason is not None:
                    # LLM method requested but no usable client: record an explicit skip.
                    result = _skipped_result(
                        method=method,
                        dataset=dataset_name,
                        seed=seed,
                        reason=skip_reason or "LLM client unavailable.",
                        reproduction_command=reproduction_command,
                    )
                elif method == "llm_zeroshot":
                    result = run_llm_zeroshot_episode(dataset, seed=seed, client=client)
                else:  # "llm_react" -- _validate_inputs rejected every other name.
                    result = run_llm_react_episode(dataset, seed=seed, client=client)

                # Stamp run-level provenance onto the record with a single copy.
                updates: dict[str, object] = {}
                if result.reproduction_command != reproduction_command:
                    updates["reproduction_command"] = reproduction_command
                if method == "heuristic":
                    # NOTE(review): presumably the heuristic episode does not record the
                    # requested seed itself; confirm in dataforge.bench.methods.
                    updates["seed"] = seed
                if updates:
                    result = result.model_copy(update=updates)
                records.append(result)

                if verbose:
                    print(
                        f"[dataforge bench] done method={method} dataset={dataset_name} "
                        f"seed={seed} status={result.status}",
                        file=sys.stderr,
                        flush=True,
                    )

    aggregates: list[AggregateBenchmarkResult] = aggregate_seed_results(
        records, seeds_requested=len(resolved_seed_list)
    )
    dataset_evidence = [
        dataset_evidence_from_loaded(loaded_datasets[dataset_name]) for dataset_name in datasets
    ]
    metadata = build_benchmark_metadata(
        methods=methods,
        datasets=datasets,
        seed_list=resolved_seed_list,
        reproduction_command=reproduction_command,
        dataset_evidence=dataset_evidence,
    )
    output = BenchmarkRunOutput(
        metadata=metadata.model_dump(mode="json"),
        records=records,
        aggregates=aggregates,
    )
    write_run_output(output, output_json)
    return output
@@ -0,0 +1,21 @@
1
+ """Causal analysis primitives for DataForge root-cause diagnosis."""
2
+
3
+ from dataforge.causal.dag import CausalDAG, CausalEdge
4
+ from dataforge.causal.pc import CausalDiscoveryResult, discover_causal_dag
5
+ from dataforge.causal.root_cause import (
6
+ CausalRootCauseAnalyzer,
7
+ ErrorEvidence,
8
+ RootCauseResult,
9
+ minimal_root_set,
10
+ )
11
+
12
+ __all__ = [
13
+ "CausalDAG",
14
+ "CausalDiscoveryResult",
15
+ "CausalEdge",
16
+ "CausalRootCauseAnalyzer",
17
+ "ErrorEvidence",
18
+ "RootCauseResult",
19
+ "discover_causal_dag",
20
+ "minimal_root_set",
21
+ ]