dslighting-1.3.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dsat/__init__.py +3 -0
- dsat/benchmark/__init__.py +1 -0
- dsat/benchmark/benchmark.py +168 -0
- dsat/benchmark/datasci.py +291 -0
- dsat/benchmark/mle.py +777 -0
- dsat/benchmark/sciencebench.py +304 -0
- dsat/common/__init__.py +0 -0
- dsat/common/constants.py +11 -0
- dsat/common/exceptions.py +48 -0
- dsat/common/typing.py +19 -0
- dsat/config.py +79 -0
- dsat/models/__init__.py +3 -0
- dsat/models/candidates.py +16 -0
- dsat/models/formats.py +52 -0
- dsat/models/task.py +64 -0
- dsat/operators/__init__.py +0 -0
- dsat/operators/aflow_ops.py +90 -0
- dsat/operators/autokaggle_ops.py +170 -0
- dsat/operators/automind_ops.py +38 -0
- dsat/operators/base.py +22 -0
- dsat/operators/code.py +45 -0
- dsat/operators/dsagent_ops.py +123 -0
- dsat/operators/llm_basic.py +84 -0
- dsat/prompts/__init__.py +0 -0
- dsat/prompts/aflow_prompt.py +76 -0
- dsat/prompts/aide_prompt.py +52 -0
- dsat/prompts/autokaggle_prompt.py +290 -0
- dsat/prompts/automind_prompt.py +29 -0
- dsat/prompts/common.py +51 -0
- dsat/prompts/data_interpreter_prompt.py +82 -0
- dsat/prompts/dsagent_prompt.py +88 -0
- dsat/runner.py +554 -0
- dsat/services/__init__.py +0 -0
- dsat/services/data_analyzer.py +387 -0
- dsat/services/llm.py +486 -0
- dsat/services/llm_single.py +421 -0
- dsat/services/sandbox.py +386 -0
- dsat/services/states/__init__.py +0 -0
- dsat/services/states/autokaggle_state.py +43 -0
- dsat/services/states/base.py +14 -0
- dsat/services/states/dsa_log.py +13 -0
- dsat/services/states/experience.py +237 -0
- dsat/services/states/journal.py +153 -0
- dsat/services/states/operator_library.py +290 -0
- dsat/services/vdb.py +76 -0
- dsat/services/workspace.py +178 -0
- dsat/tasks/__init__.py +3 -0
- dsat/tasks/handlers.py +376 -0
- dsat/templates/open_ended/grade_template.py +107 -0
- dsat/tools/__init__.py +4 -0
- dsat/utils/__init__.py +0 -0
- dsat/utils/context.py +172 -0
- dsat/utils/dynamic_import.py +71 -0
- dsat/utils/parsing.py +33 -0
- dsat/workflows/__init__.py +12 -0
- dsat/workflows/base.py +53 -0
- dsat/workflows/factory.py +439 -0
- dsat/workflows/manual/__init__.py +0 -0
- dsat/workflows/manual/autokaggle_workflow.py +148 -0
- dsat/workflows/manual/data_interpreter_workflow.py +153 -0
- dsat/workflows/manual/deepanalyze_workflow.py +484 -0
- dsat/workflows/manual/dsagent_workflow.py +76 -0
- dsat/workflows/search/__init__.py +0 -0
- dsat/workflows/search/aflow_workflow.py +344 -0
- dsat/workflows/search/aide_workflow.py +283 -0
- dsat/workflows/search/automind_workflow.py +237 -0
- dsat/workflows/templates/__init__.py +0 -0
- dsat/workflows/templates/basic_kaggle_loop.py +71 -0
- dslighting/__init__.py +170 -0
- dslighting/core/__init__.py +13 -0
- dslighting/core/agent.py +646 -0
- dslighting/core/config_builder.py +318 -0
- dslighting/core/data_loader.py +422 -0
- dslighting/core/task_detector.py +422 -0
- dslighting/utils/__init__.py +19 -0
- dslighting/utils/defaults.py +151 -0
- dslighting-1.3.9.dist-info/METADATA +554 -0
- dslighting-1.3.9.dist-info/RECORD +80 -0
- dslighting-1.3.9.dist-info/WHEEL +5 -0
- dslighting-1.3.9.dist-info/top_level.txt +2 -0
dsat/workflows/manual/data_interpreter_workflow.py
@@ -0,0 +1,153 @@
````python
# dsat/workflows/manual/data_interpreter_workflow.py

import logging
from pathlib import Path
from typing import Dict, Any

from dsat.workflows.base import DSATWorkflow
from dsat.models.formats import Plan

from dsat.services.sandbox import SandboxService
from dsat.operators.base import Operator

from dsat.prompts.data_interpreter_prompt import (
    PLAN_PROMPT, GENERATE_CODE_PROMPT, REFLECT_AND_DEBUG_PROMPT, FINALIZE_OUTPUT_PROMPT
)

from dsat.utils.context import ContextManager, MAX_OUTPUT_CHARS, truncate_output

logger = logging.getLogger(__name__)


class DataInterpreterWorkflow(DSATWorkflow):
    """
    Implements the DataInterpreter workflow, now conforming to the DSATWorkflow interface.

    It uses a plan -> (write -> execute -> reflect) loop to solve problems,
    and then executes a final step to generate the required output file.
    It also saves a detailed interaction report as a separate log file.
    """
    def __init__(self, operators: Dict[str, Operator], services: Dict[str, Any], agent_config: Dict[str, Any]):
        super().__init__(operators, services, agent_config)
        self.sandbox_service: SandboxService = services["sandbox"]
        self.planner_op = self.operators["planner"]
        self.generator_op = self.operators["generator"]
        self.debugger_op = self.operators["debugger"]
        self.executor_op = self.operators["executor"]
        self.max_retries = self.agent_config.get("max_retries", 3)
        self.context_manager = ContextManager()

    async def solve(self, description: str, io_instructions: str, data_dir: Path, output_path: Path) -> None:
        """
        Use Data Interpreter's plan-and-execute loop...
        """
        logger.info(f"DataInterpreterWorkflow starting to solve task. Target output: {output_path}")

        # The planner needs the full context to make the plan.
        full_context_for_planner = f"{description}\n\n{io_instructions}"

        # 1. Create a plan
        logger.info("Step 1: Planning...")
        plan: Plan = await self.planner_op(user_request=full_context_for_planner)
        logger.info(f"Plan generated with {len(plan.tasks)} tasks.")

        report_lines = [f"# Data Interpretation Report for: {description}\n"]
        report_lines.append("## Execution Plan")
        for task in plan.tasks:
            report_lines.append(f"- **Task {task.task_id}**: {task.instruction}")
        report_lines.append("\n---\n")

        # 2. Execute tasks in Notebook context
        logger.info("Step 2: Executing tasks...")
        history_steps = []
        async with self.sandbox_service.notebook_executor() as notebook:
            for task in plan.tasks:
                logger.info(f"Executing Task {task.task_id}: {task.instruction}")
                report_lines.append(f"## Task {task.task_id}: {task.instruction}")

                current_code = ""
                exec_result = None

                for attempt in range(self.max_retries):
                    logger.info(f"Attempt {attempt + 1}/{self.max_retries} for task {task.task_id}")

                    # Build history context before every attempt (needed by both generate and debug)
                    history_context = self.context_manager.build_history_context(
                        history_steps,
                        key_order=["task_id", "status", "code", "output"]
                    )

                    if attempt == 0:
                        prompt = GENERATE_CODE_PROMPT.format(
                            user_requirement=description,
                            io_instructions=io_instructions,  # Pass separately
                            plan_status=plan.model_dump_json(),
                            current_task=task.instruction, history=history_context
                        )
                        _, current_code = await self.generator_op(system_prompt=prompt)
                    else:
                        safe_error_output = truncate_output(exec_result.stderr, MAX_OUTPUT_CHARS)
                        prompt = REFLECT_AND_DEBUG_PROMPT.format(
                            user_requirement=description,
                            io_instructions=io_instructions,  # Pass separately
                            current_task=task.instruction,
                            failed_code=current_code, error_output=safe_error_output,
                            history=history_context
                        )
                        _, current_code = await self.debugger_op(system_prompt=prompt)

                    report_lines.append(f"\n**Attempt {attempt + 1} Code:**\n```python\n{current_code}\n```")
                    exec_result = await self.executor_op(code=current_code, mode="notebook", executor_context=notebook)

                    if exec_result.success:
                        logger.info(f"Task {task.task_id} succeeded.")
                        report_lines.append(f"**Result:** Success\n**Output:**\n```\n{exec_result.stdout}\n```")
                        history_steps.append({
                            "task_id": task.task_id,
                            "status": "Success",
                            "code": truncate_output(current_code, MAX_OUTPUT_CHARS),
                            "output": truncate_output(exec_result.stdout, MAX_OUTPUT_CHARS),
                        })
                        break
                    else:
                        logger.warning(f"Task {task.task_id} failed on attempt {attempt + 1}.")
                        report_lines.append(f"**Result:** Failure\n**Error:**\n```\n{exec_result.stderr}\n```")
                        history_steps.append({
                            "task_id": task.task_id,
                            "status": "Failure",
                            "code": truncate_output(current_code, MAX_OUTPUT_CHARS),
                            "output": truncate_output(exec_result.stderr, MAX_OUTPUT_CHARS),
                        })

                report_lines.append("\n---\n")

            logger.info("Step 3: Generating final output file...")
            report_lines.append("## Final Output Generation")

            final_history_context = self.context_manager.build_history_context(
                history_steps,
                key_order=["task_id", "status", "code", "output"]
            )

            finalize_prompt = FINALIZE_OUTPUT_PROMPT.format(
                user_requirement=description,
                io_instructions=io_instructions,  # Pass separately
                history=final_history_context,
                output_filename=output_path.name
            )
            _, final_code = await self.generator_op(system_prompt=finalize_prompt)
            report_lines.append(f"**Finalization Code:**\n```python\n{final_code}\n```")

            final_exec_result = await self.executor_op(code=final_code, mode="notebook", executor_context=notebook)

            if final_exec_result.success:
                logger.info("Finalization code executed successfully.")
                report_lines.append("**Result:** Success")
            else:
                logger.error(f"Finalization code failed to execute!\n{final_exec_result.stderr}")
                report_lines.append(f"**Result:** Failure\n**Error:**\n```\n{final_exec_result.stderr}\n```")

        # 4. Save the execution report as a separate log file
        report_path = output_path.parent / "execution_report.md"
        report_path.write_text("\n".join(report_lines), encoding='utf-8')
        logger.info(f"Execution report saved to {report_path}")
````
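To make the control flow above concrete, here is a self-contained toy of the plan -> (write -> execute -> reflect) retry loop. The stub generator, executor, and task are invented for illustration; only the retry-with-history shape mirrors `DataInterpreterWorkflow.solve`.

````python
# Toy sketch of the retry-with-history loop; all stubs are hypothetical.
import asyncio

async def generate(task: str, history: list[dict]) -> str:
    # Stub "generator"/"debugger": the first attempt returns broken code,
    # later attempts (which would see the failure history) return a fix.
    return "1/0" if not history else "print('ok')"

async def execute(code: str) -> tuple[bool, str]:
    # Stub "executor": run the code and report success plus stderr-like output.
    try:
        exec(code, {})
        return True, ""
    except Exception as exc:
        return False, repr(exc)

async def solve(task: str, max_retries: int = 3) -> None:
    history: list[dict] = []
    for attempt in range(max_retries):
        code = await generate(task, history)
        ok, output = await execute(code)
        history.append({"attempt": attempt + 1,
                        "status": "Success" if ok else "Failure",
                        "output": output})
        if ok:
            break

asyncio.run(solve("toy task"))
````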
dsat/workflows/manual/deepanalyze_workflow.py
@@ -0,0 +1,484 @@
````python
import logging
import re
import shutil
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

from dsat.benchmark.benchmark import BaseBenchmark
from dsat.services.llm import LLMService
from dsat.services.sandbox import SandboxService
from dsat.services.workspace import WorkspaceService
from dsat.utils.context import MAX_HISTORY_CHARS, MAX_OUTPUT_CHARS, truncate_output
from dsat.workflows.base import DSATWorkflow

logger = logging.getLogger(__name__)


class DeepAnalyzeWorkflow(DSATWorkflow):
    """
    Workflow implementation for DeepAnalyze-8B.

    DeepAnalyze uses structured tags (<Analyze>, <Code>, <Execute>, <Answer>) and requires
    multi-round dialog where the system injects code execution results into <Execute> tags.
    """

    # Common output file extensions to look for
    OUTPUT_EXTENSIONS = {'.csv', '.txt', '.json', '.xlsx', '.xls', '.png', '.jpg', '.jpeg', '.pdf', '.html', '.py', '.pkl', '.pickle', '.npy', '.npz', '.h5', '.hdf5', '.parquet'}

    # Files to ignore when scanning for outputs
    IGNORE_FILES = {'prompt.json', '.gitkeep', '.DS_Store', 'thumbs.db'}

    def __init__(
        self,
        operators: Dict[str, Any],
        services: Dict[str, Any],
        agent_config: Dict[str, Any],
        benchmark: Optional[BaseBenchmark] = None,
    ):
        super().__init__(operators, services, agent_config)
        self.llm_service: LLMService = services["llm"]
        self.sandbox_service: SandboxService = services["sandbox"]
        self.workspace_service: WorkspaceService = services.get("workspace")
        self.execute_op = operators.get("execute")
        if not self.execute_op:
            raise ValueError("DeepAnalyzeWorkflow requires an 'execute' operator.")

        self.max_iterations = agent_config.get("max_iterations", 10)
        self.benchmark = benchmark

    @staticmethod
    def _extract_code_from_segment(segment: str) -> Optional[str]:
        """Extract python code between <Code>...</Code>, optionally fenced by ```python ... ```."""
        code_match = re.search(r"<Code>(.*?)</Code>", segment, re.DOTALL)
        if not code_match:
            return None
        code_content = code_match.group(1).strip()
        md_match = re.search(r"```(?:python)?(.*?)```", code_content, re.DOTALL)
        return (md_match.group(1).strip() if md_match else code_content)

    @staticmethod
    def _extract_answer_from_segment(segment: str) -> Optional[str]:
        """Extract answer content between <Answer>...</Answer> tags."""
        # Try case-insensitive match for <Answer> or <answer>
        answer_match = re.search(r"<Answer>(.*?)</Answer>", segment, re.DOTALL | re.IGNORECASE)
        if not answer_match:
            return None
        answer_content = answer_match.group(1).strip()
        # Remove "Answer:" prefix if present
        if "Answer:" in answer_content:
            answer_content = answer_content.split("Answer:", 1)[-1].strip()
        return answer_content

    def _should_terminate(self, response: str, iteration: int) -> bool:
        """Stop when an <Answer> block appears or max iterations reached."""
        # Some models emit a full <Answer>...</Answer> in one turn; others may emit
        # a bare </Answer> later. Prefer terminating as soon as any Answer block
        # is detected to avoid an extra empty turn.
        if re.search(r"<Answer>.*?</Answer>", response, re.DOTALL | re.IGNORECASE) or "</Answer>" in response:
            logger.info("Detected <Answer> block, stopping iterations.")
            return True
        if iteration + 1 >= self.max_iterations:
            logger.info("Reached maximum iterations, stopping workflow.")
            return True
        return False

    def _build_llm_messages(self, conversation_history: List[Dict[str, str]]) -> List[Dict[str, str]]:
        if not conversation_history:
            return []

        max_chars = MAX_HISTORY_CHARS
        first = conversation_history[0]
        first_content = truncate_output(first.get("content", ""), max_chars // 2)
        messages = [{"role": first.get("role", "user"), "content": first_content}]
        total_chars = len(first_content)

        recent_messages = []
        for msg in reversed(conversation_history[1:]):
            content = truncate_output(msg.get("content", ""), MAX_OUTPUT_CHARS)
            msg_len = len(content)
            if total_chars + msg_len > max_chars:
                break
            recent_messages.append({"role": msg.get("role", "user"), "content": content})
            total_chars += msg_len

        messages.extend(reversed(recent_messages))
        return messages

    @staticmethod
    def _extract_output_filenames_from_description(description: str) -> List[str]:
        """
        Extract all potential output filenames from task description.
        Returns a list of filenames (without path) that might be expected outputs.
        """
        filenames = []

        # Patterns to match various ways output files are specified
        patterns = [
            # "saved in/to 'filename.ext'" or "saved in/to \"filename.ext\""
            r'saved?\s+(?:in|to|as)\s+["\']([^"\']+\.\w+)["\']',
            # "output file named 'filename.ext'"
            r'output\s+file\s+(?:named|called)?\s*["\']([^"\']+\.\w+)["\']',
            # "save the results/output to 'filename.ext'"
            r'save\s+(?:the\s+)?(?:results?|output|data|file)\s+(?:to|in|as)\s+["\']([^"\']+\.\w+)["\']',
            # "results should be saved in 'filename.ext'"
            r'(?:results?|output)\s+should\s+be\s+saved\s+(?:in|to|as)\s+["\']([^"\']+\.\w+)["\']',
            # "write to 'filename.ext'"
            r'write\s+(?:to|into)\s+["\']([^"\']+\.\w+)["\']',
            # "export to 'filename.ext'"
            r'export\s+(?:to|as)\s+["\']([^"\']+\.\w+)["\']',
            # "filename.csv" or 'filename.csv' standalone in quotes (common patterns)
            r'["\']([a-zA-Z0-9_\-]+\.(?:csv|txt|json|xlsx|png|jpg|pdf|html|py))["\']',
        ]

        seen = set()
        for pattern in patterns:
            for match in re.finditer(pattern, description, re.IGNORECASE):
                filename = match.group(1)
                # Clean up the filename
                filename = filename.strip()
                # Only add if it looks like a valid filename and not seen before
                if filename and '/' not in filename and '\\' not in filename and filename not in seen:
                    filenames.append(filename)
                    seen.add(filename)

        return filenames

    def _get_initial_sandbox_files(self, sandbox_workdir: Path) -> Set[str]:
        """Get the set of files initially present in sandbox (to detect new files later)."""
        if not sandbox_workdir.exists():
            return set()
        return {f.name for f in sandbox_workdir.iterdir() if f.is_file()}

    def _find_new_output_files(self, sandbox_workdir: Path, initial_files: Set[str]) -> List[Path]:
        """
        Find all new files created in sandbox since initial state.
        Returns list of Path objects for new output files.
        """
        new_files = []
        if not sandbox_workdir.exists():
            return new_files

        for f in sandbox_workdir.iterdir():
            if not f.is_file():
                continue
            if f.name in initial_files:
                continue
            if f.name.lower() in self.IGNORE_FILES:
                continue
            if f.name.startswith('_sandbox_script_'):
                continue
            # Check if it's a recognized output type
            if f.suffix.lower() in self.OUTPUT_EXTENSIONS or f.suffix == '':
                new_files.append(f)

        return new_files

    def _collect_outputs_to_destination(
        self,
        sandbox_workdir: Path,
        output_path: Path,
        expected_filenames: List[str],
        initial_files: Set[str],
    ) -> bool:
        """
        Collect output files from sandbox to destination directory.

        Strategy:
        1. First, check if the expected output file (output_path.name) exists
        2. Then, check for any files matching expected_filenames from task description
        3. Finally, collect any new files created during execution

        All matching files are copied to output_path.parent, preserving original names.
        The primary output is also copied to output_path for compatibility.

        Returns True if at least one output file was collected.
        """
        output_path.parent.mkdir(parents=True, exist_ok=True)
        collected = False
        copied_files = set()

        # 1. Check for the default expected output file
        default_output = sandbox_workdir / output_path.name
        if default_output.exists():
            shutil.copy(default_output, output_path)
            copied_files.add(output_path.name)
            collected = True
            logger.info("Copied default output file to %s", output_path)

        # 2. Check for files specified in task description
        for filename in expected_filenames:
            src_file = sandbox_workdir / filename
            if src_file.exists() and filename not in copied_files:
                dst_file = output_path.parent / filename
                shutil.copy(src_file, dst_file)
                copied_files.add(filename)
                collected = True
                logger.info("Copied expected output file '%s' to %s", filename, dst_file)

                # If no default output was found, also copy first expected file as the default
                if not output_path.exists():
                    shutil.copy(src_file, output_path)
                    logger.info("Also copied '%s' as default output to %s", filename, output_path)

        # 3. Collect any other new files created during execution
        new_files = self._find_new_output_files(sandbox_workdir, initial_files)
        for src_file in new_files:
            if src_file.name not in copied_files:
                dst_file = output_path.parent / src_file.name
                shutil.copy(src_file, dst_file)
                copied_files.add(src_file.name)
                collected = True
                logger.info("Copied new output file '%s' to %s", src_file.name, dst_file)

                # If still no default output, use first new file
                if not output_path.exists():
                    shutil.copy(src_file, output_path)
                    logger.info("Also copied '%s' as default output to %s", src_file.name, output_path)

        if collected:
            logger.info("Total %d output file(s) collected to %s", len(copied_files), output_path.parent)

        return collected

    def _write_answer_to_file(self, output_path: Path, answer_content: str) -> None:
        """Write answer content to output file based on file extension."""
        output_path.parent.mkdir(parents=True, exist_ok=True)

        suffix = output_path.suffix.lower()

        if suffix == '.txt':
            output_path.write_text(answer_content, encoding='utf-8')
            logger.info("Answer written to text file: %s", output_path)

        elif suffix == '.csv':
            # For CSV files, check if content looks like CSV data
            if '\n' in answer_content and ',' in answer_content:
                output_path.write_text(answer_content, encoding='utf-8')
            else:
                # Wrap simple answer in CSV format
                output_path.write_text(f"answer\n{answer_content}\n", encoding='utf-8')
            logger.info("Answer written to CSV file: %s", output_path)

        elif suffix == '.json':
            # Try to format as JSON if possible
            import json
            try:
                # Check if it's already valid JSON
                json.loads(answer_content)
                output_path.write_text(answer_content, encoding='utf-8')
            except json.JSONDecodeError:
                # Wrap in JSON format
                output_path.write_text(json.dumps({"answer": answer_content}), encoding='utf-8')
            logger.info("Answer written to JSON file: %s", output_path)

        else:
            # Default: write as plain text
            output_path.write_text(answer_content, encoding='utf-8')
            logger.info("Answer written to file: %s", output_path)

    @staticmethod
    def _build_initial_prompt(
        description: str,
        io_instructions: str,
        data_dir: Path,
        output_path: Path,
    ) -> str:
        return f"""You are an expert data scientist using the DeepAnalyze model to solve data science tasks.

Task Description:
{description}

Input/Output Requirements:
{io_instructions}

Data Directory: {data_dir}
Output File: {output_path}

Please use the following format to analyze and solve the problem:
1. Analyze the task in the <Analyze>...</Analyze> tags
2. Provide the code to execute in the <Code>...</Code> tags (may include ```python``` code blocks)
3. The system will provide code execution results in the <Execute>...</Execute> tags
4. When you determine that the code execution results can successfully solve the problem, output the final answer in the <Answer>...</Answer> tags

Please begin the first round of analysis."""

    async def solve(
        self,
        description: str,
        io_instructions: str,
        data_dir: Path,
        output_path: Path,
    ) -> None:
        """
        Execute the DeepAnalyze workflow to solve a data science task.

        The workflow:
        1. Sends task description to LLM
        2. Extracts and executes code from <Code> blocks
        3. Feeds execution results back via <Execute> blocks
        4. Repeats until <Answer> tag is found or max iterations reached
        5. Collects all output files from sandbox to destination
        """
        logger.info(
            "DeepAnalyzeWorkflow starting to solve task. Target output: %s",
            output_path,
        )

        if not self.workspace_service:
            raise ValueError("WorkspaceService is required for DeepAnalyzeWorkflow.")

        # Extract expected output filenames from task description
        expected_filenames = self._extract_output_filenames_from_description(description)
        if expected_filenames:
            logger.info("Expected output files from task description: %s", expected_filenames)

        # Get sandbox workdir and record initial files
        sandbox_workdir = self.workspace_service.get_path("sandbox_workdir")
        initial_files = self._get_initial_sandbox_files(sandbox_workdir)
        logger.debug("Initial sandbox files: %s", initial_files)

        initial_prompt = self._build_initial_prompt(
            description=description,
            io_instructions=io_instructions,
            data_dir=data_dir,
            output_path=output_path,
        )
        conversation_history: List[Dict[str, str]] = [
            {"role": "user", "content": initial_prompt}
        ]

        final_answer_content: Optional[str] = None
        # Keep the most recent assistant segment that contained an <Answer> block,
        # since some models emit a bare </Answer> in a follow-up turn.
        last_answer_segment: Optional[str] = None
        # If the model fails to emit a <Code> block, retry within the same iteration
        # with a short nudge, up to this many times.
        no_code_retries_max: int = int(self.agent_config.get("no_code_retries", 2))

        for iteration in range(self.max_iterations):
            logger.info(
                "--- DeepAnalyze Iteration %d/%d ---",
                iteration + 1,
                self.max_iterations,
            )

            no_code_retry_count = 0
            extracted_code: Optional[str] = None
            llm_response = ""

            while True:
                # LLMService.call() only supports single prompts, so use the underlying multi-turn helper.
                response = await self.llm_service._make_llm_call_with_retries(  # type: ignore[attr-defined]
                    messages=self._build_llm_messages(conversation_history),
                    max_retries=self.llm_service.config.max_retries,
                )
                llm_response = response.choices[0].message.content
                logger.debug("DeepAnalyze raw response: %s", llm_response)
                conversation_history.append({"role": "assistant", "content": llm_response})

                # Extract answer if present
                extracted_answer = self._extract_answer_from_segment(llm_response)
                if extracted_answer:
                    final_answer_content = extracted_answer
                    last_answer_segment = llm_response
                    logger.info("Extracted answer from <Answer> tag: %s", extracted_answer[:100] + "..." if len(extracted_answer) > 100 else extracted_answer)

                if self._should_terminate(llm_response, iteration):
                    break

                extracted_code = self._extract_code_from_segment(llm_response)
                if extracted_code:
                    break

                if no_code_retry_count >= no_code_retries_max:
                    logger.warning(
                        "No <Code> block found after %d retry(ies) in iteration %d, continuing to next iteration.",
                        no_code_retry_count,
                        iteration + 1,
                    )
                    break

                no_code_retry_count += 1
                logger.warning(
                    "No <Code> block found in iteration %d, retrying (%d/%d).",
                    iteration + 1,
                    no_code_retry_count,
                    no_code_retries_max,
                )
                conversation_history.append(
                    {
                        "role": "user",
                        "content": (
                            "<Feedback>\n"
                            "No <Code>...</Code> block was detected. "
                            "Please provide executable Python code inside <Code> tags so it can be run.\n"
                            "</Feedback>\n"
                        ),
                    }
                )

            if self._should_terminate(llm_response, iteration):
                break

            if not extracted_code:
                continue

            exec_result = await self.execute_op(code=extracted_code, mode="script")
            if exec_result.success:
                logger.info("Code execution succeeded in iteration %d.", iteration + 1)
                execution_output = exec_result.stdout or "Code executed successfully."
            else:
                logger.warning("Code execution failed in iteration %d.", iteration + 1)
                execution_output = f"Error during execution:\n{exec_result.stderr}"

            safe_execution_output = truncate_output(execution_output, MAX_OUTPUT_CHARS)
            execute_block = f"\n<Execute>\n{safe_execution_output}\n</Execute>\n"
            conversation_history.append({"role": "user", "content": execute_block})

            # After successful execution, collect intermediate outputs
            if exec_result.success:
                collected = self._collect_outputs_to_destination(
                    sandbox_workdir=sandbox_workdir,
                    output_path=output_path,
                    expected_filenames=expected_filenames,
                    initial_files=initial_files,
                )
                if collected:
                    logger.info("Intermediate outputs collected in iteration %d", iteration + 1)

        # Final output collection after all iterations
        logger.info("Workflow iterations complete. Performing final output collection...")
        collected = self._collect_outputs_to_destination(
            sandbox_workdir=sandbox_workdir,
            output_path=output_path,
            expected_filenames=expected_filenames,
            initial_files=initial_files,
        )

        if not collected:
            # No files were collected from sandbox
            if not final_answer_content:
                # Try to recover answer from the last assistant segment that had it.
                if last_answer_segment:
                    final_answer_content = self._extract_answer_from_segment(last_answer_segment)
                # Fallback: scan conversation history for the most recent <Answer>.
                if not final_answer_content:
                    for msg in reversed(conversation_history):
                        if msg.get("role") != "assistant":
                            continue
                        recovered = self._extract_answer_from_segment(msg.get("content", ""))
                        if recovered:
                            final_answer_content = recovered
                            break

            if final_answer_content:
                logger.info(
                    "No output files found in sandbox, but answer content recovered from <Answer>. Writing to output file."
                )
                self._write_answer_to_file(output_path, final_answer_content)
            else:
                logger.warning(
                    "DeepAnalyzeWorkflow finished but no output files found in sandbox and no <Answer> tag content extracted."
                )

        logger.info("DeepAnalyzeWorkflow completed. Output directory: %s", output_path.parent)
````
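As a quick sanity check of the tag protocol this workflow parses, the snippet below runs the same regexes used by `_extract_code_from_segment` and `_extract_answer_from_segment` on an invented model response; the sample text is hypothetical, only the extraction logic comes from the code above.

````python
import re

# Invented sample of a DeepAnalyze-style turn.
sample = (
    "<Analyze>Add the two numbers.</Analyze>\n"
    "<Code>```python\nprint(2 + 2)\n```</Code>\n"
    "<Answer>Answer: 4</Answer>"
)

# Mirror _extract_code_from_segment: outer <Code> tag, optional ``` fence.
code = re.search(r"<Code>(.*?)</Code>", sample, re.DOTALL).group(1).strip()
fenced = re.search(r"```(?:python)?(.*?)```", code, re.DOTALL)
print(fenced.group(1).strip() if fenced else code)  # -> print(2 + 2)

# Mirror _extract_answer_from_segment: <Answer> tag, optional "Answer:" prefix.
answer = re.search(r"<Answer>(.*?)</Answer>", sample, re.DOTALL | re.IGNORECASE).group(1).strip()
print(answer.split("Answer:", 1)[-1].strip())  # -> 4
````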