dataforge-07 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataforge/__init__.py +204 -0
- dataforge/__main__.py +5 -0
- dataforge/agent/__init__.py +16 -0
- dataforge/agent/providers.py +259 -0
- dataforge/agent/scratchpad.py +183 -0
- dataforge/agent/tool_actions.py +343 -0
- dataforge/bench/__init__.py +31 -0
- dataforge/bench/core.py +426 -0
- dataforge/bench/groq_client.py +386 -0
- dataforge/bench/methods.py +443 -0
- dataforge/bench/report.py +309 -0
- dataforge/bench/runner.py +247 -0
- dataforge/causal/__init__.py +21 -0
- dataforge/causal/dag.py +174 -0
- dataforge/causal/pc.py +232 -0
- dataforge/causal/root_cause.py +193 -0
- dataforge/cli/__init__.py +50 -0
- dataforge/cli/audit.py +70 -0
- dataforge/cli/bench.py +154 -0
- dataforge/cli/common.py +267 -0
- dataforge/cli/constraints.py +407 -0
- dataforge/cli/profile.py +147 -0
- dataforge/cli/release.py +166 -0
- dataforge/cli/repair.py +407 -0
- dataforge/cli/revert.py +139 -0
- dataforge/cli/watch.py +144 -0
- dataforge/datasets/__init__.py +25 -0
- dataforge/datasets/embedded/hospital/clean.csv +11 -0
- dataforge/datasets/embedded/hospital/dirty.csv +11 -0
- dataforge/datasets/real_world.py +290 -0
- dataforge/datasets/registry.py +103 -0
- dataforge/detectors/__init__.py +80 -0
- dataforge/detectors/base.py +145 -0
- dataforge/detectors/decimal_shift.py +166 -0
- dataforge/detectors/fd_violation.py +157 -0
- dataforge/detectors/type_mismatch.py +173 -0
- dataforge/engine/__init__.py +39 -0
- dataforge/engine/repair.py +905 -0
- dataforge/env/__init__.py +22 -0
- dataforge/env/environment.py +883 -0
- dataforge/env/observation.py +61 -0
- dataforge/env/openenv_core.py +161 -0
- dataforge/env/reward.py +128 -0
- dataforge/env/server.py +176 -0
- dataforge/evaluation_contract.py +76 -0
- dataforge/fixtures/hospital_10rows.csv +11 -0
- dataforge/fixtures/hospital_schema.yaml +17 -0
- dataforge/http/__init__.py +1 -0
- dataforge/http/problem.py +103 -0
- dataforge/integrations/__init__.py +1 -0
- dataforge/integrations/dbt.py +164 -0
- dataforge/observability.py +76 -0
- dataforge/py.typed +1 -0
- dataforge/release/__init__.py +1 -0
- dataforge/release/doctor.py +367 -0
- dataforge/release/full_vision.py +702 -0
- dataforge/release/gate.py +861 -0
- dataforge/release/playground_check.py +411 -0
- dataforge/repair_contract.py +468 -0
- dataforge/repairers/__init__.py +88 -0
- dataforge/repairers/base.py +77 -0
- dataforge/repairers/decimal_shift.py +43 -0
- dataforge/repairers/fd_violation.py +225 -0
- dataforge/repairers/type_mismatch.py +73 -0
- dataforge/safety/__init__.py +5 -0
- dataforge/safety/adversarial/attack_01_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_02_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_03_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_04_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_05_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_06_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_07_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_08_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_09_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_10_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_11_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_12_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_13_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_14_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_15_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_16_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_17_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_18_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_19_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_20_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_21_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_22_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_23_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_24_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_25_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_26_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_27_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_28_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_29_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_30_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_31_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_32_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_33_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_34_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_35_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_36_row_delete.yaml +11 -0
- dataforge/safety/adversarial/attack_37_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_38_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_39_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_40_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_41_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_42_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_43_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_44_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_45_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_46_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_47_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_48_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_49_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_50_row_delete.yaml +7 -0
- dataforge/safety/constitution.py +307 -0
- dataforge/safety/constitutions/default.yaml +40 -0
- dataforge/safety/filter.py +134 -0
- dataforge/schema_inference.py +620 -0
- dataforge/stores/__init__.py +46 -0
- dataforge/stores/base.py +73 -0
- dataforge/stores/cloud.py +78 -0
- dataforge/stores/csv.py +94 -0
- dataforge/stores/duckdb.py +313 -0
- dataforge/stores/patch_plan.py +178 -0
- dataforge/stores/registry.py +82 -0
- dataforge/stores/repair.py +121 -0
- dataforge/stores/revert.py +22 -0
- dataforge/stores/sql.py +27 -0
- dataforge/table.py +228 -0
- dataforge/transactions/__init__.py +34 -0
- dataforge/transactions/files.py +96 -0
- dataforge/transactions/log.py +613 -0
- dataforge/transactions/revert.py +102 -0
- dataforge/transactions/txn.py +104 -0
- dataforge/ui/__init__.py +1 -0
- dataforge/ui/profile_view.py +136 -0
- dataforge/ui/repair_diff.py +91 -0
- dataforge/verifier/__init__.py +55 -0
- dataforge/verifier/constraint_ir.py +155 -0
- dataforge/verifier/explain.py +47 -0
- dataforge/verifier/gate.py +5 -0
- dataforge/verifier/schema.py +111 -0
- dataforge/verifier/smt.py +433 -0
- dataforge_07-0.1.0.dist-info/METADATA +436 -0
- dataforge_07-0.1.0.dist-info/RECORD +150 -0
- dataforge_07-0.1.0.dist-info/WHEEL +5 -0
- dataforge_07-0.1.0.dist-info/entry_points.txt +3 -0
- dataforge_07-0.1.0.dist-info/licenses/LICENSE +176 -0
- dataforge_07-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
"""Typed tool-use action models for the DataForge RL environment.
|
|
2
|
+
|
|
3
|
+
This module defines a discriminated union of 8 action types that an RL agent
|
|
4
|
+
can submit to the DataForge environment. Each action is a standalone Pydantic
|
|
5
|
+
model with its own validation rules, preventing cross-model field pollution.
|
|
6
|
+
|
|
7
|
+
The ``parse_action`` function is the single entry point for HTTP handlers
|
|
8
|
+
and tests to validate raw action dicts into typed models.
|
|
9
|
+
|
|
10
|
+
Action Types:
|
|
11
|
+
INSPECT_ROWS — View a slice of the dataset.
|
|
12
|
+
SQL_QUERY — Execute read-only SQL against the episode DataFrame.
|
|
13
|
+
STAT_TEST — Run a statistical test on a column.
|
|
14
|
+
PATTERN_MATCH — Evaluate a regex pattern against column values.
|
|
15
|
+
HYPOTHESIS — Record a causal-root claim for credit.
|
|
16
|
+
ROOT_CAUSE — Analyze selected detected errors for minimal roots.
|
|
17
|
+
DIAGNOSE — Flag a suspected issue at (row, column).
|
|
18
|
+
FIX — Propose a corrected value for a diagnosed issue.
|
|
19
|
+
|
|
20
|
+
Example::
|
|
21
|
+
|
|
22
|
+
>>> from dataforge.agent.tool_actions import parse_action
|
|
23
|
+
>>> action = parse_action({"action_type": "INSPECT_ROWS", "row_indices": [0, 1]})
|
|
24
|
+
>>> action.action_type
|
|
25
|
+
'INSPECT_ROWS'
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from __future__ import annotations
|
|
29
|
+
|
|
30
|
+
from typing import Annotated, Any, Literal
|
|
31
|
+
|
|
32
|
+
from pydantic import BaseModel, Field, field_validator
|
|
33
|
+
|
|
34
|
+
# Public API of this module: the eight action models, the ``Action``
# discriminated union over them, and the ``parse_action`` entry point.
__all__ = [
    "Action",
    "Diagnose",
    "Fix",
    "Hypothesis",
    "InspectRows",
    "PatternMatch",
    "RootCause",
    "SqlQuery",
    "StatTest",
    "parse_action",
]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class InspectRows(BaseModel):
    """Action requesting a selection of dataset rows for viewing.

    Instances are frozen (immutable after construction).

    Attributes:
        action_type: Discriminator tag; always ``"INSPECT_ROWS"``.
        row_indices: Zero-based indices of the rows to retrieve. The list
            must hold at least one entry and no entry may be negative.
        column_names: Optional list of columns to restrict the view to.
            Defaults to ``None``, meaning no column filter was supplied.

    Example::

        >>> InspectRows(action_type="INSPECT_ROWS", row_indices=[0, 1, 2])
    """

    model_config = {"frozen": True}

    action_type: Literal["INSPECT_ROWS"]
    row_indices: list[int] = Field(
        min_length=1,
        description="Row indices to inspect (0-indexed).",
    )
    column_names: list[str] | None = Field(
        default=None,
        description="Optional column filter.",
    )

    @field_validator("row_indices")
    @classmethod
    def _validate_row_indices(cls, indices: list[int]) -> list[int]:
        """Pass the indices through unchanged unless one of them is negative."""
        for index in indices:
            if index < 0:
                raise ValueError("All row indices must be >= 0")
        return indices
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class SqlQuery(BaseModel):
    """Action carrying a SQL query to run against the episode DataFrame via DuckDB.

    The query is expected to be read-only (``SELECT`` only). This model
    itself only checks that the query string is non-empty; it does not
    inspect the SQL. Instances are frozen (immutable after construction).

    Attributes:
        action_type: Discriminator tag; always ``"SQL_QUERY"``.
        query: The SQL text; must contain at least one character.

    Example::

        >>> SqlQuery(action_type="SQL_QUERY", query="SELECT * FROM data LIMIT 5")
    """

    model_config = {"frozen": True}

    action_type: Literal["SQL_QUERY"]
    query: str = Field(
        min_length=1,
        description="Read-only SQL query.",
    )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class StatTest(BaseModel):
    """Action requesting a statistical test over one dataset column.

    Instances are frozen (immutable after construction).

    Attributes:
        action_type: Discriminator tag; always ``"STAT_TEST"``.
        test_type: Which test to run: ``"zscore"``, ``"iqr"`` or ``"ks"``.
        column: Name of the column to test; must be non-empty.
        threshold: Optional override for the test threshold. ``None``
            (the default) leaves the choice of threshold to the consumer
            of this action.

    Example::

        >>> StatTest(action_type="STAT_TEST", test_type="zscore", column="rating")
    """

    model_config = {"frozen": True}

    action_type: Literal["STAT_TEST"]
    test_type: Literal["zscore", "iqr", "ks"] = Field(
        description="Statistical test to run.",
    )
    column: str = Field(
        min_length=1,
        description="Column name to test.",
    )
    threshold: float | None = Field(
        default=None,
        description="Optional threshold override.",
    )
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class PatternMatch(BaseModel):
    """Action asking for a regex to be evaluated against a column's values.

    Instances are frozen (immutable after construction).

    Attributes:
        action_type: Discriminator tag; always ``"PATTERN_MATCH"``.
        pattern: Regular expression text; must be non-empty.
        column: Name of the column to evaluate; must be non-empty.
        expect_match: ``True`` (the default) to report the rows that match,
            ``False`` to report the rows that do not.

    Example::

        >>> PatternMatch(
        ...     action_type="PATTERN_MATCH",
        ...     pattern=r"^\\d{5}$",
        ...     column="zip_code",
        ... )
    """

    model_config = {"frozen": True}

    action_type: Literal["PATTERN_MATCH"]
    pattern: str = Field(
        min_length=1,
        description="Regex pattern.",
    )
    column: str = Field(
        min_length=1,
        description="Column name to evaluate.",
    )
    expect_match: bool = Field(
        default=True,
        description="True to report matches, False to report non-matches.",
    )
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class Hypothesis(BaseModel):
    """Action recording a claim about the causal root of observed errors.

    Instances are frozen (immutable after construction).

    Attributes:
        action_type: Discriminator tag; always ``"HYPOTHESIS"``.
        claim: Free-text description of the hypothesized root cause;
            must be non-empty.
        affected_rows: Row indices believed to be affected. At least one
            entry is required and none may be negative.
        affected_columns: Column names believed to be affected; at least
            one entry is required.
        root_cause_type: Root cause type in the detector vocabulary
            (e.g., ``"decimal_shift"``, ``"type_mismatch"``); must be
            non-empty.

    Example::

        >>> Hypothesis(
        ...     action_type="HYPOTHESIS",
        ...     claim="Column 'rating' has a decimal shift at row 5",
        ...     affected_rows=[5],
        ...     affected_columns=["rating"],
        ...     root_cause_type="decimal_shift",
        ... )
    """

    model_config = {"frozen": True}

    action_type: Literal["HYPOTHESIS"]
    claim: str = Field(
        min_length=1,
        description="Root-cause claim.",
    )
    affected_rows: list[int] = Field(
        min_length=1,
        description="Affected row indices.",
    )
    affected_columns: list[str] = Field(
        min_length=1,
        description="Affected column names.",
    )
    root_cause_type: str = Field(
        min_length=1,
        description="Detector-vocabulary root cause type.",
    )

    @field_validator("affected_rows")
    @classmethod
    def _validate_affected_rows(cls, rows: list[int]) -> list[int]:
        """Pass the row indices through unchanged unless one is negative."""
        for row in rows:
            if row < 0:
                raise ValueError("All affected row indices must be >= 0")
        return rows
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class RootCause(BaseModel):
    """Action selecting detected errors to analyze for minimal causal roots.

    Instances are frozen (immutable after construction).

    Attributes:
        action_type: Discriminator tag; always ``"ROOT_CAUSE"``.
        error_indices: Zero-based positions in the episode's list of
            detected issues. At least one entry is required and none may
            be negative.

    Example::

        >>> RootCause(action_type="ROOT_CAUSE", error_indices=[0, 1])
    """

    model_config = {"frozen": True}

    action_type: Literal["ROOT_CAUSE"]
    error_indices: list[int] = Field(
        min_length=1,
        description="Detected issue indices.",
    )

    @field_validator("error_indices")
    @classmethod
    def _validate_error_indices(cls, indices: list[int]) -> list[int]:
        """Pass the issue indices through unchanged unless one is negative."""
        for index in indices:
            if index < 0:
                raise ValueError("All error indices must be >= 0")
        return indices
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class Diagnose(BaseModel):
    """Action flagging a suspected data-quality issue at one (row, column) cell.

    Instances are frozen (immutable after construction).

    Attributes:
        action_type: Discriminator tag; always ``"DIAGNOSE"``.
        row: Zero-based row number; must be ``>= 0``.
        column: Column name; must be non-empty.
        issue_type: Issue type taken from the detector vocabulary; must
            be non-empty.

    Example::

        >>> Diagnose(
        ...     action_type="DIAGNOSE",
        ...     row=5, column="rating",
        ...     issue_type="decimal_shift",
        ... )
    """

    model_config = {"frozen": True}

    action_type: Literal["DIAGNOSE"]
    row: int = Field(
        ge=0,
        description="Zero-indexed row number.",
    )
    column: str = Field(
        min_length=1,
        description="Column name.",
    )
    issue_type: str = Field(
        min_length=1,
        description="Issue type classification.",
    )
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class Fix(BaseModel):
    """Action proposing a corrected value for a diagnosed cell.

    Instances are frozen (immutable after construction).

    Attributes:
        action_type: Discriminator tag; always ``"FIX"``.
        row: Zero-based row number; must be ``>= 0``.
        column: Column name; must be non-empty.
        new_value: The corrected cell value, as a string. No minimum
            length is imposed, so an empty string is accepted.
        justification: Explanation of why the fix is correct; must be
            non-empty.
        fix_type: Kind of fix operation: ``"correct_value"`` (the
            default), ``"delete_row"``, ``"impute"`` or ``"standardize"``.

    Example::

        >>> Fix(
        ...     action_type="FIX",
        ...     row=5, column="rating",
        ...     new_value="4.5",
        ...     justification="Decimal shift: 45.0 should be 4.5",
        ... )
    """

    model_config = {"frozen": True}

    action_type: Literal["FIX"]
    row: int = Field(
        ge=0,
        description="Zero-indexed row number.",
    )
    column: str = Field(
        min_length=1,
        description="Column name.",
    )
    new_value: str = Field(
        description="Corrected cell value.",
    )
    justification: str = Field(
        min_length=1,
        description="Fix justification.",
    )
    fix_type: Literal["correct_value", "delete_row", "impute", "standardize"] = Field(
        default="correct_value",
        description="Fix operation type.",
    )
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
# ═══════════════════════════════════════════════════════════════════════════
|
|
270
|
+
# Discriminated union and parser
|
|
271
|
+
# ═══════════════════════════════════════════════════════════════════════════
|
|
272
|
+
|
|
273
|
+
# Tagged union over every action model; pydantic selects the concrete model
# from the value of each payload's ``action_type`` field.
Action = Annotated[
    InspectRows
    | SqlQuery
    | StatTest
    | PatternMatch
    | Hypothesis
    | RootCause
    | Diagnose
    | Fix,
    Field(discriminator="action_type"),
]
"""Discriminated union of all valid DataForge environment actions."""
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
# Lazily-built pydantic ``TypeAdapter`` for ``Action``. Constructing an adapter
# compiles the validation schema for the whole union, so it is built once on
# first use and reused by every later ``parse_action`` call.
_ACTION_ADAPTER: Any = None


def parse_action(raw: dict[str, Any]) -> Action:
    """Parse and validate a raw action dict into the appropriate typed model.

    This is the single entry point for HTTP handlers and tests to validate
    actions. Supported external field aliases are first mapped onto the
    canonical field names by ``_normalize_action``; the ``action_type``
    field is then used as the discriminator to pick the model.

    Args:
        raw: Dictionary with an ``action_type`` key and action-specific fields.
            The dictionary itself is not modified.

    Returns:
        A validated action model instance.

    Raises:
        pydantic.ValidationError: If ``action_type`` is missing or not
            recognized, or if the action-specific fields are malformed or
            invalid. ``ValidationError`` is a ``ValueError`` subclass, so
            callers catching ``ValueError`` also catch it.

    Example::

        >>> action = parse_action({"action_type": "INSPECT_ROWS", "row_indices": [0]})
        >>> isinstance(action, InspectRows)
        True
    """
    global _ACTION_ADAPTER

    from pydantic import TypeAdapter

    if _ACTION_ADAPTER is None:
        # First call: build the adapter once; later calls reuse it.
        _ACTION_ADAPTER = TypeAdapter(Action)
    adapter: TypeAdapter[Action] = _ACTION_ADAPTER
    return adapter.validate_python(_normalize_action(raw))
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def _normalize_action(raw: dict[str, Any]) -> dict[str, Any]:
|
|
310
|
+
"""Return a canonical action dictionary from supported external aliases."""
|
|
311
|
+
normalized = dict(raw)
|
|
312
|
+
action_type = normalized.get("action_type")
|
|
313
|
+
if action_type == "SQL_QUERY" and "sql" in normalized and "query" not in normalized:
|
|
314
|
+
normalized["query"] = normalized["sql"]
|
|
315
|
+
if action_type == "STAT_TEST" and "test" in normalized and "test_type" not in normalized:
|
|
316
|
+
normalized["test_type"] = normalized["test"]
|
|
317
|
+
if action_type == "PATTERN_MATCH":
|
|
318
|
+
if "regex" in normalized and "pattern" not in normalized:
|
|
319
|
+
normalized["pattern"] = normalized["regex"]
|
|
320
|
+
if "expect" in normalized and "expect_match" not in normalized:
|
|
321
|
+
normalized["expect_match"] = normalized["expect"] == "match"
|
|
322
|
+
if action_type == "HYPOTHESIS":
|
|
323
|
+
root_column = normalized.get("root_column")
|
|
324
|
+
downstream = normalized.get("downstream")
|
|
325
|
+
if root_column is not None and "affected_columns" not in normalized:
|
|
326
|
+
downstream_columns = downstream if isinstance(downstream, list) else []
|
|
327
|
+
normalized["affected_columns"] = [root_column, *downstream_columns]
|
|
328
|
+
if "affected_rows" not in normalized:
|
|
329
|
+
normalized["affected_rows"] = [0]
|
|
330
|
+
if root_column is not None and "root_cause_type" not in normalized:
|
|
331
|
+
normalized["root_cause_type"] = root_column
|
|
332
|
+
if (
|
|
333
|
+
action_type == "ROOT_CAUSE"
|
|
334
|
+
and "indices" in normalized
|
|
335
|
+
and "error_indices" not in normalized
|
|
336
|
+
):
|
|
337
|
+
normalized["error_indices"] = normalized["indices"]
|
|
338
|
+
if action_type == "FIX":
|
|
339
|
+
if "proposed_value" in normalized and "new_value" not in normalized:
|
|
340
|
+
normalized["new_value"] = normalized["proposed_value"]
|
|
341
|
+
if "justification" not in normalized:
|
|
342
|
+
normalized["justification"] = "Agent proposed value via FIX."
|
|
343
|
+
return normalized
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Shared benchmark helpers for real-world DataForge evaluation."""
|
|
2
|
+
|
|
3
|
+
from dataforge.bench.core import (
|
|
4
|
+
AggregateBenchmarkResult,
|
|
5
|
+
BenchmarkRepair,
|
|
6
|
+
BenchmarkRunOutput,
|
|
7
|
+
SeedBenchmarkResult,
|
|
8
|
+
chunk_row_indices,
|
|
9
|
+
estimate_llm_calls,
|
|
10
|
+
normalize_repairs,
|
|
11
|
+
quota_units,
|
|
12
|
+
score_repairs,
|
|
13
|
+
validate_estimated_calls,
|
|
14
|
+
)
|
|
15
|
+
from dataforge.bench.report import write_benchmark_outputs
|
|
16
|
+
from dataforge.bench.runner import run_agent_comparison
|
|
17
|
+
|
|
18
|
+
# Public API of the ``dataforge.bench`` package: every name listed here is
# imported above from the ``core``, ``report`` or ``runner`` submodule.
__all__ = [
    "AggregateBenchmarkResult",
    "BenchmarkRepair",
    "BenchmarkRunOutput",
    "SeedBenchmarkResult",
    "chunk_row_indices",
    "estimate_llm_calls",
    "normalize_repairs",
    "quota_units",
    "run_agent_comparison",
    "score_repairs",
    "validate_estimated_calls",
    "write_benchmark_outputs",
]
|