dslighting 1.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dsat/__init__.py +3 -0
- dsat/benchmark/__init__.py +1 -0
- dsat/benchmark/benchmark.py +168 -0
- dsat/benchmark/datasci.py +291 -0
- dsat/benchmark/mle.py +777 -0
- dsat/benchmark/sciencebench.py +304 -0
- dsat/common/__init__.py +0 -0
- dsat/common/constants.py +11 -0
- dsat/common/exceptions.py +48 -0
- dsat/common/typing.py +19 -0
- dsat/config.py +79 -0
- dsat/models/__init__.py +3 -0
- dsat/models/candidates.py +16 -0
- dsat/models/formats.py +52 -0
- dsat/models/task.py +64 -0
- dsat/operators/__init__.py +0 -0
- dsat/operators/aflow_ops.py +90 -0
- dsat/operators/autokaggle_ops.py +170 -0
- dsat/operators/automind_ops.py +38 -0
- dsat/operators/base.py +22 -0
- dsat/operators/code.py +45 -0
- dsat/operators/dsagent_ops.py +123 -0
- dsat/operators/llm_basic.py +84 -0
- dsat/prompts/__init__.py +0 -0
- dsat/prompts/aflow_prompt.py +76 -0
- dsat/prompts/aide_prompt.py +52 -0
- dsat/prompts/autokaggle_prompt.py +290 -0
- dsat/prompts/automind_prompt.py +29 -0
- dsat/prompts/common.py +51 -0
- dsat/prompts/data_interpreter_prompt.py +82 -0
- dsat/prompts/dsagent_prompt.py +88 -0
- dsat/runner.py +554 -0
- dsat/services/__init__.py +0 -0
- dsat/services/data_analyzer.py +387 -0
- dsat/services/llm.py +486 -0
- dsat/services/llm_single.py +421 -0
- dsat/services/sandbox.py +386 -0
- dsat/services/states/__init__.py +0 -0
- dsat/services/states/autokaggle_state.py +43 -0
- dsat/services/states/base.py +14 -0
- dsat/services/states/dsa_log.py +13 -0
- dsat/services/states/experience.py +237 -0
- dsat/services/states/journal.py +153 -0
- dsat/services/states/operator_library.py +290 -0
- dsat/services/vdb.py +76 -0
- dsat/services/workspace.py +178 -0
- dsat/tasks/__init__.py +3 -0
- dsat/tasks/handlers.py +376 -0
- dsat/templates/open_ended/grade_template.py +107 -0
- dsat/tools/__init__.py +4 -0
- dsat/utils/__init__.py +0 -0
- dsat/utils/context.py +172 -0
- dsat/utils/dynamic_import.py +71 -0
- dsat/utils/parsing.py +33 -0
- dsat/workflows/__init__.py +12 -0
- dsat/workflows/base.py +53 -0
- dsat/workflows/factory.py +439 -0
- dsat/workflows/manual/__init__.py +0 -0
- dsat/workflows/manual/autokaggle_workflow.py +148 -0
- dsat/workflows/manual/data_interpreter_workflow.py +153 -0
- dsat/workflows/manual/deepanalyze_workflow.py +484 -0
- dsat/workflows/manual/dsagent_workflow.py +76 -0
- dsat/workflows/search/__init__.py +0 -0
- dsat/workflows/search/aflow_workflow.py +344 -0
- dsat/workflows/search/aide_workflow.py +283 -0
- dsat/workflows/search/automind_workflow.py +237 -0
- dsat/workflows/templates/__init__.py +0 -0
- dsat/workflows/templates/basic_kaggle_loop.py +71 -0
- dslighting/__init__.py +170 -0
- dslighting/core/__init__.py +13 -0
- dslighting/core/agent.py +646 -0
- dslighting/core/config_builder.py +318 -0
- dslighting/core/data_loader.py +422 -0
- dslighting/core/task_detector.py +422 -0
- dslighting/utils/__init__.py +19 -0
- dslighting/utils/defaults.py +151 -0
- dslighting-1.3.9.dist-info/METADATA +554 -0
- dslighting-1.3.9.dist-info/RECORD +80 -0
- dslighting-1.3.9.dist-info/WHEEL +5 -0
- dslighting-1.3.9.dist-info/top_level.txt +2 -0
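Before the per-file diffs, a quick orientation: per top_level.txt the wheel installs two top-level packages, dsat and dslighting. Below is a minimal smoke-test sketch, assuming the wheel has already been installed with `pip install dslighting==1.3.9`; it relies only on the package names and the dist metadata shown in this listing, not on any specific API.

# Confirms the two top-level packages listed above import, and reports the
# installed distribution version via the standard library.
import importlib.metadata

import dsat
import dslighting

print(importlib.metadata.version("dslighting"))  # expected: 1.3.9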
dsat/__init__.py
ADDED
dsat/benchmark/__init__.py
ADDED
@@ -0,0 +1 @@
# benchmarks/__init__.py
dsat/benchmark/benchmark.py
ADDED
@@ -0,0 +1,168 @@
import asyncio
import json
from pathlib import Path
from typing import Any, Callable, List, Tuple, Optional, Dict
import pandas as pd
import logging

logger = logging.getLogger(__name__)

class BaseBenchmark:
    """
    Abstract base class for all benchmark tests.
    """
    def __init__(self, name: str, file_path: Optional[str], log_path: str, **kwargs):
        self.name = name
        self.file_path = file_path
        self.log_path = log_path
        self.problems = self._load_problems()
        # self.results_path = Path(self.log_path) / f"{self.name}_results.csv"
        from datetime import datetime
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.results_path = Path(self.log_path) / f"{self.name}_results_{timestamp}.csv"
        self.mismatches_path = Path(self.log_path) / f"{self.name}_mismatches.log"

    def _load_problems(self) -> List[Dict[str, Any]]:
        """Load problems from jsonl file."""
        # MODIFICATION: Handle cases where file_path is not provided.
        if not self.file_path:
            logger.debug("No file_path provided. Subclass is expected to override _load_problems.")
            return []

        with open(self.file_path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f]

    def get_result_columns(self) -> List[str]:
        raise NotImplementedError

    async def evaluate_problem(self, problem: Dict, eval_fn: Callable, **kwargs) -> Tuple:
        raise NotImplementedError

    def log_mismatch(self, **kwargs):
        with open(self.mismatches_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(kwargs) + "\n")

    # REFACTORED: The main evaluation loop now accepts and passes down `eval_fn`.
    async def run_evaluation(self, eval_fn: Callable, **kwargs):
        """
        Run the entire benchmark evaluation.

        Args:
            eval_fn: The generic evaluation function provided by DSATRunner.get_eval_function().
        """
        if not self.problems:
            logger.error(f"Evaluation for '{self.name}' aborted: No problems were loaded.")
            return

        logger.info(f"Starting evaluation for benchmark '{self.name}' with {len(self.problems)} problems.")

        results = []
        tasks = [self.evaluate_problem(problem, eval_fn=eval_fn) for problem in self.problems]

        # Use native asyncio for task completion
        for future in asyncio.as_completed(tasks):
            try:
                # evaluate_problem returns (csv_tuple, report, error_message)
                result_tuple, report, error_message = await future
                results.append(result_tuple)
            except Exception as e:
                logger.error(f"An unexpected error occurred in evaluate_problem: {e}", exc_info=True)

        # Save results to CSV
        df = pd.DataFrame(results, columns=self.get_result_columns())
        df.to_csv(self.results_path, index=False)

        # Add metadata summary
        self._append_metadata_to_csv(df, **kwargs)

        logger.info(f"Evaluation complete. Results saved to {self.results_path}")

    def _append_metadata_to_csv(self, df: pd.DataFrame, **kwargs):
        """Append metadata summary to the CSV file."""
        try:
            import numpy as np

            # Calculate statistics
            score_col = 'score' if 'score' in df.columns else None
            cost_col = 'cost' if 'cost' in df.columns else None
            running_time_col = 'running_time' if 'running_time' in df.columns else None
            input_tokens_col = 'input_tokens' if 'input_tokens' in df.columns else None
            output_tokens_col = 'output_tokens' if 'output_tokens' in df.columns else None
            total_tokens_col = 'total_tokens' if 'total_tokens' in df.columns else None

            stats = {}
            if score_col:
                valid_scores = df[score_col].dropna()
                if len(valid_scores) > 0:
                    stats['avg_score'] = valid_scores.mean()
                    stats['median_score'] = valid_scores.median()
                    stats['std_score'] = valid_scores.std()

            if cost_col:
                valid_costs = df[cost_col].dropna()
                if len(valid_costs) > 0:
                    stats['avg_cost'] = valid_costs.mean()
                    stats['total_cost'] = valid_costs.sum()

            if running_time_col:
                valid_times = df[running_time_col].dropna()
                if len(valid_times) > 0:
                    stats['avg_running_time'] = valid_times.mean()
                    stats['total_running_time'] = valid_times.sum()

            if input_tokens_col:
                valid_input_tokens = df[input_tokens_col].dropna()
                if len(valid_input_tokens) > 0:
                    stats['avg_input_tokens'] = valid_input_tokens.mean()
                    stats['total_input_tokens'] = valid_input_tokens.sum()

            if output_tokens_col:
                valid_output_tokens = df[output_tokens_col].dropna()
                if len(valid_output_tokens) > 0:
                    stats['avg_output_tokens'] = valid_output_tokens.mean()
                    stats['total_output_tokens'] = valid_output_tokens.sum()

            if total_tokens_col:
                valid_total_tokens = df[total_tokens_col].dropna()
                if len(valid_total_tokens) > 0:
                    stats['avg_total_tokens'] = valid_total_tokens.mean()
                    stats['total_total_tokens'] = valid_total_tokens.sum()

            # Get model info from kwargs
            model_info = kwargs.get('model_name', 'N/A')

            # Create metadata rows
            meta_rows = [
                [''] * len(df.columns),  # Empty separator row
                ['=== METADATA ==='] + [''] * (len(df.columns) - 1),
            ]

            meta_data = {
                'Model': model_info,
                'Total Tasks': len(df),
                'Average Score': f"{stats.get('avg_score', 0):.4f}" if 'avg_score' in stats else 'N/A',
                'Median Score': f"{stats.get('median_score', 0):.4f}" if 'median_score' in stats else 'N/A',
                'Std Score': f"{stats.get('std_score', 0):.4f}" if 'std_score' in stats else 'N/A',
                'Average Cost': f"${stats.get('avg_cost', 0):.4f}" if 'avg_cost' in stats else 'N/A',
                'Total Cost': f"${stats.get('total_cost', 0):.4f}" if 'total_cost' in stats else 'N/A',
                'Average Running Time': f"{stats.get('avg_running_time', 0):.4f}s" if 'avg_running_time' in stats else 'N/A',
                'Total Running Time': f"{stats.get('total_running_time', 0):.4f}s" if 'total_running_time' in stats else 'N/A',
                'Average Input Tokens': f"{stats.get('avg_input_tokens', 0):.0f}" if 'avg_input_tokens' in stats else 'N/A',
                'Total Input Tokens': f"{stats.get('total_input_tokens', 0):.0f}" if 'total_input_tokens' in stats else 'N/A',
                'Average Output Tokens': f"{stats.get('avg_output_tokens', 0):.0f}" if 'avg_output_tokens' in stats else 'N/A',
                'Total Output Tokens': f"{stats.get('total_output_tokens', 0):.0f}" if 'total_output_tokens' in stats else 'N/A',
                'Average Total Tokens': f"{stats.get('avg_total_tokens', 0):.0f}" if 'avg_total_tokens' in stats else 'N/A',
                'Total Total Tokens': f"{stats.get('total_total_tokens', 0):.0f}" if 'total_total_tokens' in stats else 'N/A',
            }

            for key, value in meta_data.items():
                meta_rows.append([key, value] + [''] * (len(df.columns) - 2))

            # Append to CSV
            with open(self.results_path, 'a', encoding='utf-8') as f:
                for row in meta_rows:
                    f.write(','.join(str(x) for x in row) + '\n')

            logger.info(f"Metadata appended to {self.results_path}")
        except Exception as e:
            logger.warning(f"Failed to append metadata: {e}")
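The two NotImplementedError methods above define the subclass contract: get_result_columns names the CSV columns, and evaluate_problem must return a (csv_tuple, report, error_message) triple that run_evaluation collects and writes to the timestamped results CSV. Here is a minimal, hypothetical sketch of wiring that together; ToyBenchmark, dummy_eval_fn, and problems.jsonl are invented for illustration and are not part of the package.

# Illustrative subclass of BaseBenchmark; assumes problems.jsonl exists and
# holds one {"task_id": ...} object per line (the base class does not create
# the log directory, so we make it first).
import asyncio
from pathlib import Path
from typing import Callable, Dict, List, Tuple

from dsat.benchmark.benchmark import BaseBenchmark


class ToyBenchmark(BaseBenchmark):
    def get_result_columns(self) -> List[str]:
        return ["task_id", "score", "cost"]

    async def evaluate_problem(self, problem: Dict, eval_fn: Callable, **kwargs) -> Tuple:
        score, cost = await eval_fn(problem)
        # run_evaluation expects (csv_tuple, report, error_message)
        return (problem["task_id"], score, cost), {"score": score}, None


async def dummy_eval_fn(problem: Dict):
    # Stand-in for the callable returned by DSATRunner.get_eval_function()
    return 1.0, 0.0


Path("./logs").mkdir(exist_ok=True)
bench = ToyBenchmark(name="toy", file_path="problems.jsonl", log_path="./logs")
asyncio.run(bench.run_evaluation(dummy_eval_fn, model_name="demo-model"))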
dsat/benchmark/datasci.py
ADDED
@@ -0,0 +1,291 @@
# dsat/benchmark/datasci.py

import json
import uuid
import yaml
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, List, Tuple, Optional, Dict
import pandas as pd
import logging

from dsat.benchmark.benchmark import BaseBenchmark
from dsat.models.task import TaskDefinition

logger = logging.getLogger(__name__)


class DataSciBenchmark(BaseBenchmark):
    """
    Benchmark class for DataSciBench tasks.

    DataSciBench tasks involve multi-step data science workflows where:
    - Input: prompt.json (task description) + optional input data files
    - Output: Generated files that are compared against ground truth
    - Evaluation: Uses metric.yaml to define evaluation functions
    """

    def __init__(
        self,
        name: str,
        file_path: Optional[str],
        log_path: str,
        datasci_root_dir: Optional[str] = None,
        tasks: Optional[List[str]] = None,
        **kwargs
    ):
        # Set up root directory before calling parent constructor
        if datasci_root_dir:
            self.root_dir = Path(datasci_root_dir)
        else:
            # Default to DataSciBench_Selected in benchmarks directory
            self.root_dir = Path(__file__).parent.parent.parent / "benchmarks" / "DataSciBench_Selected"

        self.data_dir = self.root_dir / "data"
        self.metric_dir = self.root_dir / "metric"
        self.gt_data_dir = self.root_dir / "gt_data"

        # Override tasks if provided via CLI
        self.task_filter = tasks

        # Call parent constructor (which calls _load_problems)
        super().__init__(name, file_path, log_path, **kwargs)

        Path(self.log_path).mkdir(parents=True, exist_ok=True)

        # Re-initialize problems after setting up directories
        self.problems = self._load_problems()
        logger.info(f"DataSciBenchmark initialized with root_dir: {self.root_dir}")
        logger.info(f"Loaded {len(self.problems)} tasks")

    def _load_problems(self) -> List[Dict[str, Any]]:
        """Load DataSciBench tasks from the data directory."""
        if not self.data_dir.exists():
            logger.error(f"DataSciBench data directory not found: {self.data_dir}")
            return []

        problems = []

        # Get all task directories
        task_dirs = sorted([d for d in self.data_dir.iterdir() if d.is_dir()])

        for task_dir in task_dirs:
            task_id = task_dir.name

            # Apply task filter if specified
            if self.task_filter and task_id not in self.task_filter:
                continue

            prompt_file = task_dir / "prompt.json"
            if not prompt_file.exists():
                logger.warning(f"Skipping task '{task_id}': no prompt.json found")
                continue

            # Load prompt
            try:
                with open(prompt_file, 'r', encoding='utf-8') as f:
                    prompt_data = json.load(f)
            except Exception as e:
                logger.warning(f"Skipping task '{task_id}': failed to load prompt.json: {e}")
                continue

            # Check for metric file
            metric_file = self.metric_dir / task_id / "metric.yaml"
            if not metric_file.exists():
                logger.warning(f"Skipping task '{task_id}': no metric.yaml found")
                continue

            # Get input files (exclude prompt.json)
            input_files = [f.name for f in task_dir.iterdir()
                           if f.is_file() and f.name not in ('prompt.json', 'orig_prompt.json')]

            problems.append({
                "task_id": task_id,
                "prompt": prompt_data.get("prompt", ""),
                "data_source_type": prompt_data.get("data_source_type", "unknown"),
                "input_files": input_files,
                "task_dir": str(task_dir),
                "metric_file": str(metric_file),
                "gt_dir": str(self.gt_data_dir / task_id / "gt"),
            })
            logger.debug(f"Loaded task: {task_id}")

        if not problems:
            logger.error(f"No valid tasks found in {self.data_dir}")

        return problems

    def get_result_columns(self) -> List[str]:
        """Define columns for results CSV."""
        return [
            "task_id",
            "output_dir",
            "gt_dir",
            "score",
            "cost",
            "files_generated",
            "evaluation_passed",
            "error_message",
        ]

    def _evaluate_task_outputs(
        self,
        task_id: str,
        output_dir: Path,
        metric_file: Path,
        gt_dir: Path
    ) -> Tuple[float, bool, Optional[str]]:
        """
        Evaluate task outputs using metric.yaml evaluation functions.

        Returns:
            Tuple of (score, passed, error_message)
        """
        try:
            # Load metric configuration
            with open(metric_file, 'r', encoding='utf-8') as f:
                metric_config = yaml.safe_load(f)

            if not metric_config or 'TMC-list' not in metric_config:
                return 0.0, False, "Invalid metric.yaml: missing TMC-list"

            tmc_list = metric_config['TMC-list']
            total_score = 0
            max_score = len(tmc_list) * 2  # Each task can score 0, 1, or 2

            for metric_item in tmc_list:
                eval_code = metric_item.get('code', '')
                gt_file = metric_item.get('ground_truth', '')

                if not eval_code:
                    continue

                # Prepare ground truth path
                gt_path = gt_dir / gt_file if gt_file else gt_dir

                try:
                    # Create evaluation context
                    import os
                    original_cwd = os.getcwd()
                    os.chdir(str(output_dir))

                    # Execute evaluation code
                    local_vars = {'ground_truth': str(gt_path)}
                    exec(eval_code, {'__builtins__': __builtins__, 'pd': pd}, local_vars)

                    # Find the evaluation function and call it
                    for name, obj in local_vars.items():
                        if callable(obj) and name != '__builtins__':
                            try:
                                result = obj(str(gt_path))
                                if result is True:
                                    total_score += 2
                                elif result is not None:
                                    total_score += 1
                            except Exception as func_error:
                                logger.debug(f"Evaluation function {name} failed: {func_error}")
                            break

                    os.chdir(original_cwd)

                except Exception as eval_error:
                    logger.debug(f"Evaluation error for {task_id}: {eval_error}")
                    try:
                        os.chdir(original_cwd)
                    except:
                        pass

            # Calculate normalized score (0-1)
            score = total_score / max_score if max_score > 0 else 0.0
            passed = score > 0

            return score, passed, None

        except Exception as e:
            return 0.0, False, str(e)

    async def evaluate_problem(
        self,
        problem: Dict[str, Any],
        eval_fn: Callable
    ) -> Tuple[Tuple, Any, Optional[str]]:
        """
        Evaluate a single DataSciBench task.
        """
        task_id = problem.get("task_id")
        if not task_id:
            raise ValueError("Problem data must contain 'task_id'")

        prompt = problem.get("prompt", "")
        task_dir = Path(problem.get("task_dir", ""))
        metric_file = Path(problem.get("metric_file", ""))
        gt_dir = Path(problem.get("gt_dir", ""))

        # Create unique output directory
        unique_id = uuid.uuid4().hex[:6]
        output_dir = Path(self.log_path) / f"output_{task_id}_{unique_id}"
        output_dir.mkdir(parents=True, exist_ok=True)

        cost = 0.0
        error_message: Optional[str] = None
        files_generated = 0
        evaluation_passed = False
        score = 0.0

        try:
            # Create TaskDefinition for the workflow
            task = TaskDefinition(
                task_id=task_id,
                task_type="datasci",
                payload={
                    "prompt": prompt,
                    "input_dir": str(task_dir),
                    "output_dir": str(output_dir),
                    "expected_outputs": [],  # DataSciBench doesn't require specific output files
                }
            )

            # Execute the workflow
            result, cost = await eval_fn(task)

            # Check for errors
            if isinstance(result, str) and result.startswith("[ERROR]"):
                error_message = result
                logger.error(f"Task {task_id} failed: {error_message}")
            else:
                # Count generated files
                files_generated = len(list(output_dir.glob("*")))

                # Evaluate outputs against ground truth
                if metric_file.exists() and gt_dir.exists():
                    score, evaluation_passed, eval_error = self._evaluate_task_outputs(
                        task_id, output_dir, metric_file, gt_dir
                    )
                    if eval_error:
                        error_message = f"Evaluation error: {eval_error}"
                else:
                    # If no metric/gt, consider it passed if files were generated
                    evaluation_passed = files_generated > 0
                    score = 1.0 if evaluation_passed else 0.0

            logger.info(f"Task {task_id}: score={score:.2f}, files={files_generated}, passed={evaluation_passed}")

        except Exception as e:
            error_message = f"Error during DataSciBenchmark evaluation of {task_id}: {e}"
            logger.error(error_message, exc_info=True)

        # Build result tuple
        csv_tuple = (
            task_id,
            str(output_dir),
            str(gt_dir),
            score,
            cost,
            files_generated,
            evaluation_passed,
            error_message,
        )

        return csv_tuple, {"score": score, "passed": evaluation_passed}, error_message
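End to end, DataSciBenchmark discovers tasks under <root>/data, hands each one to eval_fn as a TaskDefinition, and scores whatever lands in the per-task output directory against <root>/gt_data using the metric.yaml TMC-list. The sketch below is a hypothetical driver: the stub stands in for the real eval_fn from DSATRunner.get_eval_function() (defined in dsat/runner.py, not shown in this hunk), and both the attribute-style task.payload access and the local DataSciBench_Selected path are assumptions.

# Illustrative only: stub_eval_fn replaces the real agent workflow.
import asyncio
from pathlib import Path

from dsat.benchmark.datasci import DataSciBenchmark
from dsat.models.task import TaskDefinition


async def stub_eval_fn(task: TaskDefinition):
    # A real eval_fn would run the configured workflow and write its artifacts
    # into task.payload["output_dir"]; evaluate_problem creates that directory.
    out_dir = Path(task.payload["output_dir"])
    (out_dir / "result.txt").write_text("placeholder")
    return "done", 0.0  # (result, cost); a "[ERROR] ..." string marks failure


bench = DataSciBenchmark(
    name="datasci",
    file_path=None,           # tasks are discovered from datasci_root_dir, not a jsonl file
    log_path="./logs/datasci",
    datasci_root_dir="./benchmarks/DataSciBench_Selected",  # assumed local copy of the benchmark
    tasks=None,               # or a list of task_id strings to run a subset
)
asyncio.run(bench.run_evaluation(stub_eval_fn, model_name="demo-model"))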