dataforge-07 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. dataforge/__init__.py +204 -0
  2. dataforge/__main__.py +5 -0
  3. dataforge/agent/__init__.py +16 -0
  4. dataforge/agent/providers.py +259 -0
  5. dataforge/agent/scratchpad.py +183 -0
  6. dataforge/agent/tool_actions.py +343 -0
  7. dataforge/bench/__init__.py +31 -0
  8. dataforge/bench/core.py +426 -0
  9. dataforge/bench/groq_client.py +386 -0
  10. dataforge/bench/methods.py +443 -0
  11. dataforge/bench/report.py +309 -0
  12. dataforge/bench/runner.py +247 -0
  13. dataforge/causal/__init__.py +21 -0
  14. dataforge/causal/dag.py +174 -0
  15. dataforge/causal/pc.py +232 -0
  16. dataforge/causal/root_cause.py +193 -0
  17. dataforge/cli/__init__.py +50 -0
  18. dataforge/cli/audit.py +70 -0
  19. dataforge/cli/bench.py +154 -0
  20. dataforge/cli/common.py +267 -0
  21. dataforge/cli/constraints.py +407 -0
  22. dataforge/cli/profile.py +147 -0
  23. dataforge/cli/release.py +166 -0
  24. dataforge/cli/repair.py +407 -0
  25. dataforge/cli/revert.py +139 -0
  26. dataforge/cli/watch.py +144 -0
  27. dataforge/datasets/__init__.py +25 -0
  28. dataforge/datasets/embedded/hospital/clean.csv +11 -0
  29. dataforge/datasets/embedded/hospital/dirty.csv +11 -0
  30. dataforge/datasets/real_world.py +290 -0
  31. dataforge/datasets/registry.py +103 -0
  32. dataforge/detectors/__init__.py +80 -0
  33. dataforge/detectors/base.py +145 -0
  34. dataforge/detectors/decimal_shift.py +166 -0
  35. dataforge/detectors/fd_violation.py +157 -0
  36. dataforge/detectors/type_mismatch.py +173 -0
  37. dataforge/engine/__init__.py +39 -0
  38. dataforge/engine/repair.py +905 -0
  39. dataforge/env/__init__.py +22 -0
  40. dataforge/env/environment.py +883 -0
  41. dataforge/env/observation.py +61 -0
  42. dataforge/env/openenv_core.py +161 -0
  43. dataforge/env/reward.py +128 -0
  44. dataforge/env/server.py +176 -0
  45. dataforge/evaluation_contract.py +76 -0
  46. dataforge/fixtures/hospital_10rows.csv +11 -0
  47. dataforge/fixtures/hospital_schema.yaml +17 -0
  48. dataforge/http/__init__.py +1 -0
  49. dataforge/http/problem.py +103 -0
  50. dataforge/integrations/__init__.py +1 -0
  51. dataforge/integrations/dbt.py +164 -0
  52. dataforge/observability.py +76 -0
  53. dataforge/py.typed +1 -0
  54. dataforge/release/__init__.py +1 -0
  55. dataforge/release/doctor.py +367 -0
  56. dataforge/release/full_vision.py +702 -0
  57. dataforge/release/gate.py +861 -0
  58. dataforge/release/playground_check.py +411 -0
  59. dataforge/repair_contract.py +468 -0
  60. dataforge/repairers/__init__.py +88 -0
  61. dataforge/repairers/base.py +77 -0
  62. dataforge/repairers/decimal_shift.py +43 -0
  63. dataforge/repairers/fd_violation.py +225 -0
  64. dataforge/repairers/type_mismatch.py +73 -0
  65. dataforge/safety/__init__.py +5 -0
  66. dataforge/safety/adversarial/attack_01_phone_pii.yaml +8 -0
  67. dataforge/safety/adversarial/attack_02_phone_pii.yaml +8 -0
  68. dataforge/safety/adversarial/attack_03_phone_pii.yaml +8 -0
  69. dataforge/safety/adversarial/attack_04_phone_pii.yaml +8 -0
  70. dataforge/safety/adversarial/attack_05_phone_pii.yaml +8 -0
  71. dataforge/safety/adversarial/attack_06_phone_pii.yaml +8 -0
  72. dataforge/safety/adversarial/attack_07_phone_pii.yaml +8 -0
  73. dataforge/safety/adversarial/attack_08_phone_pii.yaml +8 -0
  74. dataforge/safety/adversarial/attack_09_phone_pii.yaml +8 -0
  75. dataforge/safety/adversarial/attack_10_phone_pii.yaml +8 -0
  76. dataforge/safety/adversarial/attack_11_ssn_pii.yaml +8 -0
  77. dataforge/safety/adversarial/attack_12_ssn_pii.yaml +8 -0
  78. dataforge/safety/adversarial/attack_13_ssn_pii.yaml +8 -0
  79. dataforge/safety/adversarial/attack_14_ssn_pii.yaml +8 -0
  80. dataforge/safety/adversarial/attack_15_ssn_pii.yaml +8 -0
  81. dataforge/safety/adversarial/attack_16_ssn_pii.yaml +8 -0
  82. dataforge/safety/adversarial/attack_17_ssn_pii.yaml +8 -0
  83. dataforge/safety/adversarial/attack_18_ssn_pii.yaml +8 -0
  84. dataforge/safety/adversarial/attack_19_ssn_pii.yaml +8 -0
  85. dataforge/safety/adversarial/attack_20_ssn_pii.yaml +8 -0
  86. dataforge/safety/adversarial/attack_21_email_pii.yaml +8 -0
  87. dataforge/safety/adversarial/attack_22_email_pii.yaml +8 -0
  88. dataforge/safety/adversarial/attack_23_email_pii.yaml +8 -0
  89. dataforge/safety/adversarial/attack_24_email_pii.yaml +8 -0
  90. dataforge/safety/adversarial/attack_25_email_pii.yaml +8 -0
  91. dataforge/safety/adversarial/attack_26_email_pii.yaml +8 -0
  92. dataforge/safety/adversarial/attack_27_email_pii.yaml +8 -0
  93. dataforge/safety/adversarial/attack_28_email_pii.yaml +8 -0
  94. dataforge/safety/adversarial/attack_29_email_pii.yaml +8 -0
  95. dataforge/safety/adversarial/attack_30_email_pii.yaml +8 -0
  96. dataforge/safety/adversarial/attack_31_row_delete.yaml +7 -0
  97. dataforge/safety/adversarial/attack_32_row_delete.yaml +8 -0
  98. dataforge/safety/adversarial/attack_33_row_delete.yaml +7 -0
  99. dataforge/safety/adversarial/attack_34_row_delete.yaml +7 -0
  100. dataforge/safety/adversarial/attack_35_row_delete.yaml +7 -0
  101. dataforge/safety/adversarial/attack_36_row_delete.yaml +11 -0
  102. dataforge/safety/adversarial/attack_37_row_delete.yaml +7 -0
  103. dataforge/safety/adversarial/attack_38_row_delete.yaml +7 -0
  104. dataforge/safety/adversarial/attack_39_row_delete.yaml +8 -0
  105. dataforge/safety/adversarial/attack_40_row_delete.yaml +7 -0
  106. dataforge/safety/adversarial/attack_41_row_delete.yaml +7 -0
  107. dataforge/safety/adversarial/attack_42_row_delete.yaml +7 -0
  108. dataforge/safety/adversarial/attack_43_row_delete.yaml +7 -0
  109. dataforge/safety/adversarial/attack_44_row_delete.yaml +7 -0
  110. dataforge/safety/adversarial/attack_45_row_delete.yaml +8 -0
  111. dataforge/safety/adversarial/attack_46_row_delete.yaml +8 -0
  112. dataforge/safety/adversarial/attack_47_row_delete.yaml +7 -0
  113. dataforge/safety/adversarial/attack_48_row_delete.yaml +7 -0
  114. dataforge/safety/adversarial/attack_49_row_delete.yaml +8 -0
  115. dataforge/safety/adversarial/attack_50_row_delete.yaml +7 -0
  116. dataforge/safety/constitution.py +307 -0
  117. dataforge/safety/constitutions/default.yaml +40 -0
  118. dataforge/safety/filter.py +134 -0
  119. dataforge/schema_inference.py +620 -0
  120. dataforge/stores/__init__.py +46 -0
  121. dataforge/stores/base.py +73 -0
  122. dataforge/stores/cloud.py +78 -0
  123. dataforge/stores/csv.py +94 -0
  124. dataforge/stores/duckdb.py +313 -0
  125. dataforge/stores/patch_plan.py +178 -0
  126. dataforge/stores/registry.py +82 -0
  127. dataforge/stores/repair.py +121 -0
  128. dataforge/stores/revert.py +22 -0
  129. dataforge/stores/sql.py +27 -0
  130. dataforge/table.py +228 -0
  131. dataforge/transactions/__init__.py +34 -0
  132. dataforge/transactions/files.py +96 -0
  133. dataforge/transactions/log.py +613 -0
  134. dataforge/transactions/revert.py +102 -0
  135. dataforge/transactions/txn.py +104 -0
  136. dataforge/ui/__init__.py +1 -0
  137. dataforge/ui/profile_view.py +136 -0
  138. dataforge/ui/repair_diff.py +91 -0
  139. dataforge/verifier/__init__.py +55 -0
  140. dataforge/verifier/constraint_ir.py +155 -0
  141. dataforge/verifier/explain.py +47 -0
  142. dataforge/verifier/gate.py +5 -0
  143. dataforge/verifier/schema.py +111 -0
  144. dataforge/verifier/smt.py +433 -0
  145. dataforge_07-0.1.0.dist-info/METADATA +436 -0
  146. dataforge_07-0.1.0.dist-info/RECORD +150 -0
  147. dataforge_07-0.1.0.dist-info/WHEEL +5 -0
  148. dataforge_07-0.1.0.dist-info/entry_points.txt +3 -0
  149. dataforge_07-0.1.0.dist-info/licenses/LICENSE +176 -0
  150. dataforge_07-0.1.0.dist-info/top_level.txt +1 -0
dataforge/__init__.py ADDED
@@ -0,0 +1,204 @@
1
+ """DataForge public package.
2
+
3
+ The root package is the stable facade for integration surfaces. Symbols are
4
+ resolved lazily so importing :mod:`dataforge` does not eagerly import pandas,
5
+ FastAPI-facing helpers, or the SMT stack.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from importlib import import_module
11
+ from typing import TYPE_CHECKING, Any
12
+
13
+ if TYPE_CHECKING:
14
+ from dataforge.cli.common import load_schema, schema_from_mapping
15
+ from dataforge.detectors import Issue, Schema, Severity, run_all_detectors
16
+ from dataforge.engine.repair import (
17
+ CandidateFix,
18
+ CandidateRepair,
19
+ ProofObligation,
20
+ RepairFailure,
21
+ RepairPipelineRequest,
22
+ RepairPipelineResult,
23
+ RepairReceipt,
24
+ RootCause,
25
+ VerifiedFix,
26
+ run_repair_pipeline,
27
+ )
28
+ from dataforge.integrations.dbt import schema_from_dbt_artifacts, schema_from_dbt_manifest
29
+ from dataforge.repair_contract import CONTRACT_VERSION
30
+ from dataforge.repairers import ProposedFix
31
+ from dataforge.safety import SafetyContext, SafetyFilter, SafetyResult, SafetyVerdict
32
+ from dataforge.schema_inference import (
33
+ ConstraintCandidate,
34
+ ConstraintReviewArtifact,
35
+ ConstraintReviewError,
36
+ ReviewedConstraintCandidate,
37
+ SchemaInferenceResult,
38
+ build_constraint_review_artifact,
39
+ dump_constraint_review_artifact,
40
+ infer_schema,
41
+ load_constraint_review_artifact,
42
+ )
43
+ from dataforge.stores import (
44
+ DuckDBStore,
45
+ PatchPlan,
46
+ TableStoreError,
47
+ TableStoreRepairResult,
48
+ is_table_store_uri,
49
+ run_table_store_repair,
50
+ store_from_uri,
51
+ )
52
+ from dataforge.table import read_csv
53
+ from dataforge.transactions.log import (
54
+ TransactionAuditReport,
55
+ TransactionAuditVerdict,
56
+ TransactionLogError,
57
+ verify_transaction_log,
58
+ )
59
+ from dataforge.transactions.revert import TransactionRevertError, revert_transaction
60
+ from dataforge.transactions.txn import CellFix, RepairTransaction
61
+ from dataforge.verifier import (
62
+ ConstraintIR,
63
+ SMTVerifier,
64
+ VerificationResult,
65
+ VerificationVerdict,
66
+ constraint_ir_from_schema,
67
+ )
68
+
69
# Public surface of the package facade. Apart from ``__version__`` (assigned
# right after this list), each name is served by the module-level
# ``__getattr__`` hook, so ``from dataforge import *`` imports the owning
# module of every listed name on demand.
__all__ = [
    # Types and constants.
    "CONTRACT_VERSION",
    "CandidateFix",
    "CandidateRepair",
    "CellFix",
    "ConstraintCandidate",
    "ConstraintReviewArtifact",
    "ConstraintReviewError",
    "ConstraintIR",
    "DuckDBStore",
    "Issue",
    "PatchPlan",
    "ProposedFix",
    "ProofObligation",
    "RepairFailure",
    "RepairPipelineRequest",
    "RepairPipelineResult",
    "RepairReceipt",
    "RepairTransaction",
    "RootCause",
    "ReviewedConstraintCandidate",
    "SMTVerifier",
    "SafetyContext",
    "SafetyFilter",
    "SafetyResult",
    "SafetyVerdict",
    "Schema",
    "SchemaInferenceResult",
    "Severity",
    "TransactionAuditReport",
    "TransactionAuditVerdict",
    "TransactionLogError",
    "TransactionRevertError",
    "TableStoreError",
    "TableStoreRepairResult",
    "VerificationResult",
    "VerificationVerdict",
    "VerifiedFix",
    # Package metadata.
    "__version__",
    # Functions.
    "load_schema",
    "build_constraint_review_artifact",
    "constraint_ir_from_schema",
    "dump_constraint_review_artifact",
    "load_constraint_review_artifact",
    "read_csv",
    "revert_transaction",
    "run_all_detectors",
    "run_repair_pipeline",
    "schema_from_mapping",
    "schema_from_dbt_artifacts",
    "schema_from_dbt_manifest",
    "infer_schema",
    "is_table_store_uri",
    "run_table_store_repair",
    "store_from_uri",
    "verify_transaction_log",
]

# Release version of this distribution.
__version__ = "0.1.0"
128
+
129
# Lazily exported public name -> (module to import, attribute to read from it).
_PUBLIC_EXPORTS: dict[str, tuple[str, str]] = {
    "CONTRACT_VERSION": ("dataforge.repair_contract", "CONTRACT_VERSION"),
    "CandidateFix": ("dataforge.engine.repair", "CandidateFix"),
    "CandidateRepair": ("dataforge.engine.repair", "CandidateRepair"),
    "CellFix": ("dataforge.transactions.txn", "CellFix"),
    "ConstraintCandidate": ("dataforge.schema_inference", "ConstraintCandidate"),
    "ConstraintReviewArtifact": ("dataforge.schema_inference", "ConstraintReviewArtifact"),
    "ConstraintReviewError": ("dataforge.schema_inference", "ConstraintReviewError"),
    "ConstraintIR": ("dataforge.verifier", "ConstraintIR"),
    "DuckDBStore": ("dataforge.stores", "DuckDBStore"),
    "Issue": ("dataforge.detectors", "Issue"),
    "ProposedFix": ("dataforge.repairers", "ProposedFix"),
    "ProofObligation": ("dataforge.engine.repair", "ProofObligation"),
    "PatchPlan": ("dataforge.stores", "PatchPlan"),
    "RepairFailure": ("dataforge.engine.repair", "RepairFailure"),
    "RepairPipelineRequest": ("dataforge.engine.repair", "RepairPipelineRequest"),
    "RepairPipelineResult": ("dataforge.engine.repair", "RepairPipelineResult"),
    "RepairReceipt": ("dataforge.engine.repair", "RepairReceipt"),
    "RepairTransaction": ("dataforge.transactions.txn", "RepairTransaction"),
    "RootCause": ("dataforge.engine.repair", "RootCause"),
    "ReviewedConstraintCandidate": ("dataforge.schema_inference", "ReviewedConstraintCandidate"),
    "SMTVerifier": ("dataforge.verifier", "SMTVerifier"),
    "SafetyContext": ("dataforge.safety", "SafetyContext"),
    "SafetyFilter": ("dataforge.safety", "SafetyFilter"),
    "SafetyResult": ("dataforge.safety", "SafetyResult"),
    "SafetyVerdict": ("dataforge.safety", "SafetyVerdict"),
    "Schema": ("dataforge.detectors", "Schema"),
    "SchemaInferenceResult": ("dataforge.schema_inference", "SchemaInferenceResult"),
    "Severity": ("dataforge.detectors", "Severity"),
    "TransactionAuditReport": ("dataforge.transactions.log", "TransactionAuditReport"),
    "TransactionAuditVerdict": ("dataforge.transactions.log", "TransactionAuditVerdict"),
    "TransactionLogError": ("dataforge.transactions.log", "TransactionLogError"),
    "TransactionRevertError": ("dataforge.transactions.revert", "TransactionRevertError"),
    "TableStoreError": ("dataforge.stores", "TableStoreError"),
    "TableStoreRepairResult": ("dataforge.stores", "TableStoreRepairResult"),
    "VerificationResult": ("dataforge.verifier", "VerificationResult"),
    "VerificationVerdict": ("dataforge.verifier", "VerificationVerdict"),
    "VerifiedFix": ("dataforge.engine.repair", "VerifiedFix"),
    "load_schema": ("dataforge.cli.common", "load_schema"),
    "build_constraint_review_artifact": (
        "dataforge.schema_inference",
        "build_constraint_review_artifact",
    ),
    "constraint_ir_from_schema": ("dataforge.verifier", "constraint_ir_from_schema"),
    "dump_constraint_review_artifact": (
        "dataforge.schema_inference",
        "dump_constraint_review_artifact",
    ),
    "load_constraint_review_artifact": (
        "dataforge.schema_inference",
        "load_constraint_review_artifact",
    ),
    "read_csv": ("dataforge.table", "read_csv"),
    "revert_transaction": ("dataforge.transactions.revert", "revert_transaction"),
    "run_all_detectors": ("dataforge.detectors", "run_all_detectors"),
    "run_repair_pipeline": ("dataforge.engine.repair", "run_repair_pipeline"),
    "schema_from_mapping": ("dataforge.cli.common", "schema_from_mapping"),
    "schema_from_dbt_artifacts": ("dataforge.integrations.dbt", "schema_from_dbt_artifacts"),
    "schema_from_dbt_manifest": ("dataforge.integrations.dbt", "schema_from_dbt_manifest"),
    "infer_schema": ("dataforge.schema_inference", "infer_schema"),
    "is_table_store_uri": ("dataforge.stores", "is_table_store_uri"),
    "run_table_store_repair": ("dataforge.stores", "run_table_store_repair"),
    "store_from_uri": ("dataforge.stores", "store_from_uri"),
    "verify_transaction_log": ("dataforge.transactions.log", "verify_transaction_log"),
}


def __getattr__(name: str) -> Any:
    """Resolve a public facade export on first access (PEP 562 module hook).

    The owning module is imported only when the name is first requested. The
    resolved object is then stored in the module globals, so later lookups
    find it directly and never reach this hook again.

    Args:
        name: Attribute requested on the package.

    Returns:
        The exported object.

    Raises:
        AttributeError: ``name`` is not one of the public exports.
        ImportError: The owning module fails to import (propagated as-is).
    """
    try:
        module_name, attribute_name = _PUBLIC_EXPORTS[name]
    except KeyError:
        # Use the interpreter's standard wording so tracebacks and error
        # messages look the same as for any ordinary missing module attribute.
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None
    value = getattr(import_module(module_name), attribute_name)
    globals()[name] = value
    return value


def __dir__() -> list[str]:
    """List module globals plus every lazily exported name (PEP 562 module hook).

    Without this hook, ``dir()`` on the package omits exports that have not been
    accessed yet, which hides them from interactive completion.
    """
    return sorted(set(globals()) | set(_PUBLIC_EXPORTS))
dataforge/__main__.py ADDED
@@ -0,0 +1,5 @@
1
"""Entry point for ``python -m dataforge``: hands control to the CLI app."""

from dataforge.cli import app

# Invoke the CLI application object exported by ``dataforge.cli``.
app()
dataforge/agent/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ """DataForge agent package — typed tool-use actions and scratchpad.
2
+
3
+ Public API:
4
+ parse_action — Parse raw dict into typed Action model.
5
+ Action — Discriminated union of all action types.
6
+ Scratchpad — In-episode hypothesis tracker.
7
+ """
8
+
9
+ from dataforge.agent.scratchpad import Scratchpad
10
+ from dataforge.agent.tool_actions import Action, parse_action
11
+
12
# Re-exported names: ``Scratchpad`` from ``dataforge.agent.scratchpad``;
# ``Action`` and ``parse_action`` from ``dataforge.agent.tool_actions``.
__all__ = ["Action", "Scratchpad", "parse_action"]
dataforge/agent/providers.py ADDED
@@ -0,0 +1,259 @@
1
+ """Multi-provider LLM client for DataForge.
2
+
3
+ Reads ``DATAFORGE_LLM_PROVIDER`` from the environment and dispatches to the
4
+ matching provider. Week 1 implements **groq** and **gemini** only; other
5
+ providers raise ``NotImplementedError``.
6
+
7
+ No LLM calls are made by detectors — this module is for the agent loop
8
+ (Week 2+) and is stubbed here to establish the interface.
9
+
10
+ The interface is:
11
+ ``async def complete(messages, model, temperature) -> str``
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import os
17
+ from typing import Literal, TypedDict
18
+
19
+ import httpx
20
+ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
21
+
22
+ # ── Message type ──────────────────────────────────────────────────────────
23
+
24
+
25
+ class Message(TypedDict):
26
+ """A single chat message.
27
+
28
+ Args:
29
+ role: The speaker role — ``"system"``, ``"user"``, or ``"assistant"``.
30
+ content: The text content of the message.
31
+ """
32
+
33
+ role: Literal["system", "user", "assistant"]
34
+ content: str
35
+
36
+
37
+ # ── Exceptions ────────────────────────────────────────────────────────────
38
+
39
+
40
class ProviderError(Exception):
    """Raised when an LLM provider cannot produce a completion.

    In this module it signals a missing API key or a malformed provider
    response. ``str(err)`` is ``"[provider] message"``.

    Args:
        provider: The provider name that failed (e.g. ``"groq"``).
        message: Description of the failure.

    Attributes:
        provider: The provider name passed to the constructor.
        message: The failure description passed to the constructor.
    """

    def __init__(self, provider: str, message: str) -> None:
        self.provider = provider
        self.message = message
        super().__init__(f"[{provider}] {message}")

    def __reduce__(self) -> tuple[object, ...]:
        """Make the exception picklable (e.g. across process-pool boundaries).

        ``BaseException`` pickles as ``cls(*self.args)``. Here ``args`` holds only
        the pre-formatted string, which does not fit this two-argument
        constructor, so unpickling would fail with ``TypeError``. This method
        rebuilds the exception from the original arguments instead.
        """
        return (type(self), (self.provider, self.message), dict(self.__dict__))
51
+
52
+
53
+ # ── Provider dispatch ─────────────────────────────────────────────────────
54
+
55
+ _SUPPORTED_PROVIDERS = frozenset({"groq", "gemini", "cerebras", "openrouter", "hf", "cloudflare"})
56
+
57
+
58
+ def get_provider_name() -> str:
59
+ """Read the active provider from the environment.
60
+
61
+ Returns:
62
+ The lowercased provider name from ``DATAFORGE_LLM_PROVIDER``.
63
+ When no explicit provider is configured, prefer a provider whose
64
+ credential is present in the environment.
65
+
66
+ Example:
67
+ >>> import os
68
+ >>> os.environ["DATAFORGE_LLM_PROVIDER"] = "gemini"
69
+ >>> get_provider_name()
70
+ 'gemini'
71
+ """
72
+ configured = os.environ.get("DATAFORGE_LLM_PROVIDER")
73
+ if configured:
74
+ return configured.lower()
75
+ if os.environ.get("GROQ_API_KEY"):
76
+ return "groq"
77
+ if os.environ.get("GEMINI_API_KEY"):
78
+ return "gemini"
79
+ return "groq"
80
+
81
+
82
async def complete(
    messages: list[Message],
    *,
    model: str | None = None,
    temperature: float = 0.0,
) -> str:
    """Route a chat completion to the provider chosen by ``get_provider_name``.

    Args:
        messages: The conversation to complete.
        model: Provider-specific model name; ``None`` uses the provider default.
        temperature: Sampling temperature (``0.0`` for deterministic output).

    Returns:
        The text of the assistant's reply.

    Raises:
        NotImplementedError: The configured provider is planned but not yet
            implemented, or is not a known provider at all.
        ProviderError: The provider's API key is missing or its response could
            not be parsed.
        httpx.HTTPError: HTTP or transport failures raised by ``httpx``. These
            propagate to the caller after any retries the provider function
            performs.

    Example:
        >>> import asyncio
        >>> msgs = [{"role": "user", "content": "What is 2+2?"}]
        >>> # reply = asyncio.run(complete(msgs))  # needs a provider API key
    """
    active = get_provider_name()

    if active == "groq":
        return await _complete_groq(messages, model=model, temperature=temperature)
    if active == "gemini":
        return await _complete_gemini(messages, model=model, temperature=temperature)

    # Anything else is either not a provider at all or one still on the roadmap.
    if active not in _SUPPORTED_PROVIDERS:
        raise NotImplementedError(
            f"Unknown provider '{active}'. Supported: {sorted(_SUPPORTED_PROVIDERS)}"
        )
    raise NotImplementedError(
        f"Provider '{active}' is planned but not yet implemented. "
        "Use 'groq' or 'gemini' for Week 1."
    )
123
+
124
+
125
+ # ── Groq provider ────────────────────────────────────────────────────────
126
+
127
_GROQ_URL = "https://api.groq.com/openai/v1/chat/completions"
# ``llama-3.1-70b-versatile`` has been decommissioned by Groq, so requests for it
# fail. ``llama-3.3-70b-versatile`` is the documented replacement.
_GROQ_DEFAULT_MODEL = "llama-3.3-70b-versatile"


@retry(
    # Retry HTTP error statuses (as before) and also transient transport
    # failures such as timeouts and dropped connections.
    retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.TransportError)),
    wait=wait_exponential(multiplier=1, min=1, max=30),
    stop=stop_after_attempt(3),
    reraise=True,
)
async def _complete_groq(
    messages: list[Message],
    *,
    model: str | None = None,
    temperature: float = 0.0,
) -> str:
    """Call Groq's OpenAI-compatible chat completions API.

    Args:
        messages: Chat messages, sent unchanged.
        model: Model name (defaults to ``_GROQ_DEFAULT_MODEL``).
        temperature: Sampling temperature.

    Returns:
        The assistant's response text.

    Raises:
        ProviderError: ``GROQ_API_KEY`` is unset, the body is not JSON, it lacks
            ``choices[0].message.content``, or that content is null.
        httpx.HTTPStatusError: Non-2xx response, re-raised after the final attempt.
        httpx.TransportError: Network or timeout failure, re-raised after the
            final attempt.
    """
    api_key = os.environ.get("GROQ_API_KEY", "")
    if not api_key:
        raise ProviderError("groq", "GROQ_API_KEY environment variable not set")

    payload = {
        "model": model or _GROQ_DEFAULT_MODEL,
        "messages": [dict(m) for m in messages],
        "temperature": temperature,
    }

    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(
            _GROQ_URL,
            json=payload,
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
        )
        response.raise_for_status()

    try:
        data = response.json()
    except ValueError as exc:
        raise ProviderError(
            "groq", f"Response body is not valid JSON: {response.text[:200]!r}"
        ) from exc

    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise ProviderError("groq", f"Unexpected response format: {data}") from exc
    # OpenAI-compatible APIs can return ``"content": null``. Without this check,
    # ``str(None)`` would silently return the literal text "None".
    if content is None:
        raise ProviderError("groq", f"Response contained no message content: {data}")
    return str(content)
182
+
183
+
184
+ # ── Gemini provider ──────────────────────────────────────────────────────
185
+
186
_GEMINI_URL_TEMPLATE = (
    "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
)
_GEMINI_DEFAULT_MODEL = "gemini-2.0-flash"


@retry(
    # Retry HTTP error statuses (as before) and also transient transport
    # failures such as timeouts and dropped connections.
    retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.TransportError)),
    wait=wait_exponential(multiplier=1, min=1, max=30),
    stop=stop_after_attempt(3),
    reraise=True,
)
async def _complete_gemini(
    messages: list[Message],
    *,
    model: str | None = None,
    temperature: float = 0.0,
) -> str:
    """Call Google's Gemini generativeLanguage ``generateContent`` API.

    System messages are sent, in order, as the parts of ``systemInstruction``.
    Assistant turns are sent with Gemini's ``"model"`` role.

    Args:
        messages: Chat messages (converted to Gemini's content format).
        model: Model name (defaults to ``_GEMINI_DEFAULT_MODEL``).
        temperature: Sampling temperature.

    Returns:
        The concatenated text of the first candidate's text parts.

    Raises:
        ProviderError: ``GEMINI_API_KEY`` is unset, the body is not JSON, or the
            first candidate carries no text parts.
        httpx.HTTPStatusError: Non-2xx response, re-raised after the final attempt.
        httpx.TransportError: Network or timeout failure, re-raised after the
            final attempt.
    """
    api_key = os.environ.get("GEMINI_API_KEY", "")
    if not api_key:
        raise ProviderError("gemini", "GEMINI_API_KEY environment variable not set")

    model_name = model or _GEMINI_DEFAULT_MODEL
    url = _GEMINI_URL_TEMPLATE.format(model=model_name)

    # Convert OpenAI-style messages to Gemini format. Every non-empty system
    # message is kept; previously each one overwrote the last.
    contents: list[dict[str, object]] = []
    system_parts: list[dict[str, str]] = []
    for msg in messages:
        if msg["role"] == "system":
            if msg["content"]:
                system_parts.append({"text": msg["content"]})
            continue
        role = "user" if msg["role"] == "user" else "model"
        contents.append({"role": role, "parts": [{"text": msg["content"]}]})

    payload: dict[str, object] = {
        "contents": contents,
        "generationConfig": {"temperature": temperature},
    }
    if system_parts:
        payload["systemInstruction"] = {"parts": system_parts}

    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(
            url,
            json=payload,
            # Send the key as a header rather than a ``?key=`` query parameter.
            # httpx puts the full request URL in ``HTTPStatusError`` messages,
            # which would otherwise leak the key into logs and tracebacks.
            headers={
                "Content-Type": "application/json",
                "x-goog-api-key": api_key,
            },
        )
        response.raise_for_status()

    try:
        data = response.json()
    except ValueError as exc:
        raise ProviderError(
            "gemini", f"Response body is not valid JSON: {response.text[:200]!r}"
        ) from exc

    try:
        parts = data["candidates"][0]["content"]["parts"]
        # A reply may be split across several parts. Join every text part
        # (skipping thought summaries) instead of returning only the first.
        texts = [
            str(part["text"])
            for part in parts
            if "text" in part and not part.get("thought")
        ]
    except (KeyError, IndexError, TypeError, AttributeError) as exc:
        raise ProviderError("gemini", f"Unexpected response format: {data}") from exc
    if not texts:
        raise ProviderError("gemini", f"Unexpected response format: {data}")
    return "".join(texts)
dataforge/agent/scratchpad.py ADDED
@@ -0,0 +1,183 @@
1
+ """In-episode hypothesis and issue tracker for the DataForge RL agent.
2
+
3
+ The scratchpad is a mutable, episode-scoped data structure that the agent
4
+ uses to record hypotheses, confirmed issues, and dead ends. The environment
5
+ exposes a compact summary of the scratchpad in each observation, enabling
6
+ the agent to reason about its investigation history without direct access
7
+ to the underlying data structure.
8
+
9
+ Example::
10
+
11
+ >>> from dataforge.agent.scratchpad import Scratchpad
12
+ >>> pad = Scratchpad()
13
+ >>> _ = pad.add_hypothesis("Rating column has decimal shift", [5], ["rating"], "decimal_shift")
13
+ >>> pad.confirm_issue(5, "rating", "decimal_shift")
14
+ >>> pad.summary()
15
+ 'Hypotheses: 1 (1 pending). Confirmed: 1. Dead ends: 0.'
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ from dataclasses import dataclass, field
22
+
23
__all__ = [
    "ConfirmedIssue",
    "DeadEnd",
    "HypothesisRecord",
    "Scratchpad",
]


@dataclass(frozen=True)
class HypothesisRecord:
    """A recorded hypothesis about a data-quality root cause.

    Attributes:
        claim: Textual description of the hypothesis.
        affected_rows: Row indices the hypothesis covers.
        affected_columns: Column names the hypothesis covers.
        root_cause_type: Detector-vocabulary root cause type.
        confirmed: Whether the hypothesis has been marked confirmed (set by
            ``Scratchpad.confirm_hypothesis``).
    """

    claim: str
    affected_rows: tuple[int, ...]
    affected_columns: tuple[str, ...]
    root_cause_type: str
    confirmed: bool = False


@dataclass(frozen=True)
class ConfirmedIssue:
    """A confirmed data-quality issue at a specific location.

    Attributes:
        row: Zero-indexed row number.
        column: Column name.
        issue_type: Issue type classification.
    """

    row: int
    column: str
    issue_type: str


@dataclass(frozen=True)
class DeadEnd:
    """A recorded dead end: an investigation path that yielded nothing.

    Attributes:
        description: What was tried and why it failed.
        step_number: Step at which the dead end was recorded.
    """

    description: str
    step_number: int


@dataclass
class Scratchpad:
    """Mutable in-episode tracker for hypotheses, confirmed issues, and dead ends.

    Reset at the start of each episode. The ``summary()`` method produces a
    compact string for inclusion in agent observations.

    Example::

        >>> pad = Scratchpad()
        >>> _ = pad.add_hypothesis("Decimal shift in rating", [5], ["rating"], "decimal_shift")
        >>> len(pad.hypotheses)
        1
    """

    hypotheses: list[HypothesisRecord] = field(default_factory=list)
    confirmed_issues: list[ConfirmedIssue] = field(default_factory=list)
    dead_ends: list[DeadEnd] = field(default_factory=list)

    def add_hypothesis(
        self,
        claim: str,
        affected_rows: list[int],
        affected_columns: list[str],
        root_cause_type: str,
    ) -> HypothesisRecord:
        """Record a new, unconfirmed hypothesis.

        Args:
            claim: Textual description of the hypothesis.
            affected_rows: Row indices the hypothesis covers (stored as a tuple).
            affected_columns: Column names the hypothesis covers (stored as a tuple).
            root_cause_type: Detector-vocabulary root cause type.

        Returns:
            The recorded hypothesis.
        """
        record = HypothesisRecord(
            claim=claim,
            affected_rows=tuple(affected_rows),
            affected_columns=tuple(affected_columns),
            root_cause_type=root_cause_type,
        )
        self.hypotheses.append(record)
        return record

    def confirm_hypothesis(self, index: int) -> None:
        """Mark a hypothesis as confirmed.

        Args:
            index: Index into the ``hypotheses`` list (standard list indexing).

        Raises:
            IndexError: If the index is out of range.
        """
        from dataclasses import replace

        # Records are frozen, so store a copy with only ``confirmed`` changed.
        # ``replace`` keeps every other field (and any subclass) intact.
        self.hypotheses[index] = replace(self.hypotheses[index], confirmed=True)

    def confirm_issue(self, row: int, column: str, issue_type: str) -> None:
        """Record a confirmed issue.

        Args:
            row: Zero-indexed row number.
            column: Column name.
            issue_type: Issue type classification.
        """
        self.confirmed_issues.append(ConfirmedIssue(row=row, column=column, issue_type=issue_type))

    def add_dead_end(self, description: str, step_number: int) -> None:
        """Record a dead end.

        Args:
            description: What was tried and why it failed.
            step_number: Step at which the dead end was recorded.
        """
        self.dead_ends.append(DeadEnd(description=description, step_number=step_number))

    def reset(self) -> None:
        """Clear all tracked state for a new episode (lists are emptied in place)."""
        self.hypotheses.clear()
        self.confirmed_issues.clear()
        self.dead_ends.clear()

    def summary(self) -> str:
        """Produce a compact summary string for observation embedding.

        Returns:
            A one-line summary of scratchpad state. "pending" counts hypotheses
            not yet marked confirmed.

        Example::

            >>> Scratchpad().summary()
            'Hypotheses: 0 (0 pending). Confirmed: 0. Dead ends: 0.'
        """
        pending = sum(1 for h in self.hypotheses if not h.confirmed)
        return (
            f"Hypotheses: {len(self.hypotheses)} ({pending} pending). "
            f"Confirmed: {len(self.confirmed_issues)}. "
            f"Dead ends: {len(self.dead_ends)}."
        )
+ )