dslighting-1.3.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dsat/__init__.py +3 -0
- dsat/benchmark/__init__.py +1 -0
- dsat/benchmark/benchmark.py +168 -0
- dsat/benchmark/datasci.py +291 -0
- dsat/benchmark/mle.py +777 -0
- dsat/benchmark/sciencebench.py +304 -0
- dsat/common/__init__.py +0 -0
- dsat/common/constants.py +11 -0
- dsat/common/exceptions.py +48 -0
- dsat/common/typing.py +19 -0
- dsat/config.py +79 -0
- dsat/models/__init__.py +3 -0
- dsat/models/candidates.py +16 -0
- dsat/models/formats.py +52 -0
- dsat/models/task.py +64 -0
- dsat/operators/__init__.py +0 -0
- dsat/operators/aflow_ops.py +90 -0
- dsat/operators/autokaggle_ops.py +170 -0
- dsat/operators/automind_ops.py +38 -0
- dsat/operators/base.py +22 -0
- dsat/operators/code.py +45 -0
- dsat/operators/dsagent_ops.py +123 -0
- dsat/operators/llm_basic.py +84 -0
- dsat/prompts/__init__.py +0 -0
- dsat/prompts/aflow_prompt.py +76 -0
- dsat/prompts/aide_prompt.py +52 -0
- dsat/prompts/autokaggle_prompt.py +290 -0
- dsat/prompts/automind_prompt.py +29 -0
- dsat/prompts/common.py +51 -0
- dsat/prompts/data_interpreter_prompt.py +82 -0
- dsat/prompts/dsagent_prompt.py +88 -0
- dsat/runner.py +554 -0
- dsat/services/__init__.py +0 -0
- dsat/services/data_analyzer.py +387 -0
- dsat/services/llm.py +486 -0
- dsat/services/llm_single.py +421 -0
- dsat/services/sandbox.py +386 -0
- dsat/services/states/__init__.py +0 -0
- dsat/services/states/autokaggle_state.py +43 -0
- dsat/services/states/base.py +14 -0
- dsat/services/states/dsa_log.py +13 -0
- dsat/services/states/experience.py +237 -0
- dsat/services/states/journal.py +153 -0
- dsat/services/states/operator_library.py +290 -0
- dsat/services/vdb.py +76 -0
- dsat/services/workspace.py +178 -0
- dsat/tasks/__init__.py +3 -0
- dsat/tasks/handlers.py +376 -0
- dsat/templates/open_ended/grade_template.py +107 -0
- dsat/tools/__init__.py +4 -0
- dsat/utils/__init__.py +0 -0
- dsat/utils/context.py +172 -0
- dsat/utils/dynamic_import.py +71 -0
- dsat/utils/parsing.py +33 -0
- dsat/workflows/__init__.py +12 -0
- dsat/workflows/base.py +53 -0
- dsat/workflows/factory.py +439 -0
- dsat/workflows/manual/__init__.py +0 -0
- dsat/workflows/manual/autokaggle_workflow.py +148 -0
- dsat/workflows/manual/data_interpreter_workflow.py +153 -0
- dsat/workflows/manual/deepanalyze_workflow.py +484 -0
- dsat/workflows/manual/dsagent_workflow.py +76 -0
- dsat/workflows/search/__init__.py +0 -0
- dsat/workflows/search/aflow_workflow.py +344 -0
- dsat/workflows/search/aide_workflow.py +283 -0
- dsat/workflows/search/automind_workflow.py +237 -0
- dsat/workflows/templates/__init__.py +0 -0
- dsat/workflows/templates/basic_kaggle_loop.py +71 -0
- dslighting/__init__.py +170 -0
- dslighting/core/__init__.py +13 -0
- dslighting/core/agent.py +646 -0
- dslighting/core/config_builder.py +318 -0
- dslighting/core/data_loader.py +422 -0
- dslighting/core/task_detector.py +422 -0
- dslighting/utils/__init__.py +19 -0
- dslighting/utils/defaults.py +151 -0
- dslighting-1.3.9.dist-info/METADATA +554 -0
- dslighting-1.3.9.dist-info/RECORD +80 -0
- dslighting-1.3.9.dist-info/WHEEL +5 -0
- dslighting-1.3.9.dist-info/top_level.txt +2 -0
dsat/operators/aflow_ops.py
ADDED

@@ -0,0 +1,90 @@

from typing import List, Dict, Any
from pydantic import BaseModel, Field

from dsat.operators.base import Operator
from dsat.services.llm import LLMService

# --- Pydantic Models for Structured I/O ---

class ScEnsembleResponse(BaseModel):
    """Structured response for the Self-Consistency Ensemble operator."""
    thought: str = Field(description="The step-by-step thinking process to determine the most consistent solution.")
    solution_letter: str = Field(description="The single letter (A, B, C, etc.) of the most consistent solution.")

class ReviewResponse(BaseModel):
    """Structured response for the Review operator."""
    is_correct: bool = Field(description="True if the solution is very likely correct, False otherwise.")
    feedback: str = Field(description="If incorrect, detailed feedback for revision. If correct, a brief justification.")

class ReviseResponse(BaseModel):
    """Structured response for the Revise operator."""
    solution: str = Field(description="The complete, revised solution based on the provided feedback.")


# --- Operator Implementations ---

class ScEnsembleOperator(Operator):
    """
    Performs a self-consistency check by asking the LLM to vote on the most
    frequent or logical answer from a list of candidate solutions.
    """
    async def __call__(self, solutions: List[str], problem: str) -> str:
        if not self.llm_service:
            raise ValueError("LLMService is required for this operator.")

        solution_text = ""
        solution_map = {}
        for i, solution in enumerate(solutions):
            letter = chr(65 + i)
            solution_map[letter] = solution
            solution_text += f"{letter}: \n{solution}\n\n"

        prompt = (
            f"Given the problem: '{problem}'\n\n"
            f"Several solutions have been generated:\n{solution_text}\n"
            "Carefully evaluate these solutions and identify the answer that appears most frequently or is most logical. "
            "Respond with a JSON object containing your thought process and the letter of the most consistent solution."
        )

        response_model = await self.llm_service.call_with_json(prompt, output_model=ScEnsembleResponse)

        chosen_letter = response_model.solution_letter.strip().upper()
        return solution_map.get(chosen_letter, solutions[0])  # Default to first solution on failure

class ReviewOperator(Operator):
    """
    Critically reviews a solution for correctness and provides structured feedback.
    """
    async def __call__(self, problem: str, solution: str) -> ReviewResponse:
        if not self.llm_service:
            raise ValueError("LLMService is required for this operator.")

        prompt = (
            "You are a meticulous reviewer. Given a problem and a solution, your task is to critically evaluate the solution's correctness. "
            "If you are more than 95% confident the solution is incorrect, provide feedback for fixing it. Otherwise, confirm its correctness.\n\n"
            f"# PROBLEM\n{problem}\n\n"
            f"# SOLUTION\n{solution}\n\n"
            "Respond with a JSON object containing your evaluation."
        )

        return await self.llm_service.call_with_json(prompt, output_model=ReviewResponse)

class ReviseOperator(Operator):
    """
    Revises a solution based on feedback from the Review operator.
    """
    async def __call__(self, problem: str, solution: str, feedback: str) -> str:
        if not self.llm_service:
            raise ValueError("LLMService is required for this operator.")

        prompt = (
            "You are an expert programmer. A previous solution was found to be incorrect. "
            "Your task is to revise the solution based on the provided feedback.\n\n"
            f"# PROBLEM\n{problem}\n\n"
            f"# INCORRECT SOLUTION\n{solution}\n\n"
            f"# FEEDBACK\n{feedback}\n\n"
            "Provide a JSON object containing the complete, revised solution."
        )

        response_model = await self.llm_service.call_with_json(prompt, output_model=ReviseResponse)
        return response_model.solution
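
A minimal composition sketch (hypothetical, not part of the packaged code) of how the three operators above could be chained into an ensemble-then-review-and-revise loop; `llm` stands for an already-configured LLMService instance and the `refine` helper name is an assumption.

from dsat.operators.aflow_ops import ReviewOperator, ReviseOperator, ScEnsembleOperator

async def refine(llm, problem: str, drafts: list[str], max_rounds: int = 2) -> str:
    # Vote for the most consistent draft, then review/revise until it passes.
    candidate = await ScEnsembleOperator(llm)(solutions=drafts, problem=problem)
    for _ in range(max_rounds):
        review = await ReviewOperator(llm)(problem=problem, solution=candidate)
        if review.is_correct:
            break
        candidate = await ReviseOperator(llm)(problem=problem, solution=candidate, feedback=review.feedback)
    return candidate
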
dsat/operators/autokaggle_ops.py
ADDED

@@ -0,0 +1,170 @@

import json
import logging
import re
from pathlib import Path
from typing import Dict, Any, List

from pydantic import BaseModel, Field

from dsat.operators.base import Operator
from dsat.services.llm import LLMService
from dsat.services.sandbox import SandboxService
from dsat.services.states.autokaggle_state import AutoKaggleState, TaskContract, PhaseMemory
from dsat.prompts.autokaggle_prompt import (
    get_deconstructor_prompt,
    get_phase_planner_prompt,
    get_step_planner_prompt,
    get_developer_prompt,
    get_validator_prompt,
    get_reviewer_prompt,
    get_summarizer_prompt,
)
from dsat.models.formats import StepPlan

logger = logging.getLogger(__name__)


class PhasePlanningResponse(BaseModel):
    """Response model for phase planning."""
    phases: List[str]


class ValidationResponse(BaseModel):
    """Response model for file validation."""
    passed: bool
    reason: str = ""


class ReviewResponse(BaseModel):
    """Response model for code review."""
    score: int
    suggestion: str = Field(default="", description="Constructive feedback or a suggestion for improvement.")


class TaskDeconstructionOperator(Operator):
    """Parses the natural language description into a structured TaskContract."""

    async def __call__(self, description: str) -> TaskContract:
        logger.info("Deconstructing task description into a structured contract...")
        prompt = get_deconstructor_prompt(description, TaskContract.model_json_schema())
        contract = await self.llm_service.call_with_json(prompt, output_model=TaskContract)
        logger.info(f"Task deconstructed. Goal: {contract.task_goal}")
        return contract


class AutoKagglePlannerOperator(Operator):
    """Handles high-level phase planning and low-level step planning."""

    async def __call__(self, *args, **kwargs) -> Any:
        """
        Main entry point for the planner operator.
        Can be called with different arguments for different planning tasks.
        """
        if len(args) == 1 and isinstance(args[0], TaskContract):
            # Called for phase planning
            return await self.plan_phases(args[0])
        elif len(args) == 2 and isinstance(args[0], AutoKaggleState) and isinstance(args[1], str):
            # Called for step planning
            return await self.plan_step_details(args[0], args[1])
        else:
            raise ValueError(f"AutoKagglePlannerOperator called with unexpected arguments: {args}, {kwargs}")

    async def plan_phases(self, contract: TaskContract) -> List[str]:
        logger.info("Planning dynamic phases for the workflow...")
        prompt = get_phase_planner_prompt(contract)
        response = await self.llm_service.call_with_json(prompt, output_model=PhasePlanningResponse)
        phases = response.phases
        logger.info(f"Dynamic phases planned: {phases}")
        return phases

    async def plan_step_details(self, state: AutoKaggleState, phase_goal: str) -> StepPlan:
        logger.info(f"Planning detailed steps for phase: '{phase_goal}'...")
        prompt = get_step_planner_prompt(state, phase_goal)
        step_plan = await self.llm_service.call_with_json(prompt, output_model=StepPlan)
        return step_plan


class DynamicValidationOperator(Operator):
    """Dynamically validates generated files against the TaskContract."""

    async def __call__(self, contract: TaskContract, workspace_dir: Path) -> Dict[str, Any]:
        logger.info("Performing dynamic validation of output files...")
        results = {}
        for output_file in contract.output_files:
            file_path = workspace_dir / output_file.filename
            if not file_path.exists():
                results[output_file.filename] = {"passed": False, "reason": "File was not generated."}
                continue

            try:
                content_snippet = "\n".join(file_path.read_text().splitlines()[:20])
            except Exception as e:
                results[output_file.filename] = {"passed": False, "reason": f"Could not read file: {str(e)}"}
                continue

            prompt = get_validator_prompt(contract, output_file.filename, content_snippet)
            validation = await self.llm_service.call_with_json(prompt, output_model=ValidationResponse)
            results[output_file.filename] = {"passed": validation.passed, "reason": validation.reason}
        logger.info(f"Validation results: {results}")
        return results


class AutoKaggleDeveloperOperator(Operator):
    """Writes, executes, and validates code."""

    def __init__(self, llm_service: LLMService, sandbox_service: SandboxService, validator: DynamicValidationOperator):
        super().__init__(llm_service, name="AutoKaggleDeveloper")
        self.sandbox = sandbox_service
        self.validator = validator

    async def __call__(self, state: AutoKaggleState, phase_goal: str, plan: str, attempt_history: List) -> Dict:
        logger.info(f"Developer starting work for phase: '{phase_goal}'")
        prompt = get_developer_prompt(state, phase_goal, plan, attempt_history)

        raw_reply = await self.llm_service.call(prompt)
        match = re.search(r"```(?:python|py)?\s*([\s\S]*?)\s*```", raw_reply, re.DOTALL)
        code = match.group(1).strip() if match else ""

        if not code:
            return {"code": "", "status": False, "output": "", "error": "No code was generated.", "validation_result": {}}

        exec_result = self.sandbox.run_script(code)

        validation_result = {}
        if exec_result.success:
            # Note: This still validates against the *final* contract outputs. The reviewer logic handles this.
            validation_result = await self.validator(state.contract, self.sandbox.workspace.run_dir)

        return {
            "code": code,
            "status": exec_result.success,
            "output": exec_result.stdout,
            "error": exec_result.stderr,
            "validation_result": validation_result
        }


class AutoKaggleReviewerOperator(Operator):
    """Reviews the developer's work and provides a score and suggestions."""

    async def __call__(self, state: AutoKaggleState, phase_goal: str, dev_result: Dict, plan: str = "") -> Dict:
        logger.info("Reviewer assessing developer's work...")
        prompt = get_reviewer_prompt(phase_goal, dev_result, plan)
        review = await self.llm_service.call_with_json(prompt, output_model=ReviewResponse)
        review_dict = {
            "score": review.score,
            "suggestion": review.suggestion
        }
        logger.info(f"Review complete. Score: {review.score}")
        return review_dict


class AutoKaggleSummarizerOperator(Operator):
    """Summarizes a successful phase into a report."""

    async def __call__(self, state: AutoKaggleState, phase_memory: PhaseMemory) -> str:
        logger.info(f"Summarizer creating report for phase: '{phase_memory.phase_goal}'")
        prompt = get_summarizer_prompt(state, phase_memory)
        report = await self.llm_service.call(prompt)
        logger.info("Report created.")
        return report
dsat/operators/automind_ops.py
ADDED

@@ -0,0 +1,38 @@

import logging
from typing import Dict

from dsat.models.formats import ComplexityScore, DecomposedPlan
from dsat.operators.base import Operator
from dsat.services.llm import LLMService

logger = logging.getLogger(__name__)

class ComplexityScorerOperator(Operator):
    """Scores the complexity of a natural language plan."""
    async def __call__(self, plan: str, task_goal: str) -> ComplexityScore:
        logger.info("Scoring plan complexity...")
        prompt = (
            "You are an expert project manager. On a scale of 1 to 5, where 1 is trivial "
            "(e.g., load a file and report basic properties) and 5 is highly complex (e.g., multi-stage pipeline involving custom algorithms, extensive preprocessing, or complex simulations), how complex is the following plan?\n\n"
            f"## Task Goal:\n{task_goal}\n\n"
            f"## Proposed Plan:\n{plan}\n\n"
            "Respond with a JSON object containing your score and a brief justification."
        )
        score = await self.llm_service.call_with_json(prompt, output_model=ComplexityScore)
        logger.info(f"Plan complexity scored at {score.complexity}/5.")
        return score

class PlanDecomposerOperator(Operator):
    """Decomposes a complex plan into a structured list of sequential steps."""
    async def __call__(self, plan: str, task_goal: str) -> DecomposedPlan:
        logger.info("Decomposing complex plan into steps...")
        prompt = (
            "You are an expert data scientist. Decompose the following high-level plan into a sequence of "
            "small, logical, and executable steps. Each step should represent a single cell in a data science notebook.\n\n"
            f"## Task Goal:\n{task_goal}\n\n"
            f"## High-Level Plan:\n{plan}\n\n"
            "Respond with a JSON object containing a list of tasks. Each task must have a unique `task_id` and an `instruction`."
        )
        decomposed_plan = await self.llm_service.call_with_json(prompt, output_model=DecomposedPlan)
        logger.info(f"Plan decomposed into {len(decomposed_plan.tasks)} steps.")
        return decomposed_plan

dsat/operators/base.py
ADDED
@@ -0,0 +1,22 @@

from abc import ABC, abstractmethod
from typing import Optional, Any

from dsat.services.llm import LLMService

class Operator(ABC):
    """
    Abstract base class for a self-contained capability.

    Operators are the "verbs" of the agent framework, representing discrete
    actions like generating code, executing it, or reviewing results.
    """
    def __init__(self, llm_service: Optional[LLMService] = None, name: Optional[str] = None):
        self.name = name or self.__class__.__name__
        self.llm_service = llm_service

    @abstractmethod
    async def __call__(self, *args, **kwargs) -> Any:
        """
        Executes the operator's logic. All operator calls are asynchronous.
        """
        raise NotImplementedError
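
A minimal subclassing sketch (hypothetical, not part of the diff) showing how this base class is meant to be used; the EchoOperator name is an assumption, and no LLMService is needed because the constructor makes it optional.

import asyncio

from dsat.operators.base import Operator

class EchoOperator(Operator):
    """Hypothetical operator that simply returns its input."""
    async def __call__(self, text: str) -> str:
        return text

# Operator instances are called like coroutines; name defaults to the class name.
print(asyncio.run(EchoOperator()("hello")))
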
dsat/operators/code.py
ADDED
@@ -0,0 +1,45 @@

import logging
from typing import Any

from dsat.common.typing import ExecutionResult
from dsat.operators.base import Operator
from dsat.services.sandbox import SandboxService, ProcessIsolatedNotebookExecutor

logger = logging.getLogger(__name__)

class ExecuteAndTestOperator(Operator):
    """
    An operator that acts as a clean interface to the SandboxService,
    handling both script and notebook execution modes.
    """
    def __init__(self, sandbox_service: SandboxService):
        super().__init__(name="ExecuteAndTest")
        self.sandbox = sandbox_service

    async def __call__(self, code: str, mode: str = "script", executor_context: Any = None) -> ExecutionResult:
        """
        Executes code using the configured sandbox.

        Args:
            code (str): The Python code to execute.
            mode (str): The execution mode, either 'script' or 'notebook'.
            executor_context (Any): For notebook mode, this must be the active ProcessIsolatedNotebookExecutor instance.

        Returns:
            ExecutionResult: The outcome of the execution.
        """
        if mode == "script":
            logger.info("Executing code as a script...")
            # run_script is synchronous in the sandbox service, but we call it from an async operator.
            # A fully async sandbox would use asyncio.to_thread.
            return self.sandbox.run_script(code)

        elif mode == "notebook":
            if not isinstance(executor_context, ProcessIsolatedNotebookExecutor):
                raise TypeError("Notebook mode requires a valid ProcessIsolatedNotebookExecutor instance passed via executor_context.")

            logger.info("Executing notebook cell...")
            return await executor_context.execute_cell(code)

        else:
            raise ValueError(f"Unknown execution mode: {mode}")
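
A hedged usage sketch (not from the package) for script mode; `sandbox` is assumed to be a configured SandboxService, and the success/stdout fields follow the ExecutionResult usage seen in the other operators in this diff.

from dsat.operators.code import ExecuteAndTestOperator

async def smoke_test(sandbox) -> None:
    runner = ExecuteAndTestOperator(sandbox_service=sandbox)
    result = await runner("print('hello from the sandbox')", mode="script")
    # ExecutionResult exposes success/stdout/stderr, as used by the operators above.
    print(result.success, result.stdout)
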
dsat/operators/dsagent_ops.py
ADDED

@@ -0,0 +1,123 @@

import logging
from typing import Optional, Tuple

from dsat.operators.base import Operator
from dsat.prompts.dsagent_prompt import (
    PLAN_PROMPT_TEMPLATE, PROGRAMMER_PROMPT_TEMPLATE,
    DEBUGGER_PROMPT_TEMPLATE, LOGGER_PROMPT_TEMPLATE
)
from dsat.services.llm import LLMService
from dsat.services.sandbox import SandboxService
from dsat.services.vdb import VDBService
from dsat.common.typing import ExecutionResult
from dsat.utils.parsing import parse_plan_and_code
from dsat.utils.context import MAX_HISTORY_CHARS, MAX_OUTPUT_CHARS, truncate_output

# Define how much recent context to keep verbatim during summarization
RECENT_CONTEXT_VERBATIM = 8000  # Increased from 2000

logger = logging.getLogger(__name__)

class DevelopPlanOperator(Operator):
    """Retrieves cases and generates an experiment plan."""
    def __init__(self, llm_service: LLMService, vdb_service: Optional[VDBService] = None):
        super().__init__(llm_service, name="DevelopPlan")
        self.vdb = vdb_service

    async def __call__(self, research_problem: str, io_instructions: str, running_log: str) -> str:
        safe_running_log = truncate_output(running_log, MAX_HISTORY_CHARS)
        query = f"{research_problem}\n{safe_running_log}"
        if self.vdb:
            retrieved_cases = self.vdb.retrieve(query, top_k=1)
            case = retrieved_cases[0] if retrieved_cases else "No relevant cases found."
        else:
            case = "No relevant cases found."

        prompt = PLAN_PROMPT_TEMPLATE.format(
            research_problem=research_problem,
            io_instructions=io_instructions,
            running_log=safe_running_log,
            case=case
        )
        plan = await self.llm_service.call(prompt)
        return plan.strip()

class ExecutePlanOperator(Operator):
    """Manages the programmer-debugger loop to implement a plan."""
    def __init__(self, llm_service: LLMService, sandbox_service: SandboxService, max_retries: int = 10):
        super().__init__(llm_service, name="ExecutePlan")
        self.sandbox = sandbox_service
        self.max_retries = max_retries

    async def __call__(self, initial_code: str, plan: str, research_problem: str, io_instructions: str, running_log: str = "") -> Tuple[ExecutionResult, str]:
        safe_running_log = truncate_output(running_log, MAX_HISTORY_CHARS)
        current_code = initial_code
        for attempt in range(self.max_retries):
            if attempt == 0:
                prompt = PROGRAMMER_PROMPT_TEMPLATE.format(
                    code=current_code, plan=plan, research_problem=research_problem,
                    io_instructions=io_instructions, running_log=safe_running_log
                )
            else:
                safe_error_log = truncate_output(exec_result.stderr, MAX_OUTPUT_CHARS)
                prompt = DEBUGGER_PROMPT_TEMPLATE.format(
                    plan=plan, code=current_code, error_log=safe_error_log,
                    research_problem=research_problem, io_instructions=io_instructions, running_log=safe_running_log
                )

            response = await self.llm_service.call(prompt)
            _, current_code = parse_plan_and_code(response)

            exec_result = self.sandbox.run_script(current_code)
            if exec_result.success:
                return exec_result, current_code

        logger.error(f"Execution failed after {self.max_retries} attempts.")
        return exec_result, current_code  # Return last failed result

class ReviseLogOperator(Operator):
    """Summarizes the results of an experiment step."""
    async def __call__(self, running_log: str, plan: str, exec_result: ExecutionResult, diff: str) -> str:
        safe_running_log = truncate_output(running_log, MAX_HISTORY_CHARS)
        execution_log = exec_result.stdout or exec_result.stderr
        safe_execution_log = truncate_output(execution_log, MAX_OUTPUT_CHARS)
        safe_diff = truncate_output(diff, MAX_OUTPUT_CHARS)
        prompt = LOGGER_PROMPT_TEMPLATE.format(
            plan=plan,
            execution_log=safe_execution_log,
            diff=safe_diff,
            running_log=safe_running_log
        )
        new_summary = await self.llm_service.call(prompt)

        updated_log = running_log + "\n\n---\n" + new_summary.strip()

        if len(updated_log) > MAX_HISTORY_CHARS:
            logger.warning("Running log exceeds character limit; summarizing older history...")

            # Split the log: recent history (verbatim) and older history (to be summarized)
            split_point = max(0, len(updated_log) - RECENT_CONTEXT_VERBATIM)
            older_history = updated_log[:split_point]
            recent_history = updated_log[split_point:]

            if not older_history:
                return updated_log

            # Summarize only the older part
            summarize_prompt = (
                "The following is the older history of a data science project. "
                "Summarize this history concisely. You MUST preserve the following elements if present:\n"
                "1. Key findings and model performance metrics.\n"
                "2. Any explicit data format constraints (e.g., required CSV columns).\n"
                "3. Any explicit I/O details (e.g., filenames used, successful data loading patterns).\n"
                "4. Specific error messages or tracebacks from failed attempts, as they are crucial context for avoiding repeated mistakes.\n"
                f"\n\n# OLDER HISTORY\n{older_history}"
            )
            summarized_older_history = await self.llm_service.call(summarize_prompt)

            # Combine summarized older history with verbatim recent history
            updated_log = f"--- SUMMARIZED HISTORY ---\n{summarized_older_history}\n\n--- RECENT HISTORY (VERBATIM) ---\n{recent_history}"

            logger.info("Running log has been updated with summarized history.")

        return updated_log
dsat/operators/llm_basic.py
ADDED

@@ -0,0 +1,84 @@

import logging
from typing import Dict

from dsat.models.formats import Plan, ReviewResult, Task
from dsat.operators.base import Operator
from dsat.services.llm import LLMService
from dsat.utils.parsing import parse_plan_and_code
from dsat.utils.context import summarize_repetitive_logs

logger = logging.getLogger(__name__)

class GenerateCodeAndPlanOperator(Operator):
    """Generates a plan and corresponding code based on a prompt."""
    async def __call__(self, system_prompt: str, user_prompt: str = "") -> tuple[str, str]:
        if not self.llm_service:
            raise ValueError("LLMService is required for this operator.")

        logger.info("Generating new code and plan...")
        full_prompt = f"{system_prompt}\n\n{user_prompt}"

        # Use the new standard call method
        response = await self.llm_service.call(full_prompt)
        plan, code = parse_plan_and_code(response)

        if "# ERROR" in code:
            logger.warning("Failed to parse a valid code block from the LLM response.")
        else:
            logger.info("Successfully generated code and plan.")

        return plan, code

class PlanOperator(Operator):
    """Creates a structured, multi-step plan based on a user request."""
    async def __call__(self, user_request: str) -> Plan:
        if not self.llm_service:
            raise ValueError("LLMService is required for this operator.")

        logger.info(f"Generating a plan for request: '{user_request[:100]}...'")

        prompt = f"Create a structured JSON plan for this user request: {user_request}"
        # No more placeholder! This is a real structured call.
        try:
            plan_model = await self.llm_service.call_with_json(prompt, output_model=Plan)
        except Exception as e:
            logger.warning(f"Structured plan failed ({e}); falling back to text plan.")
            text = await self.llm_service.call(prompt)
            plan_model = Plan(tasks=[Task(task_id="1", instruction=text.strip(), dependent_task_ids=[])])
        logger.info(f"Successfully generated a plan with {len(plan_model.tasks)} tasks.")
        return plan_model

class ReviewOperator(Operator):
    """Reviews code execution output and provides a structured score and analysis."""
    async def __call__(self, prompt_context: Dict) -> ReviewResult:
        if not self.llm_service:
            raise ValueError("LLMService is required for this operator.")

        logger.info("Reviewing execution output...")

        raw_output = prompt_context.get('output', '# N/A')
        processed_output = summarize_repetitive_logs(raw_output)

        prompt = (
            "You are a data science judge. Review the following code and its output.\n\n"
            f"# TASK\n{prompt_context.get('task', 'N/A')}\n\n"
            f"# CODE\n```python\n{prompt_context.get('code', '# N/A')}\n```\n\n"
            f"# OUTPUT\n```\n{processed_output}\n```\n\n"
            "Respond with a JSON object containing your evaluation."
        )

        # No more simulation! This is a real structured call.
        review_model = await self.llm_service.call_with_json(prompt, output_model=ReviewResult)
        return review_model

class SummarizeOperator(Operator):
    """Generates a concise summary of a completed phase or task."""
    async def __call__(self, context: str) -> str:
        if not self.llm_service:
            raise ValueError("LLMService is required for this operator.")

        logger.info("Generating summary...")
        prompt = f"Please provide a concise summary of the following events:\n\n{context}"
        summary = await self.llm_service.call(prompt)
        logger.info("Summary generated successfully.")
        return summary

dsat/prompts/__init__.py
ADDED
File without changes

dsat/prompts/aflow_prompt.py
ADDED

@@ -0,0 +1,76 @@

from pydantic import BaseModel, Field

# 1. Pydantic model for the optimizer's structured output
class GraphOptimize(BaseModel):
    modification: str = Field(description="A brief, one-sentence summary of the change made to the workflow code.")
    graph: str = Field(description="The complete, runnable Python code for the new 'Workflow' class.")

# 2. Prompt templates
# --- FIX START: MAKE PROMPT MORE GENERIC AND STRATEGY-FOCUSED ---
WORKFLOW_OPTIMIZE_PROMPT = """
You are an expert AI workflow engineer. Your task is to iteratively optimize a Python-based AI workflow to improve its problem-solving score.

You will be given the code of a parent workflow, its performance score, and a history of modifications.
Your goal is to propose a single, small, logical modification to the workflow code. The new code must be a complete and runnable Python class named `Workflow`.

RULES:
1. **Focus on Logic**: Your modifications should improve the problem-solving STRATEGY. Examples: add a data cleaning step, try a different model, change a prompt, add a self-correction loop, etc.
2. **Adhere to the Standard Interface**: The workflow you generate MUST correctly implement the `solve` method: `async def solve(self, description: str, io_instructions: str, data_dir: Path, output_path: Path)`.
3. **Use the Arguments Correctly**:
   - `description`: Contains the task goal and a COMPLETE data analysis report.
   - `io_instructions`: Contains CRITICAL I/O requirements.
   Your code's logic MUST use both arguments effectively.
4. **DO NOT Hardcode Filenames**: Your generated workflow code should extract the required output filename from the `output_path.name` attribute provided to the `solve` method, or ensure that any internal prompts correctly instruct the final code-generation step to do so.
5. **Make only ONE logical change.** (e.g., add one new operator call, change a prompt, add a loop).
6. **Analyze the experience log.** Avoid modifications that have failed in the past. Learn from successful ones.
7. **Ensure the `graph` output is the complete, final Python code**, including all necessary imports and the full class definition.
8. **Inherit from DSATWorkflow**: The generated class MUST be `class Workflow(DSATWorkflow):`.

Your response MUST be a JSON object that adheres to the provided schema. Do not include any other text, markdown, or explanations.
"""
# --- FIX END ---

WORKFLOW_INPUT_TEMPLATE = """
# PARENT WORKFLOW CONTEXT

## Experience Log
{experience}

## Parent Score
{score:.4f}

## Parent Code
```python
{graph_code}
```

## Available Operators
{operator_description}
"""

# 3. Helper functions to assemble the final prompt
def get_operator_description() -> str:
    """Returns a formatted string of available operators for the prompt."""
    return """
You can call operators from the `self.operators` dictionary inside the `solve` method. For example: `final_answer = await self.operators['ScEnsemble'](...)`

- `ScEnsemble`: Performs self-consistency voting to find the best solution from a list.
  - `__call__(self, solutions: List[str], problem: str) -> str`
- `Review`: Critically reviews a solution and returns structured feedback.
  - `__call__(self, problem: str, solution: str) -> ReviewResponse(is_correct: bool, feedback: str)`
- `Revise`: Revises a solution based on feedback.
  - `__call__(self, problem: str, solution: str, feedback: str) -> str`
"""

# --- FIX: Update function signature ---
def get_graph_optimize_prompt(experience: str, score: float, graph_code: str) -> str:
    """Assembles the full prompt for the optimizer LLM."""
    main_prompt = WORKFLOW_OPTIMIZE_PROMPT  # Use the updated prompt directly

    inputs = WORKFLOW_INPUT_TEMPLATE.format(
        experience=experience,
        score=score,
        graph_code=graph_code,
        operator_description=get_operator_description()
    )
    return main_prompt + "\n" + inputs
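
A hypothetical wiring sketch (not in the package) of how the helper above and the GraphOptimize model could be used together, assuming an LLMService exposing the call_with_json(prompt, output_model=...) method seen throughout this diff.

from dsat.prompts.aflow_prompt import GraphOptimize, get_graph_optimize_prompt

async def propose_modification(llm, experience: str, score: float, parent_code: str) -> GraphOptimize:
    # Assemble the optimizer prompt, then request a structured GraphOptimize response.
    prompt = get_graph_optimize_prompt(experience, score, parent_code)
    return await llm.call_with_json(prompt, output_model=GraphOptimize)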