tunacode-cli 0.0.41__py3-none-any.whl → 0.0.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tunacode-cli might be problematic. Click here for more details.
- tunacode/cli/repl.py +8 -4
- tunacode/configuration/defaults.py +1 -0
- tunacode/constants.py +6 -1
- tunacode/core/agents/dspy_integration.py +223 -0
- tunacode/core/agents/dspy_tunacode.py +458 -0
- tunacode/core/agents/main.py +156 -27
- tunacode/core/agents/utils.py +54 -6
- tunacode/core/recursive/__init__.py +18 -0
- tunacode/core/recursive/aggregator.py +467 -0
- tunacode/core/recursive/budget.py +414 -0
- tunacode/core/recursive/decomposer.py +398 -0
- tunacode/core/recursive/executor.py +467 -0
- tunacode/core/recursive/hierarchy.py +487 -0
- tunacode/core/state.py +41 -0
- tunacode/exceptions.py +23 -0
- tunacode/prompts/dspy_task_planning.md +45 -0
- tunacode/prompts/dspy_tool_selection.md +58 -0
- tunacode/ui/console.py +1 -1
- tunacode/ui/output.py +2 -1
- tunacode/ui/panels.py +4 -1
- tunacode/ui/recursive_progress.py +380 -0
- tunacode/ui/tool_ui.py +24 -6
- tunacode/ui/utils.py +1 -1
- tunacode/utils/retry.py +163 -0
- {tunacode_cli-0.0.41.dist-info → tunacode_cli-0.0.43.dist-info}/METADATA +3 -1
- {tunacode_cli-0.0.41.dist-info → tunacode_cli-0.0.43.dist-info}/RECORD +30 -18
- {tunacode_cli-0.0.41.dist-info → tunacode_cli-0.0.43.dist-info}/WHEEL +0 -0
- {tunacode_cli-0.0.41.dist-info → tunacode_cli-0.0.43.dist-info}/entry_points.txt +0 -0
- {tunacode_cli-0.0.41.dist-info → tunacode_cli-0.0.43.dist-info}/licenses/LICENSE +0 -0
- {tunacode_cli-0.0.41.dist-info → tunacode_cli-0.0.43.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,458 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
TunaCode DSPy Production Module
|
|
4
|
+
|
|
5
|
+
Optimizes tool selection and task planning for TunaCode using DSPy.
|
|
6
|
+
Includes 3-4 tool batching optimization for 3x performance gains.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
import os
|
|
12
|
+
import re
|
|
13
|
+
from typing import Dict, List
|
|
14
|
+
|
|
15
|
+
import dspy
|
|
16
|
+
from dotenv import load_dotenv
|
|
17
|
+
|
|
18
|
+
# Load environment variables
|
|
19
|
+
load_dotenv()
|
|
20
|
+
|
|
21
|
+
# Configure logging
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
# Tool Categories for TunaCode
|
|
25
|
+
# Tool names grouped by how they are meant to be scheduled.
TOOL_CATEGORIES = {
    # Parallel-executable
    "read_only": ["read_file", "grep", "list_dir", "glob"],
    # Fast, sequential
    "task_management": ["todo"],
    # Sequential, needs confirmation
    "write_execute": ["write_file", "update_file", "run_command", "bash"],
}

# Every known tool name as one flat list, in category declaration order.
ALL_TOOLS = sum(TOOL_CATEGORIES.values(), [])
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class ToolSelectionSignature(dspy.Signature):
    """Select optimal tools with batching awareness."""

    # NOTE: the class docstring and every ``desc`` string are consumed by DSPy
    # at runtime, so they are kept verbatim; only layout and comments differ.

    # --- inputs ---
    user_request: str = dspy.InputField(desc="The user's request or task")
    current_directory: str = dspy.InputField(desc="Current working directory context")

    # --- outputs ---
    # JSON text; OptimizedToolSelector decodes it with json.loads.
    tools_json: str = dspy.OutputField(
        desc="JSON array of tool calls with batch grouping, e.g. [[tool1, tool2, tool3], [tool4]]"
    )
    requires_confirmation: bool = dspy.OutputField(desc="Whether any tools require user confirmation")
    reasoning: str = dspy.OutputField(desc="Explanation of tool choice and batching strategy")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class TaskPlanningSignature(dspy.Signature):
    """Break down complex tasks with tool hints."""

    # NOTE: the class docstring and every ``desc`` string are consumed by DSPy
    # at runtime, so they are kept verbatim; only layout and comments differ.

    # --- input ---
    complex_request: str = dspy.InputField(desc="A complex task that needs breakdown")

    # --- outputs ---
    # JSON text; EnhancedTaskPlanner decodes it with json.loads.
    subtasks_with_tools: str = dspy.OutputField(desc="JSON array of {task, tools, priority} objects")
    total_tool_calls: int = dspy.OutputField(desc="Estimated total number of tool calls")
    requires_todo: bool = dspy.OutputField(desc="Whether todo tool should be used")
    parallelization_opportunities: int = dspy.OutputField(
        desc="Number of parallel execution opportunities"
    )
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class PathValidationSignature(dspy.Signature):
    """Validate and convert paths to relative format."""

    # NOTE: the class docstring and every ``desc`` string are consumed by DSPy
    # at runtime, so they are kept verbatim; only layout and comments differ.

    # --- inputs ---
    path: str = dspy.InputField(desc="Path to validate")
    current_directory: str = dspy.InputField(desc="Current working directory")

    # --- outputs ---
    is_valid: bool = dspy.OutputField(desc="Whether path is valid relative path")
    relative_path: str = dspy.OutputField(desc="Converted relative path")
    reason: str = dspy.OutputField(desc="Validation result explanation")
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class OptimizedToolSelector(dspy.Module):
    """Tool selection with batching optimization.

    Runs a ``dspy.ChainOfThought`` predictor over ``ToolSelectionSignature``
    and post-processes its ``tools_json`` output into ``tool_batches``: a list
    of batches, each batch being a list of ``{"tool": name, "args": {...}}``
    dicts.
    """

    # Read-only batches longer than this are split into several batches.
    _MAX_READ_ONLY_BATCH = 4

    # Names given to a call's positional arguments, in call order. Extra
    # positional arguments beyond these names are ignored.
    _POSITIONAL_ARGS = {
        "read_file": ("filepath",),
        "grep": ("pattern", "directory"),
        "list_dir": ("directory",),
        "glob": ("pattern", "directory"),
        "todo": ("action",),
        "write_file": ("filepath",),
        "update_file": ("filepath",),
        "run_command": ("command",),
        "bash": ("command",),
    }

    # Arguments filled in when the call does not supply them.
    _DEFAULT_ARGS = {
        "list_dir": {"directory": "."},
        "todo": {"action": "list"},
    }

    def __init__(self):
        # Initialise dspy.Module's own state before attaching sub-modules.
        super().__init__()
        self.predict = dspy.ChainOfThought(ToolSelectionSignature)

    def forward(self, user_request: str, current_directory: str = "."):
        """Predict tools for ``user_request`` and attach parsed ``tool_batches``.

        Args:
            user_request: The user's request text.
            current_directory: Directory context passed to the predictor.

        Returns:
            The predictor's result with an extra ``tool_batches`` attribute.
            ``tool_batches`` is ``[]`` when ``tools_json`` cannot be decoded
            into a JSON array.
        """
        logger.debug(f"Tool Selection for: {user_request}")
        result = self.predict(user_request=user_request, current_directory=current_directory)

        # Parse and validate tool batches. This is deliberately best-effort:
        # model output is untrusted, so any failure degrades to "no batches"
        # (and is logged) instead of aborting the request.
        try:
            tool_batches = json.loads(result.tools_json)
            if not isinstance(tool_batches, list):
                raise ValueError(
                    f"expected a JSON array of batches, got {type(tool_batches).__name__}"
                )
            result.tool_batches = self._validate_batches(tool_batches)
        except Exception as e:
            logger.error(f"Failed to parse tool batches: {e}")
            result.tool_batches = []

        return result

    def _validate_batches(self, batches: List[List[str]]) -> List[List[Dict]]:
        """Parse every batch and split oversized read-only batches.

        A batch counts as read-only when the *parsed tool name* of every call
        is listed in ``TOOL_CATEGORIES["read_only"]``. (Matching on the name
        rather than on the raw call text keeps ``update_file("grep.py")`` from
        being mistaken for a read-only call.) Batches that contain no parsable
        call are dropped, so the result never contains an empty batch.

        Args:
            batches: Decoded ``tools_json``; each batch is a list of call
                strings. A bare string is accepted as a one-call batch.

        Returns:
            List of batches of ``{"tool", "args"}`` dicts.
        """
        read_only_tools = TOOL_CATEGORIES["read_only"]
        limit = self._MAX_READ_ONLY_BATCH
        validated = []

        for batch in batches:
            if isinstance(batch, str):
                # The model returned a flat list of calls; treat each call as
                # its own batch instead of iterating over its characters.
                batch = [batch]
            elif not isinstance(batch, (list, tuple)):
                continue

            parsed = self._parse_tools(batch)
            if not parsed:
                continue

            all_read_only = all(call["tool"] in read_only_tools for call in parsed)

            # Optimize batch size (3-4 tools is optimal): split large batches.
            if all_read_only and len(parsed) > limit:
                for start in range(0, len(parsed), limit):
                    validated.append(parsed[start : start + limit])
            else:
                validated.append(parsed)

        return validated

    def _parse_tools(self, tools: List[str]) -> List[Dict]:
        """Parse call strings such as ``read_file('main.py')`` into dicts.

        Entries that are not strings, or that do not look like ``name(args)``,
        are skipped. Tools without an entry in ``_POSITIONAL_ARGS`` are kept
        with empty ``args``.

        NOTE: arguments are split on every comma, so a quoted argument that
        itself contains a comma is cut at that comma.

        Args:
            tools: Call strings from one batch.

        Returns:
            One ``{"tool": name, "args": {...}}`` dict per parsable call.
        """
        parsed = []
        for tool_str in tools:
            if not isinstance(tool_str, str):
                continue

            # Extract tool name and args from string like "read_file('main.py')"
            match = re.match(r"(\w+)\((.*)\)", tool_str.strip())
            if not match:
                continue

            tool_name = match.group(1)
            args_str = match.group(2).strip()

            # Handle simple cases like 'file.py' or "pattern", "dir"
            args = (
                [arg.strip().strip("\"'") for arg in args_str.split(",")] if args_str else []
            )

            # Start from the defaults so that e.g. ``list_dir()`` still gets
            # directory="."; explicit positional arguments then override them.
            tool_args = dict(self._DEFAULT_ARGS.get(tool_name, {}))
            tool_args.update(zip(self._POSITIONAL_ARGS.get(tool_name, ()), args))

            parsed.append({"tool": tool_name, "args": tool_args})

        return parsed
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class EnhancedTaskPlanner(dspy.Module):
    """Task planning with tool awareness.

    Runs a ``dspy.ChainOfThought`` predictor over ``TaskPlanningSignature`` and
    decodes its ``subtasks_with_tools`` JSON into a ``subtasks`` list.
    """

    def __init__(self):
        # Initialise dspy.Module's own state before attaching sub-modules.
        super().__init__()
        self.predict = dspy.ChainOfThought(TaskPlanningSignature)

    def forward(self, complex_request: str):
        """Plan ``complex_request`` and attach the decoded ``subtasks``.

        Args:
            complex_request: The task to break down.

        Returns:
            The predictor's result with an extra ``subtasks`` attribute, which
            is always a list of dicts. It is ``[]`` when the model output is
            not valid JSON or is not a JSON array; array entries that are not
            objects are dropped.
        """
        logger.debug(f"Task Planning for: {complex_request}")
        result = self.predict(complex_request=complex_request)

        # Parse subtasks. Model output is untrusted, so failures degrade to an
        # empty plan (and are logged) rather than raising.
        try:
            subtasks = json.loads(result.subtasks_with_tools)
            if not isinstance(subtasks, list):
                raise ValueError(
                    f"expected a JSON array of subtasks, got {type(subtasks).__name__}"
                )
        except (TypeError, ValueError) as e:
            logger.error(f"Failed to parse subtasks: {e}")
            subtasks = []

        # The signature asks for {task, tools, priority} objects; keep only
        # entries that can actually be read as such.
        result.subtasks = [subtask for subtask in subtasks if isinstance(subtask, dict)]

        return result
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class PathValidator(dspy.Module):
    """Ensure all paths are relative.

    Thin wrapper around a ``dspy.Predict`` over ``PathValidationSignature``.
    """

    def __init__(self):
        # Initialise dspy.Module's own state before attaching sub-modules.
        super().__init__()
        self.predict = dspy.Predict(PathValidationSignature)

    def forward(self, path: str, current_directory: str = "."):
        """Ask the predictor to validate ``path``.

        Args:
            path: Path to validate.
            current_directory: Directory the path is interpreted against.

        Returns:
            The predictor's result, returned unchanged.
        """
        return self.predict(path=path, current_directory=current_directory)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class TunaCodeDSPy(dspy.Module):
    """Main TunaCode DSPy agent.

    Routes a request either to the task planner (when the request looks
    complex) or to the tool selector, and returns a plain ``dict`` summary.
    """

    # Substrings that mark a request as complex. These are matched as
    # substrings, so inflected forms such as "implementing" also count.
    _COMPLEX_INDICATORS = (
        "implement",
        "create",
        "build",
        "refactor",
        "add feature",
        "fix all",
        "update multiple",
        "migrate",
        "integrate",
        "debug",
        "optimize performance",
        "add authentication",
        "setup",
        "configure",
        "test suite",
    )

    # Connectives that suggest several operations in one request.
    _OPERATION_WORDS = frozenset({"and", "then", "also", "after", "before", "plus"})

    def __init__(self):
        # Initialise dspy.Module's own state before attaching sub-modules.
        super().__init__()
        self.tool_selector = OptimizedToolSelector()
        self.task_planner = EnhancedTaskPlanner()
        self.path_validator = PathValidator()

    def forward(self, user_request: str, current_directory: str = "."):
        """Process request with optimization.

        Args:
            user_request: The user's request text.
            current_directory: Directory context for tool selection.

        Returns:
            A dict that always has ``request``, ``is_complex`` and
            ``current_directory``. Complex requests add ``subtasks``,
            ``total_tool_calls``, ``requires_todo`` and
            ``parallelization_opportunities`` (plus ``initial_action`` when a
            todo list is requested); other requests add ``tool_batches``,
            ``requires_confirmation`` and ``reasoning``.
        """
        # Detect request complexity
        is_complex = self._is_complex_task(user_request)

        result = {
            "request": user_request,
            "is_complex": is_complex,
            "current_directory": current_directory,
        }

        if is_complex:
            # Use task planner for complex tasks
            task_plan = self.task_planner(complex_request=user_request)
            result["subtasks"] = task_plan.subtasks
            result["total_tool_calls"] = task_plan.total_tool_calls
            result["requires_todo"] = task_plan.requires_todo
            result["parallelization_opportunities"] = task_plan.parallelization_opportunities

            if task_plan.requires_todo:
                result["initial_action"] = "Use todo tool to create task list"
        else:
            # Direct tool selection with batching
            tool_selection = self.tool_selector(
                user_request=user_request, current_directory=current_directory
            )
            result["tool_batches"] = tool_selection.tool_batches
            result["requires_confirmation"] = tool_selection.requires_confirmation
            result["reasoning"] = tool_selection.reasoning

        return result

    def _is_complex_task(self, request: str) -> bool:
        """Detect if task is complex based on keywords and patterns.

        A request is complex when any of these holds:

        * it mentions more than two ``name.ext``-like tokens,
        * it contains one of ``_COMPLEX_INDICATORS`` (case-insensitive), or
        * it uses at least two distinct ``_OPERATION_WORDS`` as whole words.

        Operation words are compared against whole words so that, for example,
        "command" does not count as "and" and "authentication" does not count
        as "then".
        """
        request_lower = request.lower()

        # Check for multiple files mentioned
        file_pattern = r"\b\w+\.\w+\b"
        files_mentioned = len(re.findall(file_pattern, request)) > 2

        # Check for complex keywords
        has_complex_keyword = any(
            indicator in request_lower for indicator in self._COMPLEX_INDICATORS
        )

        # Check for multiple operations (whole-word match, see docstring)
        words_in_request = set(re.findall(r"[a-z]+", request_lower))
        has_multiple_ops = len(words_in_request & self._OPERATION_WORDS) >= 2

        return files_mentioned or has_complex_keyword or has_multiple_ops

    def validate_path(self, path: str, current_directory: str = ".") -> Dict:
        """Validate a path is relative and safe.

        Delegates to the ``PathValidator`` sub-module and returns its result
        unchanged.
        NOTE(review): the ``Dict`` annotation looks inaccurate -- the value is
        whatever the underlying predictor returns; confirm before relying on
        dict methods.
        """
        return self.path_validator(path, current_directory)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def get_tool_selection_examples():
    """Training examples for tool selection.

    Returns:
        A list of ``dspy.Example`` objects whose inputs are ``user_request``
        and ``current_directory``; every example uses ``"."`` as directory.
    """
    # One row per example:
    # (user_request, tools_json, requires_confirmation, reasoning)
    rows = [
        (
            "Show me the authentication system implementation",
            '[["grep(\\"auth\\", \\"src/\\")", "list_dir(\\"src/auth/\\")", "glob(\\"**/*auth*.py\\")"]]',
            False,
            "Batch 3 read-only tools for parallel search - optimal performance",
        ),
        (
            "Read all config files and the main module",
            '[["read_file(\\"config.json\\")", "read_file(\\"settings.py\\")", "read_file(\\".env\\")", "read_file(\\"main.py\\")"]]',
            False,
            "Batch 4 file reads together - maximum optimal batch size",
        ),
        (
            "Find the bug in validation and fix it",
            '[["grep(\\"error\\", \\"logs/\\")", "grep(\\"validation\\", \\"src/\\")", "list_dir(\\"src/validators/\\")"], ["read_file(\\"src/validators/user.py\\")"], ["update_file(\\"src/validators/user.py\\", \\"old\\", \\"new\\")"]]',
            True,
            "Search tools batched, then read, then sequential write operation",
        ),
    ]

    return [
        dspy.Example(
            user_request=request_text,
            current_directory=".",
            tools_json=batches_json,
            requires_confirmation=needs_confirmation,
            reasoning=rationale,
        ).with_inputs("user_request", "current_directory")
        for request_text, batches_json, needs_confirmation, rationale in rows
    ]
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def get_task_planning_examples():
    """Training examples for task planning.

    Returns:
        A one-element list holding a ``dspy.Example`` whose only input is
        ``complex_request``.
    """
    # Expected plan, as the JSON text the planner is asked to produce.
    plan_json = '[{"task": "Analyze current app structure", "tools": ["list_dir", "grep", "read_file"], "priority": "high"}, {"task": "Design user model", "tools": ["write_file"], "priority": "high"}, {"task": "Create auth endpoints", "tools": ["write_file", "update_file"], "priority": "high"}, {"task": "Add JWT tokens", "tools": ["write_file", "grep"], "priority": "high"}, {"task": "Write tests", "tools": ["write_file", "run_command"], "priority": "medium"}]'

    jwt_example = dspy.Example(
        complex_request="Implement user authentication system with JWT tokens",
        subtasks_with_tools=plan_json,
        total_tool_calls=15,
        requires_todo=True,
        parallelization_opportunities=3,
    )
    return [jwt_example.with_inputs("complex_request")]
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def tool_selection_metric(example, prediction, trace=None):
    """Metric for tool selection evaluation.

    Scores a prediction against a gold example on four weighted parts:

    * tool accuracy (40%): overlap (intersection over union) between the tool
      names in ``example.tools_json`` and in ``prediction.tool_batches``;
    * batching (30%): every non-empty batch made up only of read-only tools
      holds 3 or 4 calls. No credit is given when the prediction has no
      non-empty batch at all, so an empty or unparsable prediction is not
      rewarded;
    * confirmation (20%): ``requires_confirmation`` matches the example;
    * reasoning (10%): ``reasoning`` is longer than 20 characters.

    Args:
        example: Gold example; ``tools_json`` and ``requires_confirmation``
            are read when present.
        prediction: Object that may carry ``tool_batches`` (list of batches of
            ``{"tool": ...}`` dicts), ``requires_confirmation`` and
            ``reasoning``.
        trace: Accepted so the function can be passed directly wherever a
            ``metric(example, prediction, trace)`` callable is expected; unused.

    Returns:
        A float score between 0.0 and 1.0.
    """
    score = 0.0

    has_batches = hasattr(prediction, "tool_batches")
    raw_batches = getattr(prediction, "tool_batches", None)
    if not isinstance(raw_batches, (list, tuple)):
        # None or any other non-sequence value is treated as "no batches".
        raw_batches = []
    batches = [batch for batch in raw_batches if isinstance(batch, (list, tuple))]

    # Tool accuracy (40%)
    if has_batches and hasattr(example, "tools_json"):
        try:
            expected = json.loads(example.tools_json)
        except (TypeError, ValueError):
            # A gold example without decodable JSON cannot be compared;
            # this part of the score is simply not awarded.
            expected = None

        if isinstance(expected, list):
            expected_tools = set()
            for batch in expected:
                if not isinstance(batch, (list, tuple)):
                    continue
                for tool in batch:
                    if not isinstance(tool, str):
                        continue
                    tool_name = re.match(r"(\w+)\(", tool)
                    if tool_name:
                        expected_tools.add(tool_name.group(1))

            predicted_tools = {
                tool.get("tool", "")
                for batch in batches
                for tool in batch
                if isinstance(tool, dict)
            }

            union = expected_tools | predicted_tools
            if not union:
                # Both sides name no tools at all: an exact match.
                score += 0.4
            else:
                score += 0.4 * (len(expected_tools & predicted_tools) / len(union))

    # Batching optimization (30%)
    non_empty_batches = [batch for batch in batches if len(batch) > 0]
    if non_empty_batches:
        read_only_tools = TOOL_CATEGORIES["read_only"]
        optimal_batching = True

        for batch in non_empty_batches:
            batch_tools = [tool.get("tool", "") for tool in batch if isinstance(tool, dict)]
            if all(tool in read_only_tools for tool in batch_tools):
                if len(batch) < 3 or len(batch) > 4:
                    optimal_batching = False
                    break

        if optimal_batching:
            score += 0.3

    # Confirmation accuracy (20%)
    if hasattr(prediction, "requires_confirmation") and hasattr(example, "requires_confirmation"):
        if prediction.requires_confirmation == example.requires_confirmation:
            score += 0.2

    # Reasoning quality (10%)
    if hasattr(prediction, "reasoning") and prediction.reasoning and len(prediction.reasoning) > 20:
        score += 0.1

    return score
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def _metric_number(value):
    """Return ``value`` as a number, or ``None`` when it is not numeric.

    Numbers pass through unchanged; strings are parsed with ``float`` because
    model output for numeric fields may arrive as text.
    """
    if isinstance(value, (int, float)):
        return value
    if isinstance(value, str):
        try:
            return float(value.strip())
        except ValueError:
            return None
    return None


def task_planning_metric(example, prediction, trace=None):
    """Metric for task planning evaluation.

    Scores a prediction against a gold example on four weighted parts:

    * subtask count (30%): full credit when the number of predicted subtasks
      is within 1 of the example's, half credit when within 2;
    * tool-call estimate (30%): full credit within 5 calls, half within 10;
    * todo requirement (20%): ``requires_todo`` matches the example;
    * parallelization (20%): estimate within 2 of the example's.

    A part whose values are missing, undecodable or non-numeric earns no
    credit; it never raises.

    Args:
        example: Gold example; ``subtasks_with_tools`` (JSON text),
            ``total_tool_calls``, ``requires_todo`` and
            ``parallelization_opportunities`` are read when present.
        prediction: Object that may carry ``subtasks`` (a sized collection),
            ``total_tool_calls``, ``requires_todo`` and
            ``parallelization_opportunities``.
        trace: Accepted so the function can be passed directly wherever a
            ``metric(example, prediction, trace)`` callable is expected; unused.

    Returns:
        A float score between 0.0 and 1.0.
    """
    score = 0.0

    # Subtask quality (30%)
    if hasattr(prediction, "subtasks") and hasattr(example, "subtasks_with_tools"):
        try:
            expected = json.loads(example.subtasks_with_tools)
            count_gap = abs(len(expected) - len(prediction.subtasks))
        except (TypeError, ValueError):
            # Undecodable gold JSON or an unsized prediction: no credit.
            count_gap = None

        if count_gap is not None:
            if count_gap <= 1:
                score += 0.3
            elif count_gap <= 2:
                score += 0.15

    # Tool estimation accuracy (30%)
    if hasattr(prediction, "total_tool_calls") and hasattr(example, "total_tool_calls"):
        predicted_calls = _metric_number(prediction.total_tool_calls)
        expected_calls = _metric_number(example.total_tool_calls)
        if predicted_calls is not None and expected_calls is not None:
            calls_gap = abs(predicted_calls - expected_calls)
            if calls_gap <= 5:
                score += 0.3
            elif calls_gap <= 10:
                score += 0.15

    # Todo requirement (20%)
    if hasattr(prediction, "requires_todo") and hasattr(example, "requires_todo"):
        if prediction.requires_todo == example.requires_todo:
            score += 0.2

    # Parallelization awareness (20%)
    if hasattr(prediction, "parallelization_opportunities") and hasattr(
        example, "parallelization_opportunities"
    ):
        predicted_parallel = _metric_number(prediction.parallelization_opportunities)
        expected_parallel = _metric_number(example.parallelization_opportunities)
        if predicted_parallel is not None and expected_parallel is not None:
            if abs(predicted_parallel - expected_parallel) <= 2:
                score += 0.2

    return score
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
def create_optimized_agent(api_key: str = None, model: str = "openrouter/openai/gpt-4.1-mini"):
    """Create and optimize the TunaCode DSPy agent.

    Configures DSPy globally with an OpenRouter-backed language model, builds
    a ``TunaCodeDSPy`` agent and compiles its tool selector and task planner
    with ``dspy.BootstrapFewShot`` against the module's built-in examples.

    Args:
        api_key: OpenRouter API key. When empty, the ``OPENROUTER_API_KEY``
            environment variable is used.
        model: Model identifier handed to ``dspy.LM``.

    Returns:
        The agent, with both sub-modules replaced by their compiled versions.

    Raises:
        ValueError: If no API key is given and none is set in the environment.
    """
    resolved_key = api_key or os.getenv("OPENROUTER_API_KEY")
    if not resolved_key:
        raise ValueError("Please set OPENROUTER_API_KEY environment variable")

    # Configure DSPy (process-wide setting).
    language_model = dspy.LM(
        model,
        api_base="https://openrouter.ai/api/v1",
        api_key=resolved_key,
        temperature=0.3,
    )
    dspy.configure(lm=language_model)

    tuna_agent = TunaCodeDSPy()

    # The optimizer calls metrics with (example, prediction, trace); the
    # module's metric functions are invoked with the first two only.
    def _score_tool_selection(ex, pred, trace):
        return tool_selection_metric(ex, pred)

    def _score_task_planning(ex, pred, trace):
        return task_planning_metric(ex, pred)

    # Optimize tool selector
    selector_trainset = get_tool_selection_examples()
    selector_optimizer = dspy.BootstrapFewShot(
        metric=_score_tool_selection,
        max_bootstrapped_demos=3,
    )
    tuna_agent.tool_selector = selector_optimizer.compile(
        tuna_agent.tool_selector, trainset=selector_trainset
    )

    # Optimize task planner
    planner_trainset = get_task_planning_examples()
    planner_optimizer = dspy.BootstrapFewShot(
        metric=_score_task_planning,
        max_bootstrapped_demos=2,
    )
    tuna_agent.task_planner = planner_optimizer.compile(
        tuna_agent.task_planner, trainset=planner_trainset
    )

    return tuna_agent
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
# Demo objects ``agent`` and ``result``.
#
# Building them optimises an agent and sends a request to the language model,
# and fails with ValueError when no API key is configured. Doing that as a side
# effect of ``import`` would make the module unimportable without a key, so the
# two names are created lazily instead: they are still available as module
# attributes (``module.agent`` / ``from module import agent``) but are only
# built on first access.
_LAZY_DEMO_OBJECTS = {}


def _load_demo_objects():
    """Build the demo ``agent`` and ``result`` once and cache them."""
    if not _LAZY_DEMO_OBJECTS:
        demo_agent = create_optimized_agent()
        demo_result = demo_agent("Show me the authentication implementation", ".")
        # Stored together only after both steps succeeded.
        _LAZY_DEMO_OBJECTS["agent"] = demo_agent
        _LAZY_DEMO_OBJECTS["result"] = demo_result
    return _LAZY_DEMO_OBJECTS


def __getattr__(name):
    """Resolve the lazily created module attributes ``agent`` and ``result``."""
    if name in ("agent", "result"):
        return _load_demo_objects()[name]
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


if __name__ == "__main__":
    # Running the file as a script keeps the original eager behaviour.
    agent = create_optimized_agent()
    result = agent("Show me the authentication implementation", ".")
|
tunacode/core/agents/main.py
CHANGED
|
@@ -30,9 +30,12 @@ except ImportError:
|
|
|
30
30
|
STREAMING_AVAILABLE = False
|
|
31
31
|
|
|
32
32
|
from tunacode.constants import READ_ONLY_TOOLS
|
|
33
|
+
from tunacode.core.agents.dspy_integration import DSPyIntegration
|
|
34
|
+
from tunacode.core.recursive import RecursiveTaskExecutor
|
|
33
35
|
from tunacode.core.state import StateManager
|
|
34
36
|
from tunacode.core.token_usage.api_response_parser import ApiResponseParser
|
|
35
37
|
from tunacode.core.token_usage.cost_calculator import CostCalculator
|
|
38
|
+
from tunacode.exceptions import ToolBatchingJSONError
|
|
36
39
|
from tunacode.services.mcp import get_mcp_servers
|
|
37
40
|
from tunacode.tools.bash import bash
|
|
38
41
|
from tunacode.tools.glob import glob
|
|
@@ -471,9 +474,17 @@ async def _process_node(
|
|
|
471
474
|
if not has_tool_calls and buffering_callback:
|
|
472
475
|
for part in node.model_response.parts:
|
|
473
476
|
if hasattr(part, "content") and isinstance(part.content, str):
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
+
try:
|
|
478
|
+
await extract_and_execute_tool_calls(
|
|
479
|
+
part.content, buffering_callback, state_manager
|
|
480
|
+
)
|
|
481
|
+
except ToolBatchingJSONError as e:
|
|
482
|
+
# Handle JSON parsing failure after retries
|
|
483
|
+
logger.error(f"Tool batching JSON error: {e}")
|
|
484
|
+
if state_manager.session.show_thoughts:
|
|
485
|
+
await ui.error(str(e))
|
|
486
|
+
# Continue processing other parts instead of failing completely
|
|
487
|
+
continue
|
|
477
488
|
|
|
478
489
|
# Final flush: disabled temporarily while fixing the parallel execution design
|
|
479
490
|
# The buffer is not being used in the current implementation
|
|
@@ -508,6 +519,18 @@ def get_or_create_agent(model: ModelName, state_manager: StateManager) -> Pydant
|
|
|
508
519
|
# Use a default system prompt if neither file exists
|
|
509
520
|
system_prompt = "You are a helpful AI assistant for software development tasks."
|
|
510
521
|
|
|
522
|
+
# Enhance with DSPy optimization if enabled
|
|
523
|
+
use_dspy = state_manager.session.user_config.get("settings", {}).get(
|
|
524
|
+
"use_dspy_optimization", True
|
|
525
|
+
)
|
|
526
|
+
if use_dspy:
|
|
527
|
+
try:
|
|
528
|
+
dspy_integration = DSPyIntegration(state_manager)
|
|
529
|
+
system_prompt = dspy_integration.enhance_system_prompt(system_prompt)
|
|
530
|
+
logger.info("Enhanced system prompt with DSPy optimizations")
|
|
531
|
+
except Exception as e:
|
|
532
|
+
logger.warning(f"Failed to enhance prompt with DSPy: {e}")
|
|
533
|
+
|
|
511
534
|
# Load TUNACODE.md context
|
|
512
535
|
# Use sync version of get_code_style to avoid nested event loop issues
|
|
513
536
|
try:
|
|
@@ -749,6 +772,110 @@ async def process_request(
|
|
|
749
772
|
fallback_enabled = state_manager.session.user_config.get("settings", {}).get(
|
|
750
773
|
"fallback_response", True
|
|
751
774
|
)
|
|
775
|
+
|
|
776
|
+
# Check if DSPy optimization is enabled and if this is a complex task
|
|
777
|
+
use_dspy = state_manager.session.user_config.get("settings", {}).get(
|
|
778
|
+
"use_dspy_optimization", True
|
|
779
|
+
)
|
|
780
|
+
dspy_integration = None
|
|
781
|
+
task_breakdown = None
|
|
782
|
+
|
|
783
|
+
# Check if recursive execution is enabled
|
|
784
|
+
use_recursive = state_manager.session.user_config.get("settings", {}).get(
|
|
785
|
+
"use_recursive_execution", True
|
|
786
|
+
)
|
|
787
|
+
recursive_threshold = state_manager.session.user_config.get("settings", {}).get(
|
|
788
|
+
"recursive_complexity_threshold", 0.7
|
|
789
|
+
)
|
|
790
|
+
|
|
791
|
+
if use_dspy:
|
|
792
|
+
try:
|
|
793
|
+
dspy_integration = DSPyIntegration(state_manager)
|
|
794
|
+
|
|
795
|
+
# Check if this is a complex task that needs planning
|
|
796
|
+
if dspy_integration.should_use_task_planner(message):
|
|
797
|
+
task_breakdown = dspy_integration.get_task_breakdown(message)
|
|
798
|
+
if task_breakdown and task_breakdown.get("requires_todo"):
|
|
799
|
+
# Auto-create todos for complex tasks
|
|
800
|
+
from tunacode.tools.todo import TodoTool
|
|
801
|
+
|
|
802
|
+
todo_tool = TodoTool(state_manager=state_manager)
|
|
803
|
+
|
|
804
|
+
if state_manager.session.show_thoughts:
|
|
805
|
+
from tunacode.ui import console as ui
|
|
806
|
+
|
|
807
|
+
await ui.muted("DSPy: Detected complex task - creating todo list")
|
|
808
|
+
|
|
809
|
+
# Create todos from subtasks
|
|
810
|
+
todos = []
|
|
811
|
+
for subtask in task_breakdown["subtasks"][:5]: # Limit to first 5
|
|
812
|
+
todos.append(
|
|
813
|
+
{
|
|
814
|
+
"content": subtask["task"],
|
|
815
|
+
"priority": subtask.get("priority", "medium"),
|
|
816
|
+
}
|
|
817
|
+
)
|
|
818
|
+
|
|
819
|
+
if todos:
|
|
820
|
+
await todo_tool._execute(action="add_multiple", todos=todos)
|
|
821
|
+
except Exception as e:
|
|
822
|
+
logger.warning(f"DSPy task planning failed: {e}")
|
|
823
|
+
|
|
824
|
+
# Check if recursive execution should be used
|
|
825
|
+
if use_recursive and state_manager.session.current_recursion_depth == 0:
|
|
826
|
+
try:
|
|
827
|
+
# Initialize recursive executor
|
|
828
|
+
recursive_executor = RecursiveTaskExecutor(
|
|
829
|
+
state_manager=state_manager,
|
|
830
|
+
max_depth=state_manager.session.max_recursion_depth,
|
|
831
|
+
min_complexity_threshold=recursive_threshold,
|
|
832
|
+
default_iteration_budget=max_iterations,
|
|
833
|
+
)
|
|
834
|
+
|
|
835
|
+
# Analyze task complexity
|
|
836
|
+
complexity_result = await recursive_executor.decomposer.analyze_and_decompose(message)
|
|
837
|
+
|
|
838
|
+
if (
|
|
839
|
+
complexity_result.should_decompose
|
|
840
|
+
and complexity_result.total_complexity >= recursive_threshold
|
|
841
|
+
):
|
|
842
|
+
if state_manager.session.show_thoughts:
|
|
843
|
+
from tunacode.ui import console as ui
|
|
844
|
+
|
|
845
|
+
await ui.muted(
|
|
846
|
+
f"\n🔄 RECURSIVE EXECUTION: Task complexity {complexity_result.total_complexity:.2f} >= {recursive_threshold}"
|
|
847
|
+
)
|
|
848
|
+
await ui.muted(f"Reasoning: {complexity_result.reasoning}")
|
|
849
|
+
await ui.muted(f"Subtasks: {len(complexity_result.subtasks)}")
|
|
850
|
+
|
|
851
|
+
# Execute recursively
|
|
852
|
+
success, result, error = await recursive_executor.execute_task(
|
|
853
|
+
request=message, parent_task_id=None, depth=0
|
|
854
|
+
)
|
|
855
|
+
|
|
856
|
+
# Create AgentRun response
|
|
857
|
+
from datetime import datetime
|
|
858
|
+
|
|
859
|
+
if success:
|
|
860
|
+
return AgentRun(
|
|
861
|
+
messages=[{"role": "assistant", "content": str(result)}],
|
|
862
|
+
timestamp=datetime.now(),
|
|
863
|
+
model=model,
|
|
864
|
+
iterations=1,
|
|
865
|
+
status="success",
|
|
866
|
+
)
|
|
867
|
+
else:
|
|
868
|
+
return AgentRun(
|
|
869
|
+
messages=[{"role": "assistant", "content": f"Task failed: {error}"}],
|
|
870
|
+
timestamp=datetime.now(),
|
|
871
|
+
model=model,
|
|
872
|
+
iterations=1,
|
|
873
|
+
status="error",
|
|
874
|
+
)
|
|
875
|
+
except Exception as e:
|
|
876
|
+
logger.warning(f"Recursive execution failed, falling back to normal: {e}")
|
|
877
|
+
# Continue with normal execution
|
|
878
|
+
|
|
752
879
|
from tunacode.configuration.models import ModelRegistry
|
|
753
880
|
from tunacode.core.token_usage.usage_tracker import UsageTracker
|
|
754
881
|
|
|
@@ -847,27 +974,28 @@ async def process_request(
|
|
|
847
974
|
buffered_tasks = tool_buffer.flush()
|
|
848
975
|
start_time = time.time()
|
|
849
976
|
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
977
|
+
if state_manager.session.show_thoughts:
|
|
978
|
+
await ui.muted("\n" + "=" * 60)
|
|
979
|
+
await ui.muted(
|
|
980
|
+
f"🚀 FINAL BATCH: Executing {len(buffered_tasks)} buffered read-only tools"
|
|
981
|
+
)
|
|
982
|
+
await ui.muted("=" * 60)
|
|
855
983
|
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
984
|
+
for idx, (part, node) in enumerate(buffered_tasks, 1):
|
|
985
|
+
tool_desc = f" [{idx}] {part.tool_name}"
|
|
986
|
+
if hasattr(part, "args") and isinstance(part.args, dict):
|
|
987
|
+
if part.tool_name == "read_file" and "file_path" in part.args:
|
|
988
|
+
tool_desc += f" → {part.args['file_path']}"
|
|
989
|
+
elif part.tool_name == "grep" and "pattern" in part.args:
|
|
990
|
+
tool_desc += f" → pattern: '{part.args['pattern']}'"
|
|
991
|
+
if "include_files" in part.args:
|
|
992
|
+
tool_desc += f", files: '{part.args['include_files']}'"
|
|
993
|
+
elif part.tool_name == "list_dir" and "directory" in part.args:
|
|
994
|
+
tool_desc += f" → {part.args['directory']}"
|
|
995
|
+
elif part.tool_name == "glob" and "pattern" in part.args:
|
|
996
|
+
tool_desc += f" → pattern: '{part.args['pattern']}'"
|
|
997
|
+
await ui.muted(tool_desc)
|
|
998
|
+
await ui.muted("=" * 60)
|
|
871
999
|
|
|
872
1000
|
await execute_tools_parallel(buffered_tasks, tool_callback)
|
|
873
1001
|
|
|
@@ -875,10 +1003,11 @@ async def process_request(
|
|
|
875
1003
|
sequential_estimate = len(buffered_tasks) * 100
|
|
876
1004
|
speedup = sequential_estimate / elapsed_time if elapsed_time > 0 else 1.0
|
|
877
1005
|
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
1006
|
+
if state_manager.session.show_thoughts:
|
|
1007
|
+
await ui.muted(
|
|
1008
|
+
f"✅ Final batch completed in {elapsed_time:.0f}ms "
|
|
1009
|
+
f"(~{speedup:.1f}x faster than sequential)\n"
|
|
1010
|
+
)
|
|
882
1011
|
|
|
883
1012
|
# If we need to add a fallback response, create a wrapper
|
|
884
1013
|
if not response_state.has_user_response and i >= max_iterations and fallback_enabled:
|