dslighting-1.3.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dsat/__init__.py +3 -0
- dsat/benchmark/__init__.py +1 -0
- dsat/benchmark/benchmark.py +168 -0
- dsat/benchmark/datasci.py +291 -0
- dsat/benchmark/mle.py +777 -0
- dsat/benchmark/sciencebench.py +304 -0
- dsat/common/__init__.py +0 -0
- dsat/common/constants.py +11 -0
- dsat/common/exceptions.py +48 -0
- dsat/common/typing.py +19 -0
- dsat/config.py +79 -0
- dsat/models/__init__.py +3 -0
- dsat/models/candidates.py +16 -0
- dsat/models/formats.py +52 -0
- dsat/models/task.py +64 -0
- dsat/operators/__init__.py +0 -0
- dsat/operators/aflow_ops.py +90 -0
- dsat/operators/autokaggle_ops.py +170 -0
- dsat/operators/automind_ops.py +38 -0
- dsat/operators/base.py +22 -0
- dsat/operators/code.py +45 -0
- dsat/operators/dsagent_ops.py +123 -0
- dsat/operators/llm_basic.py +84 -0
- dsat/prompts/__init__.py +0 -0
- dsat/prompts/aflow_prompt.py +76 -0
- dsat/prompts/aide_prompt.py +52 -0
- dsat/prompts/autokaggle_prompt.py +290 -0
- dsat/prompts/automind_prompt.py +29 -0
- dsat/prompts/common.py +51 -0
- dsat/prompts/data_interpreter_prompt.py +82 -0
- dsat/prompts/dsagent_prompt.py +88 -0
- dsat/runner.py +554 -0
- dsat/services/__init__.py +0 -0
- dsat/services/data_analyzer.py +387 -0
- dsat/services/llm.py +486 -0
- dsat/services/llm_single.py +421 -0
- dsat/services/sandbox.py +386 -0
- dsat/services/states/__init__.py +0 -0
- dsat/services/states/autokaggle_state.py +43 -0
- dsat/services/states/base.py +14 -0
- dsat/services/states/dsa_log.py +13 -0
- dsat/services/states/experience.py +237 -0
- dsat/services/states/journal.py +153 -0
- dsat/services/states/operator_library.py +290 -0
- dsat/services/vdb.py +76 -0
- dsat/services/workspace.py +178 -0
- dsat/tasks/__init__.py +3 -0
- dsat/tasks/handlers.py +376 -0
- dsat/templates/open_ended/grade_template.py +107 -0
- dsat/tools/__init__.py +4 -0
- dsat/utils/__init__.py +0 -0
- dsat/utils/context.py +172 -0
- dsat/utils/dynamic_import.py +71 -0
- dsat/utils/parsing.py +33 -0
- dsat/workflows/__init__.py +12 -0
- dsat/workflows/base.py +53 -0
- dsat/workflows/factory.py +439 -0
- dsat/workflows/manual/__init__.py +0 -0
- dsat/workflows/manual/autokaggle_workflow.py +148 -0
- dsat/workflows/manual/data_interpreter_workflow.py +153 -0
- dsat/workflows/manual/deepanalyze_workflow.py +484 -0
- dsat/workflows/manual/dsagent_workflow.py +76 -0
- dsat/workflows/search/__init__.py +0 -0
- dsat/workflows/search/aflow_workflow.py +344 -0
- dsat/workflows/search/aide_workflow.py +283 -0
- dsat/workflows/search/automind_workflow.py +237 -0
- dsat/workflows/templates/__init__.py +0 -0
- dsat/workflows/templates/basic_kaggle_loop.py +71 -0
- dslighting/__init__.py +170 -0
- dslighting/core/__init__.py +13 -0
- dslighting/core/agent.py +646 -0
- dslighting/core/config_builder.py +318 -0
- dslighting/core/data_loader.py +422 -0
- dslighting/core/task_detector.py +422 -0
- dslighting/utils/__init__.py +19 -0
- dslighting/utils/defaults.py +151 -0
- dslighting-1.3.9.dist-info/METADATA +554 -0
- dslighting-1.3.9.dist-info/RECORD +80 -0
- dslighting-1.3.9.dist-info/WHEEL +5 -0
- dslighting-1.3.9.dist-info/top_level.txt +2 -0
dslighting/core/agent.py
ADDED
@@ -0,0 +1,646 @@
"""
DSLighting Agent - Simplified API for data science automation.

This module provides the main Agent class that wraps the complexity of the
DSAT framework while providing full control when needed.
"""

import asyncio
import logging
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import pandas as pd

from dsat.config import DSATConfig
from dsat.models.task import TaskDefinition, TaskType
from dsat.runner import DSATRunner

from dslighting.core.config_builder import ConfigBuilder
from dslighting.core.data_loader import DataLoader, LoadedData

logger = logging.getLogger(__name__)


@dataclass
class AgentResult:
    """
    Result of running an Agent on a data science task.

    Attributes:
        success: Whether the task completed successfully
        output: Task output (predictions, answer, file path, etc.)
        score: Evaluation score (if available)
        cost: Total LLM cost in USD
        duration: Execution time in seconds
        artifacts_path: Path to generated artifacts
        workspace_path: Path to workspace directory
        error: Error message if failed
        metadata: Additional metadata
    """
    success: bool
    output: Any
    cost: float = 0.0
    duration: float = 0.0
    score: Optional[float] = None
    artifacts_path: Optional[Path] = None
    workspace_path: Optional[Path] = None
    error: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __repr__(self) -> str:
        if self.success:
            return (
                f"AgentResult(success={self.success}, "
                f"output={self.output}, "
                f"score={self.score}, "
                f"cost=${self.cost:.4f}, "
                f"duration={self.duration:.1f}s)"
            )
        else:
            return (
                f"AgentResult(success={self.success}, "
                f"error={self.error}, "
                f"cost=${self.cost:.4f})"
            )


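# A minimal sketch of consuming an AgentResult (illustrative only; the field
# values below are hypothetical):
#
#     result = AgentResult(success=True, output="submission.csv", score=0.87, cost=0.12)
#     if result.success:
#         print(result.score or "no score")  # score may be None if grading was skipped
#     else:
#         print(f"failed: {result.error}")

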
class Agent:
    """
    Simplified interface to DSLighting's data science automation capabilities.

    The Agent class provides a scikit-learn-like API that handles the complexity
    of workflow selection, configuration, and execution while allowing advanced
    users to access underlying components when needed.

    Examples:
        Simple usage:
        >>> import dslighting
        >>> agent = dslighting.Agent()
        >>> result = agent.run("data/my-competition")

        Advanced usage:
        >>> agent = dslighting.Agent(
        ...     workflow="autokaggle",
        ...     model="gpt-4o",
        ...     temperature=0.5,
        ...     max_iterations=10
        ... )
        >>> result = agent.run(data_path)

        Access underlying components:
        >>> config = agent.get_config()
        >>> runner = agent.get_runner()
    """

    def __init__(
        self,
        workflow: Optional[str] = None,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        provider: Optional[str] = None,
        temperature: Optional[float] = None,
        max_iterations: Optional[int] = None,
        num_drafts: Optional[int] = None,
        workspace_dir: Optional[str] = None,
        run_name: Optional[str] = None,
        keep_workspace: bool = False,
        keep_workspace_on_failure: bool = True,
        verbose: bool = True,
        **kwargs
    ):
        """
        Initialize the DSLighting Agent.

        Args:
            workflow: Workflow name (aide, autokaggle, automind, etc.).
                Defaults to "aide" or auto-detected from data.
            model: LLM model name (e.g., "gpt-4o-mini", "deepseek-chat").
                Defaults to the LLM_MODEL env var or "gpt-4o-mini".
            api_key: API key for the LLM service.
                Defaults to the API_KEY env var.
            api_base: API base URL.
                Defaults to the API_BASE env var or OpenAI.
            provider: LLM provider for LiteLLM (e.g., "siliconflow").
            temperature: LLM temperature (0.0-1.0).
            max_iterations: Maximum agent search iterations.
            num_drafts: Number of drafts to generate.
            workspace_dir: Custom workspace directory.
            run_name: Name for this run.
            keep_workspace: Keep workspace after completion (default: False).
            keep_workspace_on_failure: Keep workspace on failure (default: True).
            verbose: Enable verbose logging.
            **kwargs: Additional parameters passed to DSATConfig.
        """
        self.verbose = verbose
        self.logger = logger

        # Build configuration
        self.config_builder = ConfigBuilder()
        self.config = self.config_builder.build_config(
            workflow=workflow,
            model=model,
            api_key=api_key,
            api_base=api_base,
            provider=provider,
            temperature=temperature,
            max_iterations=max_iterations,
            num_drafts=num_drafts,
            workspace_dir=workspace_dir,
            run_name=run_name,
            keep_workspace=keep_workspace,
            keep_workspace_on_failure=keep_workspace_on_failure,
            **kwargs
        )

        # Runner is created lazily on first use
        self._runner: Optional[DSATRunner] = None

        # Track results
        self._results: List[AgentResult] = []

        self.logger.info(f"Agent initialized with workflow: '{self.config.workflow.name}'")

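    # A minimal construction sketch (env var names follow the docstring above;
    # values are placeholders):
    #
    #     import os
    #     os.environ["LLM_MODEL"] = "gpt-4o-mini"
    #     os.environ["API_KEY"] = "sk-..."
    #     agent = Agent()                       # falls back to env defaults
    #     agent = Agent(model="deepseek-chat")  # explicit arguments take precedence
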
    def run(
        self,
        data: Optional[Union[str, Path, dict, pd.DataFrame, LoadedData]] = None,
        task_id: Optional[str] = None,
        data_dir: Optional[str] = None,
        output_path: Optional[str] = None,
        description: Optional[str] = None,
        **kwargs
    ) -> AgentResult:
        """
        Run the agent on a data science task.

        This is the main entry point for executing data science tasks.
        It handles data loading, task creation, workflow execution, and
        result collection.

        Args:
            data: Optional data source (path, DataFrame, dict, or LoadedData).
                If not provided, use the task_id + data_dir pattern.
            task_id: Task/competition identifier (e.g., "bike-sharing-demand").
                Required when using the MLE benchmark format.
            data_dir: Base data directory containing competition data.
                Default: "data/competitions"
            output_path: Custom output path for results
            description: Optional task description (overrides the detected one)
            **kwargs: Additional task parameters

        Returns:
            AgentResult with output, metrics, and metadata

        Examples:
            >>> # Method 1: Recommended - using task_id + data_dir
            >>> result = agent.run(
            ...     task_id="bike-sharing-demand",
            ...     data_dir="data/competitions"
            ... )

            >>> # Method 2: Using a data path directly
            >>> result = agent.run("path/to/competition")

            >>> # Method 3: Using a DataFrame
            >>> result = agent.run(df, description="Predict price")
        """
        # Start timing
        start_time = time.time()

        try:
            # ========== Simplified API: task_id + data_dir ==========
            if task_id:
                # Set default data_dir if not provided
                if data_dir is None:
                    data_dir = "data/competitions"

                self.logger.info("Using MLE benchmark format")
                self.logger.info(f"  task_id: {task_id}")
                self.logger.info(f"  data_dir: {data_dir}")

                # Resolve paths
                data_dir_path = Path(data_dir).resolve()
                competition_dir = data_dir_path / task_id

                # Check if the task exists in the benchmark registry
                benchmark_dir = self._get_default_benchmark_dir()
                task_registry = benchmark_dir / task_id

                if not task_registry.exists():
                    self.logger.warning(
                        f"Task '{task_id}' not found in benchmark registry: {benchmark_dir}"
                    )
                    self.logger.warning(
                        f"This means the task cannot be auto-graded. "
                        f"To enable grading, register the task at: {task_registry}"
                    )
                else:
                    self.logger.info(f"  ✓ Task registered: {task_registry}")

                # Check if the data exists
                if not competition_dir.exists():
                    raise FileNotFoundError(
                        f"Data directory not found: {competition_dir}\n"
                        f"Please ensure data is prepared at: {competition_dir}/prepared/"
                    )

                self.logger.info(f"  Data directory: {competition_dir}")

                # Load data
                loader = DataLoader()
                loaded_data = loader.load(competition_dir)

            # ========== Alternative API: data source (path / DataFrame / LoadedData) ==========
            elif data is not None:
                # Load data if not already loaded
                if not isinstance(data, LoadedData):
                    loader = DataLoader()
                    loaded_data = loader.load(data)
                else:
                    loaded_data = data

                # Extract task_id from loaded_data if available
                if loaded_data.task_id:
                    # Use the task_id from loaded_data for benchmark initialization
                    extracted_task_id = loaded_data.task_id
                    self.logger.info(f"Detected task_id from data: {extracted_task_id}")

                    # Override the task_id parameter for benchmark initialization
                    task_id = extracted_task_id
            else:
                raise ValueError(
                    "Either 'task_id' or 'data' must be provided. "
                    "Example: agent.run(task_id='bike-sharing-demand', data_dir='data/competitions')"
                )

            # Get task information
            task_detection = loaded_data.task_detection

            # Determine workflow
            workflow = self._determine_workflow(task_detection)
            self.logger.info(f"Using workflow: {workflow}")

            # Create task definition
            task = self._create_task_definition(
                loaded_data=loaded_data,
                task_id=task_id,
                description=description,
                output_path=output_path,
                **kwargs
            )

            # Initialize benchmark for MLE tasks (for grading)
            if task_id and loaded_data.get_task_type() == "kaggle":
                try:
                    # Import here to avoid a hard dependency
                    from dsat.benchmark.mle import MLEBenchmark

                    # Resolve data_dir
                    if data_dir is None:
                        data_dir = "data/competitions"
                    data_dir_path = Path(data_dir).expanduser().resolve()

                    # Create benchmark instance
                    benchmark = MLEBenchmark(
                        data_dir=str(data_dir_path),
                        task_ids=[task_id]
                    )

                    # Set the benchmark on the runner
                    runner = self.get_runner()
                    runner.benchmark = benchmark

                    self.logger.info(f"✓ Benchmark initialized for grading: {task_id}")
                except ImportError:
                    self.logger.warning("MLE-Bench not installed. Grading will be skipped.")
                    self.logger.warning("Install with: pip install mlebench")
                except Exception as e:
                    self.logger.warning(f"Benchmark initialization failed: {e}")
                    self.logger.warning("Grading will be skipped")

            # Execute task (async wrapper)
            result = asyncio.run(self._execute_task(task, loaded_data))

            # Calculate duration
            duration = time.time() - start_time
            result.duration = duration

            # Store result
            self._results.append(result)

            if self.verbose:
                self._log_result(result)

            return result

        except Exception as e:
            self.logger.error(f"Task execution failed: {e}", exc_info=True)

            duration = time.time() - start_time

            return AgentResult(
                success=False,
                output=None,
                duration=duration,
                error=str(e),
                metadata={"exception_type": type(e).__name__}
            )

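    # A minimal end-to-end sketch (task name and paths are hypothetical and
    # assume data prepared under data/competitions/<task_id>/prepared/):
    #
    #     agent = Agent(workflow="aide", model="gpt-4o-mini")
    #     result = agent.run(task_id="bike-sharing-demand", data_dir="data/competitions")
    #     if result.success:
    #         print(result.score, result.cost, result.workspace_path)
    #     else:
    #         print(result.error)
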
    def run_batch(
        self,
        data_list: List[Union[str, Path, dict, pd.DataFrame]],
        **kwargs
    ) -> List[AgentResult]:
        """
        Run the agent on multiple data sources sequentially.

        Args:
            data_list: List of data sources
            **kwargs: Additional parameters passed to run()

        Returns:
            List of AgentResult objects

        Examples:
            >>> results = agent.run_batch([
            ...     "data/titanic",
            ...     "data/house-prices",
            ...     "data/fraud-detection"
            ... ])
            >>> for r in results:
            ...     print(f"score={r.score}, cost=${r.cost:.4f}")
        """
        results = []

        for i, data in enumerate(data_list):
            self.logger.info(f"Running batch task {i+1}/{len(data_list)}")

            result = self.run(data, **kwargs)
            results.append(result)

        return results

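    # A minimal batch-summary sketch (inputs are hypothetical):
    #
    #     results = agent.run_batch(["data/titanic", "data/house-prices"])
    #     total_cost = sum(r.cost for r in results)
    #     solved = sum(1 for r in results if r.success)
    #     print(f"{solved}/{len(results)} succeeded, total cost ${total_cost:.2f}")
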
    def get_config(self) -> DSATConfig:
        """
        Get the underlying DSAT configuration.

        This allows advanced users to modify the configuration directly.

        Returns:
            DSATConfig object

        Examples:
            >>> config = agent.get_config()
            >>> config.llm.temperature = 0.5
        """
        return self.config

    def get_runner(self) -> DSATRunner:
        """
        Get the underlying DSATRunner instance.

        This allows advanced users to interact with the runner directly.

        Returns:
            DSATRunner instance

        Examples:
            >>> runner = agent.get_runner()
            >>> eval_fn = runner.get_eval_function()
        """
        if self._runner is None:
            self._runner = DSATRunner(self.config)

        return self._runner

    def get_results(self) -> List[AgentResult]:
        """Get all results from this agent session."""
        return self._results.copy()

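    # Because the runner is built lazily from self.config on first use, config
    # edits made before the first run() take effect; a sketch:
    #
    #     agent = Agent()
    #     agent.get_config().llm.temperature = 0.2  # tweak before the first run
    #     runner = agent.get_runner()               # runner is built with the tweak
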
    def _determine_workflow(self, task_detection) -> str:
        """Determine which workflow to use."""
        # If the user specified a workflow, use it
        if self.config.workflow and self.config.workflow.name:
            return self.config.workflow.name

        # Otherwise, use the recommended workflow from detection
        if task_detection and task_detection.recommended_workflow:
            return task_detection.recommended_workflow

        # Fall back to the default
        return "aide"

    def _create_task_definition(
        self,
        loaded_data: LoadedData,
        task_id: Optional[str] = None,
        description: Optional[str] = None,
        output_path: Optional[str] = None,
        **kwargs
    ) -> TaskDefinition:
        """Create a TaskDefinition from LoadedData."""
        # Generate a task ID if not provided
        if task_id is None:
            safe_name = str(uuid.uuid4())[:8]
            task_id = f"task_{safe_name}"

        # Get task type
        task_type = loaded_data.get_task_type()

        # Get description
        if description is None:
            description = loaded_data.get_description()

        # Get I/O instructions
        io_instructions = loaded_data.get_io_instructions()

        # Build payload
        payload = kwargs.copy()
        payload["description"] = description
        payload["io_instructions"] = io_instructions

        # Add the data directory based on task type
        if loaded_data.data_dir:
            data_dir = loaded_data.data_dir

            if task_type == "kaggle":
                # MLE/Kaggle format: needs public_data_dir and output_submission_path
                # Follow the MLEBenchmark pattern: {data_dir}/prepared/public
                prepared_dir = data_dir / "prepared"
                public_dir = prepared_dir / "public"

                # Check if prepared/public exists (MLE format)
                if public_dir.exists():
                    payload["public_data_dir"] = str(public_dir.resolve())
                    self.logger.info(f"Using MLE prepared data: {public_dir.resolve()}")
                else:
                    # Fallback: use data_dir directly
                    payload["public_data_dir"] = str(data_dir.resolve())
                    self.logger.warning(
                        f"Prepared data not found at {public_dir}, using data_dir instead"
                    )

                # Set the output path - a simple filename, saved in the workspace/sandbox
                if output_path is None:
                    # Extract the competition_id from the data_dir path if possible
                    competition_id = data_dir.name
                    unique_id = str(uuid.uuid4())[:8]
                    output_filename = f"submission_{competition_id}_{unique_id}.csv"

                    # Use just the filename - DSAT will save it in the workspace/sandbox
                    output_path = Path(output_filename)

                payload["output_submission_path"] = str(output_path)
                self.logger.info(f"Output submission file: {output_path}")
            else:
                # Other task types: use data_dir
                payload["data_dir"] = str(data_dir)

        # Add the output path if specified (for non-kaggle tasks)
        if output_path and task_type != "kaggle":
            payload["output_submission_path"] = str(output_path)

        # Create the TaskDefinition
        task = TaskDefinition(
            task_id=task_id,
            task_type=task_type,  # Pass the string directly; Pydantic will validate
            payload=payload
        )

        return task

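    # For a kaggle-type task the resulting payload has roughly this shape
    # (a sketch; paths and IDs are hypothetical):
    #
    #     {
    #         "description": "...",
    #         "io_instructions": "...",
    #         "public_data_dir": "/abs/data/competitions/bike-sharing-demand/prepared/public",
    #         "output_submission_path": "submission_bike-sharing-demand_1a2b3c4d.csv",
    #     }
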
    def _get_workspace_dir(self) -> Path:
        """Get the workspace directory for this agent run."""
        # Try to get it from the config
        workspace_dir = None

        if hasattr(self, 'config') and hasattr(self.config, 'run'):
            run_config = self.config.run
            if hasattr(run_config, 'parameters') and run_config.parameters:
                workspace_dir = run_config.parameters.get('workspace_dir')

        # Fall back to the default workspace directory
        if workspace_dir is None:
            from dslighting.utils.defaults import DEFAULT_WORKSPACE_DIR
            workspace_dir = DEFAULT_WORKSPACE_DIR

        # Ensure the workspace exists
        workspace_path = Path(workspace_dir)
        workspace_path.mkdir(parents=True, exist_ok=True)

        return workspace_path

    def _get_default_benchmark_dir(self) -> Path:
        """
        Get the default benchmark registry directory.

        This is where task registration files (grade.py, description.md, etc.) are stored.
        Default: benchmarks/mlebench/competitions/

        Returns:
            Path to the benchmark registry directory
        """
        # Try to get it from the config
        benchmark_dir = None

        if hasattr(self, 'config') and hasattr(self.config, 'run'):
            run_config = self.config.run
            if hasattr(run_config, 'parameters') and run_config.parameters:
                benchmark_dir = run_config.parameters.get('benchmark_dir')

        # Fall back to the default benchmark directory
        if benchmark_dir is None:
            # Use a path relative to the current working directory
            # Default: benchmarks/mlebench/competitions/
            benchmark_dir = "benchmarks/mlebench/competitions"

        benchmark_path = Path(benchmark_dir).resolve()

        self.logger.debug(f"Benchmark registry directory: {benchmark_path}")

        return benchmark_path

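    # Expected registry layout for an auto-gradable task (a sketch based on the
    # docstring above; the exact file set may vary by task):
    #
    #     benchmarks/mlebench/competitions/
    #         bike-sharing-demand/
    #             grade.py
    #             description.md
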
    async def _execute_task(
        self,
        task: TaskDefinition,
        loaded_data: LoadedData
    ) -> AgentResult:
        """Execute a single task using the DSATRunner."""
        runner = self.get_runner()

        # Link the data directory to the workspace if available
        if loaded_data.data_dir and hasattr(runner, 'benchmark') and runner.benchmark:
            workspace_service = None
            # Try to get the workspace service from the workflow
            if hasattr(runner, 'factory'):
                try:
                    workflow = runner.factory.create_workflow(self.config, benchmark=runner.benchmark)
                    workspace_service = workflow.services.get("workspace")
                    if workspace_service:
                        workspace_service.link_data_to_workspace(loaded_data.data_dir)
                except Exception:
                    pass

        # Get the evaluation function
        eval_fn = runner.get_eval_function()

        # Execute the task
        output, cost, usage_summary = await eval_fn(task)

        # Check for errors
        if isinstance(output, str) and output.startswith("[ERROR]"):
            return AgentResult(
                success=False,
                output=output,
                cost=cost,
                error=output
            )

        # Extract a score if available
        score = None
        if isinstance(output, (int, float)):
            score = float(output)
        elif isinstance(output, dict) and "score" in output:
            score = output["score"]

        # Get the workspace path
        workspace_path = None
        if hasattr(runner, 'run_records') and runner.run_records:
            last_record = runner.run_records[-1]
            workspace_path = last_record.get("workspace_dir")

        return AgentResult(
            success=True,
            output=output,
            cost=cost,
            score=score,
            workspace_path=Path(workspace_path) if workspace_path else None,
            metadata={"usage": usage_summary}
        )

    def _log_result(self, result: AgentResult):
        """Log a result summary."""
        if result.success:
            self.logger.info(
                f"✓ Task completed successfully | "
                f"Score: {result.score or 'N/A'} | "
                f"Cost: ${result.cost:.4f} | "
                f"Duration: {result.duration:.1f}s"
            )
        else:
            self.logger.error(
                f"✗ Task failed | "
                f"Error: {result.error} | "
                f"Cost: ${result.cost:.4f}"
            )

    def __repr__(self) -> str:
        return (
            f"Agent(workflow='{self.config.workflow.name}', "
            f"model='{self.config.llm.model}', "
            f"results={len(self._results)})"
        )