tunacode-cli 0.0.46__py3-none-any.whl → 0.0.48__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tunacode-cli might be problematic.

Files changed (32)
  1. tunacode/cli/main.py +0 -4
  2. tunacode/cli/repl.py +7 -14
  3. tunacode/configuration/defaults.py +1 -0
  4. tunacode/constants.py +1 -6
  5. tunacode/core/agents/dspy_integration.py +223 -0
  6. tunacode/core/agents/dspy_tunacode.py +458 -0
  7. tunacode/core/agents/main.py +237 -311
  8. tunacode/core/agents/utils.py +6 -54
  9. tunacode/core/state.py +0 -41
  10. tunacode/core/token_usage/cost_calculator.py +8 -0
  11. tunacode/core/token_usage/usage_tracker.py +17 -1
  12. tunacode/exceptions.py +0 -23
  13. tunacode/prompts/dspy_task_planning.md +45 -0
  14. tunacode/prompts/dspy_tool_selection.md +58 -0
  15. tunacode/ui/input.py +1 -2
  16. tunacode/ui/keybindings.py +1 -17
  17. tunacode/ui/panels.py +2 -9
  18. tunacode/utils/token_counter.py +2 -1
  19. {tunacode_cli-0.0.46.dist-info → tunacode_cli-0.0.48.dist-info}/METADATA +3 -3
  20. {tunacode_cli-0.0.46.dist-info → tunacode_cli-0.0.48.dist-info}/RECORD +24 -28
  21. tunacode/core/recursive/__init__.py +0 -18
  22. tunacode/core/recursive/aggregator.py +0 -467
  23. tunacode/core/recursive/budget.py +0 -414
  24. tunacode/core/recursive/decomposer.py +0 -398
  25. tunacode/core/recursive/executor.py +0 -470
  26. tunacode/core/recursive/hierarchy.py +0 -487
  27. tunacode/ui/recursive_progress.py +0 -380
  28. tunacode/utils/retry.py +0 -163
  29. {tunacode_cli-0.0.46.dist-info → tunacode_cli-0.0.48.dist-info}/WHEEL +0 -0
  30. {tunacode_cli-0.0.46.dist-info → tunacode_cli-0.0.48.dist-info}/entry_points.txt +0 -0
  31. {tunacode_cli-0.0.46.dist-info → tunacode_cli-0.0.48.dist-info}/licenses/LICENSE +0 -0
  32. {tunacode_cli-0.0.46.dist-info → tunacode_cli-0.0.48.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,458 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ TunaCode DSPy Production Module
4
+
5
+ Optimizes tool selection and task planning for TunaCode using DSPy.
6
+ Includes 3-4 tool batching optimization for 3x performance gains.
7
+ """
8
+
9
import ast
import json
import logging
import os
import re
from typing import Dict, List, Optional

import dspy
from dotenv import load_dotenv
17
+
18
# Load variables from a .env file into the process environment at import time
# (python-dotenv's default lookup). create_optimized_agent later reads
# OPENROUTER_API_KEY from os.environ, so a key kept in .env is picked up.
load_dotenv()

# Module logger. No handlers or levels are configured in this module; that is
# left to whichever application imports it.
logger = logging.getLogger(__name__)
23
+
24
# Known TunaCode tools, grouped by how the agent may schedule them.
TOOL_CATEGORIES = dict(
    # Side-effect free: safe to run in parallel batches.
    read_only=["read_file", "grep", "list_dir", "glob"],
    # Quick bookkeeping, run one call at a time.
    task_management=["todo"],
    # Mutating / executing tools: run sequentially, need user confirmation.
    write_execute=["write_file", "update_file", "run_command", "bash"],
)

# Every known tool name as one flat list, in category order.
ALL_TOOLS = sum(TOOL_CATEGORIES.values(), [])
37
+
38
+
39
class ToolSelectionSignature(dspy.Signature):
    """Select optimal tools with batching awareness."""

    # NOTE: DSPy uses a Signature's docstring as the LM instruction and the
    # field names / ``desc`` strings as the prompt schema, so the text in this
    # class is prompt text, not just documentation.

    user_request: str = dspy.InputField(desc="The user's request or task")
    current_directory: str = dspy.InputField(desc="Current working directory context")
    # OptimizedToolSelector json-decodes this and then only understands call
    # strings of the form name("arg", ...). The example must therefore be valid
    # JSON in exactly that shape. The previous example, [[tool1, tool2], ...],
    # was neither valid JSON nor parseable call strings.
    tools_json: str = dspy.OutputField(
        desc=(
            "JSON array of tool-call batches; each batch is a JSON array of call "
            'strings such as "read_file(\\"main.py\\")", e.g. '
            '[["grep(\\"auth\\", \\"src/\\")", "list_dir(\\"src/auth/\\")"], '
            '["read_file(\\"src/auth/login.py\\")"]]'
        )
    )
    requires_confirmation: bool = dspy.OutputField(
        desc="Whether any tools require user confirmation"
    )
    reasoning: str = dspy.OutputField(desc="Explanation of tool choice and batching strategy")
51
+
52
+
53
class TaskPlanningSignature(dspy.Signature):
    """Break down complex tasks with tool hints."""

    # NOTE: DSPy uses a Signature's docstring as the LM instruction and the
    # field names / ``desc`` strings as the prompt schema. Edit them as prompt
    # text, and keep the field order stable.

    # Input: the user's request as free text.
    complex_request: str = dspy.InputField(desc="A complex task that needs breakdown")
    # Raw JSON text. EnhancedTaskPlanner.forward json-decodes it into
    # ``result.subtasks`` and falls back to [] when decoding fails.
    subtasks_with_tools: str = dspy.OutputField(
        desc="JSON array of {task, tools, priority} objects"
    )
    # LM estimate; task_planning_metric compares it to the label with tolerances.
    total_tool_calls: int = dspy.OutputField(desc="Estimated total number of tool calls")
    # When truthy, TunaCodeDSPy.forward adds an "initial_action" pointing at the todo tool.
    requires_todo: bool = dspy.OutputField(desc="Whether todo tool should be used")
    # LM estimate; task_planning_metric accepts values within 2 of the label.
    parallelization_opportunities: int = dspy.OutputField(
        desc="Number of parallel execution opportunities"
    )
65
+
66
+
67
class PathValidationSignature(dspy.Signature):
    """Validate and convert paths to relative format."""

    # NOTE: DSPy uses a Signature's docstring as the LM instruction and the
    # field names / ``desc`` strings as the prompt schema, so treat them as
    # prompt text. All three outputs are produced by the LM; PathValidator
    # returns them as-is without checking them against the filesystem.

    path: str = dspy.InputField(desc="Path to validate")
    current_directory: str = dspy.InputField(desc="Current working directory")
    is_valid: bool = dspy.OutputField(desc="Whether path is valid relative path")
    relative_path: str = dspy.OutputField(desc="Converted relative path")
    reason: str = dspy.OutputField(desc="Validation result explanation")
75
+
76
+
77
class OptimizedToolSelector(dspy.Module):
    """Tool selection with batching optimization.

    Runs ``dspy.ChainOfThought`` over ``ToolSelectionSignature``. It then decodes
    the model's ``tools_json`` text into ``result.tool_batches``: a list of
    batches, each a list of ``{"tool": <name>, "args": {...}}`` dicts.
    Decoding is best-effort; on failure ``tool_batches`` is an empty list.
    """

    # Read-only batches larger than this are split into chunks of this size
    # (the "3-4 tools per batch" rule).
    MAX_READ_ONLY_BATCH = 4

    # Matches a call string such as ``read_file('main.py')`` -> (name, raw args).
    _CALL_RE = re.compile(r"(\w+)\((.*)\)", re.DOTALL)

    # Positional parameter names per tool, in call order. Tools not listed here
    # keep only their keyword arguments.
    _POSITIONAL_PARAMS = {
        "read_file": ("filepath",),
        "grep": ("pattern", "directory"),
        "list_dir": ("directory",),
        "glob": ("pattern", "directory"),
        "todo": ("action",),
        "write_file": ("filepath",),
        "update_file": ("filepath",),
        "run_command": ("command",),
        "bash": ("command",),
    }

    def __init__(self):
        # Initialise dspy.Module state before registering the predictor.
        super().__init__()
        self.predict = dspy.ChainOfThought(ToolSelectionSignature)

    def forward(self, user_request: str, current_directory: str = "."):
        """Predict tool batches for ``user_request`` and attach ``tool_batches``."""
        logger.debug("Tool Selection for: %s", user_request)
        result = self.predict(user_request=user_request, current_directory=current_directory)

        # Best-effort decode: a malformed LM answer must not break the caller.
        try:
            tool_batches = json.loads(result.tools_json)
            result.tool_batches = self._validate_batches(tool_batches)
        except Exception as e:
            logger.error(f"Failed to parse tool batches: {e}")
            result.tool_batches = []

        return result

    def _validate_batches(self, batches: List[List[str]]) -> List[List[Dict]]:
        """Parse raw batches and split oversized read-only ones (3-4 tool rule).

        Args:
            batches: Decoded ``tools_json``, normally a list of lists of call
                strings. A bare call where a batch is expected is treated as a
                one-tool batch. A non-list top level yields no batches.

        Returns:
            Only the non-empty batches of parsed tool dicts (see ``_parse_tools``).
        """
        if not isinstance(batches, list):
            return []

        read_only = set(TOOL_CATEGORIES["read_only"])
        size = self.MAX_READ_ONLY_BATCH
        validated = []
        for batch in batches:
            if not isinstance(batch, list):
                batch = [batch]
            # Compare exact tool names. A substring test would treat e.g.
            # run_command("grep x") as the read-only tool "grep".
            all_read_only = all(self._tool_name(tool) in read_only for tool in batch)
            if all_read_only and len(batch) > size:
                chunks = [batch[i : i + size] for i in range(0, len(batch), size)]
            else:
                chunks = [batch]
            for chunk in chunks:
                parsed = self._parse_tools(chunk)
                if parsed:
                    validated.append(parsed)
        return validated

    @classmethod
    def _tool_name(cls, tool) -> str:
        """Return the tool name of a call string or ``{"tool": ...}`` dict ("" if unknown)."""
        if isinstance(tool, dict):
            name = tool.get("tool")
            return name if isinstance(name, str) else ""
        if isinstance(tool, str):
            match = cls._CALL_RE.match(tool.strip())
            if match:
                return match.group(1)
        return ""

    @staticmethod
    def _split_args(args_str: str):
        """Split a raw argument string into ``(positional, keyword)`` string values.

        Literal arguments are parsed with ``ast`` (parse + literal_eval only;
        nothing is executed). This keeps commas and escaped quotes inside
        strings intact. Anything that is not a plain literal argument list
        falls back to a naive comma split with quotes stripped.
        """
        try:
            call = ast.parse(f"_({args_str})", mode="eval").body
            if not (
                isinstance(call, ast.Call)
                and isinstance(call.func, ast.Name)
                and call.func.id == "_"
            ):
                raise ValueError("not a plain argument list")
            positional = [str(ast.literal_eval(node)) for node in call.args]
            keywords = {
                kw.arg: str(ast.literal_eval(kw.value)) for kw in call.keywords if kw.arg
            }
            return positional, keywords
        except (SyntaxError, ValueError, TypeError, RecursionError):
            return [arg.strip().strip("\"'") for arg in args_str.split(",")], {}

    def _parse_tools(self, tools: List[str]) -> List[Dict]:
        """Parse call strings like ``read_file('main.py')`` into tool dicts.

        Each recognised entry becomes ``{"tool": name, "args": {...}}``.
        Positional arguments are named via ``_POSITIONAL_PARAMS``; keyword
        arguments are kept under their own names. Entries already in
        ``{"tool": ..., "args": ...}`` form are passed through. Any other entry
        is skipped instead of aborting the whole batch.
        """
        parsed = []
        for tool in tools:
            if isinstance(tool, dict):
                name = tool.get("tool")
                if isinstance(name, str) and name:
                    args = tool.get("args")
                    parsed.append(
                        {"tool": name, "args": dict(args) if isinstance(args, dict) else {}}
                    )
                continue
            if not isinstance(tool, str):
                continue

            match = self._CALL_RE.match(tool.strip())
            if not match:
                continue
            tool_name, args_str = match.group(1), match.group(2)

            tool_dict = {"tool": tool_name, "args": {}}
            if args_str.strip():
                positional, keywords = self._split_args(args_str)
                params = self._POSITIONAL_PARAMS.get(tool_name, ())
                tool_dict["args"].update(zip(params, positional))
                tool_dict["args"].update(keywords)
            parsed.append(tool_dict)

        return parsed
160
+
161
+
162
class EnhancedTaskPlanner(dspy.Module):
    """Task planning with tool awareness.

    Runs ``dspy.ChainOfThought`` over ``TaskPlanningSignature`` and decodes the
    model's ``subtasks_with_tools`` JSON into ``result.subtasks``. The result
    is always a list: a single JSON object becomes a one-item plan, and
    anything undecodable becomes ``[]``.
    """

    def __init__(self):
        # Initialise dspy.Module state before registering the predictor.
        super().__init__()
        self.predict = dspy.ChainOfThought(TaskPlanningSignature)

    def forward(self, complex_request: str):
        """Plan ``complex_request`` and attach the decoded ``subtasks`` list."""
        logger.debug("Task Planning for: %s", complex_request)
        result = self.predict(complex_request=complex_request)

        try:
            subtasks = json.loads(getattr(result, "subtasks_with_tools", None))
        except (TypeError, ValueError) as e:
            logger.error(f"Failed to parse subtasks: {e}")
            subtasks = []

        # Callers index and len() this, so normalise to a list.
        if isinstance(subtasks, dict):
            subtasks = [subtasks]
        elif not isinstance(subtasks, list):
            logger.error(
                "Failed to parse subtasks: expected a JSON array, got %s",
                type(subtasks).__name__,
            )
            subtasks = []

        result.subtasks = subtasks
        return result
180
+
181
+
182
class PathValidator(dspy.Module):
    """Ask the LM whether a path is a valid relative path and how to relativize it.

    This only returns the model's answer; nothing here enforces that paths
    are actually relative.
    """

    def __init__(self):
        # Initialise dspy.Module state before registering the predictor.
        super().__init__()
        self.predict = dspy.Predict(PathValidationSignature)

    def forward(self, path: str, current_directory: str = "."):
        """Return the PathValidationSignature prediction (is_valid, relative_path, reason)."""
        return self.predict(path=path, current_directory=current_directory)
190
+
191
+
192
class TunaCodeDSPy(dspy.Module):
    """Main TunaCode DSPy agent.

    Routes a request to EnhancedTaskPlanner when ``_is_complex_task`` flags it
    as complex, and to OptimizedToolSelector otherwise. Returns a plain dict
    summarising the plan or the selected tool batches.
    """

    # Phrases that mark a request as complex. They are matched on word
    # boundaries, with an optional "re"/"re-" prefix and simple inflections
    # (s/d/ed/ing/ging). So "implemented", "debugging" or "rebuild" count, but
    # the noun "implementation" does not ("show me the ... implementation" is
    # a read-only request).
    _COMPLEX_INDICATORS = (
        "implement",
        "create",
        "build",
        "refactor",
        "add feature",
        "fix all",
        "update multiple",
        "migrate",
        "integrate",
        "debug",
        "optimize performance",
        "add authentication",
        "setup",
        "configure",
        "test suite",
    )
    _COMPLEX_RE = re.compile(
        r"\b(?:re-?)?(?:"
        + "|".join(re.escape(indicator) for indicator in _COMPLEX_INDICATORS)
        + r")(?:s|d|ed|ing|ging)?\b"
    )
    # Connective words suggesting several operations. They are matched as
    # whole words, so "authentication" is not "then" and "command" is not "and".
    _MULTI_OP_RE = re.compile(r"\b(?:and|then|also|after|before|plus)\b")
    # File-like tokens such as "main.py".
    _FILE_RE = re.compile(r"\b\w+\.\w+\b")

    def __init__(self):
        # Initialise dspy.Module state before registering sub-modules.
        super().__init__()
        self.tool_selector = OptimizedToolSelector()
        self.task_planner = EnhancedTaskPlanner()
        self.path_validator = PathValidator()

    def forward(self, user_request: str, current_directory: str = "."):
        """Process request with optimization.

        Returns a dict that always has ``request``, ``is_complex`` and
        ``current_directory``.
        - Complex requests add ``subtasks``, ``total_tool_calls``,
          ``requires_todo`` and ``parallelization_opportunities``, plus
          ``initial_action`` when a todo list is advised.
        - Other requests add ``tool_batches``, ``requires_confirmation`` and
          ``reasoning``.
        """
        is_complex = self._is_complex_task(user_request)

        result = {
            "request": user_request,
            "is_complex": is_complex,
            "current_directory": current_directory,
        }

        if is_complex:
            task_plan = self.task_planner(complex_request=user_request)
            result["subtasks"] = task_plan.subtasks
            result["total_tool_calls"] = task_plan.total_tool_calls
            result["requires_todo"] = task_plan.requires_todo
            result["parallelization_opportunities"] = task_plan.parallelization_opportunities

            if task_plan.requires_todo:
                result["initial_action"] = "Use todo tool to create task list"
        else:
            tool_selection = self.tool_selector(
                user_request=user_request, current_directory=current_directory
            )
            result["tool_batches"] = tool_selection.tool_batches
            result["requires_confirmation"] = tool_selection.requires_confirmation
            result["reasoning"] = tool_selection.reasoning

        return result

    def _is_complex_task(self, request: str) -> bool:
        """Heuristically decide whether ``request`` needs task planning.

        True if any of the following holds:
        - more than two file-like tokens are mentioned;
        - a complexity indicator phrase appears;
        - at least two different connective words appear.
        """
        request_lower = request.lower()

        files_mentioned = len(self._FILE_RE.findall(request)) > 2
        has_complex_keyword = self._COMPLEX_RE.search(request_lower) is not None
        has_multiple_ops = len(set(self._MULTI_OP_RE.findall(request_lower))) >= 2

        return files_mentioned or has_complex_keyword or has_multiple_ops

    def validate_path(self, path: str, current_directory: str = ".") -> "dspy.Prediction":
        """Ask the LM-backed PathValidator about ``path``.

        Returns the PathValidationSignature prediction (``is_valid``,
        ``relative_path``, ``reason``). The answer is not enforced here.
        """
        return self.path_validator(path=path, current_directory=current_directory)
271
+
272
+
273
def get_tool_selection_examples():
    """Return the labelled examples used to bootstrap the tool selector.

    Each ``dspy.Example`` takes ``user_request`` and ``current_directory`` as
    inputs. Its labels are ``tools_json`` (batches of call strings),
    ``requires_confirmation`` and ``reasoning``.
    """
    # (user_request, tools_json, requires_confirmation, reasoning)
    rows = [
        (
            "Show me the authentication system implementation",
            '[["grep(\\"auth\\", \\"src/\\")", "list_dir(\\"src/auth/\\")", "glob(\\"**/*auth*.py\\")"]]',
            False,
            "Batch 3 read-only tools for parallel search - optimal performance",
        ),
        (
            "Read all config files and the main module",
            '[["read_file(\\"config.json\\")", "read_file(\\"settings.py\\")", "read_file(\\".env\\")", "read_file(\\"main.py\\")"]]',
            False,
            "Batch 4 file reads together - maximum optimal batch size",
        ),
        (
            "Find the bug in validation and fix it",
            '[["grep(\\"error\\", \\"logs/\\")", "grep(\\"validation\\", \\"src/\\")", "list_dir(\\"src/validators/\\")"], ["read_file(\\"src/validators/user.py\\")"], ["update_file(\\"src/validators/user.py\\", \\"old\\", \\"new\\")"]]',
            True,
            "Search tools batched, then read, then sequential write operation",
        ),
    ]
    input_keys = ("user_request", "current_directory")

    return [
        dspy.Example(
            user_request=request,
            current_directory=".",
            tools_json=tools,
            requires_confirmation=needs_confirmation,
            reasoning=why,
        ).with_inputs(*input_keys)
        for request, tools, needs_confirmation, why in rows
    ]
298
+
299
+
300
def get_task_planning_examples():
    """Return the labelled examples used to bootstrap the task planner.

    Each ``dspy.Example`` takes ``complex_request`` as input. Its labels are
    the subtask JSON, the estimated tool-call count, the todo flag and the
    number of parallelization opportunities.
    """
    # (task, tools, priority) for each step of the reference plan.
    plan = [
        ("Analyze current app structure", ["list_dir", "grep", "read_file"], "high"),
        ("Design user model", ["write_file"], "high"),
        ("Create auth endpoints", ["write_file", "update_file"], "high"),
        ("Add JWT tokens", ["write_file", "grep"], "high"),
        ("Write tests", ["write_file", "run_command"], "medium"),
    ]
    # json.dumps with default separators reproduces the labelled JSON text exactly.
    subtasks_json = json.dumps(
        [{"task": task, "tools": tools, "priority": priority} for task, tools, priority in plan]
    )

    example = dspy.Example(
        complex_request="Implement user authentication system with JWT tokens",
        subtasks_with_tools=subtasks_json,
        total_tool_calls=15,
        requires_todo=True,
        parallelization_opportunities=3,
    )
    return [example.with_inputs("complex_request")]
311
+
312
+
313
def tool_selection_metric(example, prediction):
    """Score a tool-selection prediction against a labelled example.

    Returns a float in [0.0, 1.0] made of:
      * 0.4 - tool-name agreement: full credit for identical name sets,
        otherwise the Jaccard overlap of expected vs predicted names;
      * 0.3 - batching: every non-empty read-only batch holds 3-4 tools.
        This is awarded only when at least one non-empty batch was predicted,
        so an empty / failed prediction earns nothing here;
      * 0.2 - ``requires_confirmation`` matches the example;
      * 0.1 - ``reasoning`` is longer than 20 characters.
    """
    score = 0.0
    predicted_batches = getattr(prediction, "tool_batches", None)

    # Tool accuracy (40%)
    if predicted_batches is not None and hasattr(example, "tools_json"):
        try:
            expected = json.loads(example.tools_json)
        except (TypeError, ValueError):
            expected = None  # unparseable label: no accuracy credit

        if isinstance(expected, list):
            expected_tools = set()
            for batch in expected:
                if not isinstance(batch, list):
                    continue
                for tool in batch:
                    match = re.match(r"(\w+)\(", tool) if isinstance(tool, str) else None
                    if match:
                        expected_tools.add(match.group(1))

            predicted_tools = {
                tool.get("tool", "")
                for batch in predicted_batches
                if isinstance(batch, list)
                for tool in batch
                if isinstance(tool, dict)
            }

            if expected_tools == predicted_tools:
                score += 0.4
            else:
                union = expected_tools | predicted_tools
                if union:
                    score += 0.4 * (len(expected_tools & predicted_tools) / len(union))

    # Batching optimization (30%)
    if predicted_batches is not None:
        read_only = set(TOOL_CATEGORIES["read_only"])
        non_empty = [batch for batch in predicted_batches if batch]
        optimal_batching = bool(non_empty)
        for batch in non_empty:
            names = [tool.get("tool", "") for tool in batch if isinstance(tool, dict)]
            if all(name in read_only for name in names) and not 3 <= len(names) <= 4:
                optimal_batching = False
                break
        if optimal_batching:
            score += 0.3

    # Confirmation accuracy (20%)
    if hasattr(prediction, "requires_confirmation") and hasattr(example, "requires_confirmation"):
        if prediction.requires_confirmation == example.requires_confirmation:
            score += 0.2

    # Reasoning quality (10%)
    reasoning = getattr(prediction, "reasoning", None)
    if reasoning and len(reasoning) > 20:
        score += 0.1

    return score
374
+
375
+
376
def task_planning_metric(example, prediction):
    """Score a task-planning prediction against a labelled example.

    Returns a float in [0.0, 1.0] made of:
      * 0.3 - subtask count within 1 of the label (0.15 if within 2). Only a
        non-empty predicted plan earns this, so a failed JSON decode
        (``subtasks == []``) is not rewarded;
      * 0.3 - ``total_tool_calls`` within 5 of the label (0.15 if within 10);
      * 0.2 - ``requires_todo`` matches;
      * 0.2 - ``parallelization_opportunities`` within 2 of the label.
    Numeric fields that cannot be read as numbers earn no credit instead of
    raising.
    """

    def _gap(predicted, expected):
        """Return |predicted - expected| as numbers, or None if either is non-numeric."""
        try:
            return abs(float(predicted) - float(expected))
        except (TypeError, ValueError):
            return None

    score = 0.0

    # Subtask quality (30%)
    if hasattr(prediction, "subtasks") and hasattr(example, "subtasks_with_tools"):
        predicted = prediction.subtasks
        try:
            expected = json.loads(example.subtasks_with_tools)
            count_gap = abs(len(expected) - len(predicted)) if predicted else None
        except (TypeError, ValueError):
            count_gap = None
        if count_gap is not None:
            if count_gap <= 1:
                score += 0.3
            elif count_gap <= 2:
                score += 0.15

    # Tool estimation accuracy (30%)
    if hasattr(prediction, "total_tool_calls") and hasattr(example, "total_tool_calls"):
        calls_gap = _gap(prediction.total_tool_calls, example.total_tool_calls)
        if calls_gap is not None:
            if calls_gap <= 5:
                score += 0.3
            elif calls_gap <= 10:
                score += 0.15

    # Todo requirement (20%)
    if hasattr(prediction, "requires_todo") and hasattr(example, "requires_todo"):
        if prediction.requires_todo == example.requires_todo:
            score += 0.2

    # Parallelization awareness (20%)
    if hasattr(prediction, "parallelization_opportunities") and hasattr(
        example, "parallelization_opportunities"
    ):
        parallel_gap = _gap(
            prediction.parallelization_opportunities, example.parallelization_opportunities
        )
        if parallel_gap is not None and parallel_gap <= 2:
            score += 0.2

    return score
416
+
417
+
418
def create_optimized_agent(
    api_key: Optional[str] = None,
    model: str = "openrouter/openai/gpt-4.1-mini",
    *,
    metric_threshold: Optional[float] = 0.5,
):
    """Create and optimize the TunaCode DSPy agent.

    This function:
    - sets DSPy's global LM (``dspy.configure``) to ``model`` served through
      the OpenRouter API;
    - builds a ``TunaCodeDSPy`` agent;
    - compiles its tool selector and task planner with
      ``dspy.BootstrapFewShot`` on the built-in examples.
    It makes live LM calls.

    Args:
        api_key: OpenRouter key; falls back to ``$OPENROUTER_API_KEY``.
        model: Model id passed to ``dspy.LM``.
        metric_threshold: Minimum metric score for a bootstrapped trace to be
            kept as a demo. Both metrics return floats in [0, 1]. With ``None``,
            BootstrapFewShot treats any truthy (non-zero) score as success,
            which lets low-quality traces through.

    Returns:
        The compiled ``TunaCodeDSPy`` agent.

    Raises:
        ValueError: If no API key is given or found in the environment.
    """
    if not api_key:
        api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        raise ValueError("Please set OPENROUTER_API_KEY environment variable")

    lm = dspy.LM(
        model,
        api_base="https://openrouter.ai/api/v1",
        api_key=api_key,
        temperature=0.3,
    )
    dspy.configure(lm=lm)

    agent = TunaCodeDSPy()

    # BootstrapFewShot calls metric(example, prediction, trace). ``trace`` is
    # optional here so the same callables also work with 2-argument callers.
    def _tool_metric(example, prediction, trace=None):
        return tool_selection_metric(example, prediction)

    def _task_metric(example, prediction, trace=None):
        return task_planning_metric(example, prediction)

    # Optimize tool selector
    tool_optimizer = dspy.BootstrapFewShot(
        metric=_tool_metric,
        metric_threshold=metric_threshold,
        max_bootstrapped_demos=3,
    )
    agent.tool_selector = tool_optimizer.compile(
        agent.tool_selector, trainset=get_tool_selection_examples()
    )

    # Optimize task planner
    task_optimizer = dspy.BootstrapFewShot(
        metric=_task_metric,
        metric_threshold=metric_threshold,
        max_bootstrapped_demos=2,
    )
    agent.task_planner = task_optimizer.compile(
        agent.task_planner, trainset=get_task_planning_examples()
    )

    return agent
455
+
456
+
457
if __name__ == "__main__":
    # Manual smoke test only. Importing this module must not make live LM
    # calls or require OPENROUTER_API_KEY, so the demo runs only as a script.
    agent = create_optimized_agent()
    result = agent("Show me the authentication implementation", ".")
    print(json.dumps(result, indent=2, default=str))