dataforge-07 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataforge/__init__.py +204 -0
- dataforge/__main__.py +5 -0
- dataforge/agent/__init__.py +16 -0
- dataforge/agent/providers.py +259 -0
- dataforge/agent/scratchpad.py +183 -0
- dataforge/agent/tool_actions.py +343 -0
- dataforge/bench/__init__.py +31 -0
- dataforge/bench/core.py +426 -0
- dataforge/bench/groq_client.py +386 -0
- dataforge/bench/methods.py +443 -0
- dataforge/bench/report.py +309 -0
- dataforge/bench/runner.py +247 -0
- dataforge/causal/__init__.py +21 -0
- dataforge/causal/dag.py +174 -0
- dataforge/causal/pc.py +232 -0
- dataforge/causal/root_cause.py +193 -0
- dataforge/cli/__init__.py +50 -0
- dataforge/cli/audit.py +70 -0
- dataforge/cli/bench.py +154 -0
- dataforge/cli/common.py +267 -0
- dataforge/cli/constraints.py +407 -0
- dataforge/cli/profile.py +147 -0
- dataforge/cli/release.py +166 -0
- dataforge/cli/repair.py +407 -0
- dataforge/cli/revert.py +139 -0
- dataforge/cli/watch.py +144 -0
- dataforge/datasets/__init__.py +25 -0
- dataforge/datasets/embedded/hospital/clean.csv +11 -0
- dataforge/datasets/embedded/hospital/dirty.csv +11 -0
- dataforge/datasets/real_world.py +290 -0
- dataforge/datasets/registry.py +103 -0
- dataforge/detectors/__init__.py +80 -0
- dataforge/detectors/base.py +145 -0
- dataforge/detectors/decimal_shift.py +166 -0
- dataforge/detectors/fd_violation.py +157 -0
- dataforge/detectors/type_mismatch.py +173 -0
- dataforge/engine/__init__.py +39 -0
- dataforge/engine/repair.py +905 -0
- dataforge/env/__init__.py +22 -0
- dataforge/env/environment.py +883 -0
- dataforge/env/observation.py +61 -0
- dataforge/env/openenv_core.py +161 -0
- dataforge/env/reward.py +128 -0
- dataforge/env/server.py +176 -0
- dataforge/evaluation_contract.py +76 -0
- dataforge/fixtures/hospital_10rows.csv +11 -0
- dataforge/fixtures/hospital_schema.yaml +17 -0
- dataforge/http/__init__.py +1 -0
- dataforge/http/problem.py +103 -0
- dataforge/integrations/__init__.py +1 -0
- dataforge/integrations/dbt.py +164 -0
- dataforge/observability.py +76 -0
- dataforge/py.typed +1 -0
- dataforge/release/__init__.py +1 -0
- dataforge/release/doctor.py +367 -0
- dataforge/release/full_vision.py +702 -0
- dataforge/release/gate.py +861 -0
- dataforge/release/playground_check.py +411 -0
- dataforge/repair_contract.py +468 -0
- dataforge/repairers/__init__.py +88 -0
- dataforge/repairers/base.py +77 -0
- dataforge/repairers/decimal_shift.py +43 -0
- dataforge/repairers/fd_violation.py +225 -0
- dataforge/repairers/type_mismatch.py +73 -0
- dataforge/safety/__init__.py +5 -0
- dataforge/safety/adversarial/attack_01_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_02_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_03_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_04_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_05_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_06_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_07_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_08_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_09_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_10_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_11_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_12_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_13_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_14_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_15_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_16_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_17_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_18_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_19_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_20_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_21_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_22_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_23_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_24_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_25_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_26_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_27_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_28_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_29_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_30_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_31_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_32_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_33_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_34_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_35_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_36_row_delete.yaml +11 -0
- dataforge/safety/adversarial/attack_37_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_38_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_39_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_40_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_41_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_42_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_43_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_44_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_45_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_46_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_47_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_48_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_49_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_50_row_delete.yaml +7 -0
- dataforge/safety/constitution.py +307 -0
- dataforge/safety/constitutions/default.yaml +40 -0
- dataforge/safety/filter.py +134 -0
- dataforge/schema_inference.py +620 -0
- dataforge/stores/__init__.py +46 -0
- dataforge/stores/base.py +73 -0
- dataforge/stores/cloud.py +78 -0
- dataforge/stores/csv.py +94 -0
- dataforge/stores/duckdb.py +313 -0
- dataforge/stores/patch_plan.py +178 -0
- dataforge/stores/registry.py +82 -0
- dataforge/stores/repair.py +121 -0
- dataforge/stores/revert.py +22 -0
- dataforge/stores/sql.py +27 -0
- dataforge/table.py +228 -0
- dataforge/transactions/__init__.py +34 -0
- dataforge/transactions/files.py +96 -0
- dataforge/transactions/log.py +613 -0
- dataforge/transactions/revert.py +102 -0
- dataforge/transactions/txn.py +104 -0
- dataforge/ui/__init__.py +1 -0
- dataforge/ui/profile_view.py +136 -0
- dataforge/ui/repair_diff.py +91 -0
- dataforge/verifier/__init__.py +55 -0
- dataforge/verifier/constraint_ir.py +155 -0
- dataforge/verifier/explain.py +47 -0
- dataforge/verifier/gate.py +5 -0
- dataforge/verifier/schema.py +111 -0
- dataforge/verifier/smt.py +433 -0
- dataforge_07-0.1.0.dist-info/METADATA +436 -0
- dataforge_07-0.1.0.dist-info/RECORD +150 -0
- dataforge_07-0.1.0.dist-info/WHEEL +5 -0
- dataforge_07-0.1.0.dist-info/entry_points.txt +3 -0
- dataforge_07-0.1.0.dist-info/licenses/LICENSE +176 -0
- dataforge_07-0.1.0.dist-info/top_level.txt +1 -0
dataforge/__init__.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
"""DataForge public package.
|
|
2
|
+
|
|
3
|
+
The root package is the stable facade for integration surfaces. Symbols are
|
|
4
|
+
resolved lazily so importing :mod:`dataforge` does not eagerly import pandas,
|
|
5
|
+
FastAPI-facing helpers, or the SMT stack.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from importlib import import_module
|
|
11
|
+
from typing import TYPE_CHECKING, Any
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from dataforge.cli.common import load_schema, schema_from_mapping
|
|
15
|
+
from dataforge.detectors import Issue, Schema, Severity, run_all_detectors
|
|
16
|
+
from dataforge.engine.repair import (
|
|
17
|
+
CandidateFix,
|
|
18
|
+
CandidateRepair,
|
|
19
|
+
ProofObligation,
|
|
20
|
+
RepairFailure,
|
|
21
|
+
RepairPipelineRequest,
|
|
22
|
+
RepairPipelineResult,
|
|
23
|
+
RepairReceipt,
|
|
24
|
+
RootCause,
|
|
25
|
+
VerifiedFix,
|
|
26
|
+
run_repair_pipeline,
|
|
27
|
+
)
|
|
28
|
+
from dataforge.integrations.dbt import schema_from_dbt_artifacts, schema_from_dbt_manifest
|
|
29
|
+
from dataforge.repair_contract import CONTRACT_VERSION
|
|
30
|
+
from dataforge.repairers import ProposedFix
|
|
31
|
+
from dataforge.safety import SafetyContext, SafetyFilter, SafetyResult, SafetyVerdict
|
|
32
|
+
from dataforge.schema_inference import (
|
|
33
|
+
ConstraintCandidate,
|
|
34
|
+
ConstraintReviewArtifact,
|
|
35
|
+
ConstraintReviewError,
|
|
36
|
+
ReviewedConstraintCandidate,
|
|
37
|
+
SchemaInferenceResult,
|
|
38
|
+
build_constraint_review_artifact,
|
|
39
|
+
dump_constraint_review_artifact,
|
|
40
|
+
infer_schema,
|
|
41
|
+
load_constraint_review_artifact,
|
|
42
|
+
)
|
|
43
|
+
from dataforge.stores import (
|
|
44
|
+
DuckDBStore,
|
|
45
|
+
PatchPlan,
|
|
46
|
+
TableStoreError,
|
|
47
|
+
TableStoreRepairResult,
|
|
48
|
+
is_table_store_uri,
|
|
49
|
+
run_table_store_repair,
|
|
50
|
+
store_from_uri,
|
|
51
|
+
)
|
|
52
|
+
from dataforge.table import read_csv
|
|
53
|
+
from dataforge.transactions.log import (
|
|
54
|
+
TransactionAuditReport,
|
|
55
|
+
TransactionAuditVerdict,
|
|
56
|
+
TransactionLogError,
|
|
57
|
+
verify_transaction_log,
|
|
58
|
+
)
|
|
59
|
+
from dataforge.transactions.revert import TransactionRevertError, revert_transaction
|
|
60
|
+
from dataforge.transactions.txn import CellFix, RepairTransaction
|
|
61
|
+
from dataforge.verifier import (
|
|
62
|
+
ConstraintIR,
|
|
63
|
+
SMTVerifier,
|
|
64
|
+
VerificationResult,
|
|
65
|
+
VerificationVerdict,
|
|
66
|
+
constraint_ir_from_schema,
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
# Public API of the ``dataforge`` facade. Every name here except ``__version__``
# is resolved lazily through ``_PUBLIC_EXPORTS`` below, so each entry is
# expected to have a matching key there: ``from dataforge import *`` looks up
# every listed name and would fail on one that ``__getattr__`` cannot resolve.
# The ``TYPE_CHECKING`` imports above mirror this list for static analysis only.
__all__ = [
    "CONTRACT_VERSION",
    "CandidateFix",
    "CandidateRepair",
    "CellFix",
    "ConstraintCandidate",
    "ConstraintReviewArtifact",
    "ConstraintReviewError",
    "ConstraintIR",
    "DuckDBStore",
    "Issue",
    "PatchPlan",
    "ProposedFix",
    "ProofObligation",
    "RepairFailure",
    "RepairPipelineRequest",
    "RepairPipelineResult",
    "RepairReceipt",
    "RepairTransaction",
    "RootCause",
    "ReviewedConstraintCandidate",
    "SMTVerifier",
    "SafetyContext",
    "SafetyFilter",
    "SafetyResult",
    "SafetyVerdict",
    "Schema",
    "SchemaInferenceResult",
    "Severity",
    "TransactionAuditReport",
    "TransactionAuditVerdict",
    "TransactionLogError",
    "TransactionRevertError",
    "TableStoreError",
    "TableStoreRepairResult",
    "VerificationResult",
    "VerificationVerdict",
    "VerifiedFix",
    "__version__",
    "load_schema",
    "build_constraint_review_artifact",
    "constraint_ir_from_schema",
    "dump_constraint_review_artifact",
    "load_constraint_review_artifact",
    "read_csv",
    "revert_transaction",
    "run_all_detectors",
    "run_repair_pipeline",
    "schema_from_mapping",
    "schema_from_dbt_artifacts",
    "schema_from_dbt_manifest",
    "infer_schema",
    "is_table_store_uri",
    "run_table_store_repair",
    "store_from_uri",
    "verify_transaction_log",
]

# Package version string; defined eagerly (it is not part of the lazy table).
__version__ = "0.1.0"

# Lazy export table read by the module-level ``__getattr__`` below:
#   public name -> (module to import on first access, attribute to fetch from it)
# The owning module is only imported when the name is first requested.
_PUBLIC_EXPORTS: dict[str, tuple[str, str]] = {
    "CONTRACT_VERSION": ("dataforge.repair_contract", "CONTRACT_VERSION"),
    "CandidateFix": ("dataforge.engine.repair", "CandidateFix"),
    "CandidateRepair": ("dataforge.engine.repair", "CandidateRepair"),
    "CellFix": ("dataforge.transactions.txn", "CellFix"),
    "ConstraintCandidate": ("dataforge.schema_inference", "ConstraintCandidate"),
    "ConstraintReviewArtifact": ("dataforge.schema_inference", "ConstraintReviewArtifact"),
    "ConstraintReviewError": ("dataforge.schema_inference", "ConstraintReviewError"),
    "ConstraintIR": ("dataforge.verifier", "ConstraintIR"),
    "DuckDBStore": ("dataforge.stores", "DuckDBStore"),
    "Issue": ("dataforge.detectors", "Issue"),
    "ProposedFix": ("dataforge.repairers", "ProposedFix"),
    "ProofObligation": ("dataforge.engine.repair", "ProofObligation"),
    "PatchPlan": ("dataforge.stores", "PatchPlan"),
    "RepairFailure": ("dataforge.engine.repair", "RepairFailure"),
    "RepairPipelineRequest": ("dataforge.engine.repair", "RepairPipelineRequest"),
    "RepairPipelineResult": ("dataforge.engine.repair", "RepairPipelineResult"),
    "RepairReceipt": ("dataforge.engine.repair", "RepairReceipt"),
    "RepairTransaction": ("dataforge.transactions.txn", "RepairTransaction"),
    "RootCause": ("dataforge.engine.repair", "RootCause"),
    "ReviewedConstraintCandidate": ("dataforge.schema_inference", "ReviewedConstraintCandidate"),
    "SMTVerifier": ("dataforge.verifier", "SMTVerifier"),
    "SafetyContext": ("dataforge.safety", "SafetyContext"),
    "SafetyFilter": ("dataforge.safety", "SafetyFilter"),
    "SafetyResult": ("dataforge.safety", "SafetyResult"),
    "SafetyVerdict": ("dataforge.safety", "SafetyVerdict"),
    "Schema": ("dataforge.detectors", "Schema"),
    "SchemaInferenceResult": ("dataforge.schema_inference", "SchemaInferenceResult"),
    "Severity": ("dataforge.detectors", "Severity"),
    "TransactionAuditReport": ("dataforge.transactions.log", "TransactionAuditReport"),
    "TransactionAuditVerdict": ("dataforge.transactions.log", "TransactionAuditVerdict"),
    "TransactionLogError": ("dataforge.transactions.log", "TransactionLogError"),
    "TransactionRevertError": ("dataforge.transactions.revert", "TransactionRevertError"),
    "TableStoreError": ("dataforge.stores", "TableStoreError"),
    "TableStoreRepairResult": ("dataforge.stores", "TableStoreRepairResult"),
    "VerificationResult": ("dataforge.verifier", "VerificationResult"),
    "VerificationVerdict": ("dataforge.verifier", "VerificationVerdict"),
    "VerifiedFix": ("dataforge.engine.repair", "VerifiedFix"),
    # NOTE(review): load_schema / schema_from_mapping are served from the CLI
    # helper module; confirm importing it has no heavy CLI-only side effects.
    "load_schema": ("dataforge.cli.common", "load_schema"),
    "build_constraint_review_artifact": (
        "dataforge.schema_inference",
        "build_constraint_review_artifact",
    ),
    "constraint_ir_from_schema": ("dataforge.verifier", "constraint_ir_from_schema"),
    "dump_constraint_review_artifact": (
        "dataforge.schema_inference",
        "dump_constraint_review_artifact",
    ),
    "load_constraint_review_artifact": (
        "dataforge.schema_inference",
        "load_constraint_review_artifact",
    ),
    "read_csv": ("dataforge.table", "read_csv"),
    "revert_transaction": ("dataforge.transactions.revert", "revert_transaction"),
    "run_all_detectors": ("dataforge.detectors", "run_all_detectors"),
    "run_repair_pipeline": ("dataforge.engine.repair", "run_repair_pipeline"),
    "schema_from_mapping": ("dataforge.cli.common", "schema_from_mapping"),
    "schema_from_dbt_artifacts": ("dataforge.integrations.dbt", "schema_from_dbt_artifacts"),
    "schema_from_dbt_manifest": ("dataforge.integrations.dbt", "schema_from_dbt_manifest"),
    "infer_schema": ("dataforge.schema_inference", "infer_schema"),
    "is_table_store_uri": ("dataforge.stores", "is_table_store_uri"),
    "run_table_store_repair": ("dataforge.stores", "run_table_store_repair"),
    "store_from_uri": ("dataforge.stores", "store_from_uri"),
    "verify_transaction_log": ("dataforge.transactions.log", "verify_transaction_log"),
}
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def __getattr__(name: str) -> Any:
|
|
197
|
+
"""Resolve public facade exports on first use."""
|
|
198
|
+
try:
|
|
199
|
+
module_name, attribute_name = _PUBLIC_EXPORTS[name]
|
|
200
|
+
except KeyError as exc:
|
|
201
|
+
raise AttributeError(name) from exc
|
|
202
|
+
value = getattr(import_module(module_name), attribute_name)
|
|
203
|
+
globals()[name] = value
|
|
204
|
+
return value
|
dataforge/__main__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""DataForge agent package — typed tool-use actions and scratchpad.
|
|
2
|
+
|
|
3
|
+
Public API:
|
|
4
|
+
parse_action — Parse raw dict into typed Action model.
|
|
5
|
+
Action — Discriminated union of all action types.
|
|
6
|
+
Scratchpad — In-episode hypothesis tracker.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from dataforge.agent.scratchpad import Scratchpad
|
|
10
|
+
from dataforge.agent.tool_actions import Action, parse_action
|
|
11
|
+
|
|
12
|
+
# Names re-exported from the agent submodules imported above.
__all__ = ["Action", "Scratchpad", "parse_action"]
|
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
"""Multi-provider LLM client for DataForge.
|
|
2
|
+
|
|
3
|
+
Reads ``DATAFORGE_LLM_PROVIDER`` from the environment and dispatches to the
|
|
4
|
+
matching provider. Week 1 implements **groq** and **gemini** only; other
|
|
5
|
+
providers raise ``NotImplementedError``.
|
|
6
|
+
|
|
7
|
+
No LLM calls are made by detectors — this module exists for the agent loop
(Week 2+). The groq and gemini backends are functional; the other providers
are placeholders that establish the interface.
|
|
9
|
+
|
|
10
|
+
The interface is:
|
|
11
|
+
``async def complete(messages, model, temperature) -> str``
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from typing import Literal, TypedDict
|
|
18
|
+
|
|
19
|
+
import httpx
|
|
20
|
+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
|
21
|
+
|
|
22
|
+
# ── Message type ──────────────────────────────────────────────────────────
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Message(TypedDict):
    """One turn of a chat conversation in OpenAI-style message format.

    Attributes:
        role: Who is speaking — ``"system"``, ``"user"``, or ``"assistant"``.
        content: Plain text of the turn.
    """

    # Speaker of this turn.
    role: Literal["system", "user", "assistant"]
    # Text body of this turn.
    content: str
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# ── Exceptions ────────────────────────────────────────────────────────────
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ProviderError(Exception):
    """Failure detected by an LLM provider backend itself.

    Backends raise it for problems such as a missing API key or a response
    that cannot be parsed. The message is prefixed with the provider name,
    e.g. ``"[groq] ..."``.

    Args:
        provider: Name of the failing provider; also stored on ``self.provider``.
        message: Human-readable description of the failure.
    """

    def __init__(self, provider: str, message: str) -> None:
        super().__init__(f"[{provider}] {message}")
        self.provider = provider
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# ── Provider dispatch ─────────────────────────────────────────────────────
|
|
54
|
+
|
|
55
|
+
# Every provider name complete() recognises. Only "groq" and "gemini" have
# working backends; the others raise NotImplementedError.
_SUPPORTED_PROVIDERS = frozenset(
    (
        "cerebras",
        "cloudflare",
        "gemini",
        "groq",
        "hf",
        "openrouter",
    )
)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_provider_name() -> str:
    """Resolve the active LLM provider name from the environment.

    Resolution order:

    1. ``DATAFORGE_LLM_PROVIDER``, stripped of surrounding whitespace and
       lowercased, when it holds a non-blank value.
    2. ``"groq"`` when ``GROQ_API_KEY`` is set.
    3. ``"gemini"`` when ``GEMINI_API_KEY`` is set.
    4. ``"groq"`` as the final fallback.

    The name is not validated here; :func:`complete` rejects unknown names.

    Returns:
        The normalized provider name.

    Example:
        >>> import os
        >>> from unittest import mock
        >>> with mock.patch.dict(os.environ, {"DATAFORGE_LLM_PROVIDER": " Gemini "}):
        ...     get_provider_name()
        'gemini'
    """
    # Strip so values like " groq " or a trailing newline from a .env file
    # still match; a blank value is treated as "not configured".
    configured = os.environ.get("DATAFORGE_LLM_PROVIDER", "").strip()
    if configured:
        return configured.lower()
    if os.environ.get("GROQ_API_KEY"):
        return "groq"
    if os.environ.get("GEMINI_API_KEY"):
        return "gemini"
    # Nothing configured and no credential present: keep the historical default.
    return "groq"
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
async def complete(
    messages: list[Message],
    *,
    model: str | None = None,
    temperature: float = 0.0,
) -> str:
    """Send a chat completion request to the active LLM provider.

    The provider is re-resolved with :func:`get_provider_name` on every call.

    Args:
        messages: The conversation so far, oldest turn first.
        model: Optional model override; ``None`` selects the provider default.
        temperature: Sampling temperature (``0.0`` is the most deterministic).

    Returns:
        The assistant's reply text.

    Raises:
        NotImplementedError: If the provider is a planned-but-unimplemented
            name, or not a recognised provider at all.
        ProviderError: Raised by the backend, e.g. for a missing API key or an
            unparseable response. HTTP error statuses propagate from the
            backend (as raised by httpx) once its retries are exhausted.

    Example:
        >>> import asyncio
        >>> msgs = [{"role": "user", "content": "What is 2+2?"}]
        >>> # result = asyncio.run(complete(msgs))  # requires API key
    """
    provider = get_provider_name()

    # Providers with a working backend, looked up at call time.
    backends = {"groq": _complete_groq, "gemini": _complete_gemini}
    backend = backends.get(provider)
    if backend is not None:
        return await backend(messages, model=model, temperature=temperature)

    if provider not in _SUPPORTED_PROVIDERS:
        raise NotImplementedError(
            f"Unknown provider '{provider}'. Supported: {sorted(_SUPPORTED_PROVIDERS)}"
        )
    raise NotImplementedError(
        f"Provider '{provider}' is planned but not yet implemented. "
        "Use 'groq' or 'gemini' for Week 1."
    )
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
# ── Groq provider ────────────────────────────────────────────────────────
|
|
126
|
+
|
|
127
|
+
# OpenAI-compatible chat completions endpoint on Groq.
_GROQ_URL = "https://api.groq.com/openai/v1/chat/completions"
# Default model when callers pass no ``model``. The previous default,
# llama-3.1-70b-versatile, was decommissioned by Groq in January 2025 (requests
# naming it are rejected); llama-3.3-70b-versatile is the documented replacement.
_GROQ_DEFAULT_MODEL = "llama-3.3-70b-versatile"
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@retry(
    retry=retry_if_exception_type(httpx.HTTPStatusError),
    wait=wait_exponential(multiplier=1, min=1, max=30),
    stop=stop_after_attempt(3),
    reraise=True,
)
async def _complete_groq(
    messages: list[Message],
    *,
    model: str | None = None,
    temperature: float = 0.0,
) -> str:
    """Call Groq's OpenAI-compatible chat completions API.

    The decorator retries any ``httpx.HTTPStatusError`` (up to 3 attempts with
    exponential back-off) and then re-raises the last one unchanged.

    Args:
        messages: Chat messages, sent as-is in OpenAI format.
        model: Model name; ``None`` selects ``_GROQ_DEFAULT_MODEL``.
        temperature: Sampling temperature.

    Returns:
        The text content of the first choice's message.

    Raises:
        ProviderError: If ``GROQ_API_KEY`` is unset, the body is not JSON, the
            response lacks the expected ``choices[0].message.content`` shape,
            or that content is ``null``.
        httpx.HTTPStatusError: If every attempt returns an error status.
    """
    api_key = os.environ.get("GROQ_API_KEY", "")
    if not api_key:
        raise ProviderError("groq", "GROQ_API_KEY environment variable not set")

    payload = {
        "model": model or _GROQ_DEFAULT_MODEL,
        "messages": [dict(m) for m in messages],
        "temperature": temperature,
    }

    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(
            _GROQ_URL,
            json=payload,
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
        )
        # NOTE(review): the retry policy also retries 4xx statuses (e.g. an
        # invalid key); consider limiting retries to 429/5xx.
        response.raise_for_status()

    try:
        data = response.json()
    except ValueError as exc:  # json.JSONDecodeError is a ValueError subclass
        raise ProviderError(
            "groq", f"Response body is not valid JSON: {response.text[:200]!r}"
        ) from exc

    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        # TypeError covers shapes such as "choices": null or a non-object body.
        raise ProviderError("groq", f"Unexpected response format: {data}") from exc
    if content is None:
        # str(None) would silently hand callers the literal text "None".
        raise ProviderError("groq", f"Response has no message content: {data}")
    return str(content)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
# ── Gemini provider ──────────────────────────────────────────────────────
|
|
185
|
+
|
|
186
|
+
# REST endpoint for a single non-streaming generateContent call; ``{model}`` is
# substituted with the model name at request time.
_GEMINI_URL_TEMPLATE = (
    "https://generativelanguage.googleapis.com/v1beta"
    "/models/{model}:generateContent"
)
# Model used when callers pass no ``model``.
_GEMINI_DEFAULT_MODEL = "gemini-2.0-flash"
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
@retry(
    retry=retry_if_exception_type(httpx.HTTPStatusError),
    wait=wait_exponential(multiplier=1, min=1, max=30),
    stop=stop_after_attempt(3),
    reraise=True,
)
async def _complete_gemini(
    messages: list[Message],
    *,
    model: str | None = None,
    temperature: float = 0.0,
) -> str:
    """Call Google's Gemini ``generateContent`` REST API.

    OpenAI-style messages are converted to Gemini's format:

    * every non-empty ``system`` message becomes one part of
      ``systemInstruction``, in order;
    * ``user`` keeps the ``user`` role and ``assistant`` maps to ``model``.

    The API key is sent in the ``x-goog-api-key`` header rather than as a
    ``?key=`` query parameter. This keeps it out of the request URL, which
    httpx embeds in ``HTTPStatusError`` messages (and so in logs and
    tracebacks).

    Args:
        messages: Chat messages in OpenAI format.
        model: Model name; ``None`` selects ``_GEMINI_DEFAULT_MODEL``.
        temperature: Sampling temperature.

    Returns:
        The concatenated text of every text part of the first candidate.

    Raises:
        ProviderError: If ``GEMINI_API_KEY`` is unset, the body is not JSON,
            the response has no ``candidates[0].content.parts``, or none of
            those parts carries text.
        httpx.HTTPStatusError: If every attempt (up to 3, with exponential
            back-off) returns an error status.
    """
    api_key = os.environ.get("GEMINI_API_KEY", "")
    if not api_key:
        raise ProviderError("gemini", "GEMINI_API_KEY environment variable not set")

    url = _GEMINI_URL_TEMPLATE.format(model=model or _GEMINI_DEFAULT_MODEL)

    # Convert OpenAI-style messages to Gemini format. All system messages are
    # kept (previously each one overwrote the last, silently dropping prompts).
    system_parts: list[dict[str, str]] = []
    contents: list[dict[str, object]] = []
    for msg in messages:
        if msg["role"] == "system":
            if msg["content"]:
                system_parts.append({"text": msg["content"]})
            continue
        role = "user" if msg["role"] == "user" else "model"
        contents.append({"role": role, "parts": [{"text": msg["content"]}]})

    payload: dict[str, object] = {
        "contents": contents,
        "generationConfig": {"temperature": temperature},
    }
    if system_parts:
        payload["systemInstruction"] = {"parts": system_parts}

    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(
            url,
            json=payload,
            headers={
                "Content-Type": "application/json",
                "x-goog-api-key": api_key,
            },
        )
        response.raise_for_status()

    try:
        data = response.json()
    except ValueError as exc:  # json.JSONDecodeError is a ValueError subclass
        raise ProviderError(
            "gemini", f"Response body is not valid JSON: {response.text[:200]!r}"
        ) from exc

    try:
        parts = data["candidates"][0]["content"]["parts"]
        texts = [part["text"] for part in parts if "text" in part]
    except (KeyError, IndexError, TypeError) as exc:
        # Covers e.g. a safety-blocked prompt (no "candidates") or null fields.
        raise ProviderError("gemini", f"Unexpected response format: {data}") from exc
    if not texts:
        raise ProviderError("gemini", f"Response contained no text parts: {data}")
    # A reply may be split across several text parts; return all of it.
    return "".join(str(text) for text in texts)
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
"""In-episode hypothesis and issue tracker for the DataForge RL agent.
|
|
2
|
+
|
|
3
|
+
The scratchpad is a mutable, episode-scoped data structure that the agent
|
|
4
|
+
uses to record hypotheses, confirmed issues, and dead ends. The environment
|
|
5
|
+
exposes a compact summary of the scratchpad in each observation, enabling
|
|
6
|
+
the agent to reason about its investigation history without direct access
|
|
7
|
+
to the underlying data structure.
|
|
8
|
+
|
|
9
|
+
Example::
|
|
10
|
+
|
|
11
|
+
>>> from dataforge.agent.scratchpad import Scratchpad
|
|
12
|
+
>>> pad = Scratchpad()
|
|
13
|
+
>>> _ = pad.add_hypothesis("Rating column has decimal shift", [5], ["rating"], "decimal_shift")
>>> pad.confirm_issue(5, "rating", "decimal_shift")
>>> pad.summary()
'Hypotheses: 1 (1 pending). Confirmed: 1. Dead ends: 0.'
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations

from dataclasses import dataclass, field, replace
|
|
22
|
+
|
|
23
|
+
# Public names of this module: the three record types and the tracker itself.
__all__ = ["ConfirmedIssue", "DeadEnd", "HypothesisRecord", "Scratchpad"]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(frozen=True)
class HypothesisRecord:
    """Immutable record of one hypothesis about a data-quality root cause.

    Instances are frozen (and therefore hashable). Confirming a hypothesis
    means storing a new record with ``confirmed=True`` in its place.

    Attributes:
        claim: Free-text statement of the hypothesis.
        affected_rows: Row indices the hypothesis is about.
        affected_columns: Column names the hypothesis is about.
        root_cause_type: Root-cause label from the detector vocabulary.
        confirmed: ``True`` once the hypothesis has been confirmed.
    """

    # Field order defines the positional constructor signature.
    claim: str
    affected_rows: tuple[int, ...]
    affected_columns: tuple[str, ...]
    root_cause_type: str
    # Starts out unconfirmed.
    confirmed: bool = False
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass(frozen=True)
class ConfirmedIssue:
    """Immutable record of a data-quality issue confirmed at one cell.

    Attributes:
        row: Zero-based row index of the affected cell.
        column: Name of the affected column.
        issue_type: Classification of the issue.
    """

    # Cell location first, then its classification.
    row: int
    column: str
    issue_type: str
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass(frozen=True)
class DeadEnd:
    """Immutable note about an investigation path that led nowhere.

    Attributes:
        description: What was tried and why it did not pan out.
        step_number: Episode step at which the note was recorded.
    """

    # What was attempted, and at which step.
    description: str
    step_number: int
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@dataclass
class Scratchpad:
    """Mutable in-episode tracker for hypotheses, confirmed issues, and dead ends.

    Call :meth:`reset` at the start of each episode. :meth:`summary` renders
    the current state as a compact one-line string for agent observations.

    Example::

        >>> pad = Scratchpad()
        >>> record = pad.add_hypothesis("Decimal shift in rating", [5], ["rating"], "decimal_shift")
        >>> len(pad.hypotheses)
        1
        >>> record.affected_rows
        (5,)
    """

    hypotheses: list[HypothesisRecord] = field(default_factory=list)
    confirmed_issues: list[ConfirmedIssue] = field(default_factory=list)
    dead_ends: list[DeadEnd] = field(default_factory=list)

    def add_hypothesis(
        self,
        claim: str,
        affected_rows: list[int],
        affected_columns: list[str],
        root_cause_type: str,
    ) -> HypothesisRecord:
        """Record a new, unconfirmed hypothesis.

        Args:
            claim: Textual description of the hypothesis.
            affected_rows: Row indices the hypothesis covers. Copied into a
                tuple, so later changes to the caller's list do not leak in.
            affected_columns: Column names the hypothesis covers (also copied).
            root_cause_type: Detector-vocabulary root cause type.

        Returns:
            The recorded hypothesis (also appended to :attr:`hypotheses`).
        """
        record = HypothesisRecord(
            claim=claim,
            affected_rows=tuple(affected_rows),
            affected_columns=tuple(affected_columns),
            root_cause_type=root_cause_type,
        )
        self.hypotheses.append(record)
        return record

    def confirm_hypothesis(self, index: int) -> None:
        """Mark the hypothesis at ``index`` as confirmed.

        Records are frozen, so the stored entry is replaced by a copy with
        ``confirmed=True``.

        Args:
            index: Position in :attr:`hypotheses` (ordinary list indexing, so
                negative indices count from the end).

        Raises:
            IndexError: If ``index`` is out of range.
        """
        # dataclasses.replace carries over every other field, including any
        # added to HypothesisRecord later, instead of hand-copying them.
        self.hypotheses[index] = replace(self.hypotheses[index], confirmed=True)

    def confirm_issue(self, row: int, column: str, issue_type: str) -> None:
        """Record a confirmed issue.

        This does not change the ``confirmed`` flag of any hypothesis.

        Args:
            row: Zero-indexed row number.
            column: Column name.
            issue_type: Issue type classification.
        """
        self.confirmed_issues.append(ConfirmedIssue(row=row, column=column, issue_type=issue_type))

    def add_dead_end(self, description: str, step_number: int) -> None:
        """Record a dead end.

        Args:
            description: What was tried and why it failed.
            step_number: Step at which the dead end was recorded.
        """
        self.dead_ends.append(DeadEnd(description=description, step_number=step_number))

    def reset(self) -> None:
        """Clear all tracked state for a new episode (lists are emptied in place)."""
        self.hypotheses.clear()
        self.confirmed_issues.clear()
        self.dead_ends.clear()

    def summary(self) -> str:
        """Produce a compact summary string for observation embedding.

        Returns:
            A one-line summary of scratchpad state. "pending" counts the
            hypotheses that have not been confirmed.

        Example::

            >>> Scratchpad().summary()
            'Hypotheses: 0 (0 pending). Confirmed: 0. Dead ends: 0.'
        """
        pending = sum(1 for h in self.hypotheses if not h.confirmed)
        return (
            f"Hypotheses: {len(self.hypotheses)} ({pending} pending). "
            f"Confirmed: {len(self.confirmed_issues)}. "
            f"Dead ends: {len(self.dead_ends)}."
        )
|