dslighting 1.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dsat/__init__.py +3 -0
- dsat/benchmark/__init__.py +1 -0
- dsat/benchmark/benchmark.py +168 -0
- dsat/benchmark/datasci.py +291 -0
- dsat/benchmark/mle.py +777 -0
- dsat/benchmark/sciencebench.py +304 -0
- dsat/common/__init__.py +0 -0
- dsat/common/constants.py +11 -0
- dsat/common/exceptions.py +48 -0
- dsat/common/typing.py +19 -0
- dsat/config.py +79 -0
- dsat/models/__init__.py +3 -0
- dsat/models/candidates.py +16 -0
- dsat/models/formats.py +52 -0
- dsat/models/task.py +64 -0
- dsat/operators/__init__.py +0 -0
- dsat/operators/aflow_ops.py +90 -0
- dsat/operators/autokaggle_ops.py +170 -0
- dsat/operators/automind_ops.py +38 -0
- dsat/operators/base.py +22 -0
- dsat/operators/code.py +45 -0
- dsat/operators/dsagent_ops.py +123 -0
- dsat/operators/llm_basic.py +84 -0
- dsat/prompts/__init__.py +0 -0
- dsat/prompts/aflow_prompt.py +76 -0
- dsat/prompts/aide_prompt.py +52 -0
- dsat/prompts/autokaggle_prompt.py +290 -0
- dsat/prompts/automind_prompt.py +29 -0
- dsat/prompts/common.py +51 -0
- dsat/prompts/data_interpreter_prompt.py +82 -0
- dsat/prompts/dsagent_prompt.py +88 -0
- dsat/runner.py +554 -0
- dsat/services/__init__.py +0 -0
- dsat/services/data_analyzer.py +387 -0
- dsat/services/llm.py +486 -0
- dsat/services/llm_single.py +421 -0
- dsat/services/sandbox.py +386 -0
- dsat/services/states/__init__.py +0 -0
- dsat/services/states/autokaggle_state.py +43 -0
- dsat/services/states/base.py +14 -0
- dsat/services/states/dsa_log.py +13 -0
- dsat/services/states/experience.py +237 -0
- dsat/services/states/journal.py +153 -0
- dsat/services/states/operator_library.py +290 -0
- dsat/services/vdb.py +76 -0
- dsat/services/workspace.py +178 -0
- dsat/tasks/__init__.py +3 -0
- dsat/tasks/handlers.py +376 -0
- dsat/templates/open_ended/grade_template.py +107 -0
- dsat/tools/__init__.py +4 -0
- dsat/utils/__init__.py +0 -0
- dsat/utils/context.py +172 -0
- dsat/utils/dynamic_import.py +71 -0
- dsat/utils/parsing.py +33 -0
- dsat/workflows/__init__.py +12 -0
- dsat/workflows/base.py +53 -0
- dsat/workflows/factory.py +439 -0
- dsat/workflows/manual/__init__.py +0 -0
- dsat/workflows/manual/autokaggle_workflow.py +148 -0
- dsat/workflows/manual/data_interpreter_workflow.py +153 -0
- dsat/workflows/manual/deepanalyze_workflow.py +484 -0
- dsat/workflows/manual/dsagent_workflow.py +76 -0
- dsat/workflows/search/__init__.py +0 -0
- dsat/workflows/search/aflow_workflow.py +344 -0
- dsat/workflows/search/aide_workflow.py +283 -0
- dsat/workflows/search/automind_workflow.py +237 -0
- dsat/workflows/templates/__init__.py +0 -0
- dsat/workflows/templates/basic_kaggle_loop.py +71 -0
- dslighting/__init__.py +170 -0
- dslighting/core/__init__.py +13 -0
- dslighting/core/agent.py +646 -0
- dslighting/core/config_builder.py +318 -0
- dslighting/core/data_loader.py +422 -0
- dslighting/core/task_detector.py +422 -0
- dslighting/utils/__init__.py +19 -0
- dslighting/utils/defaults.py +151 -0
- dslighting-1.3.9.dist-info/METADATA +554 -0
- dslighting-1.3.9.dist-info/RECORD +80 -0
- dslighting-1.3.9.dist-info/WHEEL +5 -0
- dslighting-1.3.9.dist-info/top_level.txt +2 -0
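For orientation, here is a minimal, hypothetical sketch of how the search workflows shipped in this wheel appear to be wired together, based only on the `AIDEWorkflow` constructor and `solve()` signature visible in the diff below. The `operators`, `services`, and `agent_config` objects are assumptions here; in the package they are presumably assembled by `dsat/runner.py` and `dsat/workflows/factory.py`.

```python
# Illustrative only: wiring inferred from the AIDEWorkflow signatures shown in the diff below.
# The operators/services/agent_config objects are hypothetical stand-ins supplied by the caller.
import asyncio
from pathlib import Path

from dsat.workflows.search.aide_workflow import AIDEWorkflow


async def run_search(operators, services, agent_config, benchmark=None):
    # AIDEWorkflow.__init__ reads the "state", "sandbox", "workspace", and "llm" services
    # plus the "execute", "generate", and "review" operators.
    workflow = AIDEWorkflow(operators, services, agent_config, benchmark=benchmark)
    await workflow.solve(
        description="<task description and data report>",
        io_instructions="<required output format>",
        data_dir=Path("data"),
        output_path=Path("submission.csv"),
    )

# asyncio.run(run_search(operators, services, agent_config))  # objects built by the runner/factory
```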
dsat/workflows/search/aide_workflow.py
@@ -0,0 +1,283 @@
+# dsat/workflows/search/aide_workflow.py
+
+import logging
+import shutil
+from pathlib import Path
+from typing import Dict, Optional, Any, List
+
+from dsat.workflows.base import DSATWorkflow
+from dsat.services.states.journal import JournalState, Node, MetricValue
+from dsat.benchmark.benchmark import BaseBenchmark
+from dsat.services.llm import LLMService
+
+from dsat.prompts.aide_prompt import create_improve_prompt, create_debug_prompt
+from dsat.prompts.common import create_draft_prompt
+
+from dsat.utils.context import ContextManager, summarize_repetitive_logs
+
+logger = logging.getLogger(__name__)
+
+
+class AIDEWorkflow(DSATWorkflow):
+    """
+    Implements the AIDE iterative search algorithm.
+    This class serves as the base for search-based workflows, containing shared
+    logic for the main search loop, node selection policy, and final artifact generation.
+    """
+    def __init__(self, operators: Dict[str, Any], services: Dict[str, Any], agent_config: Dict[str, Any], benchmark: Optional[BaseBenchmark] = None):
+        """
+        Initializes the AIDEWorkflow and its required services and operators.
+        """
+        super().__init__(operators, services, agent_config)
+        self.state: JournalState = services["state"]
+        self.sandbox_service = services["sandbox"]
+        self.workspace_service = services.get("workspace")
+        self.llm_service: LLMService = services["llm"]
+        self.benchmark = benchmark
+
+        self.execute_op = self.operators["execute"]
+
+        self.generate_op = self.operators["generate"]
+        self.review_op = self.operators["review"]
+
+        self.context_manager = ContextManager()
+
+    def _get_error_history(self, node: Node, max_depth: int = 3) -> str:
+        """
+        Traverses up a chain of buggy parent nodes to build a concise error history.
+        """
+        history = []
+        current = node
+        depth = 0
+        while current and current.is_buggy and depth < max_depth:
+            error_summary = self.context_manager.summarize_error(current.term_out, current.exc_type)
+            entry = (
+                f"--- Failure at Step #{current.step} ---\n"
+                f"Plan: {current.plan}\n"
+                f"Code:\n```python\n{current.code}\n```\n"
+                f"Error:\n```\n{error_summary}\n```"
+            )
+            history.append(entry)
+            depth += 1
+            current = self.state.get_node(current.parent_id) if current.parent_id else None
+
+        if not history:
+            return "No error history found."
+
+        # Reverse to show chronological order (oldest failure first)
+        return "\n".join(reversed(history))
+
+    def _llm_history_length(self) -> int:
+        return len(self.llm_service.get_call_history())
+
+    def _capture_llm_calls_since(self, start_index: int) -> List[Dict[str, Any]]:
+        history = self.llm_service.get_call_history()
+        if start_index < len(history):
+            return history[start_index:]
+        return []
+
+    async def solve(self, description: str, io_instructions: str, data_dir: Path, output_path: Path) -> None:
+        """
+        The main entry point for the workflow.
+        ...
+        """
+        logger.info(f"{self.__class__.__name__} starting to solve task. Target output: {output_path}")
+
+        max_iterations = self.agent_config.get("search", {}).get("max_iterations", 3)
+
+        for i in range(max_iterations):
+            logger.info(f"--- Starting {self.__class__.__name__} Solve Step {i + 1}/{max_iterations} ---")
+
+            task_context = {
+                "goal_and_data": description,
+                "io_instructions": io_instructions
+            }
+
+            await self._execute_search_step(task_context, output_path)
+
+        logger.info("Max iterations reached. Generating final output from the best found solution.")
+        best_node = self.state.get_best_node()
+
+        if best_node:
+            logger.info(f"Executing code from best node #{best_node.step} to generate final artifact.")
+            final_exec_result = await self.execute_op(code=best_node.code, mode="script")
+            if not final_exec_result.success:
+                logger.warning(f"Final execution of best node's code failed: {final_exec_result.stderr}")
+            final_solution_path = self._write_final_submission(best_node, output_path)
+            if final_solution_path:
+                logger.info(f"Final solution code saved to {final_solution_path}")
+        else:
+            logger.warning("No successful solution was found during the search.")
+
+    async def _execute_search_step(self, task_context: Dict, output_path: Path):
+        """
+        Execute a single, concrete step of the AIDE search loop.
+        This involves selecting a node to expand, generating new code, executing it,
+        and performing **grounded validation** against the benchmark's grading function.
+
+        Args:
+            task_context: A dictionary containing the task goal, data report, and I/O instructions.
+            output_path: The path where the final output file is expected. Used for grounded validation.
+        """
+        # 1. Select a node
+        parent_node = self._select_node_to_expand()
+
+        # 2. Create a prompt
+        if parent_node is None:
+            prompt = create_draft_prompt(task_context, self.state.generate_summary())
+        elif parent_node.is_buggy:
+            error_history = self._get_error_history(parent_node)
+            prompt = create_debug_prompt(task_context, parent_node.code, error_history, previous_plan=parent_node.plan, memory_summary=self.state.generate_summary())
+        else:
+            summarized_output = summarize_repetitive_logs(parent_node.term_out)
+            prompt = create_improve_prompt(
+                task_context, self.state.generate_summary(), parent_node.code,
+                parent_node.analysis, previous_plan=parent_node.plan, previous_output=summarized_output
+            )
+
+        # 3. Generate a new plan and code using the LLM.
+        generate_start = self._llm_history_length()
+        plan, code = await self.generate_op(system_prompt=prompt)
+        new_node = Node(plan=plan, code=code)
+        new_node.generate_prompt = prompt
+        new_node.task_context = task_context
+        new_calls = self._capture_llm_calls_since(generate_start)
+        if new_calls:
+            new_node.llm_generate = new_calls[-1]
+
+        # 4. Execute the new code in the sandbox.
+        exec_result = await self.execute_op(code=new_node.code, mode="script")
+        new_node.absorb_exec_result(exec_result)
+
+        # Perform grounded validation if code execution was successful.
+        if exec_result.success:
+            submission_file_in_sandbox = self.sandbox_service.workspace.get_path("sandbox_workdir") / output_path.name
+
+            if not submission_file_in_sandbox.exists():
+                new_node.is_buggy = True
+                new_node.analysis = "Code executed without error, but failed to produce the required output file."
+                new_node.metric = MetricValue(value=0.0, maximize=True)
+            elif self.benchmark and hasattr(self.benchmark, 'grade'):
+                logger.info(f"Performing grounded validation using benchmark grader on '{submission_file_in_sandbox}'...")
+                score = await self.benchmark.grade(submission_file_in_sandbox)
+
+                # A score > 0 from the grader is the ground truth for a non-buggy, valid submission.
+                if score > 0:
+                    new_node.is_buggy = False
+                    new_node.metric = MetricValue(value=score, maximize=True)
+                    logger.info(f"Grounded validation PASSED. Score: {score:.4f}")
+                    review_context = {
+                        "task": task_context,
+                        "code": new_node.code,
+                        "output": new_node.term_out
+                    }
+                    new_node.review_context = review_context
+                    review_start = self._llm_history_length()
+                    review = await self.review_op(prompt_context=review_context)
+                    review_calls = self._capture_llm_calls_since(review_start)
+                    if review_calls:
+                        new_node.llm_review = review_calls[-1]
+                    new_node.analysis = f"Grounded Score: {score:.4f}. Reviewer Summary: {review.summary}"
+                else:
+                    new_node.is_buggy = True
+                    new_node.metric = MetricValue(value=score, maximize=True)
+                    new_node.analysis = "Grounded validation FAILED: The generated submission file was invalid or scored 0.0."
+                    logger.warning(f"Grounded validation FAILED. Score: {score}")
+            else:
+                review_context = {
+                    "task": task_context,
+                    "code": new_node.code,
+                    "output": new_node.term_out
+                }
+                new_node.review_context = review_context
+                review_start = self._llm_history_length()
+                review = await self.review_op(prompt_context=review_context)
+                review_calls = self._capture_llm_calls_since(review_start)
+                if review_calls:
+                    new_node.llm_review = review_calls[-1]
+                new_node.analysis = review.summary
+                new_node.is_buggy = review.is_buggy
+                new_node.metric = MetricValue(value=review.metric_value, maximize=not review.lower_is_better) if review.metric_value is not None else MetricValue(value=None)
+
+        # 7. Add the new node to the search tree state and persist artifacts.
+        self.state.append(new_node, parent=parent_node)
+        self._persist_node_artifacts(new_node)
+        logger.info(f"Step {new_node.step} complete. Buggy: {new_node.is_buggy}. Metric: {new_node.metric}.")
+
+    def _persist_node_artifacts(self, node: Node) -> None:
+        """
+        Saves the code generated by each node so the full search path can be analyzed and reconstructed later.
+        """
+        if not self.workspace_service:
+            return
+        code_steps_dir = self.workspace_service.get_path("artifacts") / "code_steps"
+        code_steps_dir.mkdir(parents=True, exist_ok=True)
+        filename = f"step_{node.step:03d}_{node.id}.py"
+        file_path = code_steps_dir / filename
+        file_path.write_text(node.code, encoding="utf-8")
+        node.code_artifact_path = str(file_path)
+
+    def _write_final_submission(self, node: Node, expected_output: Path) -> Optional[Path]:
+        """
+        Copies the final successful code and its output to a fixed location so the run can be reproduced later.
+        """
+        if not self.workspace_service:
+            return None
+        final_dir = self.workspace_service.get_path("artifacts") / "final_submission"
+        final_dir.mkdir(parents=True, exist_ok=True)
+
+        final_solution_path = final_dir / "final_solution.py"
+        final_solution_path.write_text(node.code, encoding="utf-8")
+        node.final_submission_path = str(final_solution_path)
+
+        if expected_output and expected_output.exists():
+            try:
+                shutil.copy2(expected_output, final_dir / expected_output.name)
+            except Exception as copy_error:
+                logger.warning(f"Failed to copy final submission artifact '{expected_output}': {copy_error}")
+
+        return final_solution_path
+
+    def _select_node_to_expand(self) -> Optional[Node]:
+        """
+        Implements the REVISED, DEBUG-FIRST search policy.
+        The policy prioritizes:
+        1. Debugging failed solutions.
+        2. Improving the current best solution if no bugs exist.
+        3. Drafting a new solution only as a last resort.
+        """
+        search_cfg = self.agent_config.get("search", {})
+        # num_drafts is no longer used to block debugging
+        max_debug_depth = search_cfg.get("max_debug_depth", 3)
+
+        buggy_nodes = [n for n in self.state.nodes.values() if n.is_buggy]
+        # Find nodes that are leaves in the bug-chain and haven't exceeded debug depth
+        debuggable = [
+            n for n in buggy_nodes
+            if not n.children_ids and self._get_debug_depth(n) < max_debug_depth
+        ]
+
+        if debuggable:
+            # Select the most recent buggy node to work on
+            selected = max(debuggable, key=lambda n: n.step)
+            logger.info(f"[Search Policy] Debugging: Prioritizing most recent failed node #{selected.step}.")
+            return selected
+
+        best_node = self.state.get_best_node()
+        if best_node:
+            logger.info(f"[Search Policy] Improving: No bugs to fix. Selected best node #{best_node.step} with metric {best_node.metric}.")
+            return best_node
+
+        logger.info("[Search Policy] Drafting: No bugs to fix and no successful nodes to improve. Creating new solution.")
+        return None
+
+    def _get_debug_depth(self, node: Node) -> int:
+        depth = 0
+        current = node
+        while current.parent_id:
+            parent = self.state.get_node(current.parent_id)
+            if not parent or not parent.is_buggy:
+                break
+            depth += 1
+            current = parent
+        return depth
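For reference, these are the `agent_config` keys that the search workflows above and below read, shown with the fallback defaults used in the code; the enclosing dictionary layout is illustrative only, and the real schema presumably lives in `dsat/config.py`.

```python
# agent_config keys read by the search workflows in this wheel, with the
# in-code fallback defaults; the surrounding dict layout is illustrative only.
agent_config = {
    "search": {
        "max_iterations": 3,    # outer loop in AIDEWorkflow.solve()
        "max_debug_depth": 3,   # bug-chain depth allowed by _select_node_to_expand()
    },
    "max_retries": 3,           # per-step retries in AutoMindWorkflow._execute_step_adaptively()
}
```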
dsat/workflows/search/automind_workflow.py
@@ -0,0 +1,237 @@
+import logging
+import json
+from typing import Dict, Any, Optional
+from pathlib import Path
+
+from .aide_workflow import AIDEWorkflow
+from dsat.services.states.journal import Node, MetricValue
+from dsat.common.typing import ExecutionResult
+from dsat.benchmark.benchmark import BaseBenchmark
+
+from dsat.services.vdb import VDBService
+
+from dsat.prompts.common import create_draft_prompt
+from dsat.prompts.automind_prompt import create_stepwise_code_prompt, create_stepwise_debug_prompt
+from dsat.prompts.aide_prompt import create_improve_prompt, create_debug_prompt
+
+from dsat.utils.context import (
+    ContextManager,
+    MAX_HISTORY_CHARS,
+    MAX_OUTPUT_CHARS,
+    summarize_repetitive_logs,
+    truncate_output,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class AutoMindWorkflow(AIDEWorkflow):
+    """
+    Implements the AUTOMIND iterative search algorithm.
+    This workflow extends AIDE by incorporating a knowledge base (VDB),
+    a self-adaptive coding strategy (one-pass vs. stepwise), and more
+    sophisticated context management for complex tasks.
+    """
+    def __init__(self, operators: Dict[str, Any], services: Dict[str, Any], agent_config: Dict[str, Any], benchmark: Optional[BaseBenchmark] = None):
+        """
+        Initializes the AutoMindWorkflow, building upon the AIDE base.
+        """
+        # Initialize base AIDE components, now passing the benchmark instance up.
+        super().__init__(operators, services, agent_config, benchmark=benchmark)
+
+        self.vdb_service: VDBService = services.get("vdb")
+
+        self.complexity_scorer_op = self.operators.get("complexity_scorer")
+        self.plan_decomposer_op = self.operators.get("plan_decomposer")
+
+        # AutoMind's context manager requires an LLM service to summarize knowledge and history.
+        self.context_manager = ContextManager(llm_service=services.get("llm"))
+
+    async def _execute_search_step(self, task_context: Dict, output_path: Path):
+        """
+        Execute a single step of the AutoMind search loop.
+        This overrides the AIDE implementation to add knowledge retrieval, the
+        self-adaptive coding strategy, and now uses **grounded validation**.
+
+        Args:
+            task_context: A dictionary containing the task goal, data report, and I/O instructions.
+            output_path: The path where the final output file is expected. Used for grounded validation.
+        """
+        # 1. Select a node
+        parent_node = self._select_node_to_expand()
+
+        task_goal = task_context.get('goal_and_data', 'Solve the data science task.')
+        io_instructions = task_context.get('io_instructions', 'N/A')
+
+        # 2. Create a prompt
+        if parent_node is None:
+            # For new drafts, retrieve similar examples from the knowledge base.
+            retrieved_knowledge = ""
+            if self.vdb_service:
+                cases = self.vdb_service.retrieve(task_goal, top_k=2)
+                retrieved_knowledge = await self.context_manager.summarize_knowledge(cases, task_goal)
+
+            prompt = create_draft_prompt(task_context, self.state.generate_summary(), retrieved_knowledge)
+        elif parent_node.is_buggy:
+            error_summary = self.context_manager.summarize_error(parent_node.term_out, parent_node.exc_type)
+            prompt = create_debug_prompt(task_context, parent_node.code, error_summary, previous_plan=parent_node.plan, memory_summary=self.state.generate_summary())
+        else:
+            summarized_output = summarize_repetitive_logs(parent_node.term_out)
+            prompt = create_improve_prompt(
+                task_context, self.state.generate_summary(), parent_node.code,
+                parent_node.analysis, previous_plan=parent_node.plan, previous_output=summarized_output
+            )
+
+        # 3. Generate initial plan and one-pass code.
+        plan, one_pass_code = await self.generate_op(system_prompt=prompt)
+
+        # 4. Apply the self-adaptive coding strategy for new drafts.
+        use_adaptive = self.complexity_scorer_op and self.plan_decomposer_op
+        if use_adaptive and parent_node is None:
+            final_code, exec_result = await self._execute_step_adaptively(plan, one_pass_code, task_goal, io_instructions)
+        else:
+            # For simpler tasks, improvements, or debugging, use the one-pass code.
+            final_code = one_pass_code
+            exec_result = await self.execute_op(code=final_code, mode="script")
+
+        # 5. Create a new node and absorb the execution result.
+        new_node = Node(plan=plan, code=final_code)
+        new_node.absorb_exec_result(exec_result)
+
+        if exec_result.success:
+            submission_file_in_sandbox = self.sandbox_service.workspace.get_path("sandbox_workdir") / output_path.name
+
+            if not submission_file_in_sandbox.exists():
+                new_node.is_buggy = True
+                new_node.analysis = "Code executed without error, but failed to produce the required output file."
+                new_node.metric = MetricValue(value=0.0, maximize=True)
+            elif self.benchmark and hasattr(self.benchmark, 'grade'):
+                logger.info(f"Performing grounded validation using benchmark grader on '{submission_file_in_sandbox}'...")
+                score = await self.benchmark.grade(submission_file_in_sandbox)
+
+                if score > 0:
+                    new_node.is_buggy = False
+                    new_node.metric = MetricValue(value=score, maximize=True)
+                    logger.info(f"Grounded validation PASSED. Score: {score:.4f}")
+                    review = await self.review_op(prompt_context={
+                        "task": task_context, "code": new_node.code, "output": new_node.term_out
+                    })
+                    new_node.analysis = f"Grounded Score: {score:.4f}. Reviewer Summary: {review.summary}"
+                else:
+                    new_node.is_buggy = True
+                    new_node.metric = MetricValue(value=score, maximize=True)
+                    new_node.analysis = "Grounded validation FAILED: The generated submission file was invalid or scored 0.0."
+                    logger.warning(f"Grounded validation FAILED. Score: {score}")
+            else:
+                logger.warning("No benchmark with 'grade' method found. Falling back to unreliable LLM-based review.")
+                review = await self.review_op(prompt_context={
+                    "task": task_context, "code": new_node.code, "output": new_node.term_out
+                })
+                new_node.analysis = review.summary
+                new_node.is_buggy = review.is_buggy
+                new_node.metric = MetricValue(value=review.metric_value, maximize=not review.lower_is_better) if review.metric_value is not None else MetricValue(value=None)
+
+        # 8. Add the new node to the search tree state.
+        self.state.append(new_node, parent=parent_node)
+        logger.info(f"Step {new_node.step} complete. Buggy: {new_node.is_buggy}. Metric: {new_node.metric}.")
+
+    async def _execute_step_adaptively(self, plan: str, one_pass_code: str, task_goal: str, io_instructions: str) -> tuple[str, ExecutionResult]:
+        """
+        Core of the Self-Adaptive Coding Strategy.
+        It scores the complexity of a plan and chooses to either execute the provided
+        one-pass code directly or decompose the plan into smaller steps and execute them
+        sequentially in a notebook context.
+
+        Args:
+            plan: The overall plan for the task.
+            one_pass_code: The single block of code generated for the entire plan.
+            task_goal: The user's primary goal.
+
+        Returns:
+            A tuple containing the final code (either one-pass or combined steps) and the
+            final ExecutionResult.
+        """
+        final_code = one_pass_code
+
+        # 1. Score plan complexity using the dedicated operator.
+        score = await self.complexity_scorer_op(plan=plan, task_goal=task_goal)
+
+        # 2. Choose strategy based on the complexity score.
+        if score.complexity <= 3:  # Threshold for one-pass vs stepwise
+            logger.info("Plan is simple. Executing in one-pass mode.")
+            exec_result = await self.execute_op(code=final_code, mode="script")
+        else:
+            logger.info("Plan is complex. Decomposing and executing in stepwise mode.")
+            # 3. Decompose the complex plan into a sequence of smaller tasks.
+            decomposed_plan = await self.plan_decomposer_op(plan=plan, task_goal=task_goal)
+
+            step_codes = []
+            history_steps = []
+            final_exec_result = None
+            # Get max retries config
+            max_step_retries = self.agent_config.get("max_retries", 3)
+
+            async with self.sandbox_service.notebook_executor() as notebook:
+                for task in decomposed_plan.tasks:
+                    logger.info(f"Executing step {task.task_id}: {task.instruction}")
+
+                    step_succeeded = False
+                    current_code = ""
+                    step_failure_history = []  # History for the current step
+
+                    # Implement retry loop for robustness
+                    for attempt in range(max_step_retries):
+                        logger.info(f"Step {task.task_id}, Attempt {attempt + 1}/{max_step_retries}")
+
+                        # Build a concise history of recent steps to provide context for the next step.
+                        recent_history_str = self.context_manager.build_history_context(
+                            history_steps,
+                            key_order=["task_id", "code", "output"]
+                        )
+
+                        if attempt == 0:
+                            step_prompt = create_stepwise_code_prompt(task_goal, plan, recent_history_str, task.instruction, io_instructions)
+                        else:
+                            error_summary = self.context_manager.summarize_error(exec_result.stderr, exec_result.exc_type)
+                            step_failure_history.append({
+                                "attempt": attempt,
+                                "code": truncate_output(current_code, MAX_OUTPUT_CHARS),
+                                "error": error_summary
+                            })
+
+                            formatted_failure_history = "\n".join([
+                                f"--- Attempt {f['attempt']} Failed ---\nCode:\n```python\n{f['code']}\n```\nError: {f['error']}\n---"
+                                for f in step_failure_history
+                            ])
+                            safe_failure_history = truncate_output(formatted_failure_history, MAX_HISTORY_CHARS)
+
+                            step_prompt = create_stepwise_debug_prompt(
+                                task_goal, plan, recent_history_str, task.instruction,
+                                current_code, safe_failure_history, io_instructions
+                            )
+
+                        _, current_code = await self.generate_op(system_prompt=step_prompt)
+                        exec_result = await self.execute_op(code=current_code, mode="notebook", executor_context=notebook)
+
+                        if exec_result.success:
+                            step_succeeded = True
+                            break
+
+                    if not step_succeeded:
+                        logger.error(f"Step {task.task_id} failed after {max_step_retries} attempts. Aborting stepwise execution.")
+                        final_exec_result = exec_result  # Capture the failed result
+                        break
+
+                    step_codes.append(f"# --- Step {task.task_id}: {task.instruction} ---\n{current_code}")
+                    # Record the successful step for future context.
+                    history_steps.append({
+                        "task_id": task.task_id,
+                        "code": truncate_output(current_code, MAX_OUTPUT_CHARS),
+                        "output": truncate_output(exec_result.stdout, MAX_OUTPUT_CHARS),
+                    })
+                    final_exec_result = exec_result  # Update with the latest successful result
+
+            final_code = "\n\n".join(step_codes)
+            exec_result = final_exec_result
+
+        return final_code, exec_result
File without changes
dsat/workflows/templates/basic_kaggle_loop.py
@@ -0,0 +1,71 @@
+"""
+A simple, baseline workflow that serves as the starting point (seed)
+for the meta-optimization evolutionary algorithm.
+"""
+
+def get_initial_workflow_code() -> str:
+    """Returns the source code for a simple workflow using injected operators."""
+    return '''
+import shutil
+from pathlib import Path
+from dsat.workflows.base import DSATWorkflow
+from dsat.services.llm import LLMService
+from dsat.services.sandbox import SandboxService
+from dsat.utils.parsing import parse_plan_and_code
+from typing import Dict, Any, List
+
+class Workflow(DSATWorkflow):
+    """
+    An initial workflow that generates Python code to solve a Kaggle-style task,
+    executes it in a sandbox, and produces a submission file.
+    This is a strong starting point for the optimization process.
+    """
+    def __init__(self, operators: Dict[str, Any], services: Dict[str, Any], agent_config: Dict[str, Any]):
+        """
+        The driver injects the llm_service and sandbox_service.
+        """
+        super().__init__(operators, services, agent_config)
+        self.llm_service: LLMService = services["llm"]
+        self.sandbox_service: SandboxService = services["sandbox"]
+
+    async def solve(self, description: str, io_instructions: str, data_dir: Path, output_path: Path):
+        """
+        The main entry point for executing the workflow.
+        """
+        print(f" Initial Workflow: Starting task. Target: {output_path.name}")
+
+        prompt = (
+            f"You are an expert AI developer and data scientist. Your task is to write a single, complete Python script to solve the following problem. "
+            f"Provide only the Python code in a single code block.\\n\\n"
+            f"# PROBLEM DESCRIPTION AND DATA REPORT\\n{description}\\n\\n"
+            f"# CRITICAL I/O REQUIREMENTS (MUST BE FOLLOWED)\\n{io_instructions}\\n\\n"
+            f"Ensure the script strictly follows the CRITICAL I/O REQUIREMENTS."
+        )
+
+        # 1. Generate the code
+        llm_response = await self.llm_service.call(prompt)
+        _, code_to_execute = parse_plan_and_code(llm_response)
+
+        if "# ERROR" in code_to_execute:
+            print(" ERROR: Failed to generate valid code from LLM.")
+            return
+
+        # 2. Execute the code in the sandbox
+        # The sandbox will have its own isolated workspace.
+        print(" Initial Workflow: Executing generated script in sandbox...")
+        exec_result = self.sandbox_service.run_script(code_to_execute)
+
+        if not exec_result.success:
+            print(f" ERROR: Code execution failed.\\n{exec_result.stderr}")
+            return
+
+        # 3. Verify the generated submission file exists.
+        sandbox_workdir = self.sandbox_service.workspace.get_path("sandbox_workdir")
+        generated_file = sandbox_workdir / output_path.name
+
+        if generated_file.exists():
+            print(f" SUCCESS: Submission file '{output_path.name}' successfully generated in sandbox.")
+        else:
+            print(f" ERROR: Execution succeeded, but the required output file '{output_path.name}' was not created in {sandbox_workdir}.")
+
+'''
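The seed returned by `get_initial_workflow_code()` is plain source text, so a caller has to materialize and import it before use. Below is a minimal sketch of one way to do that with standard `importlib`; the package ships its own loader in `dsat/utils/dynamic_import.py`, whose exact API is not shown in this diff, so the helper name and flow here are assumptions.

```python
# Illustrative only: materialize the seed workflow source and load its Workflow class.
# The real loader presumably lives in dsat/utils/dynamic_import.py.
import importlib.util
from pathlib import Path

from dsat.workflows.templates.basic_kaggle_loop import get_initial_workflow_code


def load_seed_workflow(target: Path = Path("generated_workflow.py")):
    # Write the seed source to disk, then import it as a throwaway module.
    target.write_text(get_initial_workflow_code(), encoding="utf-8")
    spec = importlib.util.spec_from_file_location("generated_workflow", target)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.Workflow  # the class defined inside the seed source

# workflow_cls = load_seed_workflow()
# workflow = workflow_cls(operators, services, agent_config)  # objects supplied by the runner
```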