dslighting-1.3.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dsat/__init__.py +3 -0
- dsat/benchmark/__init__.py +1 -0
- dsat/benchmark/benchmark.py +168 -0
- dsat/benchmark/datasci.py +291 -0
- dsat/benchmark/mle.py +777 -0
- dsat/benchmark/sciencebench.py +304 -0
- dsat/common/__init__.py +0 -0
- dsat/common/constants.py +11 -0
- dsat/common/exceptions.py +48 -0
- dsat/common/typing.py +19 -0
- dsat/config.py +79 -0
- dsat/models/__init__.py +3 -0
- dsat/models/candidates.py +16 -0
- dsat/models/formats.py +52 -0
- dsat/models/task.py +64 -0
- dsat/operators/__init__.py +0 -0
- dsat/operators/aflow_ops.py +90 -0
- dsat/operators/autokaggle_ops.py +170 -0
- dsat/operators/automind_ops.py +38 -0
- dsat/operators/base.py +22 -0
- dsat/operators/code.py +45 -0
- dsat/operators/dsagent_ops.py +123 -0
- dsat/operators/llm_basic.py +84 -0
- dsat/prompts/__init__.py +0 -0
- dsat/prompts/aflow_prompt.py +76 -0
- dsat/prompts/aide_prompt.py +52 -0
- dsat/prompts/autokaggle_prompt.py +290 -0
- dsat/prompts/automind_prompt.py +29 -0
- dsat/prompts/common.py +51 -0
- dsat/prompts/data_interpreter_prompt.py +82 -0
- dsat/prompts/dsagent_prompt.py +88 -0
- dsat/runner.py +554 -0
- dsat/services/__init__.py +0 -0
- dsat/services/data_analyzer.py +387 -0
- dsat/services/llm.py +486 -0
- dsat/services/llm_single.py +421 -0
- dsat/services/sandbox.py +386 -0
- dsat/services/states/__init__.py +0 -0
- dsat/services/states/autokaggle_state.py +43 -0
- dsat/services/states/base.py +14 -0
- dsat/services/states/dsa_log.py +13 -0
- dsat/services/states/experience.py +237 -0
- dsat/services/states/journal.py +153 -0
- dsat/services/states/operator_library.py +290 -0
- dsat/services/vdb.py +76 -0
- dsat/services/workspace.py +178 -0
- dsat/tasks/__init__.py +3 -0
- dsat/tasks/handlers.py +376 -0
- dsat/templates/open_ended/grade_template.py +107 -0
- dsat/tools/__init__.py +4 -0
- dsat/utils/__init__.py +0 -0
- dsat/utils/context.py +172 -0
- dsat/utils/dynamic_import.py +71 -0
- dsat/utils/parsing.py +33 -0
- dsat/workflows/__init__.py +12 -0
- dsat/workflows/base.py +53 -0
- dsat/workflows/factory.py +439 -0
- dsat/workflows/manual/__init__.py +0 -0
- dsat/workflows/manual/autokaggle_workflow.py +148 -0
- dsat/workflows/manual/data_interpreter_workflow.py +153 -0
- dsat/workflows/manual/deepanalyze_workflow.py +484 -0
- dsat/workflows/manual/dsagent_workflow.py +76 -0
- dsat/workflows/search/__init__.py +0 -0
- dsat/workflows/search/aflow_workflow.py +344 -0
- dsat/workflows/search/aide_workflow.py +283 -0
- dsat/workflows/search/automind_workflow.py +237 -0
- dsat/workflows/templates/__init__.py +0 -0
- dsat/workflows/templates/basic_kaggle_loop.py +71 -0
- dslighting/__init__.py +170 -0
- dslighting/core/__init__.py +13 -0
- dslighting/core/agent.py +646 -0
- dslighting/core/config_builder.py +318 -0
- dslighting/core/data_loader.py +422 -0
- dslighting/core/task_detector.py +422 -0
- dslighting/utils/__init__.py +19 -0
- dslighting/utils/defaults.py +151 -0
- dslighting-1.3.9.dist-info/METADATA +554 -0
- dslighting-1.3.9.dist-info/RECORD +80 -0
- dslighting-1.3.9.dist-info/WHEEL +5 -0
- dslighting-1.3.9.dist-info/top_level.txt +2 -0
dsat/templates/open_ended/grade_template.py
ADDED

@@ -0,0 +1,107 @@
import os
from pathlib import Path
from typing import Any
import pandas as pd

# This is a generic LLM-based grader for open-ended tasks.
# It reads 'rubric.md' from the task directory and evaluates the submission.

try:
    from dsat.services.llm import LLMService
    from dsat.config import LLMConfig
except ImportError:
    # Fallback for when running outside of dsat package context
    import sys
    sys.path.append(str(Path(__file__).resolve().parent.parent.parent.parent))
    from dsat.services.llm import LLMService
    from dsat.config import LLMConfig

class Report:
    def __init__(self, score, feedback):
        self.score = score
        self.feedback = feedback
        # Standard fields expected by the framework
        self.is_lower_better = False
        self.submission_exists = True
        self.valid_submission = True
        self.gold_medal = score >= 0.9
        self.silver_medal = score >= 0.7
        self.bronze_medal = score >= 0.5
        self.above_median = score >= 0.5
        self.submission_path = ""
        self.competition_id = "open_ended_task"

def grade(submission_path: Path, competition: Any) -> Report:
    """
    Grades the submission using an LLM Judge against rubric.md.
    """
    # 1. Load the Rubric
    task_dir = competition.raw_dir.parent
    rubric_path = task_dir / "rubric.md"

    if not rubric_path.exists():
        # Fallback if no rubric exists
        print(f"Warning: Rubric not found at {rubric_path}. Returning default score.")
        return Report(0.5, "No grading rubric defined.")

    rubric_content = rubric_path.read_text(encoding="utf-8")

    # 2. Load the Submission Content (Preview)
    # Since it's open-ended, the 'submission_path' might be a CSV, code, or just a marker.
    # We'll try to peek at the output artifacts if possible, or assume the agent's recent work
    # is what we are grading. Ideally, AIDE produces a submission file.

    submission_content = "No submission content readable."
    if submission_path.exists():
        try:
            if submission_path.suffix == '.csv':
                df = pd.read_csv(submission_path)
                submission_content = f"CSV Submission Preview:\n{df.head().to_markdown()}"
            else:
                submission_content = submission_path.read_text(encoding="utf-8")[:2000]
        except Exception as e:
            submission_content = f"Error reading submission: {e}"

    # 3. Setup LLM for Judging
    # Note: In a real run, we might want to inject the API key securely.
    # Here we assume environment variables are set (which they are in DSATRunner).
    try:
        api_key = os.getenv("API_KEY", "EMPTY")
        base_url = os.getenv("API_BASE", "https://api.openai.com/v1")
        model = os.getenv("LLM_MODEL", "gpt-4o")

        llm = LLMService(LLMConfig(api_key=api_key, api_base=base_url, model=model))

        prompt = f"""You are an impartial Judge. Evaluate the following submission against the provided Rubric.

# RUBRIC
{rubric_content}

# SUBMISSION CONTENT
{submission_content}

# INSTRUCTION
Assess the submission.
Output ONLY a float number between 0.0 and 1.0 on the first line.
On subsequent lines, provide brief feedback.
"""
        # Synchronous call wrapper or direct call if possible.
        # Since grade() is synchronous in standard mlebench, we need a way to run async code.
        import asyncio
        response = asyncio.run(llm.achat([{"role": "user", "content": prompt}]))

        lines = response.strip().split('\n')
        try:
            score = float(lines[0].strip())
        except ValueError:
            # Fallback if LLM is chatty
            import re
            match = re.search(r"(\d+(\.\d+)?)", lines[0])
            score = float(match.group(1)) if match else 0.5

        feedback = "\n".join(lines[1:])
        return Report(score, feedback)

    except Exception as e:
        print(f"LLM Judging failed: {e}")
        return Report(0.0, f"Judging failed: {e}")
dsat/tools/__init__.py
ADDED

dsat/utils/__init__.py
ADDED

File without changes
dsat/utils/context.py
ADDED

@@ -0,0 +1,172 @@
from itertools import groupby
import logging
from typing import List, Dict, Optional, Any

from dsat.services.llm import LLMService

logger = logging.getLogger(__name__)

# Use character count as a simple, fast proxy for token count.
# In a production system, you would use a real tokenizer.
MAX_HISTORY_CHARS = 32000
MAX_ERROR_CHARS = 8000
MAX_KNOWLEDGE_CHARS = 8000
MAX_OUTPUT_CHARS = 16000  # Maximum characters for execution output

class ContextManager:
    """
    A utility for intelligently building prompt context to prevent overflow.
    It uses summarization, windowing, and truncation to manage context size.
    """
    def __init__(self, llm_service: Optional[LLMService] = None):
        self.llm_service = llm_service

    def build_history_context(self, history: List[Dict[str, Any]], key_order: List[str]) -> str:
        """
        Builds a context string from a list of historical events using a windowing strategy.

        Args:
            history: A list of dictionary-like objects representing historical steps.
            key_order: The keys to extract from each history object and their order.

        Returns:
            A formatted string of the most recent history that fits within the budget.
        """
        if not history:
            return "No history available."

        context_parts = []
        total_chars = 0

        # Iterate backwards from the most recent history
        for item in reversed(history):
            part = "\n".join([f"{key.capitalize()}: {item.get(key, 'N/A')}" for key in key_order])
            part_len = len(part)

            if total_chars + part_len > MAX_HISTORY_CHARS:
                logger.warning("History context truncated to fit budget.")
                break

            context_parts.append(part)
            total_chars += part_len

        # Reverse again to restore chronological order
        final_context = "\n---\n".join(reversed(context_parts))

        if len(history) > len(context_parts):
            final_context = f"[... {len(history) - len(context_parts)} older steps summarized ...]\n\n{final_context}"

        return final_context

    def summarize_error(self, stderr: str, exc_type: Optional[str] = None) -> str:
        """
        Extracts the most relevant parts of a long error message.
        """
        if not stderr:
            return "No error output."

        # Prioritize the exception type if available
        summary = f"Exception Type: {exc_type}\n\n" if exc_type else ""

        if len(stderr) > MAX_ERROR_CHARS:
            # Keep the beginning and the end of the traceback
            head = stderr[:MAX_ERROR_CHARS // 2]
            tail = stderr[-MAX_ERROR_CHARS // 2:]
            summary += f"Traceback (truncated):\n{head}\n[...]\n{tail}"
            logger.warning("Error context truncated to fit budget.")
        else:
            summary += f"Traceback:\n{stderr}"

        return summary

    async def summarize_knowledge(self, knowledge_docs: List[str], task_goal: str) -> str:
        """
        Uses an LLM to summarize a list of retrieved documents into a concise
        knowledge block relevant to the task.
        """
        if not knowledge_docs:
            return "No relevant knowledge was retrieved for this task."
        if not self.llm_service:
            logger.warning("LLMService not provided to ContextManager; returning raw knowledge.")
            return "\n\n".join(knowledge_docs)[:MAX_KNOWLEDGE_CHARS]

        full_knowledge = "\n\n---\n\n".join(knowledge_docs)

        if len(full_knowledge) < MAX_KNOWLEDGE_CHARS:
            return full_knowledge

        logger.info("Retrieved knowledge is too long; summarizing with LLM...")
        prompt = (
            f"The user is trying to achieve the following goal: '{task_goal}'.\n\n"
            "The following are retrieved documents that might be helpful. Summarize the most critical "
            "patterns, code snippets, and strategies from these documents that are directly relevant "
            "to the user's goal. Be concise."
            f"\n\n# DOCUMENTS\n{full_knowledge}"
        )

        summary = await self.llm_service.call(prompt)
        return summary


def summarize_repetitive_logs(log_text: str, min_repeats: int = 3) -> str:
    """
    Summarizes consecutive, identical lines in a log string.

    Args:
        log_text: The raw log output.
        min_repeats: The minimum number of consecutive repeats to summarize.

    Returns:
        A cleaned log string with repeated lines summarized.
    """
    if not log_text:
        return ""

    lines = log_text.strip().split('\n')
    summarized_lines = []

    # Use groupby to find consecutive identical elements
    for line, group in groupby(lines):
        count = len(list(group))
        if count >= min_repeats:
            # Add a summary line for the repeated block
            summarized_lines.append(f"<{line.strip()} (repeated {count} times)>")
        else:
            # If not repeated enough, add the lines back as they were
            summarized_lines.extend([line] * count)

    return '\n'.join(summarized_lines)


def truncate_output(output: str, max_chars: int = MAX_OUTPUT_CHARS) -> str:
    """
    Truncates long output by keeping the beginning and end portions.
    This is useful when execution output exceeds max_seq_len limits.

    Args:
        output: The raw execution output.
        max_chars: Maximum allowed characters (default: MAX_OUTPUT_CHARS).

    Returns:
        Truncated output with truncation notice if needed.
    """
    if not output:
        return ""

    if len(output) <= max_chars:
        return output

    # Keep the first and last portions
    head_size = max_chars // 2
    tail_size = max_chars - head_size

    head = output[:head_size]
    tail = output[-tail_size:]

    truncation_notice = (
        f"\n\n... [TRUNCATED: {len(output) - max_chars} characters omitted "
        f"to prevent context overflow] ...\n\n"
    )

    logger.warning(f"Output truncated from {len(output)} to {max_chars} characters.")
    return head + truncation_notice + tail
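
The helpers above are pure string manipulation except for summarize_knowledge. A minimal sketch of the non-LLM paths, using only the names defined in this file, could look like:

# Illustrative sketch (not part of the package); exercises the non-LLM helpers only.
from dsat.utils.context import ContextManager, summarize_repetitive_logs, truncate_output

cm = ContextManager()  # no LLMService: summarize_knowledge would fall back to raw truncation
history = [
    {"plan": "baseline model", "result": "AUC 0.71"},
    {"plan": "add feature crosses", "result": "AUC 0.74"},
]
print(cm.build_history_context(history, key_order=["plan", "result"]))

noisy_log = "epoch 1: loss nan\n" * 5 + "training aborted"
print(summarize_repetitive_logs(noisy_log))   # -> "<epoch 1: loss nan (repeated 5 times)>" plus the final line
print(len(truncate_output("x" * 50_000)))     # capped around MAX_OUTPUT_CHARS plus the truncation notice
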
dsat/utils/dynamic_import.py
ADDED

@@ -0,0 +1,71 @@
"""
Utility for dynamically importing and instantiating classes from code strings.
Ported from the AFlow project for use in the meta-optimization evaluation step.
"""
import importlib.util
import sys
from typing import Any, Optional, Dict
import logging
from dsat.common.exceptions import DynamicImportError  # Import the new exception

logger = logging.getLogger(__name__)

def import_workflow_from_string(code_string: str, class_name: str = "Workflow") -> Any:
    """
    Dynamically imports a workflow class from a code string.

    Args:
        code_string: The string containing the Python code.
        class_name: The name of the class to import (default: "Workflow").

    Returns:
        The workflow class.

    Raises:
        DynamicImportError: If the import fails for any reason.
    """
    try:
        # Create a temporary, unique module name to avoid conflicts
        module_name = f"dynamic_workflow_module_{hash(code_string)}"
        if module_name in sys.modules:
            del sys.modules[module_name]

        spec = importlib.util.spec_from_loader(module_name, loader=None)
        module = importlib.util.module_from_spec(spec)

        # Inject necessary base classes and types into the module's scope
        # to prevent NameError during exec.
        module.__dict__['DSATWorkflow'] = __import__('dsat.workflows.base', fromlist=['DSATWorkflow']).DSATWorkflow
        module.__dict__['Path'] = __import__('pathlib').Path
        module.__dict__['asyncio'] = __import__('asyncio')
        module.__dict__['shutil'] = __import__('shutil')
        module.__dict__['Dict'] = __import__('typing').Dict
        module.__dict__['Any'] = __import__('typing').Any
        module.__dict__['List'] = __import__('typing').List
        module.__dict__['LLMService'] = __import__('dsat.services.llm', fromlist=['LLMService']).LLMService
        module.__dict__['SandboxService'] = __import__('dsat.services.sandbox', fromlist=['SandboxService']).SandboxService
        module.__dict__['parse_plan_and_code'] = __import__('dsat.utils.parsing', fromlist=['parse_plan_and_code']).parse_plan_and_code

        # Execute the code within the new module's namespace
        exec(code_string, module.__dict__)

        # Get the class from the module
        WorkflowClass = getattr(module, class_name, None)

        if WorkflowClass:
            return WorkflowClass
        else:
            error_msg = f"Class '{class_name}' not found in the provided dynamic code."
            logger.error(error_msg)
            raise DynamicImportError(error_msg)

    except DynamicImportError:
        raise  # Re-raise if already caught
    except Exception as e:
        error_msg = f"Error during dynamic class import (e.g., syntax error): {e}"
        logger.error(error_msg, exc_info=True)
        raise DynamicImportError(error_msg) from e
    finally:
        if module_name in sys.modules:
            del sys.modules[module_name]
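
A hedged sketch of how this importer might be used (the workflow source below is illustrative; real generated code would do actual planning and execution):

# Illustrative sketch (not part of the package).
from dsat.utils.dynamic_import import import_workflow_from_string

generated_source = '''
class Workflow(DSATWorkflow):  # DSATWorkflow is injected into the module namespace above
    async def solve(self, description, io_instructions, data_dir, output_path):
        output_path.write_text("placeholder answer")
'''

WorkflowClass = import_workflow_from_string(generated_source)   # raises DynamicImportError on failure
workflow = WorkflowClass(operators={}, services={}, agent_config={})
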
dsat/utils/parsing.py
ADDED

@@ -0,0 +1,33 @@
"""
Helper functions for parsing structured content from raw LLM text responses.
"""
import re

def parse_plan_and_code(response: str) -> tuple[str, str]:
    """
    Extracts a natural language plan and a Python code block from an LLM's response.
    Assumes a format where the plan precedes the code block.

    Args:
        response: The raw text response from the LLM.

    Returns:
        A tuple containing the extracted plan and code.
        Returns a default error message for code if not found.
    """
    # Use a non-greedy match for the plan to capture everything before the first code block.
    plan_match = re.search(r"(.*?)```(?:python|py)?", response, re.DOTALL)
    if plan_match:
        plan = plan_match.group(1).strip()
    else:
        # If no code block is found, assume the entire response is the plan.
        plan = response.strip()

    code_match = re.search(r"```(?:python|py)?\n(.*?)\n```", response, re.DOTALL)
    if code_match:
        code = code_match.group(1).strip()
    else:
        # Fallback if the code block is malformed or missing
        code = "# ERROR: Could not parse code block from LLM response."

    return plan, code
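
For illustration, a response containing a short plan followed by a fenced Python block would split as follows (sketch, not part of the package):

# Illustrative sketch of splitting an LLM reply into plan text and code.
from dsat.utils.parsing import parse_plan_and_code

reply = """First load the data and train a baseline model.

```python
import pandas as pd
df = pd.read_csv("train.csv")
```"""

plan, code = parse_plan_and_code(reply)
# plan -> "First load the data and train a baseline model."
# code -> 'import pandas as pd\ndf = pd.read_csv("train.csv")'
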
dsat/workflows/__init__.py
ADDED

@@ -0,0 +1,12 @@
# dsat/workflows/__init__.py

# This file makes the 'workflows' directory a Python package.
from .base import DSATWorkflow
from .factory import (
    WorkflowFactory,
    AIDEWorkflowFactory,
    AutoMindWorkflowFactory,
    DSAgentWorkflowFactory,
    DataInterpreterWorkflowFactory,
    AutoKaggleWorkflowFactory
)
dsat/workflows/base.py
ADDED

@@ -0,0 +1,53 @@
# dsat/workflows/base.py

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Any
import logging


# --- New, standardized workflow interface (core) ---
class DSATWorkflow(ABC):
    """
    New standardized workflow abstract base class defining the "physical interface contract".

    Any workflow implementing this interface becomes a generic problem solver
    that is completely decoupled from the specific form of the task (QA, Kaggle, etc.).
    It only understands files and directories.
    """
    def __init__(self, operators: Dict[str, Any], services: Dict[str, Any], agent_config: Dict[str, Any]):
        """
        Initialize through dependency injection.

        Args:
            operators: A dictionary containing all operator instances needed by this workflow.
            services: A dictionary containing required service instances (e.g., LLMService, SandboxService).
            agent_config: A dictionary containing agent behavior-specific configuration.
        """
        self.operators = operators
        self.services = services
        self.agent_config = agent_config

    @abstractmethod
    async def solve(
        self,
        description: str,
        io_instructions: str,  # NEW ARGUMENT
        data_dir: Path,
        output_path: Path
    ) -> None:
        """
        Solve a given task based on the physical file interface.

        This is the core method for all standardized workflows. Workflows implementing this method need to:
        1. Treat `data_dir` as their only input source.
        2. Execute their internal logic (e.g., call LLM, run code) to solve the task described in `description`.
        3. Write their final, evaluable answer as a single file to the complete path specified by `output_path`.

        Args:
            description: Natural language goal description and data analysis report.
            io_instructions: Explicit, standardized instructions for reading input and writing output.
            data_dir: A directory containing all input files needed to solve the task (e.g., `problem.txt`, `train.csv`).
            output_path: The complete path where the final output file (e.g., `answer.txt`, `submission.csv`) must be saved.
        """
        raise NotImplementedError
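
A minimal sketch of a workflow honoring this contract (illustrative only; the class name is hypothetical, and a real workflow would use the injected services and operators to plan, generate, and execute code):

# Illustrative sketch (not part of the package).
from pathlib import Path
from dsat.workflows.base import DSATWorkflow

class EchoWorkflow(DSATWorkflow):
    async def solve(self, description: str, io_instructions: str,
                    data_dir: Path, output_path: Path) -> None:
        # Read inputs only from data_dir; write the single evaluable artifact to output_path.
        input_names = sorted(p.name for p in data_dir.iterdir())
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(f"Goal: {description[:80]}\nInputs seen: {input_names}\n")

# Usage: instantiate with injected dependencies and await solve() from an event loop,
# e.g. asyncio.run(EchoWorkflow({}, {}, {}).solve(desc, io_instr, Path("data"), Path("out/answer.txt"))).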