dslighting-1.3.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. dsat/__init__.py +3 -0
  2. dsat/benchmark/__init__.py +1 -0
  3. dsat/benchmark/benchmark.py +168 -0
  4. dsat/benchmark/datasci.py +291 -0
  5. dsat/benchmark/mle.py +777 -0
  6. dsat/benchmark/sciencebench.py +304 -0
  7. dsat/common/__init__.py +0 -0
  8. dsat/common/constants.py +11 -0
  9. dsat/common/exceptions.py +48 -0
  10. dsat/common/typing.py +19 -0
  11. dsat/config.py +79 -0
  12. dsat/models/__init__.py +3 -0
  13. dsat/models/candidates.py +16 -0
  14. dsat/models/formats.py +52 -0
  15. dsat/models/task.py +64 -0
  16. dsat/operators/__init__.py +0 -0
  17. dsat/operators/aflow_ops.py +90 -0
  18. dsat/operators/autokaggle_ops.py +170 -0
  19. dsat/operators/automind_ops.py +38 -0
  20. dsat/operators/base.py +22 -0
  21. dsat/operators/code.py +45 -0
  22. dsat/operators/dsagent_ops.py +123 -0
  23. dsat/operators/llm_basic.py +84 -0
  24. dsat/prompts/__init__.py +0 -0
  25. dsat/prompts/aflow_prompt.py +76 -0
  26. dsat/prompts/aide_prompt.py +52 -0
  27. dsat/prompts/autokaggle_prompt.py +290 -0
  28. dsat/prompts/automind_prompt.py +29 -0
  29. dsat/prompts/common.py +51 -0
  30. dsat/prompts/data_interpreter_prompt.py +82 -0
  31. dsat/prompts/dsagent_prompt.py +88 -0
  32. dsat/runner.py +554 -0
  33. dsat/services/__init__.py +0 -0
  34. dsat/services/data_analyzer.py +387 -0
  35. dsat/services/llm.py +486 -0
  36. dsat/services/llm_single.py +421 -0
  37. dsat/services/sandbox.py +386 -0
  38. dsat/services/states/__init__.py +0 -0
  39. dsat/services/states/autokaggle_state.py +43 -0
  40. dsat/services/states/base.py +14 -0
  41. dsat/services/states/dsa_log.py +13 -0
  42. dsat/services/states/experience.py +237 -0
  43. dsat/services/states/journal.py +153 -0
  44. dsat/services/states/operator_library.py +290 -0
  45. dsat/services/vdb.py +76 -0
  46. dsat/services/workspace.py +178 -0
  47. dsat/tasks/__init__.py +3 -0
  48. dsat/tasks/handlers.py +376 -0
  49. dsat/templates/open_ended/grade_template.py +107 -0
  50. dsat/tools/__init__.py +4 -0
  51. dsat/utils/__init__.py +0 -0
  52. dsat/utils/context.py +172 -0
  53. dsat/utils/dynamic_import.py +71 -0
  54. dsat/utils/parsing.py +33 -0
  55. dsat/workflows/__init__.py +12 -0
  56. dsat/workflows/base.py +53 -0
  57. dsat/workflows/factory.py +439 -0
  58. dsat/workflows/manual/__init__.py +0 -0
  59. dsat/workflows/manual/autokaggle_workflow.py +148 -0
  60. dsat/workflows/manual/data_interpreter_workflow.py +153 -0
  61. dsat/workflows/manual/deepanalyze_workflow.py +484 -0
  62. dsat/workflows/manual/dsagent_workflow.py +76 -0
  63. dsat/workflows/search/__init__.py +0 -0
  64. dsat/workflows/search/aflow_workflow.py +344 -0
  65. dsat/workflows/search/aide_workflow.py +283 -0
  66. dsat/workflows/search/automind_workflow.py +237 -0
  67. dsat/workflows/templates/__init__.py +0 -0
  68. dsat/workflows/templates/basic_kaggle_loop.py +71 -0
  69. dslighting/__init__.py +170 -0
  70. dslighting/core/__init__.py +13 -0
  71. dslighting/core/agent.py +646 -0
  72. dslighting/core/config_builder.py +318 -0
  73. dslighting/core/data_loader.py +422 -0
  74. dslighting/core/task_detector.py +422 -0
  75. dslighting/utils/__init__.py +19 -0
  76. dslighting/utils/defaults.py +151 -0
  77. dslighting-1.3.9.dist-info/METADATA +554 -0
  78. dslighting-1.3.9.dist-info/RECORD +80 -0
  79. dslighting-1.3.9.dist-info/WHEEL +5 -0
  80. dslighting-1.3.9.dist-info/top_level.txt +2 -0
dsat/workflows/search/aide_workflow.py
@@ -0,0 +1,283 @@
+ # dsat/workflows/search/aide_workflow.py
+
+ import logging
+ import shutil
+ from pathlib import Path
+ from typing import Dict, Optional, Any, List
+
+ from dsat.workflows.base import DSATWorkflow
+ from dsat.services.states.journal import JournalState, Node, MetricValue
+ from dsat.benchmark.benchmark import BaseBenchmark
+ from dsat.services.llm import LLMService
+
+ from dsat.prompts.aide_prompt import create_improve_prompt, create_debug_prompt
+ from dsat.prompts.common import create_draft_prompt
+
+ from dsat.utils.context import ContextManager, summarize_repetitive_logs
+
+ logger = logging.getLogger(__name__)
+
+
+ class AIDEWorkflow(DSATWorkflow):
+     """
+     Implements the AIDE iterative search algorithm.
+     This class serves as the base for search-based workflows, containing shared
+     logic for the main search loop, node selection policy, and final artifact generation.
+     """
+     def __init__(self, operators: Dict[str, Any], services: Dict[str, Any], agent_config: Dict[str, Any], benchmark: Optional[BaseBenchmark] = None):
+         """
+         Initializes the AIDEWorkflow and its required services and operators.
+         """
+         super().__init__(operators, services, agent_config)
+         self.state: JournalState = services["state"]
+         self.sandbox_service = services["sandbox"]
+         self.workspace_service = services.get("workspace")
+         self.llm_service: LLMService = services["llm"]
+         self.benchmark = benchmark
+
+         self.execute_op = self.operators["execute"]
+
+         self.generate_op = self.operators["generate"]
+         self.review_op = self.operators["review"]
+
+         self.context_manager = ContextManager()
+
+     def _get_error_history(self, node: Node, max_depth: int = 3) -> str:
+         """
+         Traverses up a chain of buggy parent nodes to build a concise error history.
+         """
+         history = []
+         current = node
+         depth = 0
+         while current and current.is_buggy and depth < max_depth:
+             error_summary = self.context_manager.summarize_error(current.term_out, current.exc_type)
+             entry = (
+                 f"--- Failure at Step #{current.step} ---\n"
+                 f"Plan: {current.plan}\n"
+                 f"Code:\n```python\n{current.code}\n```\n"
+                 f"Error:\n```\n{error_summary}\n```"
+             )
+             history.append(entry)
+             depth += 1
+             current = self.state.get_node(current.parent_id) if current.parent_id else None
+
+         if not history:
+             return "No error history found."
+
+         # Reverse to show chronological order (oldest failure first)
+         return "\n".join(reversed(history))
+
+     def _llm_history_length(self) -> int:
+         return len(self.llm_service.get_call_history())
+
+     def _capture_llm_calls_since(self, start_index: int) -> List[Dict[str, Any]]:
+         history = self.llm_service.get_call_history()
+         if start_index < len(history):
+             return history[start_index:]
+         return []
+
+     async def solve(self, description: str, io_instructions: str, data_dir: Path, output_path: Path) -> None:
+         """
+         The main entry point for the workflow.
+         ...
+         """
+         logger.info(f"{self.__class__.__name__} starting to solve task. Target output: {output_path}")
+
+         max_iterations = self.agent_config.get("search", {}).get("max_iterations", 3)
+
+         for i in range(max_iterations):
+             logger.info(f"--- Starting {self.__class__.__name__} Solve Step {i + 1}/{max_iterations} ---")
+
+             task_context = {
+                 "goal_and_data": description,
+                 "io_instructions": io_instructions
+             }
+
+             await self._execute_search_step(task_context, output_path)
+
+         logger.info("Max iterations reached. Generating final output from the best found solution.")
+         best_node = self.state.get_best_node()
+
+         if best_node:
+             logger.info(f"Executing code from best node #{best_node.step} to generate final artifact.")
+             final_exec_result = await self.execute_op(code=best_node.code, mode="script")
+             if not final_exec_result.success:
+                 logger.warning(f"Final execution of best node's code failed: {final_exec_result.stderr}")
+             final_solution_path = self._write_final_submission(best_node, output_path)
+             if final_solution_path:
+                 logger.info(f"Final solution code saved to {final_solution_path}")
+         else:
+             logger.warning("No successful solution was found during the search.")
+
+     async def _execute_search_step(self, task_context: Dict, output_path: Path):
+         """
+         Execute a single, concrete step of the AIDE search loop.
+         This involves selecting a node to expand, generating new code, executing it,
+         and performing **grounded validation** against the benchmark's grading function.
+
+         Args:
+             task_context: A dictionary containing the task goal, data report, and I/O instructions.
+             output_path: The path where the final output file is expected. Used for grounded validation.
+         """
+         # 1. Select a node
+         parent_node = self._select_node_to_expand()
+
+         # 2. Create a prompt
+         if parent_node is None:
+             prompt = create_draft_prompt(task_context, self.state.generate_summary())
+         elif parent_node.is_buggy:
+             error_history = self._get_error_history(parent_node)
+             prompt = create_debug_prompt(task_context, parent_node.code, error_history, previous_plan=parent_node.plan, memory_summary=self.state.generate_summary())
+         else:
+             summarized_output = summarize_repetitive_logs(parent_node.term_out)
+             prompt = create_improve_prompt(
+                 task_context, self.state.generate_summary(), parent_node.code,
+                 parent_node.analysis, previous_plan=parent_node.plan, previous_output=summarized_output
+             )
+
+         # 3. Generate a new plan and code using the LLM.
+         generate_start = self._llm_history_length()
+         plan, code = await self.generate_op(system_prompt=prompt)
+         new_node = Node(plan=plan, code=code)
+         new_node.generate_prompt = prompt
+         new_node.task_context = task_context
+         new_calls = self._capture_llm_calls_since(generate_start)
+         if new_calls:
+             new_node.llm_generate = new_calls[-1]
+
+         # 4. Execute the new code in the sandbox.
+         exec_result = await self.execute_op(code=new_node.code, mode="script")
+         new_node.absorb_exec_result(exec_result)
+
+         # Perform grounded validation if code execution was successful.
+         if exec_result.success:
+             submission_file_in_sandbox = self.sandbox_service.workspace.get_path("sandbox_workdir") / output_path.name
+
+             if not submission_file_in_sandbox.exists():
+                 new_node.is_buggy = True
+                 new_node.analysis = "Code executed without error, but failed to produce the required output file."
+                 new_node.metric = MetricValue(value=0.0, maximize=True)
+             elif self.benchmark and hasattr(self.benchmark, 'grade'):
+                 logger.info(f"Performing grounded validation using benchmark grader on '{submission_file_in_sandbox}'...")
+                 score = await self.benchmark.grade(submission_file_in_sandbox)
+
+                 # A score > 0 from the grader is the ground truth for a non-buggy, valid submission.
+                 if score > 0:
+                     new_node.is_buggy = False
+                     new_node.metric = MetricValue(value=score, maximize=True)
+                     logger.info(f"Grounded validation PASSED. Score: {score:.4f}")
+                     review_context = {
+                         "task": task_context,
+                         "code": new_node.code,
+                         "output": new_node.term_out
+                     }
+                     new_node.review_context = review_context
+                     review_start = self._llm_history_length()
+                     review = await self.review_op(prompt_context=review_context)
+                     review_calls = self._capture_llm_calls_since(review_start)
+                     if review_calls:
+                         new_node.llm_review = review_calls[-1]
+                     new_node.analysis = f"Grounded Score: {score:.4f}. Reviewer Summary: {review.summary}"
+                 else:
+                     new_node.is_buggy = True
+                     new_node.metric = MetricValue(value=score, maximize=True)
+                     new_node.analysis = "Grounded validation FAILED: The generated submission file was invalid or scored 0.0."
+                     logger.warning(f"Grounded validation FAILED. Score: {score}")
+             else:
+                 review_context = {
+                     "task": task_context,
+                     "code": new_node.code,
+                     "output": new_node.term_out
+                 }
+                 new_node.review_context = review_context
+                 review_start = self._llm_history_length()
+                 review = await self.review_op(prompt_context=review_context)
+                 review_calls = self._capture_llm_calls_since(review_start)
+                 if review_calls:
+                     new_node.llm_review = review_calls[-1]
+                 new_node.analysis = review.summary
+                 new_node.is_buggy = review.is_buggy
+                 new_node.metric = MetricValue(value=review.metric_value, maximize=not review.lower_is_better) if review.metric_value is not None else MetricValue(value=None)
+
+         # 7. Add the new node to the search tree state and persist artifacts.
+         self.state.append(new_node, parent=parent_node)
+         self._persist_node_artifacts(new_node)
+         logger.info(f"Step {new_node.step} complete. Buggy: {new_node.is_buggy}. Metric: {new_node.metric}.")
+
+     def _persist_node_artifacts(self, node: Node) -> None:
+         """
+         Persist the code generated at each node so the full search path can be analyzed and reproduced later.
+         """
+         if not self.workspace_service:
+             return
+         code_steps_dir = self.workspace_service.get_path("artifacts") / "code_steps"
+         code_steps_dir.mkdir(parents=True, exist_ok=True)
+         filename = f"step_{node.step:03d}_{node.id}.py"
+         file_path = code_steps_dir / filename
+         file_path.write_text(node.code, encoding="utf-8")
+         node.code_artifact_path = str(file_path)
+
+     def _write_final_submission(self, node: Node, expected_output: Path) -> Optional[Path]:
+         """
+         Copy the final successful code and its output to a fixed location for easy reproduction.
+         """
+         if not self.workspace_service:
+             return None
+         final_dir = self.workspace_service.get_path("artifacts") / "final_submission"
+         final_dir.mkdir(parents=True, exist_ok=True)
+
+         final_solution_path = final_dir / "final_solution.py"
+         final_solution_path.write_text(node.code, encoding="utf-8")
+         node.final_submission_path = str(final_solution_path)
+
+         if expected_output and expected_output.exists():
+             try:
+                 shutil.copy2(expected_output, final_dir / expected_output.name)
+             except Exception as copy_error:
+                 logger.warning(f"Failed to copy final submission artifact '{expected_output}': {copy_error}")
+
+         return final_solution_path
+
+     def _select_node_to_expand(self) -> Optional[Node]:
+         """
+         Implements the REVISED, DEBUG-FIRST search policy.
+         The policy prioritizes:
+         1. Debugging failed solutions.
+         2. Improving the current best solution if no bugs exist.
+         3. Drafting a new solution only as a last resort.
+         """
+         search_cfg = self.agent_config.get("search", {})
+         # num_drafts is no longer used to block debugging
+         max_debug_depth = search_cfg.get("max_debug_depth", 3)
+
+         buggy_nodes = [n for n in self.state.nodes.values() if n.is_buggy]
+         # Find nodes that are leaves in the bug-chain and haven't exceeded debug depth
+         debuggable = [
+             n for n in buggy_nodes
+             if not n.children_ids and self._get_debug_depth(n) < max_debug_depth
+         ]
+
+         if debuggable:
+             # Select the most recent buggy node to work on
+             selected = max(debuggable, key=lambda n: n.step)
+             logger.info(f"[Search Policy] Debugging: Prioritizing most recent failed node #{selected.step}.")
+             return selected
+
+         best_node = self.state.get_best_node()
+         if best_node:
+             logger.info(f"[Search Policy] Improving: No bugs to fix. Selected best node #{best_node.step} with metric {best_node.metric}.")
+             return best_node
+
+         logger.info("[Search Policy] Drafting: No bugs to fix and no successful nodes to improve. Creating new solution.")
+         return None
+
+     def _get_debug_depth(self, node: Node) -> int:
+         depth = 0
+         current = node
+         while current.parent_id:
+             parent = self.state.get_node(current.parent_id)
+             if not parent or not parent.is_buggy:
+                 break
+             depth += 1
+             current = parent
+         return depth
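The debug-first policy in `_select_node_to_expand` is easier to see in isolation. The sketch below is illustrative only: `ToyNode` and the plain dict journal are hypothetical stand-ins for dsat's `Node` and `JournalState`, but the selection order mirrors the method above: debug the most recent buggy leaf whose bug chain is shorter than `max_debug_depth`, otherwise improve the best non-buggy node, otherwise draft a new solution.

```python
# Illustrative only: ToyNode and the dict "journal" are hypothetical stand-ins for
# dsat's Node/JournalState, used to make the debug-first selection order concrete.
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ToyNode:
    step: int
    is_buggy: bool
    metric: float = 0.0
    parent_id: Optional[int] = None
    children_ids: List[int] = field(default_factory=list)


def select_node_to_expand(nodes: Dict[int, ToyNode], max_debug_depth: int = 3) -> Optional[ToyNode]:
    def debug_depth(node: ToyNode) -> int:
        # Count how many consecutive buggy ancestors the node has.
        depth, current = 0, node
        while current.parent_id is not None:
            parent = nodes.get(current.parent_id)
            if not parent or not parent.is_buggy:
                break
            depth += 1
            current = parent
        return depth

    # 1. Debug: most recent buggy leaf whose bug chain is still shallow enough.
    debuggable = [n for n in nodes.values()
                  if n.is_buggy and not n.children_ids and debug_depth(n) < max_debug_depth]
    if debuggable:
        return max(debuggable, key=lambda n: n.step)

    # 2. Improve: best non-buggy node by metric.
    good = [n for n in nodes.values() if not n.is_buggy]
    if good:
        return max(good, key=lambda n: n.metric)

    # 3. Draft: nothing to debug or improve, so start from scratch.
    return None


# Node 2 is a buggy leaf one level into a bug chain, so the policy debugs it
# before improving the passing node 3.
journal = {
    1: ToyNode(step=1, is_buggy=True, children_ids=[2]),
    2: ToyNode(step=2, is_buggy=True, parent_id=1),
    3: ToyNode(step=3, is_buggy=False, metric=0.71),
}
assert select_node_to_expand(journal).step == 2
```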
dsat/workflows/search/automind_workflow.py
@@ -0,0 +1,237 @@
+ import logging
+ import json
+ from typing import Dict, Any, Optional
+ from pathlib import Path
+
+ from .aide_workflow import AIDEWorkflow
+ from dsat.services.states.journal import Node, MetricValue
+ from dsat.common.typing import ExecutionResult
+ from dsat.benchmark.benchmark import BaseBenchmark
+
+ from dsat.services.vdb import VDBService
+
+ from dsat.prompts.common import create_draft_prompt
+ from dsat.prompts.automind_prompt import create_stepwise_code_prompt, create_stepwise_debug_prompt
+ from dsat.prompts.aide_prompt import create_improve_prompt, create_debug_prompt
+
+ from dsat.utils.context import (
+     ContextManager,
+     MAX_HISTORY_CHARS,
+     MAX_OUTPUT_CHARS,
+     summarize_repetitive_logs,
+     truncate_output,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class AutoMindWorkflow(AIDEWorkflow):
+     """
+     Implements the AUTOMIND iterative search algorithm.
+     This workflow extends AIDE by incorporating a knowledge base (VDB),
+     a self-adaptive coding strategy (one-pass vs. stepwise), and more
+     sophisticated context management for complex tasks.
+     """
+     def __init__(self, operators: Dict[str, Any], services: Dict[str, Any], agent_config: Dict[str, Any], benchmark: Optional[BaseBenchmark] = None):
+         """
+         Initializes the AutoMindWorkflow, building upon the AIDE base.
+         """
+         # Initialize base AIDE components, now passing the benchmark instance up.
+         super().__init__(operators, services, agent_config, benchmark=benchmark)
+
+         self.vdb_service: VDBService = services.get("vdb")
+
+         self.complexity_scorer_op = self.operators.get("complexity_scorer")
+         self.plan_decomposer_op = self.operators.get("plan_decomposer")
+
+         # AutoMind's context manager requires an LLM service to summarize knowledge and history.
+         self.context_manager = ContextManager(llm_service=services.get("llm"))
+
+     async def _execute_search_step(self, task_context: Dict, output_path: Path):
+         """
+         Execute a single step of the AutoMind search loop.
+         This overrides the AIDE implementation to add knowledge retrieval and the
+         self-adaptive coding strategy, and it uses **grounded validation**.
+
+         Args:
+             task_context: A dictionary containing the task goal, data report, and I/O instructions.
+             output_path: The path where the final output file is expected. Used for grounded validation.
+         """
+         # 1. Select a node
+         parent_node = self._select_node_to_expand()
+
+         task_goal = task_context.get('goal_and_data', 'Solve the data science task.')
+         io_instructions = task_context.get('io_instructions', 'N/A')
+
+         # 2. Create a prompt
+         if parent_node is None:
+             # For new drafts, retrieve similar examples from the knowledge base.
+             retrieved_knowledge = ""
+             if self.vdb_service:
+                 cases = self.vdb_service.retrieve(task_goal, top_k=2)
+                 retrieved_knowledge = await self.context_manager.summarize_knowledge(cases, task_goal)
+
+             prompt = create_draft_prompt(task_context, self.state.generate_summary(), retrieved_knowledge)
+         elif parent_node.is_buggy:
+             error_summary = self.context_manager.summarize_error(parent_node.term_out, parent_node.exc_type)
+             prompt = create_debug_prompt(task_context, parent_node.code, error_summary, previous_plan=parent_node.plan, memory_summary=self.state.generate_summary())
+         else:
+             summarized_output = summarize_repetitive_logs(parent_node.term_out)
+             prompt = create_improve_prompt(
+                 task_context, self.state.generate_summary(), parent_node.code,
+                 parent_node.analysis, previous_plan=parent_node.plan, previous_output=summarized_output
+             )
+
+         # 3. Generate initial plan and one-pass code.
+         plan, one_pass_code = await self.generate_op(system_prompt=prompt)
+
+         # 4. Apply the self-adaptive coding strategy for new drafts.
+         use_adaptive = self.complexity_scorer_op and self.plan_decomposer_op
+         if use_adaptive and parent_node is None:
+             final_code, exec_result = await self._execute_step_adaptively(plan, one_pass_code, task_goal, io_instructions)
+         else:
+             # For simpler tasks, improvements, or debugging, use the one-pass code.
+             final_code = one_pass_code
+             exec_result = await self.execute_op(code=final_code, mode="script")
+
+         # 5. Create a new node and absorb the execution result.
+         new_node = Node(plan=plan, code=final_code)
+         new_node.absorb_exec_result(exec_result)
+
+         if exec_result.success:
+             submission_file_in_sandbox = self.sandbox_service.workspace.get_path("sandbox_workdir") / output_path.name
+
+             if not submission_file_in_sandbox.exists():
+                 new_node.is_buggy = True
+                 new_node.analysis = "Code executed without error, but failed to produce the required output file."
+                 new_node.metric = MetricValue(value=0.0, maximize=True)
+             elif self.benchmark and hasattr(self.benchmark, 'grade'):
+                 logger.info(f"Performing grounded validation using benchmark grader on '{submission_file_in_sandbox}'...")
+                 score = await self.benchmark.grade(submission_file_in_sandbox)
+
+                 if score > 0:
+                     new_node.is_buggy = False
+                     new_node.metric = MetricValue(value=score, maximize=True)
+                     logger.info(f"Grounded validation PASSED. Score: {score:.4f}")
+                     review = await self.review_op(prompt_context={
+                         "task": task_context, "code": new_node.code, "output": new_node.term_out
+                     })
+                     new_node.analysis = f"Grounded Score: {score:.4f}. Reviewer Summary: {review.summary}"
+                 else:
+                     new_node.is_buggy = True
+                     new_node.metric = MetricValue(value=score, maximize=True)
+                     new_node.analysis = "Grounded validation FAILED: The generated submission file was invalid or scored 0.0."
+                     logger.warning(f"Grounded validation FAILED. Score: {score}")
+             else:
+                 logger.warning("No benchmark with 'grade' method found. Falling back to unreliable LLM-based review.")
+                 review = await self.review_op(prompt_context={
+                     "task": task_context, "code": new_node.code, "output": new_node.term_out
+                 })
+                 new_node.analysis = review.summary
+                 new_node.is_buggy = review.is_buggy
+                 new_node.metric = MetricValue(value=review.metric_value, maximize=not review.lower_is_better) if review.metric_value is not None else MetricValue(value=None)
+
+         # 8. Add the new node to the search tree state.
+         self.state.append(new_node, parent=parent_node)
+         logger.info(f"Step {new_node.step} complete. Buggy: {new_node.is_buggy}. Metric: {new_node.metric}.")
+
+     async def _execute_step_adaptively(self, plan: str, one_pass_code: str, task_goal: str, io_instructions: str) -> tuple[str, ExecutionResult]:
+         """
+         Core of the Self-Adaptive Coding Strategy.
+         It scores the complexity of a plan and chooses to either execute the provided
+         one-pass code directly or decompose the plan into smaller steps and execute them
+         sequentially in a notebook context.
+
+         Args:
+             plan: The overall plan for the task.
+             one_pass_code: The single block of code generated for the entire plan.
+             task_goal: The user's primary goal.
+             io_instructions: The critical I/O requirements passed through to the step prompts.
+
+         Returns:
+             A tuple containing the final code (either one-pass or combined steps) and the
+             final ExecutionResult.
+         """
+         final_code = one_pass_code
+
+         # 1. Score plan complexity using the dedicated operator.
+         score = await self.complexity_scorer_op(plan=plan, task_goal=task_goal)
+
+         # 2. Choose strategy based on the complexity score.
+         if score.complexity <= 3:  # Threshold for one-pass vs stepwise
+             logger.info("Plan is simple. Executing in one-pass mode.")
+             exec_result = await self.execute_op(code=final_code, mode="script")
+         else:
+             logger.info("Plan is complex. Decomposing and executing in stepwise mode.")
+             # 3. Decompose the complex plan into a sequence of smaller tasks.
+             decomposed_plan = await self.plan_decomposer_op(plan=plan, task_goal=task_goal)
+
+             step_codes = []
+             history_steps = []
+             final_exec_result = None
+             # Get max retries config
+             max_step_retries = self.agent_config.get("max_retries", 3)
+
+             async with self.sandbox_service.notebook_executor() as notebook:
+                 for task in decomposed_plan.tasks:
+                     logger.info(f"Executing step {task.task_id}: {task.instruction}")
+
+                     step_succeeded = False
+                     current_code = ""
+                     step_failure_history = []  # History for the current step
+
+                     # Implement retry loop for robustness
+                     for attempt in range(max_step_retries):
+                         logger.info(f"Step {task.task_id}, Attempt {attempt + 1}/{max_step_retries}")
+
+                         # Build a concise history of recent steps to provide context for the next step.
+                         recent_history_str = self.context_manager.build_history_context(
+                             history_steps,
+                             key_order=["task_id", "code", "output"]
+                         )
+
+                         if attempt == 0:
+                             step_prompt = create_stepwise_code_prompt(task_goal, plan, recent_history_str, task.instruction, io_instructions)
+                         else:
+                             error_summary = self.context_manager.summarize_error(exec_result.stderr, exec_result.exc_type)
+                             step_failure_history.append({
+                                 "attempt": attempt,
+                                 "code": truncate_output(current_code, MAX_OUTPUT_CHARS),
+                                 "error": error_summary
+                             })
+
+                             formatted_failure_history = "\n".join([
+                                 f"--- Attempt {f['attempt']} Failed ---\nCode:\n```python\n{f['code']}\n```\nError: {f['error']}\n---"
+                                 for f in step_failure_history
+                             ])
+                             safe_failure_history = truncate_output(formatted_failure_history, MAX_HISTORY_CHARS)
+
+                             step_prompt = create_stepwise_debug_prompt(
+                                 task_goal, plan, recent_history_str, task.instruction,
+                                 current_code, safe_failure_history, io_instructions
+                             )
+
+                         _, current_code = await self.generate_op(system_prompt=step_prompt)
+                         exec_result = await self.execute_op(code=current_code, mode="notebook", executor_context=notebook)
+
+                         if exec_result.success:
+                             step_succeeded = True
+                             break
+
+                     if not step_succeeded:
+                         logger.error(f"Step {task.task_id} failed after {max_step_retries} attempts. Aborting stepwise execution.")
+                         final_exec_result = exec_result  # Capture the failed result
+                         break
+
+                     step_codes.append(f"# --- Step {task.task_id}: {task.instruction} ---\n{current_code}")
+                     # Record the successful step for future context.
+                     history_steps.append({
+                         "task_id": task.task_id,
+                         "code": truncate_output(current_code, MAX_OUTPUT_CHARS),
+                         "output": truncate_output(exec_result.stdout, MAX_OUTPUT_CHARS),
+                     })
+                     final_exec_result = exec_result  # Update with the latest successful result
+
+             final_code = "\n\n".join(step_codes)
+             exec_result = final_exec_result
+
+         return final_code, exec_result
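The self-adaptive branch in `_execute_step_adaptively` boils down to a small control pattern: score the plan, run the one-pass code when the complexity is at or below the threshold (3 above), otherwise decompose the plan and retry each step up to the configured number of attempts, aborting the remaining steps once a step exhausts its retries. A minimal runnable sketch of that pattern follows; the stub coroutines (`score_complexity`, `decompose`, `generate_step`, `run`) are hypothetical stand-ins for dsat's complexity-scorer, plan-decomposer, generate, and execute operators.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class StepResult:
    success: bool
    stdout: str = ""
    stderr: str = ""


async def score_complexity(plan: str) -> int:
    # Stand-in for the LLM complexity scorer: long plans count as complex.
    return 5 if len(plan.splitlines()) > 3 else 2


async def decompose(plan: str) -> list[str]:
    # Stand-in for the plan decomposer: one step per plan line.
    return [line.strip() for line in plan.splitlines() if line.strip()]


async def generate_step(instruction: str, attempt: int) -> str:
    # Stand-in for code generation.
    return f"# code for: {instruction} (attempt {attempt + 1})"


async def run(code: str) -> StepResult:
    # Stand-in for sandbox execution.
    return StepResult(success=True, stdout="ok")


async def execute_adaptively(plan: str, one_pass_code: str, max_step_retries: int = 3) -> str:
    if await score_complexity(plan) <= 3:
        # Simple plan: execute the one-pass code directly.
        await run(one_pass_code)
        return one_pass_code

    # Complex plan: stepwise mode with a per-step retry loop.
    step_codes = []
    for instruction in await decompose(plan):
        for attempt in range(max_step_retries):
            code = await generate_step(instruction, attempt)
            if (await run(code)).success:
                step_codes.append(code)
                break
        else:
            # Step exhausted its retries: abort the remaining steps.
            break
    return "\n\n".join(step_codes)


print(asyncio.run(execute_adaptively("load data\nbuild features\ntrain model\npredict", "# one-pass code")))
```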
dsat/workflows/templates/__init__.py (file without changes)
dsat/workflows/templates/basic_kaggle_loop.py
@@ -0,0 +1,71 @@
+ """
+ A simple, baseline workflow that serves as the starting point (seed)
+ for the meta-optimization evolutionary algorithm.
+ """
+
+ def get_initial_workflow_code() -> str:
+     """Returns the source code for a simple workflow using injected operators."""
+     return '''
+ import shutil
+ from pathlib import Path
+ from dsat.workflows.base import DSATWorkflow
+ from dsat.services.llm import LLMService
+ from dsat.services.sandbox import SandboxService
+ from dsat.utils.parsing import parse_plan_and_code
+ from typing import Dict, Any, List
+
+ class Workflow(DSATWorkflow):
+     """
+     An initial workflow that generates Python code to solve a Kaggle-style task,
+     executes it in a sandbox, and produces a submission file.
+     This is a strong starting point for the optimization process.
+     """
+     def __init__(self, operators: Dict[str, Any], services: Dict[str, Any], agent_config: Dict[str, Any]):
+         """
+         The driver injects the llm_service and sandbox_service.
+         """
+         super().__init__(operators, services, agent_config)
+         self.llm_service: LLMService = services["llm"]
+         self.sandbox_service: SandboxService = services["sandbox"]
+
+     async def solve(self, description: str, io_instructions: str, data_dir: Path, output_path: Path):
+         """
+         The main entry point for executing the workflow.
+         """
+         print(f" Initial Workflow: Starting task. Target: {output_path.name}")
+
+         prompt = (
+             f"You are an expert AI developer and data scientist. Your task is to write a single, complete Python script to solve the following problem. "
+             f"Provide only the Python code in a single code block.\\n\\n"
+             f"# PROBLEM DESCRIPTION AND DATA REPORT\\n{description}\\n\\n"
+             f"# CRITICAL I/O REQUIREMENTS (MUST BE FOLLOWED)\\n{io_instructions}\\n\\n"
+             f"Ensure the script strictly follows the CRITICAL I/O REQUIREMENTS."
+         )
+
+         # 1. Generate the code
+         llm_response = await self.llm_service.call(prompt)
+         _, code_to_execute = parse_plan_and_code(llm_response)
+
+         if "# ERROR" in code_to_execute:
+             print(" ERROR: Failed to generate valid code from LLM.")
+             return
+
+         # 2. Execute the code in the sandbox
+         # The sandbox will have its own isolated workspace.
+         print(" Initial Workflow: Executing generated script in sandbox...")
+         exec_result = self.sandbox_service.run_script(code_to_execute)
+
+         if not exec_result.success:
+             print(f" ERROR: Code execution failed.\\n{exec_result.stderr}")
+             return
+
+         # 3. Verify the generated submission file exists.
+         sandbox_workdir = self.sandbox_service.workspace.get_path("sandbox_workdir")
+         generated_file = sandbox_workdir / output_path.name
+
+         if generated_file.exists():
+             print(f" SUCCESS: Submission file '{output_path.name}' successfully generated in sandbox.")
+         else:
+             print(f" ERROR: Execution succeeded, but the required output file '{output_path.name}' was not created in {sandbox_workdir}.")
+
+ '''
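`get_initial_workflow_code` returns Python source as a string, so the seed has to be materialized into a class before the evolutionary loop can run it. The package appears to ship its own loader in `dsat/utils/dynamic_import.py`, which is not shown in this diff; the sketch below is a generic, hypothetical way to do it with a throwaway module namespace, not the package's actual mechanism.

```python
# Hypothetical loader sketch; dsat's real one lives in dsat/utils/dynamic_import.py (not shown here).
import types


def load_workflow_class(source: str, class_name: str = "Workflow"):
    """Compile a source string into a fresh module and return the named class."""
    module = types.ModuleType("seed_workflow")
    exec(compile(source, "<seed_workflow>", "exec"), module.__dict__)
    return getattr(module, class_name)


# Usage (assumes dsat is importable so the seed's own imports resolve):
# from dsat.workflows.templates.basic_kaggle_loop import get_initial_workflow_code
# WorkflowCls = load_workflow_class(get_initial_workflow_code())
# workflow = WorkflowCls(operators={}, services={"llm": llm_service, "sandbox": sandbox_service}, agent_config={})
```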