dataforge-07 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataforge/__init__.py +204 -0
- dataforge/__main__.py +5 -0
- dataforge/agent/__init__.py +16 -0
- dataforge/agent/providers.py +259 -0
- dataforge/agent/scratchpad.py +183 -0
- dataforge/agent/tool_actions.py +343 -0
- dataforge/bench/__init__.py +31 -0
- dataforge/bench/core.py +426 -0
- dataforge/bench/groq_client.py +386 -0
- dataforge/bench/methods.py +443 -0
- dataforge/bench/report.py +309 -0
- dataforge/bench/runner.py +247 -0
- dataforge/causal/__init__.py +21 -0
- dataforge/causal/dag.py +174 -0
- dataforge/causal/pc.py +232 -0
- dataforge/causal/root_cause.py +193 -0
- dataforge/cli/__init__.py +50 -0
- dataforge/cli/audit.py +70 -0
- dataforge/cli/bench.py +154 -0
- dataforge/cli/common.py +267 -0
- dataforge/cli/constraints.py +407 -0
- dataforge/cli/profile.py +147 -0
- dataforge/cli/release.py +166 -0
- dataforge/cli/repair.py +407 -0
- dataforge/cli/revert.py +139 -0
- dataforge/cli/watch.py +144 -0
- dataforge/datasets/__init__.py +25 -0
- dataforge/datasets/embedded/hospital/clean.csv +11 -0
- dataforge/datasets/embedded/hospital/dirty.csv +11 -0
- dataforge/datasets/real_world.py +290 -0
- dataforge/datasets/registry.py +103 -0
- dataforge/detectors/__init__.py +80 -0
- dataforge/detectors/base.py +145 -0
- dataforge/detectors/decimal_shift.py +166 -0
- dataforge/detectors/fd_violation.py +157 -0
- dataforge/detectors/type_mismatch.py +173 -0
- dataforge/engine/__init__.py +39 -0
- dataforge/engine/repair.py +905 -0
- dataforge/env/__init__.py +22 -0
- dataforge/env/environment.py +883 -0
- dataforge/env/observation.py +61 -0
- dataforge/env/openenv_core.py +161 -0
- dataforge/env/reward.py +128 -0
- dataforge/env/server.py +176 -0
- dataforge/evaluation_contract.py +76 -0
- dataforge/fixtures/hospital_10rows.csv +11 -0
- dataforge/fixtures/hospital_schema.yaml +17 -0
- dataforge/http/__init__.py +1 -0
- dataforge/http/problem.py +103 -0
- dataforge/integrations/__init__.py +1 -0
- dataforge/integrations/dbt.py +164 -0
- dataforge/observability.py +76 -0
- dataforge/py.typed +1 -0
- dataforge/release/__init__.py +1 -0
- dataforge/release/doctor.py +367 -0
- dataforge/release/full_vision.py +702 -0
- dataforge/release/gate.py +861 -0
- dataforge/release/playground_check.py +411 -0
- dataforge/repair_contract.py +468 -0
- dataforge/repairers/__init__.py +88 -0
- dataforge/repairers/base.py +77 -0
- dataforge/repairers/decimal_shift.py +43 -0
- dataforge/repairers/fd_violation.py +225 -0
- dataforge/repairers/type_mismatch.py +73 -0
- dataforge/safety/__init__.py +5 -0
- dataforge/safety/adversarial/attack_01_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_02_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_03_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_04_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_05_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_06_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_07_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_08_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_09_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_10_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_11_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_12_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_13_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_14_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_15_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_16_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_17_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_18_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_19_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_20_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_21_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_22_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_23_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_24_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_25_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_26_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_27_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_28_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_29_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_30_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_31_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_32_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_33_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_34_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_35_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_36_row_delete.yaml +11 -0
- dataforge/safety/adversarial/attack_37_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_38_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_39_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_40_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_41_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_42_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_43_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_44_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_45_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_46_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_47_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_48_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_49_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_50_row_delete.yaml +7 -0
- dataforge/safety/constitution.py +307 -0
- dataforge/safety/constitutions/default.yaml +40 -0
- dataforge/safety/filter.py +134 -0
- dataforge/schema_inference.py +620 -0
- dataforge/stores/__init__.py +46 -0
- dataforge/stores/base.py +73 -0
- dataforge/stores/cloud.py +78 -0
- dataforge/stores/csv.py +94 -0
- dataforge/stores/duckdb.py +313 -0
- dataforge/stores/patch_plan.py +178 -0
- dataforge/stores/registry.py +82 -0
- dataforge/stores/repair.py +121 -0
- dataforge/stores/revert.py +22 -0
- dataforge/stores/sql.py +27 -0
- dataforge/table.py +228 -0
- dataforge/transactions/__init__.py +34 -0
- dataforge/transactions/files.py +96 -0
- dataforge/transactions/log.py +613 -0
- dataforge/transactions/revert.py +102 -0
- dataforge/transactions/txn.py +104 -0
- dataforge/ui/__init__.py +1 -0
- dataforge/ui/profile_view.py +136 -0
- dataforge/ui/repair_diff.py +91 -0
- dataforge/verifier/__init__.py +55 -0
- dataforge/verifier/constraint_ir.py +155 -0
- dataforge/verifier/explain.py +47 -0
- dataforge/verifier/gate.py +5 -0
- dataforge/verifier/schema.py +111 -0
- dataforge/verifier/smt.py +433 -0
- dataforge_07-0.1.0.dist-info/METADATA +436 -0
- dataforge_07-0.1.0.dist-info/RECORD +150 -0
- dataforge_07-0.1.0.dist-info/WHEEL +5 -0
- dataforge_07-0.1.0.dist-info/entry_points.txt +3 -0
- dataforge_07-0.1.0.dist-info/licenses/LICENSE +176 -0
- dataforge_07-0.1.0.dist-info/top_level.txt +1 -0
dataforge/bench/core.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
1
|
+
"""Shared benchmark types, metrics, and quota helpers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import importlib.metadata as package_metadata
|
|
6
|
+
import platform
|
|
7
|
+
import subprocess
|
|
8
|
+
import sys
|
|
9
|
+
from collections import OrderedDict
|
|
10
|
+
from datetime import UTC, datetime
|
|
11
|
+
from math import ceil
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from statistics import mean, stdev
|
|
14
|
+
from typing import Literal
|
|
15
|
+
|
|
16
|
+
from pydantic import BaseModel, Field
|
|
17
|
+
|
|
18
|
+
from dataforge.datasets.real_world import GroundTruthCell, RealWorldDataset
|
|
19
|
+
from dataforge.datasets.registry import DATASET_REGISTRY
|
|
20
|
+
|
|
21
|
+
# Outcome of a benchmark run: scored ("ok") or skipped; skipped results carry a
# ``skip_reason`` in the result models.
BenchmarkStatus = Literal["ok", "skipped"]
# Schema identifier stamped into benchmark JSON metadata; it is the default of
# ``BenchmarkEvidenceMetadata.schema_version``.
BENCHMARK_SCHEMA_VERSION = "dataforge_benchmark_run_v2"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class BenchmarkRepair(BaseModel):
    """One benchmark repair prediction: a proposed new value for a single cell.

    Instances are immutable (``frozen``). When predictions are de-duplicated
    they are keyed by ``(row, column)``, and ``new_value`` is compared for
    exact equality against the ground-truth clean value during scoring.
    """

    # Row index of the repaired cell (must be >= 0).
    row: int = Field(ge=0)
    # Name of the repaired column (must be non-empty).
    column: str = Field(min_length=1)
    # Proposed replacement value for the cell.
    new_value: str
    # Non-empty justification attached to the prediction.
    reason: str = Field(min_length=1)

    model_config = {"frozen": True}
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class RepairScore(BaseModel):
    """Exact-match cell repair metrics for one episode.

    Counts are non-negative and the ratio metrics are constrained to
    ``[0, 1]``. ``score_repairs`` rounds the ratios to four decimal places.
    Instances are immutable (``frozen``).
    """

    # True positives: repairs whose value exactly equals the ground-truth value.
    tp: int = Field(ge=0)
    # False positives: repairs with a wrong value or on a cell without ground truth.
    fp: int = Field(ge=0)
    # False negatives: ground-truth cells with no correct repair.
    fn: int = Field(ge=0)
    # tp / (tp + fp); score_repairs uses 0.0 when there are no repairs.
    precision: float = Field(ge=0.0, le=1.0)
    # tp / (tp + fn); score_repairs uses 0.0 when there is no ground truth.
    recall: float = Field(ge=0.0, le=1.0)
    # Harmonic mean of precision and recall; 0.0 when both are zero.
    f1: float = Field(ge=0.0, le=1.0)

    model_config = {"frozen": True}
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class SeedBenchmarkResult(BaseModel):
    """Benchmark result for one dataset/method/seed run.

    Metric fields are optional. ``aggregate_seed_results`` ignores
    ``"skipped"`` records when averaging and treats a missing metric on an
    ``"ok"`` record as ``0.0``.
    """

    # Benchmark method name (e.g. "llm_zeroshot", "llm_react").
    method: str = Field(min_length=1)
    # Dataset name; presumably a DATASET_REGISTRY key — confirm with the runner.
    dataset: str = Field(min_length=1)
    # Seed used for this run (must be >= 0).
    seed: int = Field(ge=0)
    # "ok" when the run completed, "skipped" otherwise.
    status: BenchmarkStatus
    # Explanation for a skipped run; presumably None for "ok" runs.
    skip_reason: str | None = None
    # Repair metrics; presumably copied from a RepairScore and None when the
    # run was not scored — TODO confirm against the runner.
    precision: float | None = None
    recall: float | None = None
    f1: float | None = None
    tp: int | None = None
    fp: int | None = None
    fn: int | None = None
    # NOTE(review): appears to be the mean number of agent steps per episode;
    # confirm against the method implementations.
    avg_steps: float | None = None
    # LLM usage counters for the run (default 0).
    llm_calls: int = Field(ge=0, default=0)
    prompt_tokens: int = Field(ge=0, default=0)
    completion_tokens: int = Field(ge=0, default=0)
    # Free-tier quota consumed; presumably computed with quota_units().
    quota_units: float = Field(ge=0.0, default=0.0)
    # Compute usage; runtime_s is presumably wall-clock seconds.
    gpu_hours: float = Field(ge=0.0, default=0.0)
    runtime_s: float = Field(ge=0.0, default=0.0)
    # LLM provider/model identifiers; presumably None for non-LLM methods.
    provider: str | None = None
    model: str | None = None
    # Non-fatal messages collected during the run.
    warnings: list[str] = Field(default_factory=list)
    # Command that reproduces this run (non-empty).
    reproduction_command: str = Field(min_length=1)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class AggregateBenchmarkResult(BaseModel):
    """Aggregated benchmark result across seeds for one method/dataset pair.

    As built by ``aggregate_seed_results``:

    - Each ``*_mean``/``*_std`` pair summarizes the ``"ok"`` seed records,
      rounded to four decimals.
    - ``*_std`` is the sample standard deviation, or ``0.0`` when only one
      seed completed.
    - For a ``"skipped"`` aggregate all statistics are left as ``None``.
    """

    # Method and dataset this aggregate summarizes.
    method: str = Field(min_length=1)
    dataset: str = Field(min_length=1)
    # "skipped" when no seed of this pair completed, otherwise "ok".
    status: BenchmarkStatus
    # Skip reason taken from the first record when every seed was skipped.
    skip_reason: str | None = None
    # Number of seeds requested for the run vs. the number that completed.
    seeds_requested: int = Field(ge=0)
    seeds_completed: int = Field(ge=0)
    # Per-metric mean and standard deviation across completed seeds.
    precision_mean: float | None = None
    precision_std: float | None = None
    recall_mean: float | None = None
    recall_std: float | None = None
    f1_mean: float | None = None
    f1_std: float | None = None
    avg_steps_mean: float | None = None
    avg_steps_std: float | None = None
    quota_units_mean: float | None = None
    quota_units_std: float | None = None
    gpu_hours_mean: float | None = None
    gpu_hours_std: float | None = None
    runtime_s_mean: float | None = None
    runtime_s_std: float | None = None
    # Copied from the first completed record (or the first record if skipped).
    provider: str | None = None
    model: str | None = None
    reproduction_command: str = Field(min_length=1)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class BenchmarkRunOutput(BaseModel):
    """Serializable benchmark run output, written to JSON by ``write_run_output``."""

    # Run provenance; presumably a dumped BenchmarkEvidenceMetadata — confirm
    # with the runner that builds this object.
    metadata: dict[str, object]
    # One entry per dataset/method/seed run.
    records: list[SeedBenchmarkResult]
    # One entry per method/dataset pair, summarizing its seeds.
    aggregates: list[AggregateBenchmarkResult]
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class BenchmarkDatasetEvidence(BaseModel):
    """Pinned source and loaded artifact evidence for one benchmark dataset."""

    # Dataset name (non-empty).
    name: str = Field(min_length=1)
    # Exactly two source URLs; presumably (dirty, clean) — TODO confirm order.
    source_urls: tuple[str, str]
    # Pinned upstream revision, at least 7 characters (presumably a git SHA,
    # possibly abbreviated).
    source_revision: str = Field(min_length=7)
    # Digests of the loaded dirty/clean files; length is fixed at 64, the
    # length of a hex-encoded SHA-256 digest.
    dirty_sha256: str = Field(min_length=64, max_length=64)
    clean_sha256: str = Field(min_length=64, max_length=64)
    # Shape of the clean table as loaded (dataset_evidence_from_loaded uses
    # the clean DataFrame for both counts).
    n_rows: int = Field(ge=0)
    n_columns: int = Field(ge=1)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class BenchmarkEvidenceMetadata(BaseModel):
    """Typed provenance block written into benchmark JSON artifacts."""

    # Artifact schema identifier; defaults to BENCHMARK_SCHEMA_VERSION.
    schema_version: str = BENCHMARK_SCHEMA_VERSION
    # Method and dataset names selected for the run.
    methods: list[str]
    datasets: list[str]
    # Number of seeds (build_benchmark_metadata sets it to len(seed_list)).
    seeds: int = Field(ge=1)
    # Concrete seeds that were run.
    seed_list: list[int]
    # Output of `git rev-parse HEAD`; None when git is unavailable or fails.
    git_commit: str | None
    # Whether `git status --porcelain` reported changes; None if git failed.
    git_dirty: bool | None
    # ISO-8601 UTC timestamp truncated to whole seconds.
    generated_at_utc: str
    # Interpreter version (e.g. "3.12.1") and platform.platform() string.
    python_version: str
    platform: str
    # Installed versions of benchmark-relevant distributions.
    dependency_versions: dict[str, str]
    # Command used to generate the artifact, and the command that reproduces
    # it; build_benchmark_metadata sets both to the same value.
    generator_command: str
    reproduction_command: str
    # Per-dataset source/digest evidence.
    dataset_evidence: list[BenchmarkDatasetEvidence]
    # Digest per artifact, keyed like "dataset:<name>:dirty.csv".
    artifact_sha256s: dict[str, str]
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def build_seed_list(*, seeds: int, seed_list: list[int] | None = None) -> list[int]:
    """Turn either a seed count or an explicit seed list into concrete seeds.

    Args:
        seeds: How many seeds to generate (``0 .. seeds - 1``) when no
            explicit list is given. Ignored when ``seed_list`` is provided.
        seed_list: Optional explicit seeds; takes precedence over ``seeds``.

    Returns:
        A new list of seeds (never the caller's list object).

    Raises:
        ValueError: If the explicit list is empty, contains a negative seed,
            or repeats a seed (checked in that order), or if ``seeds`` is not
            positive when no list is given.
    """
    if seed_list is None:
        if seeds > 0:
            return list(range(seeds))
        raise ValueError("Benchmark seeds must be >= 1.")

    if not seed_list:
        raise ValueError("Benchmark seed list must contain at least one seed.")
    if any(candidate < 0 for candidate in seed_list):
        raise ValueError("Benchmark seeds must be >= 0.")
    if len(set(seed_list)) < len(seed_list):
        raise ValueError("Benchmark seed list must not contain duplicates.")
    return [*seed_list]
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _package_version(name: str) -> str:
    """Look up the installed version of the distribution called *name*.

    Instead of raising for an unknown distribution, returns the fixed marker
    ``"not-installed"``, so callers always get a string.
    """
    try:
        found = package_metadata.version(name)
    except package_metadata.PackageNotFoundError:
        found = "not-installed"
    return found
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def benchmark_dependency_versions() -> dict[str, str]:
    """Return versions of dependencies that influence benchmark behavior.

    The ``"dataforge"`` entry reports this project's own version. The wheel is
    published as the ``dataforge-07`` distribution (dist-info directory
    ``dataforge_07-*.dist-info``), while the import package is ``dataforge``.
    Looking up the distribution name ``dataforge`` alone therefore reports
    ``"not-installed"`` for a normal install, or the version of an unrelated
    distribution of that name. The published name is tried first and the
    bare name is kept as a fallback.

    Returns:
        Mapping of dependency name to installed version, with
        ``"not-installed"`` for distributions that cannot be found.
    """
    dataforge_version = "not-installed"
    for dist_name in ("dataforge-07", "dataforge"):
        dataforge_version = _package_version(dist_name)
        if dataforge_version != "not-installed":
            break

    return {
        "dataforge": dataforge_version,
        "httpx": _package_version("httpx"),
        "pandas": _package_version("pandas"),
        "pydantic": _package_version("pydantic"),
        "python-dotenv": _package_version("python-dotenv"),
        "typer": _package_version("typer"),
    }
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _project_root() -> Path:
    """Return the directory that contains the ``dataforge`` package.

    This module lives at ``dataforge/bench/core.py``, so the directory two
    levels above its parent is the source checkout root when running from the
    repository. When installed from a wheel it is instead the installation
    directory (e.g. ``site-packages``).
    """
    this_module = Path(__file__).resolve()
    return this_module.parents[2]
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _git_command(args: list[str]) -> str | None:
    """Run a read-only git command in the project root and return its output.

    Provenance collection is best-effort. Any failure to run git yields
    ``None`` instead of aborting the benchmark:

    - the executable is missing or not runnable;
    - the working directory is missing or not a directory;
    - the command times out;
    - the command exits with a non-zero status.

    Args:
        args: Arguments for ``git``, without the leading ``"git"``.

    Returns:
        The command's stdout with surrounding whitespace stripped, or ``None``.
    """
    try:
        result = subprocess.run(
            ["git", *args],
            cwd=_project_root(),
            check=False,
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            timeout=10,
        )
    # OSError covers FileNotFoundError as well as PermissionError and
    # NotADirectoryError (non-executable git, unusable cwd); SubprocessError
    # covers TimeoutExpired. None of these should crash a benchmark run.
    except (OSError, subprocess.SubprocessError):
        return None
    if result.returncode != 0:
        return None
    return result.stdout.strip()
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def current_git_commit() -> str | None:
    """Return the checked-out commit reported by ``git rev-parse HEAD``.

    Returns ``None`` when git cannot be run or the command fails, e.g. when
    the code is not inside a git checkout.
    """
    head = _git_command(["rev-parse", "HEAD"])
    return head
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def git_worktree_dirty() -> bool | None:
    """Report whether ``git status --porcelain`` lists any changes.

    Porcelain output includes untracked files as well as tracked
    modifications. Returns ``None`` when the git command cannot be run or
    fails.
    """
    porcelain = _git_command(["status", "--porcelain"])
    return None if porcelain is None else len(porcelain) > 0
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def dataset_evidence_from_loaded(dataset: RealWorldDataset) -> BenchmarkDatasetEvidence:
    """Capture pinned-source and loaded-byte evidence for one dataset.

    Name, source URLs and revision come from the dataset's metadata. The
    SHA-256 digests come from the loaded dataset. Row and column counts are
    measured on the clean DataFrame.
    """
    meta = dataset.metadata
    return BenchmarkDatasetEvidence(
        name=meta.name,
        source_urls=meta.source_urls,
        source_revision=meta.source_revision,
        dirty_sha256=dataset.dirty_sha256,
        clean_sha256=dataset.clean_sha256,
        n_rows=len(dataset.clean_df.index),
        n_columns=len(dataset.clean_df.columns),
    )
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def build_benchmark_metadata(
    *,
    methods: list[str],
    datasets: list[str],
    seed_list: list[int],
    reproduction_command: str,
    dataset_evidence: list[BenchmarkDatasetEvidence],
) -> BenchmarkEvidenceMetadata:
    """Assemble the typed provenance block stored in benchmark JSON.

    Side effects: runs git (commit and dirty state) and reads the current
    UTC time, truncated to whole seconds.

    Args:
        methods: Method names included in the run.
        datasets: Dataset names included in the run.
        seed_list: Concrete seeds; ``seeds`` is recorded as its length.
        reproduction_command: Recorded as both the generator and the
            reproduction command.
        dataset_evidence: Per-dataset evidence. Each entry contributes a
            ``dataset:<name>:dirty.csv`` and ``dataset:<name>:clean.csv``
            digest to ``artifact_sha256s``.
    """
    digests: dict[str, str] = {
        f"dataset:{item.name}:{variant}.csv": sha
        for item in dataset_evidence
        for variant, sha in (("dirty", item.dirty_sha256), ("clean", item.clean_sha256))
    }

    commit = current_git_commit()
    dirty = git_worktree_dirty()
    timestamp = datetime.now(UTC).replace(microsecond=0).isoformat()

    return BenchmarkEvidenceMetadata(
        methods=methods,
        datasets=datasets,
        seeds=len(seed_list),
        seed_list=seed_list,
        git_commit=commit,
        git_dirty=dirty,
        generated_at_utc=timestamp,
        python_version=sys.version.split()[0],
        platform=platform.platform(),
        dependency_versions=benchmark_dependency_versions(),
        generator_command=reproduction_command,
        reproduction_command=reproduction_command,
        dataset_evidence=dataset_evidence,
        artifact_sha256s=digests,
    )
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def chunk_row_indices(n_rows: int) -> tuple[tuple[int, ...], ...]:
    """Partition row indices ``0 .. n_rows - 1`` into contiguous chunks.

    Every chunk holds ``ceil(n_rows / 20)`` rows except possibly the last, so
    there are at most twenty chunks. A non-positive ``n_rows`` yields ``()``.
    """
    if n_rows <= 0:
        return ()
    step = ceil(n_rows / 20)
    return tuple(
        tuple(range(first, min(first + step, n_rows)))
        for first in range(0, n_rows, step)
    )
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def normalize_repairs(repairs: list[BenchmarkRepair]) -> list[BenchmarkRepair]:
    """Keep only the last repair proposed for each ``(row, column)`` cell.

    The result is ordered by each cell's final write: a cell that is repaired
    again moves to the position of its latest repair.
    """
    latest: dict[tuple[int, str], BenchmarkRepair] = {}
    for proposal in repairs:
        cell = (proposal.row, proposal.column)
        # Drop any earlier entry so re-insertion moves the cell to the end.
        latest.pop(cell, None)
        latest[cell] = proposal
    return list(latest.values())
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def score_repairs(
    ground_truth: tuple[GroundTruthCell, ...] | list[GroundTruthCell],
    repairs: list[BenchmarkRepair],
) -> RepairScore:
    """Score repaired cells against exact dirty-to-clean ground truth.

    Repairs are first collapsed to the last one per cell.

    - A surviving repair is a true positive when its cell has a ground-truth
      clean value and ``new_value`` equals it exactly.
    - Any other repair (wrong value, or a cell without ground truth, or one
      whose clean value is ``None``) is a false positive.
    - Ground-truth cells without a correct repair are false negatives.

    Ratios fall back to ``0.0`` when their denominator is zero and are rounded
    to four decimals.
    """
    predictions = normalize_repairs(repairs)
    expected = {(cell.row, cell.column): cell.clean_value for cell in ground_truth}

    hits: set[tuple[int, str]] = set()
    misses = 0
    for prediction in predictions:
        cell = (prediction.row, prediction.column)
        target = expected.get(cell)
        if target is None or prediction.new_value != target:
            misses += 1
            continue
        hits.add(cell)

    # Predictions are unique per cell after normalization, so each hit is
    # exactly one true positive.
    tp = len(hits)
    fn = len(expected) - tp

    predicted = tp + misses
    precision = tp / predicted if predicted else 0.0
    relevant = tp + fn
    recall = tp / relevant if relevant else 0.0
    denominator = precision + recall
    f1 = 2 * precision * recall / denominator if denominator else 0.0

    return RepairScore(
        tp=tp,
        fp=misses,
        fn=fn,
        precision=round(precision, 4),
        recall=round(recall, 4),
        f1=round(f1, 4),
    )
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def quota_units(*, llm_calls: int, prompt_tokens: int, completion_tokens: int) -> float:
    """Compute free-tier quota units consumed by one episode.

    The result is the larger of the request share (``llm_calls / 1000``) and
    the token share (``(prompt_tokens + completion_tokens) / 100000``),
    rounded to four decimals.
    """
    by_requests = llm_calls / 1000 if llm_calls else 0.0
    if prompt_tokens or completion_tokens:
        by_tokens = (prompt_tokens + completion_tokens) / 100000
    else:
        by_tokens = 0.0
    return round(max(by_requests, by_tokens), 4)
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def estimate_llm_calls(*, methods: list[str], datasets: list[str], seeds: int) -> int:
    """Estimate total LLM calls for the selected run configuration.

    Per dataset chunk and seed:

    - ``llm_zeroshot`` costs one call;
    - ``llm_react`` costs two calls;
    - any other method costs none.

    Chunk counts come from ``chunk_row_indices`` over each dataset's
    registered row count. An unknown dataset name fails at the
    ``DATASET_REGISTRY`` lookup.
    """
    calls_per_chunk = {"llm_zeroshot": 1, "llm_react": 2}
    weight = sum(calls_per_chunk.get(name, 0) for name in methods)

    total = 0
    for dataset_name in datasets:
        n_chunks = len(chunk_row_indices(DATASET_REGISTRY[dataset_name].n_rows))
        total += n_chunks * weight * seeds
    return total
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def validate_estimated_calls(*, estimated_calls: int, really_run_big_bench: bool) -> None:
    """Enforce the 500-call free-tier budget unless explicitly overridden.

    Raises:
        ValueError: If ``estimated_calls`` exceeds 500 and
            ``really_run_big_bench`` is falsy.
    """
    within_budget = estimated_calls <= 500
    if within_budget or really_run_big_bench:
        return
    raise ValueError(
        "Estimated benchmark size exceeds 500 free-tier LLM calls. "
        "Pass --really-run-big-bench to continue."
    )
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def aggregate_seed_results(
    records: list[SeedBenchmarkResult],
    *,
    seeds_requested: int,
) -> list[AggregateBenchmarkResult]:
    """Summarize seed-level results per ``(method, dataset)`` pair.

    Groups are emitted in the order their first record appears.

    - A group with no ``"ok"`` record yields a ``"skipped"`` aggregate whose
      reason, provider, model and command come from its first record.
    - Otherwise each statistic is a mean and sample standard deviation over
      the ``"ok"`` records, rounded to four decimals. The deviation is
      ``0.0`` for a single seed, and a missing per-seed metric counts as
      ``0.0``. Provider, model and command come from the first ``"ok"``
      record.
    """

    def summarize(values: list[float]) -> tuple[float, float]:
        # statistics.stdev needs at least two points; one seed has no spread.
        if len(values) == 1:
            return round(values[0], 4), 0.0
        return round(mean(values), 4), round(stdev(values), 4)

    groups: dict[tuple[str, str], list[SeedBenchmarkResult]] = {}
    for record in records:
        groups.setdefault((record.method, record.dataset), []).append(record)

    summaries: list[AggregateBenchmarkResult] = []
    for (method, dataset), group in groups.items():
        completed = [rec for rec in group if rec.status == "ok"]

        if not completed:
            first = group[0]
            summaries.append(
                AggregateBenchmarkResult(
                    method=method,
                    dataset=dataset,
                    status="skipped",
                    skip_reason=first.skip_reason,
                    seeds_requested=seeds_requested,
                    seeds_completed=0,
                    provider=first.provider,
                    model=first.model,
                    reproduction_command=first.reproduction_command,
                )
            )
            continue

        p_mean, p_std = summarize([rec.precision or 0.0 for rec in completed])
        r_mean, r_std = summarize([rec.recall or 0.0 for rec in completed])
        f_mean, f_std = summarize([rec.f1 or 0.0 for rec in completed])
        s_mean, s_std = summarize([rec.avg_steps or 0.0 for rec in completed])
        q_mean, q_std = summarize([rec.quota_units for rec in completed])
        g_mean, g_std = summarize([rec.gpu_hours for rec in completed])
        t_mean, t_std = summarize([rec.runtime_s for rec in completed])

        lead = completed[0]
        summaries.append(
            AggregateBenchmarkResult(
                method=method,
                dataset=dataset,
                status="ok",
                skip_reason=None,
                seeds_requested=seeds_requested,
                seeds_completed=len(completed),
                precision_mean=p_mean,
                precision_std=p_std,
                recall_mean=r_mean,
                recall_std=r_std,
                f1_mean=f_mean,
                f1_std=f_std,
                avg_steps_mean=s_mean,
                avg_steps_std=s_std,
                quota_units_mean=q_mean,
                quota_units_std=q_std,
                gpu_hours_mean=g_mean,
                gpu_hours_std=g_std,
                runtime_s_mean=t_mean,
                runtime_s_std=t_std,
                provider=lead.provider,
                model=lead.model,
                reproduction_command=lead.reproduction_command,
            )
        )
    return summaries
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def write_run_output(output: BenchmarkRunOutput, path: Path) -> None:
    """Write *output* to *path* as two-space-indented JSON.

    Missing parent directories are created first. The file is written as
    UTF-8 and replaces any existing file at *path*.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    serialized = output.model_dump_json(indent=2)
    path.write_text(serialized, encoding="utf-8")
|