dslighting-1.3.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dsat/__init__.py +3 -0
- dsat/benchmark/__init__.py +1 -0
- dsat/benchmark/benchmark.py +168 -0
- dsat/benchmark/datasci.py +291 -0
- dsat/benchmark/mle.py +777 -0
- dsat/benchmark/sciencebench.py +304 -0
- dsat/common/__init__.py +0 -0
- dsat/common/constants.py +11 -0
- dsat/common/exceptions.py +48 -0
- dsat/common/typing.py +19 -0
- dsat/config.py +79 -0
- dsat/models/__init__.py +3 -0
- dsat/models/candidates.py +16 -0
- dsat/models/formats.py +52 -0
- dsat/models/task.py +64 -0
- dsat/operators/__init__.py +0 -0
- dsat/operators/aflow_ops.py +90 -0
- dsat/operators/autokaggle_ops.py +170 -0
- dsat/operators/automind_ops.py +38 -0
- dsat/operators/base.py +22 -0
- dsat/operators/code.py +45 -0
- dsat/operators/dsagent_ops.py +123 -0
- dsat/operators/llm_basic.py +84 -0
- dsat/prompts/__init__.py +0 -0
- dsat/prompts/aflow_prompt.py +76 -0
- dsat/prompts/aide_prompt.py +52 -0
- dsat/prompts/autokaggle_prompt.py +290 -0
- dsat/prompts/automind_prompt.py +29 -0
- dsat/prompts/common.py +51 -0
- dsat/prompts/data_interpreter_prompt.py +82 -0
- dsat/prompts/dsagent_prompt.py +88 -0
- dsat/runner.py +554 -0
- dsat/services/__init__.py +0 -0
- dsat/services/data_analyzer.py +387 -0
- dsat/services/llm.py +486 -0
- dsat/services/llm_single.py +421 -0
- dsat/services/sandbox.py +386 -0
- dsat/services/states/__init__.py +0 -0
- dsat/services/states/autokaggle_state.py +43 -0
- dsat/services/states/base.py +14 -0
- dsat/services/states/dsa_log.py +13 -0
- dsat/services/states/experience.py +237 -0
- dsat/services/states/journal.py +153 -0
- dsat/services/states/operator_library.py +290 -0
- dsat/services/vdb.py +76 -0
- dsat/services/workspace.py +178 -0
- dsat/tasks/__init__.py +3 -0
- dsat/tasks/handlers.py +376 -0
- dsat/templates/open_ended/grade_template.py +107 -0
- dsat/tools/__init__.py +4 -0
- dsat/utils/__init__.py +0 -0
- dsat/utils/context.py +172 -0
- dsat/utils/dynamic_import.py +71 -0
- dsat/utils/parsing.py +33 -0
- dsat/workflows/__init__.py +12 -0
- dsat/workflows/base.py +53 -0
- dsat/workflows/factory.py +439 -0
- dsat/workflows/manual/__init__.py +0 -0
- dsat/workflows/manual/autokaggle_workflow.py +148 -0
- dsat/workflows/manual/data_interpreter_workflow.py +153 -0
- dsat/workflows/manual/deepanalyze_workflow.py +484 -0
- dsat/workflows/manual/dsagent_workflow.py +76 -0
- dsat/workflows/search/__init__.py +0 -0
- dsat/workflows/search/aflow_workflow.py +344 -0
- dsat/workflows/search/aide_workflow.py +283 -0
- dsat/workflows/search/automind_workflow.py +237 -0
- dsat/workflows/templates/__init__.py +0 -0
- dsat/workflows/templates/basic_kaggle_loop.py +71 -0
- dslighting/__init__.py +170 -0
- dslighting/core/__init__.py +13 -0
- dslighting/core/agent.py +646 -0
- dslighting/core/config_builder.py +318 -0
- dslighting/core/data_loader.py +422 -0
- dslighting/core/task_detector.py +422 -0
- dslighting/utils/__init__.py +19 -0
- dslighting/utils/defaults.py +151 -0
- dslighting-1.3.9.dist-info/METADATA +554 -0
- dslighting-1.3.9.dist-info/RECORD +80 -0
- dslighting-1.3.9.dist-info/WHEEL +5 -0
- dslighting-1.3.9.dist-info/top_level.txt +2 -0
dsat/benchmark/mle.py
ADDED
@@ -0,0 +1,777 @@
# dsat/benchmark/mle.py

import os
import time
import uuid
import yaml
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, List, Tuple, Optional, Dict
import pandas as pd
import logging

from dsat.benchmark.benchmark import BaseBenchmark

# --- mlebench related imports ---
from mlebench.data import is_dataset_prepared
from mlebench.grade import aggregate_reports, grade_csv
from mlebench.grade_helpers import CompetitionReport
from mlebench.registry import Competition, Registry
from mlebench.registry import registry as DEFAULT_MLE_REGISTRY

# --- DSAT core model imports ---
from dsat.models.task import TaskDefinition

logger = logging.getLogger(__name__)


class MLEBenchmark(BaseBenchmark):
    """
    Benchmark class to integrate mle_bench competitions into the DSAT framework.
    """
    def __init__(
        self,
        name: str,
        file_path: Optional[str],
        log_path: str,
        data_dir: Optional[str] = None,
        competitions: Optional[List[str]] = None,
        data_source: str = "prepared"  # Default to "prepared" for competition simulation
    ):
        # Set up data_dir and registry before calling parent constructor
        self.data_dir = Path(data_dir) if data_dir else DEFAULT_MLE_REGISTRY.get_data_dir()
        self.registry: Registry = DEFAULT_MLE_REGISTRY.set_data_dir(self.data_dir)
        self.data_source = data_source  # Save data source preference

        # Load configuration
        self.config = self._load_config()

        if competitions:
            # Intelligent merge: Filter config entries by CLI args, preserving metadata (like mode).
            cli_ids = set(competitions)
            merged_list = []

            # 1. Index existing config entries
            config_map = {}
            for entry in self.config.get("competitions", []):
                if isinstance(entry, str):
                    config_map[entry] = {"id": entry, "mode": "standard_ml"}
                elif isinstance(entry, dict) and "id" in entry:
                    config_map[entry["id"]] = entry

            # 2. Build new list based on CLI args
            for cid in competitions:
                if cid in config_map:
                    merged_list.append(config_map[cid])
                else:
                    # New competition not in config, use default string format
                    merged_list.append(cid)

            self.config["competitions"] = merged_list

        # file_path is accepted for compatibility but will be ignored.
        super().__init__(name, file_path, log_path)

        Path(self.log_path).mkdir(parents=True, exist_ok=True)

        # RE-INITIALIZE problems by calling the correct loader after registry is set up.
        self.problems = self._load_problems()
        logger.info(f"MLEBenchmark initialized with data_dir: {self.data_dir}")

    def _load_config(self) -> Dict[str, Any]:
        """Load configuration from config.yaml file."""
        try:
            # Get the path to the config.yaml file relative to this module
            framework_dir = Path(__file__).parent.parent.parent
            config_path = framework_dir / "config.yaml"

            if not config_path.exists():
                logger.warning(f"Config file not found at {config_path}, using default configuration")
                return {"competitions": []}

            with open(config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
            return config or {"competitions": []}
        except Exception as e:
            logger.error(f"Error loading config file: {e}, using default configuration")
            return {"competitions": []}

    def _load_problems(self) -> List[Dict[str, Any]]:
        """
        Dynamically load competition problems from the mlebench registry
        instead of a static file. This is the correct integration pattern.
        """
        logger.info(f"Discovering prepared competitions in {self.data_dir}...")

        if not self.data_dir.exists():
            raise FileNotFoundError(
                f"MLEBench data directory not found: {self.data_dir}. "
                "Please provide a valid path via --mle-data-dir."
            )

        problems = []
        # The competition IDs are loaded from configuration (already merged with arguments in __init__)
        competition_entries = self.config.get("competitions", [])

        if not competition_entries:
            logger.warning("No competitions configured in config.yaml")
            return problems

        # Iterate over competitions from configuration
        for entry in competition_entries:
            # Handle both string (legacy) and dict (new) config formats
            if isinstance(entry, str):
                comp_id = entry
                mode = "standard_ml"
            elif isinstance(entry, dict):
                comp_id = entry.get("id")
                mode = entry.get("mode", "standard_ml")
            else:
                logger.warning(f"Invalid competition entry type: {type(entry)}")
                continue

            if not comp_id:
                continue

            try:
                competition = self.registry.get_competition(comp_id)

                # For open_ended tasks, skip the prepared check and add directly
                if mode == "open_ended":
                    problems.append({
                        "competition_id": comp_id,
                        "mode": mode
                    })
                    logger.info(f"Found open-ended competition: {comp_id} (mode={mode})")
                elif is_dataset_prepared(competition, grading_only=False):
                    problems.append({
                        "competition_id": comp_id,
                        "mode": mode
                    })
                    logger.debug(f"Found prepared competition: {comp_id} (mode={mode})")
                else:
                    # Standard ML tasks must be prepared
                    logger.warning(
                        f"Skipping standard ML competition '{comp_id}' as its dataset is not fully prepared."
                    )
            except Exception as e:
                logger.warning(f"Error loading competition '{comp_id}': {e}")

        if not problems:
            logger.error(
                f"No prepared competitions found in {self.data_dir}. "
                "Please run `mlebench prep <competition_id>` or `mlebench prep --all`."
            )

        return problems

    def set_mode(self, mode: str):
        """Sets the benchmark mode to 'validation' or 'test'."""
        logger.info(f"Setting MLEBenchmark mode to '{mode}'")
        self.registry.set_mode(mode)

    async def grade(self, submission_path: Path) -> float:
        """Grades a submission and returns a numerical fitness score for AFlow."""
        if not self.problems:
            return 0.0

        # Assume we are grading for the first (and likely only) competition
        problem = self.problems[0]
        competition_id = problem["competition_id"]
        mode = problem.get("mode", "standard_ml")

        if mode == "open_ended":
            # For open-ended tasks, use LLM judge for scoring
            return await self._grade_open_ended(submission_path, competition_id, mode)

        try:
            competition = self.registry.get_competition(competition_id)
            if not submission_path.exists():
                logger.warning(f"Grading failed: submission file not found at {submission_path}")
                return 0.0

            report = grade_csv(submission_path, competition)
            score = report.score if report.score is not None else 0.0

            # Normalize score to be a fitness value (higher is better)
            if report.is_lower_better:
                return 1.0 / (1.0 + score) if score > 0 else 1.0
            return score

        except Exception as e:
            logger.error(f"Error during grading for {competition_id}: {e}")
            return 0.0

    async def _grade_open_ended(self, artifacts_path: Path, competition_id: str, mode: str) -> float:
        """
        Grade open-ended tasks using LLM judge.

        Args:
            artifacts_path: Path to the artifacts directory
            competition_id: Competition identifier
            mode: Task mode (should be "open_ended")

        Returns:
            Float score between 0.0 and 1.0
        """
        try:
            # Check if artifacts directory exists
            if not artifacts_path.exists():
                logger.warning(f"Artifacts directory not found: {artifacts_path}")
                return 0.0

            # Load task description and rubric
            competition = self.registry.get_competition(competition_id)
            competition_dir = competition.raw_dir.parent

            description_path = competition_dir / "description.md"
            rubric_path = competition_dir / "rubric.md"

            description = ""
            rubric = ""

            if description_path.exists():
                description = description_path.read_text(encoding='utf-8')
            else:
                description = competition.description

            if rubric_path.exists():
                rubric = rubric_path.read_text(encoding='utf-8')

            # Analyze artifacts directory contents
            artifact_files = list(artifacts_path.iterdir()) if artifacts_path.is_dir() else []

            # Build file listing
            file_summary = []
            for file in sorted(artifact_files):
                if file.is_file():
                    # Read file content based on type
                    try:
                        if file.suffix in ['.py', '.md', '.txt', '.csv']:
                            content = file.read_text(encoding='utf-8', errors='ignore')
                            # Truncate large files
                            if len(content) > 2000:
                                content = content[:2000] + "\n... (truncated)"
                            file_summary.append(f"### {file.name}\n```\n{content}\n```")
                        elif file.suffix in ['.png', '.jpg', '.jpeg']:
                            file_summary.append(f"### {file.name}\n[Image file - {file.stat().st_size} bytes]")
                        else:
                            file_summary.append(f"### {file.name}\n[Binary file - {file.stat().st_size} bytes]")
                    except Exception as e:
                        file_summary.append(f"### {file.name}\n[Error reading file: {e}]")

            # If no files found, return minimal score
            if not file_summary:
                logger.warning(f"No artifacts found in {artifacts_path}")
                return 0.1  # Small score for at least creating the directory

            # Construct LLM judge prompt
            # Define separator outside f-string (backslashes not allowed in f-string expressions)
            file_sep = "\n\n"
            judge_prompt = f"""You are an expert judge for open-ended data science tasks. Your role is to evaluate submissions based on given criteria.

## Task Description

{description}

## Evaluation Criteria

{rubric if rubric else "Evaluate based on completeness, correctness, and quality of the solution."}

## Submitted Artifacts

The following files were submitted:

{file_sep.join(file_summary)}

## Scoring Instructions

Evaluate the submission on a scale from 0.0 to 1.0 based on:

1. **Completeness** (0-0.3): Did the submission address all aspects of the task?
2. **Correctness** (0-0.4): Are the methods and results correct?
3. **Quality** (0-0.3): Is the code well-structured, documented, and the analysis thorough?

Provide your response in the following JSON format:

```json
{{
    "reasoning": "Brief explanation of the score",
    "completeness": 0.0-0.3,
    "correctness": 0.0-0.4,
    "quality": 0.0-0.3,
    "total_score": 0.0-1.0
}}
```

Respond ONLY with the JSON object, no additional text.
"""

            # Call LLM for judgment
            try:
                from dsat.services.llm import LLMService
                from dsat.config import LLMConfig

                # Get LLM service - need to initialize it
                # Try to get LLM model from config or environment
                llm_model = os.environ.get('LLM_MODEL', 'glm-4.7')
                # For DSATRunner context, api_key etc are usually available in env or config
                # Here we try to instantiate a service quickly

                # Assuming environment variables are set from .env
                api_key = os.environ.get('API_KEY')
                api_base = os.environ.get('API_BASE')
                provider = os.environ.get('LLM_PROVIDER', 'openai')

                llm_config = LLMConfig(
                    model=llm_model,
                    api_key=api_key,
                    api_base=api_base,
                    provider=provider
                )
                llm_service = LLMService(llm_config)

                logger.info(f"Calling LLM judge for {competition_id} with model {llm_model}")

                # Call LLM
                # Note: LLMService.achat is async
                messages = [{"role": "user", "content": judge_prompt}]
                llm_response = await llm_service.achat(messages)

                # Parse JSON response
                import json
                import re

                # Extract JSON from response
                json_match = re.search(r'\{[^{}]*\}', llm_response, re.DOTALL)
                if json_match:
                    result = json.loads(json_match.group())
                    score = float(result.get('total_score', 0.5))

                    # Log detailed breakdown
                    completeness = result.get('completeness', 0.0)
                    correctness = result.get('correctness', 0.0)
                    quality = result.get('quality', 0.0)
                    reasoning = result.get('reasoning', '')

                    logger.info(f"LLM judge score for {competition_id}: {score:.2f}")
                    logger.info(f" Breakdown: completeness={completeness:.2f}, correctness={correctness:.2f}, quality={quality:.2f}")
                    logger.info(f" Reasoning: {reasoning[:200]}...")

                    return max(0.0, min(score, 1.0))  # Ensure score is in [0, 1]
                else:
                    logger.warning(f"Failed to parse LLM judge response as JSON: {llm_response[:200]}")
                    # Fallback to heuristic if parsing fails
                    return self._heuristic_score(artifact_files)

            except Exception as llm_error:
                logger.error(f"LLM judge call failed: {llm_error}, using heuristic scoring")
                return self._heuristic_score(artifact_files)

        except Exception as e:
            logger.error(f"Error during LLM judge grading: {e}", exc_info=True)
            return 0.0

    def _heuristic_score(self, artifact_files) -> float:
        """Fallback heuristic scoring when LLM judge is unavailable."""
        has_code = any(f.suffix == '.py' for f in artifact_files if f.is_file())
        has_plots = any(f.suffix in ['.png', '.jpg', '.jpeg'] for f in artifact_files if f.is_file())
        has_report = any(f.suffix in ['.md', '.txt'] for f in artifact_files if f.is_file())

        # Simple heuristic: 0.6 base + 0.2 for code + 0.1 for plots + 0.1 for report
        score = 0.6
        if has_code:
            score += 0.2
        if has_plots:
            score += 0.1
        if has_report:
            score += 0.1

        score = min(score, 1.0)  # Cap at 1.0
        logger.info(f"Heuristic score: {score:.2f} (code={has_code}, plots={has_plots}, report={has_report})")
        return score


    def get_result_columns(self) -> List[str]:
        return [
            "competition_id", "submission_path", "answers_path", "score", "cost", "running_time",
            "input_tokens", "output_tokens", "total_tokens",
            "gold_medal", "silver_medal", "bronze_medal", "above_median",
            "submission_exists", "valid_submission", "error_message",
        ]

    def _create_error_report(self, competition_id: str, submission_path: Path, error_msg: str) -> CompetitionReport:
        """Creates a dummy report if grading or eval_fn execution fails."""
        # Create a minimal dict that CompetitionReport.from_dict can parse
        base_data = {
            "competition_id": competition_id,
            "score": None,
            "gold_threshold": float('nan'),
            "silver_threshold": float('nan'),
            "bronze_threshold": float('nan'),
            "median_threshold": float('nan'),
            "any_medal": False,
            "gold_medal": False,
            "silver_medal": False,
            "bronze_medal": False,
            "above_median": False,
            "submission_exists": submission_path.exists(),
            "valid_submission": False,
            "is_lower_better": False,  # Default
            "created_at": datetime.now().isoformat(),
            "submission_path": str(submission_path),
        }
        # Use from_dict for safety
        report = CompetitionReport.from_dict(base_data)
        # Log the actual error separately
        logger.error(f"Error for {competition_id}: {error_msg}")
        return report

    async def evaluate_problem(self, problem: dict, eval_fn: Callable) -> Tuple[Tuple, CompetitionReport, Optional[str]]:
        """
        Evaluates a single MLEBench competition.
        """
        competition_id = problem.get("competition_id")
        mode = problem.get("mode", "standard_ml")  # Default to standard_ml

        if not competition_id:
            raise ValueError("Problem data must contain 'competition_id'")

        # Define unique output path
        # timestamp = datetime.now().strftime('%Y%m%dT%H%M%S')
        unique_id = uuid.uuid4().hex[:6]
        output_filename = f"submission_{competition_id}_{unique_id}.csv"
        output_submission_path = (Path(self.log_path) / output_filename).absolute()

        # Start timing
        start_time = time.perf_counter()

        cost = 0.0
        running_time = 0.0
        input_tokens = 0
        output_tokens = 0
        total_tokens = 0
        report: Optional[CompetitionReport] = None
        error_message: Optional[str] = None
        competition: Optional[Competition] = None

        try:
            competition = self.registry.get_competition(competition_id)
            # Skip prepared check for open-ended tasks
            if mode != "open_ended" and not is_dataset_prepared(competition, grading_only=False):
                raise ValueError(f"Dataset for '{competition_id}' not prepared in '{self.data_dir}'.")

            # Determine source data directory (Host Path)
            # This path will be symlinked to './data' (or similar) in the workspace by the Runner/WorkspaceService.
            if self.data_source == "prepared":
                if competition.public_dir.exists():
                    source_data_dir = competition.public_dir
                    logger.info(f"Using prepared data source: {source_data_dir}")
                else:
                    # Strict enforcement: Do not fallback to raw if prepared was requested.
                    # This prevents accidental leakage of raw data (which might contain answers) in competition mode.
                    raise FileNotFoundError(
                        f"Prepared public data directory not found at {competition.public_dir}. "
                        f"Please run `mlebench prep {competition_id}` first, or check your data directory structure."
                    )
            else:
                # data_source == "raw"
                source_data_dir = competition.raw_dir
                logger.info(f"Using raw data source: {source_data_dir}")

            # 1. Create standardized TaskDefinition
            # Determine task_type based on mode
            task_type = "open_ended" if mode == "open_ended" else "kaggle"

            # Prepare payload based on task type
            if mode == "open_ended":
                # For open-ended tasks, provide file paths to description, rubric, and raw data
                competition_dir = competition.raw_dir.parent  # Go up from /raw to competition root

                description_path = competition_dir / "description.md"
                rubric_path = competition_dir / "rubric.md"

                payload = {
                    "description": competition.description,  # Fallback brief description
                    "description_file": str(description_path) if description_path.exists() else "",
                    "rubric": "",  # Will be loaded from file
                    "rubric_file": str(rubric_path) if rubric_path.exists() else "",
                    "public_data_dir": str(source_data_dir.absolute()),
                    "raw_data_dir": "./data",
                    "output_submission_path": str(output_submission_path.absolute())
                }
                logger.info(f"Open-ended task payload: data_source={source_data_dir}, mapped_to=./data")
            else:
                # For standard ML tasks
                payload = {
                    "description": competition.description,
                    "public_data_dir": str(source_data_dir.absolute()),  # Use the determined source
                    "output_submission_path": str(output_submission_path.absolute())
                }

            task = TaskDefinition(
                task_id=competition_id,
                task_type=task_type,
                mode=mode,
                payload=payload
            )

            # 2. Call the generic evaluation function
            result, cost, usage_summary = await eval_fn(task)

            # 3. Extract token information from usage_summary
            input_tokens = usage_summary.get("prompt_tokens", 0)
            output_tokens = usage_summary.get("completion_tokens", 0)
            total_tokens = usage_summary.get("total_tokens", 0)

            # 3. Process result and perform grading
            if isinstance(result, Path):
                logger.debug(f"Grading submission {result} for {competition_id} (mode={mode})")

                if mode == "open_ended":
                    # Skip standard grading for open-ended tasks
                    logger.info(f"Skipping CSV grading for open-ended task {competition_id}")
                    # Create a dummy 'success' report
                    base_data = {
                        "competition_id": competition_id,
                        "score": 1.0,  # Placeholder score
                        "gold_threshold": 0.0,
                        "silver_threshold": 0.0,
                        "bronze_threshold": 0.0,
                        "median_threshold": 0.0,
                        "any_medal": True,
                        "gold_medal": True,  # Mark as success for visualization
                        "silver_medal": False,
                        "bronze_medal": False,
                        "above_median": True,
                        "submission_exists": result.exists(),
                        "valid_submission": True,  # It ran successfully
                        "is_lower_better": False,
                        "created_at": datetime.now().isoformat(),
                        "submission_path": str(result),
                    }
                    report = CompetitionReport.from_dict(base_data)
                else:
                    report = grade_csv(result, competition)

            elif isinstance(result, str) and result.startswith("[ERROR]"):
                error_message = f"DSAT workflow failed: {result}"
                logger.error(error_message)
                report = self._create_error_report(competition_id, output_submission_path, error_message)
            else:
                error_message = f"Unexpected result type from eval_fn: {type(result).__name__}"
                logger.error(error_message)
                report = self._create_error_report(competition_id, output_submission_path, error_message)

        except Exception as e:
            error_message = f"Error during MLEBenchmark evaluation of {competition_id}: {e}"
            logger.error(error_message, exc_info=True)
            report = self._create_error_report(competition_id, output_submission_path, error_message)

        if report is None:
            final_error = error_message or "Unknown error: report is None"
            report = self._create_error_report(competition_id, output_submission_path, final_error)
            error_message = final_error

        if not report.valid_submission:
            answers_path_str = str(getattr(competition, 'answers', 'N/A')) if competition else 'N/A'
            self.log_mismatch(
                problem=competition_id,
                expected_output=answers_path_str,
                prediction=f"File: {output_submission_path}, Exists: {report.submission_exists}, Valid: {report.valid_submission}",
                extracted_output=report.score,
                extract_answer_code=error_message or "Grading function failed or file invalid/missing"
            )
            if not error_message:
                error_message = "Submission invalid or missing."

        # Calculate running time
        running_time = round(time.perf_counter() - start_time, 4)

        answers_path_str = str(getattr(competition, 'answers', 'N/A')) if competition else 'N/A'
        csv_tuple = (
            report.competition_id, str(report.submission_path), answers_path_str,
            report.score, cost, running_time,
            input_tokens, output_tokens, total_tokens,
            report.gold_medal, report.silver_medal, report.bronze_medal,
            report.above_median, report.submission_exists, report.valid_submission, error_message,
        )
        return csv_tuple, report, error_message

    async def _grade_open_ended(self, artifacts_path: Path, competition_id: str, mode: str) -> float:
        """
        Grade open-ended tasks using LLM judge.

        Args:
            artifacts_path: Path to the artifacts directory
            competition_id: Competition identifier
            mode: Task mode (should be "open_ended")

        Returns:
            Float score between 0.0 and 1.0
        """
        try:
            # Check if artifacts directory exists
            if not artifacts_path.exists():
                logger.warning(f"Artifacts directory not found: {artifacts_path}")
                return 0.0

            # Load task description and rubric
            competition = self.registry.get_competition(competition_id)
            competition_dir = competition.raw_dir.parent

            description_path = competition_dir / "description.md"
            rubric_path = competition_dir / "rubric.md"

            description = ""
            rubric = ""

            if description_path.exists():
                description = description_path.read_text(encoding='utf-8')
            else:
                description = competition.description

            if rubric_path.exists():
                rubric = rubric_path.read_text(encoding='utf-8')

            # Analyze artifacts directory contents
            artifact_files = list(artifacts_path.iterdir()) if artifacts_path.is_dir() else []

            # Build file listing
            file_summary = []
            for file in sorted(artifact_files):
                if file.is_file():
                    # Read file content based on type
                    try:
                        if file.suffix in ['.py', '.md', '.txt', '.csv']:
                            content = file.read_text(encoding='utf-8', errors='ignore')
                            # Truncate large files
                            if len(content) > 2000:
                                content = content[:2000] + "\n... (truncated)"
                            file_summary.append(f"### {file.name}\n```\n{content}\n```")
                        elif file.suffix in ['.png', '.jpg', '.jpeg']:
                            file_summary.append(f"### {file.name}\n[Image file - {file.stat().st_size} bytes]")
                        else:
                            file_summary.append(f"### {file.name}\n[Binary file - {file.stat().st_size} bytes]")
                    except Exception as e:
                        file_summary.append(f"### {file.name}\n[Error reading file: {e}]")

            # If no files found, return minimal score
            if not file_summary:
                logger.warning(f"No artifacts found in {artifacts_path}")
                return 0.1  # Small score for at least creating the directory

            # Construct LLM judge prompt
            # Define separator outside f-string (backslashes not allowed in f-string expressions)
            file_sep = "\n\n"
            judge_prompt = f"""You are an expert judge for open-ended data science tasks. Your role is to evaluate submissions based on given criteria.

## Task Description

{description}

## Evaluation Criteria

{rubric if rubric else "Evaluate based on completeness, correctness, and quality of the solution."}

## Submitted Artifacts

The following files were submitted:

{file_sep.join(file_summary)}

## Scoring Instructions

Evaluate the submission on a scale from 0.0 to 1.0 based on:

1. **Completeness** (0-0.3): Did the submission address all aspects of the task?
2. **Correctness** (0-0.4): Are the methods and results correct?
3. **Quality** (0-0.3): Is the code well-structured, documented, and the analysis thorough?

Provide your response in the following JSON format:

```json
{{
    "reasoning": "Brief explanation of the score",
    "completeness": 0.0-0.3,
    "correctness": 0.0-0.4,
    "quality": 0.0-0.3,
    "total_score": 0.0-1.0
}}
```

Respond ONLY with the JSON object, no additional text.
"""

            # Call LLM for judgment
            try:
                from dsat.llm import LLMService
                from dsat.config import DSATConfig

                # Get LLM service - need to initialize it
                # Try to get LLM model from config or environment
                llm_model = os.environ.get('LLM_MODEL', 'glm-4.7')
                llm_provider = os.environ.get('LLM_PROVIDER', 'openai')

                llm_service = LLMService(
                    model=llm_model,
                    provider=llm_provider
                )

                logger.info(f"Calling LLM judge for {competition_id} with model {llm_model}")

                # Call LLM
                llm_response = await llm_service.call(judge_prompt)

                # Parse JSON response
                import json
                import re

                # Extract JSON from response
                json_match = re.search(r'\{[^{}]*\}', llm_response, re.DOTALL)
                if json_match:
                    result = json.loads(json_match.group())
                    score = float(result.get('total_score', 0.5))

                    # Log detailed breakdown
                    completeness = result.get('completeness', 0.0)
                    correctness = result.get('correctness', 0.0)
                    quality = result.get('quality', 0.0)
                    reasoning = result.get('reasoning', '')

                    logger.info(f"LLM judge score for {competition_id}: {score:.2f}")
                    logger.info(f" Breakdown: completeness={completeness:.2f}, correctness={correctness:.2f}, quality={quality:.2f}")
                    logger.info(f" Reasoning: {reasoning[:200]}...")

                    return max(0.0, min(score, 1.0))  # Ensure score is in [0, 1]
                else:
                    logger.warning(f"Failed to parse LLM judge response as JSON: {llm_response[:200]}")
                    # Fallback to heuristic if parsing fails
                    return self._heuristic_score(artifact_files)

            except Exception as llm_error:
                logger.error(f"LLM judge call failed: {llm_error}, using heuristic scoring")
                return self._heuristic_score(artifact_files)

        except Exception as e:
            logger.error(f"Error during LLM judge grading: {e}", exc_info=True)
            return 0.0

    def _heuristic_score(self, artifact_files) -> float:
        """Fallback heuristic scoring when LLM judge is unavailable."""
        has_code = any(f.suffix == '.py' for f in artifact_files if f.is_file())
        has_plots = any(f.suffix in ['.png', '.jpg', '.jpeg'] for f in artifact_files if f.is_file())
        has_report = any(f.suffix in ['.md', '.txt'] for f in artifact_files if f.is_file())

        # Simple heuristic: 0.6 base + 0.2 for code + 0.1 for plots + 0.1 for report
        score = 0.6
        if has_code:
            score += 0.2
        if has_plots:
            score += 0.1
        if has_report:
            score += 0.1

        score = min(score, 1.0)  # Cap at 1.0
        logger.info(f"Heuristic score: {score:.2f} (code={has_code}, plots={has_plots}, report={has_report})")
        return score
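
For orientation only, the sketch below (an editor's illustration, not part of the packaged wheel) shows one way the MLEBenchmark class above might be driven. The log/data paths, the competition id, and the placeholder eval_fn are hypothetical; in the package the callable passed to evaluate_problem is presumably supplied by the DSAT runner/workflow layer (dsat/runner.py, dsat/workflows/ in the file list), and access to task.payload assumes TaskDefinition exposes its payload as an attribute.

# usage_sketch.py -- hypothetical driver for MLEBenchmark (not shipped in the wheel)
import asyncio
from pathlib import Path

from dsat.benchmark.mle import MLEBenchmark


async def dummy_eval_fn(task):
    # evaluate_problem expects eval_fn to return (result, cost, usage_summary),
    # where result is a Path to the produced submission CSV (or an "[ERROR] ..."
    # string on failure). This stub just writes a placeholder submission.
    submission = Path(task.payload["output_submission_path"])  # assumes .payload attribute
    submission.write_text("id,target\n")
    return submission, 0.0, {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}


async def main():
    bench = MLEBenchmark(
        name="mle",
        file_path=None,                      # accepted for compatibility, ignored
        log_path="./logs/mle",               # hypothetical log directory
        data_dir="/data/mlebench",           # hypothetical prepared MLE-bench data dir
        competitions=["spaceship-titanic"],  # hypothetical competition id
    )
    for problem in bench.problems:
        csv_tuple, report, error = await bench.evaluate_problem(problem, dummy_eval_fn)
        print(csv_tuple, error)


if __name__ == "__main__":
    asyncio.run(main())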