dslighting 1.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dsat/__init__.py +3 -0
- dsat/benchmark/__init__.py +1 -0
- dsat/benchmark/benchmark.py +168 -0
- dsat/benchmark/datasci.py +291 -0
- dsat/benchmark/mle.py +777 -0
- dsat/benchmark/sciencebench.py +304 -0
- dsat/common/__init__.py +0 -0
- dsat/common/constants.py +11 -0
- dsat/common/exceptions.py +48 -0
- dsat/common/typing.py +19 -0
- dsat/config.py +79 -0
- dsat/models/__init__.py +3 -0
- dsat/models/candidates.py +16 -0
- dsat/models/formats.py +52 -0
- dsat/models/task.py +64 -0
- dsat/operators/__init__.py +0 -0
- dsat/operators/aflow_ops.py +90 -0
- dsat/operators/autokaggle_ops.py +170 -0
- dsat/operators/automind_ops.py +38 -0
- dsat/operators/base.py +22 -0
- dsat/operators/code.py +45 -0
- dsat/operators/dsagent_ops.py +123 -0
- dsat/operators/llm_basic.py +84 -0
- dsat/prompts/__init__.py +0 -0
- dsat/prompts/aflow_prompt.py +76 -0
- dsat/prompts/aide_prompt.py +52 -0
- dsat/prompts/autokaggle_prompt.py +290 -0
- dsat/prompts/automind_prompt.py +29 -0
- dsat/prompts/common.py +51 -0
- dsat/prompts/data_interpreter_prompt.py +82 -0
- dsat/prompts/dsagent_prompt.py +88 -0
- dsat/runner.py +554 -0
- dsat/services/__init__.py +0 -0
- dsat/services/data_analyzer.py +387 -0
- dsat/services/llm.py +486 -0
- dsat/services/llm_single.py +421 -0
- dsat/services/sandbox.py +386 -0
- dsat/services/states/__init__.py +0 -0
- dsat/services/states/autokaggle_state.py +43 -0
- dsat/services/states/base.py +14 -0
- dsat/services/states/dsa_log.py +13 -0
- dsat/services/states/experience.py +237 -0
- dsat/services/states/journal.py +153 -0
- dsat/services/states/operator_library.py +290 -0
- dsat/services/vdb.py +76 -0
- dsat/services/workspace.py +178 -0
- dsat/tasks/__init__.py +3 -0
- dsat/tasks/handlers.py +376 -0
- dsat/templates/open_ended/grade_template.py +107 -0
- dsat/tools/__init__.py +4 -0
- dsat/utils/__init__.py +0 -0
- dsat/utils/context.py +172 -0
- dsat/utils/dynamic_import.py +71 -0
- dsat/utils/parsing.py +33 -0
- dsat/workflows/__init__.py +12 -0
- dsat/workflows/base.py +53 -0
- dsat/workflows/factory.py +439 -0
- dsat/workflows/manual/__init__.py +0 -0
- dsat/workflows/manual/autokaggle_workflow.py +148 -0
- dsat/workflows/manual/data_interpreter_workflow.py +153 -0
- dsat/workflows/manual/deepanalyze_workflow.py +484 -0
- dsat/workflows/manual/dsagent_workflow.py +76 -0
- dsat/workflows/search/__init__.py +0 -0
- dsat/workflows/search/aflow_workflow.py +344 -0
- dsat/workflows/search/aide_workflow.py +283 -0
- dsat/workflows/search/automind_workflow.py +237 -0
- dsat/workflows/templates/__init__.py +0 -0
- dsat/workflows/templates/basic_kaggle_loop.py +71 -0
- dslighting/__init__.py +170 -0
- dslighting/core/__init__.py +13 -0
- dslighting/core/agent.py +646 -0
- dslighting/core/config_builder.py +318 -0
- dslighting/core/data_loader.py +422 -0
- dslighting/core/task_detector.py +422 -0
- dslighting/utils/__init__.py +19 -0
- dslighting/utils/defaults.py +151 -0
- dslighting-1.3.9.dist-info/METADATA +554 -0
- dslighting-1.3.9.dist-info/RECORD +80 -0
- dslighting-1.3.9.dist-info/WHEEL +5 -0
- dslighting-1.3.9.dist-info/top_level.txt +2 -0
dsat/benchmark/sciencebench.py
ADDED
@@ -0,0 +1,304 @@
# dsat/benchmark/sciencebench.py

import logging
import time
import uuid
import yaml
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

from dsat.benchmark.benchmark import BaseBenchmark
from dsat.models.task import TaskDefinition

from sciencebench.data import is_dataset_prepared
from sciencebench.grade import grade_submission
from sciencebench.grade_helpers import CompetitionReport
from sciencebench.registry import Competition, Registry
from sciencebench.registry import registry as DEFAULT_SCIENCEBENCH_REGISTRY

logger = logging.getLogger(__name__)


class ScienceBenchBenchmark(BaseBenchmark):
    """
    Benchmark class to integrate ScienceBench competitions into the DSAT framework.

    Expected `data_dir` layout (same as MLEBench):
        <data_dir>/<competition_id>/prepared/public/...
        <data_dir>/<competition_id>/prepared/private/...
    """

    def __init__(
        self,
        name: str,
        file_path: Optional[str],
        log_path: str,
        data_dir: Optional[str] = None,
        competitions: Optional[List[str]] = None,
    ):
        self.data_dir = Path(data_dir) if data_dir else DEFAULT_SCIENCEBENCH_REGISTRY.get_data_dir()
        self.registry: Registry = DEFAULT_SCIENCEBENCH_REGISTRY.set_data_dir(self.data_dir)

        self.config = self._load_config()
        if competitions:
            self.config["competitions"] = list(competitions)

        super().__init__(name, file_path, log_path)
        Path(self.log_path).mkdir(parents=True, exist_ok=True)

        self.problems = self._load_problems()
        logger.info("ScienceBenchBenchmark initialized with data_dir: %s", self.data_dir)

    def _load_config(self) -> Dict[str, Any]:
        try:
            framework_dir = Path(__file__).parent.parent.parent
            config_path = framework_dir / "config.yaml"
            if not config_path.exists():
                logger.warning("Config file not found at %s, using default configuration", config_path)
                return {"competitions": []}

            with open(config_path, "r", encoding="utf-8") as f:
                config = yaml.safe_load(f) or {}

            competitions = config.get("sciencebench_competitions", [])
            if competitions:
                return {"competitions": list(competitions)}
            return {"competitions": []}
        except Exception as exc:
            logger.error("Error loading config file for ScienceBench: %s", exc)
            return {"competitions": []}

    def _load_problems(self) -> List[Dict[str, Any]]:
        logger.info("Discovering prepared ScienceBench competitions in %s...", self.data_dir)

        if not self.data_dir.exists():
            raise FileNotFoundError(
                f"ScienceBench data directory not found: {self.data_dir}. "
                "Please provide a valid path via --data-dir."
            )

        competition_ids = self.config.get("competitions", [])
        if not competition_ids:
            logger.warning("No sciencebench competitions provided; pass --task-id <competition_id>.")
            return []

        problems: List[Dict[str, Any]] = []
        for competition_id in competition_ids:
            try:
                competition = self.registry.get_competition(competition_id)
                if is_dataset_prepared(competition, grading_only=False):
                    problems.append({"competition_id": competition_id})
                else:
                    logger.warning(
                        "Skipping competition '%s' as its dataset is not fully prepared under '%s'.",
                        competition_id,
                        self.data_dir,
                    )
            except Exception as exc:
                logger.warning("Error loading competition '%s': %s", competition_id, exc)

        if not problems:
            logger.error(
                "No prepared ScienceBench competitions found in %s. "
                "Please prepare the dataset or point --data-dir to a prepared competitions folder.",
                self.data_dir,
            )
        return problems

    def set_mode(self, mode: str):
        logger.info("Setting ScienceBenchBenchmark mode to '%s'", mode)
        self.registry.set_mode(mode)

    async def grade(self, submission_path: Path) -> float:
        """Grades a submission and returns a numerical fitness score for AFlow."""
        if not self.problems:
            return 0.0

        competition_id = self.problems[0].get("competition_id")
        if not competition_id:
            return 0.0

        try:
            competition = self.registry.get_competition(competition_id)
            if not submission_path.exists():
                logger.warning("Grading failed: submission file not found at %s", submission_path)
                return 0.0

            report = grade_submission(submission_path, competition)
            score = report.score if report.score is not None else 0.0

            is_lower_better = False
            try:
                if competition.leaderboard.exists():
                    import pandas as pd

                    leaderboard = pd.read_csv(competition.leaderboard)
                    if "score" in leaderboard.columns and len(leaderboard.index) > 1:
                        is_lower_better = competition.grader.is_lower_better(leaderboard)
            except Exception as exc:
                logger.warning(
                    "Could not determine score direction for %s: %s",
                    competition_id,
                    exc,
                )

            if is_lower_better:
                return 1.0 / (1.0 + score) if score > 0 else 1.0
            return float(score)

        except Exception as exc:
            logger.error("Error during grading for %s: %s", competition_id, exc)
            return 0.0

    def get_result_columns(self) -> List[str]:
        return [
            "competition_id",
            "submission_path",
            "answers_path",
            "score",
            "cost",
            "running_time",
            "input_tokens",
            "output_tokens",
            "total_tokens",
            "gold_medal",
            "silver_medal",
            "bronze_medal",
            "above_median",
            "submission_exists",
            "valid_submission",
            "error_message",
        ]

    def _create_error_report(self, competition_id: str, submission_path: Path, error_msg: str) -> CompetitionReport:
        base_data = {
            "competition_id": competition_id,
            "score": None,
            "gold_threshold": float("nan"),
            "silver_threshold": float("nan"),
            "bronze_threshold": float("nan"),
            "median_threshold": float("nan"),
            "any_medal": False,
            "gold_medal": False,
            "silver_medal": False,
            "bronze_medal": False,
            "above_median": False,
            "submission_exists": submission_path.exists(),
            "valid_submission": False,
            "is_lower_better": False,
            "created_at": datetime.now().isoformat(),
            "submission_path": str(submission_path),
        }
        report = CompetitionReport.from_dict(base_data)
        logger.error("Error for %s: %s", competition_id, error_msg)
        return report

    async def evaluate_problem(
        self, problem: dict, eval_fn: Callable
    ) -> Tuple[Tuple, CompetitionReport, Optional[str]]:
        competition_id = problem.get("competition_id")
        if not competition_id:
            raise ValueError("Problem data must contain 'competition_id'")

        unique_id = uuid.uuid4().hex[:6]
        competition: Optional[Competition] = None

        # Default output: a unique CSV name under log_path/.
        # If the competition declares a specific submission format in config (e.g. `.npy`),
        # keep the same naming scheme but switch the file extension accordingly (resolved after
        # loading the competition config).
        output_filename = f"submission_{competition_id}_{unique_id}.csv"
        output_submission_path = (Path(self.log_path) / output_filename).absolute()

        start_time = time.perf_counter()

        cost = 0.0
        input_tokens = 0
        output_tokens = 0
        total_tokens = 0
        report: Optional[CompetitionReport] = None
        error_message: Optional[str] = None

        try:
            competition = self.registry.get_competition(competition_id)

            submission_filename = getattr(competition, "submission_filename", None)
            submission_suffix = Path(submission_filename).suffix if submission_filename else ".csv"
            if submission_suffix and not submission_suffix.startswith("."):
                submission_suffix = f".{submission_suffix}"
            output_filename = f"submission_{competition_id}_{unique_id}{submission_suffix}"
            output_submission_path = (Path(self.log_path) / output_filename).absolute()

            if not is_dataset_prepared(competition, grading_only=False):
                raise ValueError(f"Dataset for '{competition_id}' not prepared in '{self.data_dir}'.")

            task = TaskDefinition(
                task_id=competition_id,
                task_type="kaggle",
                payload={
                    "description": competition.description,
                    "public_data_dir": str(competition.public_dir.absolute()),
                    "output_submission_path": str(output_submission_path.absolute()),
                },
            )

            result, cost, usage_summary = await eval_fn(task)

            input_tokens = usage_summary.get("prompt_tokens", 0)
            output_tokens = usage_summary.get("completion_tokens", 0)
            total_tokens = usage_summary.get("total_tokens", 0)

            if isinstance(result, Path):
                report = grade_submission(result, competition)
            elif isinstance(result, str) and result.startswith("[ERROR]"):
                error_message = f"DSAT workflow failed: {result}"
                report = self._create_error_report(competition_id, output_submission_path, error_message)
            else:
                error_message = f"Unexpected result type from eval_fn: {type(result).__name__}"
                report = self._create_error_report(competition_id, output_submission_path, error_message)

        except Exception as exc:
            error_message = f"Error during ScienceBench evaluation of {competition_id}: {exc}"
            logger.error(error_message, exc_info=True)
            report = self._create_error_report(competition_id, output_submission_path, error_message)

        if report is None:
            final_error = error_message or "Unknown error: report is None"
            report = self._create_error_report(competition_id, output_submission_path, final_error)
            error_message = final_error

        if not report.valid_submission:
            answers_path_str = str(getattr(competition, "answers", "N/A")) if competition else "N/A"
            self.log_mismatch(
                problem=competition_id,
                expected_output=answers_path_str,
                prediction=f"File: {output_submission_path}, Exists: {report.submission_exists}, Valid: {report.valid_submission}",
                extracted_output=report.score,
                extract_answer_code=error_message or "Grading function failed or file invalid/missing",
            )
            if not error_message:
                error_message = "Submission invalid or missing."

        running_time = round(time.perf_counter() - start_time, 4)
        answers_path_str = str(getattr(competition, "answers", "N/A")) if competition else "N/A"

        csv_tuple = (
            report.competition_id,
            str(report.submission_path),
            answers_path_str,
            report.score,
            cost,
            running_time,
            input_tokens,
            output_tokens,
            total_tokens,
            bool(getattr(report, "gold_medal", False)),
            bool(getattr(report, "silver_medal", False)),
            bool(getattr(report, "bronze_medal", False)),
            bool(getattr(report, "above_median", False)),
            report.submission_exists,
            report.valid_submission,
            error_message,
        )
        return csv_tuple, report, error_message
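As a reading aid (not part of the package), here is a minimal sketch of how this benchmark might be driven, assuming a dataset already prepared under <data_dir>/<competition_id>/prepared/... as the class docstring describes. The data_dir, competition id, and fake_eval_fn below are placeholders; a real eval_fn would be a DSAT workflow callable returning (result, cost, usage_summary).

import asyncio

# Hypothetical stand-in for a DSAT workflow. Returning an "[ERROR]" string
# exercises the error-report branch of evaluate_problem shown above.
async def fake_eval_fn(task):
    return "[ERROR] no workflow attached", 0.0, {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

async def main():
    bench = ScienceBenchBenchmark(
        name="sciencebench-demo",
        file_path=None,
        log_path="runs/sciencebench_demo",
        data_dir="/data/sciencebench",          # assumed local path
        competitions=["some-competition-id"],   # placeholder competition id
    )
    if bench.problems:
        row, report, err = await bench.evaluate_problem(bench.problems[0], fake_eval_fn)
        # get_result_columns() and the returned tuple line up one-to-one.
        print(dict(zip(bench.get_result_columns(), row)))

# asyncio.run(main())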
dsat/common/__init__.py
ADDED
File without changes
dsat/common/constants.py
ADDED
@@ -0,0 +1,11 @@
# Default timeout values (in seconds)
DEFAULT_SANDBOX_TIMEOUT = 24 * 3600
DEFAULT_LLM_TIMEOUT = 300

# Default configuration values
DEFAULT_LLM_MODEL = "gpt-4o-mini"
DEFAULT_TEMPERATURE = 0.7

# Path constants
DEFAULT_WORKSPACE_DIR = "runs"
DEFAULT_BENCHMARK_DIR = "benchmarks"
dsat/common/exceptions.py
ADDED
@@ -0,0 +1,48 @@
class DSATError(Exception):
    """Base exception for all DSAT-related errors."""
    def __init__(self, message=None, *args, **kwargs):
        if message is None:
            message = "An error occurred in the DSAT framework."
        super().__init__(message, *args)
        self.message = message
        self.details = kwargs.get("details", None)

    def __str__(self):
        if self.details:
            return f"{self.message} Details: {self.details}"
        return self.message

class WorkspaceError(DSATError):
    """Exception raised for workspace-related errors."""
    def __init__(self, message=None, *args, **kwargs):
        if message is None:
            message = "A workspace-related error occurred."
        super().__init__(message, *args, **kwargs)

class SandboxError(DSATError):
    """Exception raised for sandbox execution errors."""
    def __init__(self, message=None, *args, **kwargs):
        if message is None:
            message = "A sandbox execution error occurred."
        super().__init__(message, *args, **kwargs)

class LLMError(DSATError):
    """Exception raised for LLM service errors."""
    def __init__(self, message=None, *args, **kwargs):
        if message is None:
            message = "An LLM service error occurred."
        super().__init__(message, *args, **kwargs)

class DynamicImportError(DSATError):
    """Exception raised when dynamically importing code fails (e.g., syntax errors in generated workflows)."""
    def __init__(self, message=None, *args, **kwargs):
        if message is None:
            message = "A dynamic code import error occurred."
        super().__init__(message, *args, **kwargs)

class WorkflowError(DSATError):
    """Exception raised for workflow execution errors."""
    def __init__(self, message=None, *args, **kwargs):
        if message is None:
            message = "A workflow execution error occurred."
        super().__init__(message, *args, **kwargs)
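A short usage sketch (not part of the package) of the exception hierarchy above; the `details` keyword is the only extra the base class consumes, and the exit-code value is illustrative.

try:
    raise SandboxError("Container exited early.", details={"exit_code": 137})
except DSATError as err:
    # __str__ appends the details dict when present:
    # "Container exited early. Details: {'exit_code': 137}"
    print(str(err))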
dsat/common/typing.py
ADDED
@@ -0,0 +1,19 @@
# dsat/common/typing.py
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any

class ExecutionResult(BaseModel):
    """
    Standardized result of executing a code snippet in any sandbox mode.
    """
    success: bool = Field(description="True if the execution completed without errors, False otherwise.")
    stdout: str = Field(default="", description="The captured standard output stream.")
    stderr: str = Field(default="", description="The captured standard error stream.")
    exc_type: Optional[str] = Field(default=None, description="The type of exception if one was raised (e.g., 'TimeoutError', 'ValueError').")
    # For notebook mode, this can hold base64 encoded images or other rich outputs.
    artifacts: List[str] = Field(default_factory=list, description="A list of generated artifacts, like image filenames.")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Arbitrary execution metadata (timestamps, paths, etc.).")

    class Config:
        """Pydantic configuration."""
        extra = 'forbid'
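A small sketch (values are placeholders, not package output) of how a sandbox backend might materialize a result with this model; because of extra = 'forbid', unknown fields are rejected at validation time.

result = ExecutionResult(
    success=False,
    stdout="",
    stderr="Traceback (most recent call last): ...",
    exc_type="ValueError",
    artifacts=["figure_001.png"],
    metadata={"duration_s": 1.42},
)
assert not result.success

# extra='forbid' means unexpected keys fail validation:
# ExecutionResult(success=True, exit_code=0)  # would raise a ValidationError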
dsat/config.py
ADDED
@@ -0,0 +1,79 @@
# dsat/config.py

from typing import Dict, Any, Optional, List
from pydantic import BaseModel, Field

class LLMConfig(BaseModel):
    """LLM service settings."""
    model: str = "gpt-4o-mini"
    temperature: float = 0.7
    api_key: Optional[str] = Field(None, description="API key, defaults to API_KEY env var if not set.")
    api_base: Optional[str] = "https://api.openai.com/v1"
    provider: Optional[str] = Field(None, description="Optional LiteLLM provider alias, e.g. 'siliconflow'.")
    max_retries: int = 3

class SandboxConfig(BaseModel):
    """Code execution sandbox settings."""
    timeout: int = 6 * 3600

class TaskConfig(BaseModel):
    """Defines the problem to be solved."""
    goal: str = "Solve the given data science task."
    eval_metric: Optional[str] = None
    data_dir: Optional[str] = None

class RunConfig(BaseModel):
    """Settings for a specific execution run."""
    name: str = "dsat_run"
    total_steps: int = 4
    keep_all_workspaces: bool = Field(False, description="If True, do not delete any workspace after execution.")
    keep_workspace_on_failure: bool = Field(True, description="If True, keep the workspace only if the task execution fails.")
    parameters: Dict[str, Any] = Field(default_factory=dict, description="Arbitrary runtime parameters saved for telemetry.")

class AgentSearchConfig(BaseModel):
    """Parameters for Paradigm 2 (AIDE/AutoMind) search."""
    num_drafts: int = 5
    debug_prob: float = 0.8
    max_iterations: int = 5
    max_debug_depth: int = 10

class AutoKaggleConfig(BaseModel):
    """Parameters for the AutoKaggle SOP workflow."""
    max_attempts_per_phase: int = 10
    success_threshold: float = 3.0

class AgentConfig(BaseModel):
    """Configuration for a specific agent's behavior."""
    search: AgentSearchConfig = Field(default_factory=AgentSearchConfig)
    max_retries: int = 10
    autokaggle: AutoKaggleConfig = Field(default_factory=AutoKaggleConfig)

class OptimizerConfig(BaseModel):
    """Parameters for Paradigm 3 (AFlow) meta-optimization."""
    max_rounds: int = 10
    validation_runs_per_candidate: int = 1
    top_k_selection: int = 2


class WorkflowConfig(BaseModel):
    """Specifies which workflow to run and its parameters."""
    name: str
    params: Dict[str, Any] = Field(default_factory=dict)
    # This field is populated at runtime by main.py, not from the YAML file.
    class_ref: Optional[Any] = Field(None, exclude=True)

class DSATConfig(BaseModel):
    """The root configuration model for the entire DSAT application."""
    run: RunConfig = Field(default_factory=RunConfig)
    task: TaskConfig = Field(default_factory=TaskConfig)
    llm: LLMConfig = Field(default_factory=LLMConfig)
    sandbox: SandboxConfig = Field(default_factory=SandboxConfig)

    # Paradigm-specific configurations
    workflow: Optional[WorkflowConfig] = None
    agent: AgentConfig = Field(default_factory=AgentConfig)
    optimizer: Optional[OptimizerConfig] = None

    class Config:
        """Pydantic configuration."""
        extra = 'forbid'
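A hedged sketch of building this configuration from a YAML mapping; the snippet and its keys are illustrative rather than a documented config file, but the nesting mirrors the models above and unset sections fall back to their defaults.

import yaml

raw = yaml.safe_load("""
run:
  name: demo_run
  total_steps: 2
llm:
  model: gpt-4o-mini
  temperature: 0.2
workflow:
  name: aide
  params:
    max_iterations: 3
""")

cfg = DSATConfig(**raw)  # extra='forbid' surfaces typos in top-level keys
print(cfg.llm.model, cfg.agent.search.num_drafts)  # defaults fill unspecified sections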
dsat/models/candidates.py
ADDED
@@ -0,0 +1,16 @@
from pydantic import BaseModel, Field
from typing import Optional, List

class WorkflowCandidate(BaseModel):
    """
    Represents a proposed orchestration workflow being evaluated by a
    Paradigm 3 (AFlow-style) meta-optimization driver.
    """
    workflow_code: str = Field(description="The code for the proposed orchestration workflow.")
    fitness: Optional[float] = Field(default=None, description="The fitness score of this workflow.")
    lineage: List[str] = Field(default_factory=list, description="IDs of parent workflow candidates.")
    round_num: Optional[int] = Field(default=None, description="The optimization round this candidate was generated in.")

    class Config:
        """Pydantic configuration."""
        extra = 'forbid'
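Sketch (placeholder values) of how a Paradigm 3 driver might record a scored candidate; in practice the fitness would come from a benchmark's grade() method such as the one in sciencebench.py above.

candidate = WorkflowCandidate(
    workflow_code="async def workflow(task): ...",  # generated workflow source
    fitness=0.71,                                    # placeholder score
    lineage=["round3_candidate_2"],                  # hypothetical parent id
    round_num=4,
)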
dsat/models/formats.py
ADDED
@@ -0,0 +1,52 @@
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field

# --- Pydantic Models for Operator Outputs ---
# These models are now used with LLMService.call_with_json

class ReviewResult(BaseModel):
    """Structured result from the ReviewOperator."""
    is_buggy: bool = Field(..., description="True if the execution failed or has a clear bug, otherwise False.")
    summary: str = Field(..., description="If buggy, a proposal to fix the bug. If successful, a summary of empirical findings.")
    metric_value: Optional[float] = Field(..., description="A quantitative measure of success based on the task requirements. Null if the task does not define a quantitative metric or if it cannot be determined.")
    lower_is_better: bool = Field(default=True, description="True if a lower metric is better (e.g., RMSE), False if higher is better (e.g., Accuracy).")

class Task(BaseModel):
    """A single task within a larger plan."""
    task_id: str = Field(..., description="Unique identifier for a task, e.g., '1', '2.1'.")
    instruction: str = Field(..., description="Clear, concise instruction for what to do in this task.")
    dependent_task_ids: List[str] = Field(default_factory=list, description="List of task_ids this task depends on.")

class Plan(BaseModel):
    """A structured plan consisting of multiple tasks."""
    tasks: List[Task] = Field(..., description="A list of tasks to achieve the overall goal.")


class ComplexityScore(BaseModel):
    """Structured result from the ComplexityScorerOperator."""
    complexity: int = Field(..., description="An integer score from 1 to 5 representing the plan's complexity.", ge=1, le=5)
    justification: str = Field(..., description="A brief justification for the assigned score.")

class DecomposedPlan(BaseModel):
    """A structured plan decomposed into sequential tasks, used for stepwise execution."""
    tasks: List[Task] = Field(..., description="A list of sequential tasks to achieve the overall goal.")



class FileArtifact(BaseModel):
    filename: str = Field(description="The name of the file, e.g., 'input_data.dat', 'image_001.png', or 'results.json'.")
    description: str = Field(description="A brief description of the file's purpose and content.")


class TaskContract(BaseModel):
    task_goal: str = Field(description="A clear, one-sentence summary of the main objective.")
    task_type: str = Field(description="A high-level categorization derived from the task goal (e.g., 'Classification', 'Generation', 'Simulation', 'Data Processing').")
    input_files: List[FileArtifact] = Field(description="A list of all input files required to complete the task.")
    output_files: List[FileArtifact] = Field(description="A list of all final output files that must be generated.")
    evaluation_metric: str = Field(description="The primary metric for evaluating the success of the output, e.g., 'Accuracy Score', 'ROUGE Score', 'Visual Appeal'.")

class StepPlan(BaseModel):
    """A detailed plan for a single phase and the artifacts it's expected to create."""
    plan: str = Field(description="A detailed, step-by-step natural language plan for the developer.")
    input_artifacts: List[str] = Field(default_factory=list, description="A list of artifact filenames from the global state that this plan will use as input (e.g., ['train_preprocessed.csv', 'scaler.joblib']).")
    output_files: List[str] = Field(description="A list of filenames that the plan is expected to generate (e.g., ['model.pkl', 'submission.csv']).")
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from pydantic import BaseModel, Field
|
|
2
|
+
from typing import Literal, Dict, Any
|
|
3
|
+
|
|
4
|
+
# Define a controlled vocabulary for task types.
|
|
5
|
+
# The framework will choose appropriate TaskHandler based on this type.
|
|
6
|
+
TaskType = Literal["kaggle", "qa", "code", "datasci", "open_ended"]
|
|
7
|
+
|
|
8
|
+
# Define execution modes.
|
|
9
|
+
# 'standard_ml': Strict evaluation against ground truth (e.g., RMSE, Accuracy).
|
|
10
|
+
# 'open_ended': Exploratory tasks evaluated by LLM judge or artifact existence.
|
|
11
|
+
TaskMode = Literal["standard_ml", "open_ended"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TaskDefinition(BaseModel):
|
|
15
|
+
"""
|
|
16
|
+
Standardized, serializable representation of any task in the DSAT framework.
|
|
17
|
+
|
|
18
|
+
This is the "logical contract" for the framework to interact with external benchmarks.
|
|
19
|
+
It encapsulates all information about a task so that DSATRunner and TaskHandler
|
|
20
|
+
can understand and process it.
|
|
21
|
+
"""
|
|
22
|
+
task_id: str = Field(
|
|
23
|
+
description="Unique identifier for the task instance, e.g., 'house-prices-advanced-regression-techniques', 'gsm8k_train_001'."
|
|
24
|
+
)
|
|
25
|
+
task_type: TaskType = Field(
|
|
26
|
+
description="General category of the task, used to select the correct TaskHandler for processing."
|
|
27
|
+
)
|
|
28
|
+
mode: TaskMode = Field(
|
|
29
|
+
default="standard_ml",
|
|
30
|
+
description="The execution mode for the task. Defaults to 'standard_ml' for backward compatibility."
|
|
31
|
+
)
|
|
32
|
+
payload: Dict[str, Any] = Field(
|
|
33
|
+
description="A dictionary containing task-specific data."
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
class Config:
|
|
37
|
+
"""
|
|
38
|
+
Pydantic configuration with documentation examples.
|
|
39
|
+
|
|
40
|
+
Payload examples:
|
|
41
|
+
- task_type='kaggle':
|
|
42
|
+
{
|
|
43
|
+
"description": "Predict sales prices for houses in Ames, Iowa.",
|
|
44
|
+
"public_data_dir": "/path/to/benchmark/data",
|
|
45
|
+
"output_submission_path": "/path/to/run/artifacts/submission.csv"
|
|
46
|
+
}
|
|
47
|
+
- task_type='qa':
|
|
48
|
+
{
|
|
49
|
+
"question": "What is the result of 9*8-2?"
|
|
50
|
+
}
|
|
51
|
+
- task_type='code':
|
|
52
|
+
{
|
|
53
|
+
"prompt": "Write a Python function that computes the nth Fibonacci number.",
|
|
54
|
+
"entry_point": "fibonacci",
|
|
55
|
+
"test_cases": "[...]"
|
|
56
|
+
}
|
|
57
|
+
- task_type='open_ended':
|
|
58
|
+
{
|
|
59
|
+
"description": "Design a mathematical model to optimize traffic flow in a city.",
|
|
60
|
+
"rubric": "Evaluation criteria: model accuracy (40%), creativity (30%), clarity (30%)",
|
|
61
|
+
"output_submission_path": "/path/to/run/artifacts"
|
|
62
|
+
}
|
|
63
|
+
"""
|
|
64
|
+
frozen = True # Task definitions should be immutable after creation.
|
|
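Putting the contract together, a minimal sketch of the 'kaggle'-type TaskDefinition that ScienceBenchBenchmark.evaluate_problem constructs above; the paths and description are placeholders.

task = TaskDefinition(
    task_id="some-competition-id",
    task_type="kaggle",
    payload={
        "description": "Predict the target column from the provided features.",
        "public_data_dir": "/data/sciencebench/some-competition-id/prepared/public",
        "output_submission_path": "/runs/demo/submission_some-competition-id_ab12cd.csv",
    },
)
# frozen = True is intended to make instances immutable after creation,
# so reassigning task.task_id should raise an error.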