tunacode-cli 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tunacode-cli might be problematic. See the package registry's advisory page for more details.

Files changed (36)
  1. tunacode/cli/commands.py +91 -33
  2. tunacode/cli/model_selector.py +178 -0
  3. tunacode/cli/repl.py +11 -10
  4. tunacode/configuration/models.py +11 -1
  5. tunacode/constants.py +11 -11
  6. tunacode/context.py +1 -3
  7. tunacode/core/agents/main.py +52 -94
  8. tunacode/core/agents/tinyagent_main.py +171 -0
  9. tunacode/core/setup/git_safety_setup.py +39 -51
  10. tunacode/core/setup/optimized_coordinator.py +73 -0
  11. tunacode/exceptions.py +13 -15
  12. tunacode/services/enhanced_undo_service.py +322 -0
  13. tunacode/services/project_undo_service.py +311 -0
  14. tunacode/services/undo_service.py +18 -21
  15. tunacode/tools/base.py +11 -20
  16. tunacode/tools/tinyagent_tools.py +103 -0
  17. tunacode/tools/update_file.py +24 -14
  18. tunacode/tools/write_file.py +9 -7
  19. tunacode/types.py +2 -2
  20. tunacode/ui/completers.py +98 -33
  21. tunacode/ui/input.py +8 -7
  22. tunacode/ui/keybindings.py +1 -3
  23. tunacode/ui/lexers.py +16 -17
  24. tunacode/ui/output.py +9 -3
  25. tunacode/ui/panels.py +4 -4
  26. tunacode/ui/prompt_manager.py +6 -4
  27. tunacode/utils/lazy_imports.py +59 -0
  28. tunacode/utils/regex_cache.py +33 -0
  29. tunacode/utils/system.py +13 -13
  30. tunacode_cli-0.0.6.dist-info/METADATA +235 -0
  31. {tunacode_cli-0.0.4.dist-info → tunacode_cli-0.0.6.dist-info}/RECORD +35 -27
  32. tunacode_cli-0.0.4.dist-info/METADATA +0 -247
  33. {tunacode_cli-0.0.4.dist-info → tunacode_cli-0.0.6.dist-info}/WHEEL +0 -0
  34. {tunacode_cli-0.0.4.dist-info → tunacode_cli-0.0.6.dist-info}/entry_points.txt +0 -0
  35. {tunacode_cli-0.0.4.dist-info → tunacode_cli-0.0.6.dist-info}/licenses/LICENSE +0 -0
  36. {tunacode_cli-0.0.4.dist-info → tunacode_cli-0.0.6.dist-info}/top_level.txt +0 -0
tunacode/cli/commands.py CHANGED
@@ -6,7 +6,6 @@ from enum import Enum
6
6
  from typing import Any, Dict, List, Optional, Type
7
7
 
8
8
  from .. import utils
9
- from ..configuration.models import ModelRegistry
10
9
  from ..exceptions import ValidationError
11
10
  from ..services.undo_service import perform_undo
12
11
  from ..types import CommandArgs, CommandContext, CommandResult, ProcessRequestCallback
@@ -260,7 +259,6 @@ class UndoCommand(SimpleCommand):
260
259
  await ui.muted(" • File operations will still work, but can't be undone")
261
260
 
262
261
 
263
-
264
262
  class BranchCommand(SimpleCommand):
265
263
  """Create and switch to a new git branch."""
266
264
 
@@ -388,46 +386,106 @@ class ModelCommand(SimpleCommand):
388
386
  super().__init__(
389
387
  CommandSpec(
390
388
  name="model",
391
- aliases=["/model"],
392
- description="List models or select a model (e.g., /model 3 or /model 3 default)",
389
+ aliases=["/model", "/m"],
390
+ description="List and select AI models interactively",
393
391
  category=CommandCategory.MODEL,
394
392
  )
395
393
  )
396
394
 
397
395
  async def execute(self, args: CommandArgs, context: CommandContext) -> Optional[str]:
396
+ from tunacode.cli.model_selector import ModelSelector
397
+
398
+ selector = ModelSelector()
399
+
398
400
  if not args:
399
- # No arguments - list models
400
- await ui.models(context.state_manager)
401
+ # No arguments - show enhanced model list
402
+ await self._show_model_list(selector, context.state_manager)
401
403
  return None
402
404
 
403
- # Parse model index
404
- try:
405
- model_index = int(args[0])
406
- except ValueError:
407
- await ui.error(f"Invalid model index: {args[0]}")
408
- return None
405
+ # Find model by query (index, name, or fuzzy match)
406
+ query = args[0]
407
+ model_info = selector.find_model(query)
409
408
 
410
- # Get model list
411
- model_registry = ModelRegistry()
412
- models = list(model_registry.list_models().keys())
413
- if model_index < 0 or model_index >= len(models):
414
- await ui.error(f"Model index {model_index} out of range")
409
+ if not model_info:
410
+ # Try to provide helpful suggestions
411
+ await ui.error(f"Model '{query}' not found")
412
+ await ui.muted(
413
+ "Try: /model (to list all), or use a number 0-18, "
414
+ "or model name like 'opus' or 'gpt-4'"
415
+ )
415
416
  return None
416
417
 
417
418
  # Set the model
418
- model = models[model_index]
419
- context.state_manager.session.current_model = model
419
+ context.state_manager.session.current_model = model_info.id
420
420
 
421
421
  # Check if setting as default
422
422
  if len(args) > 1 and args[1] == "default":
423
- utils.user_configuration.set_default_model(model, context.state_manager)
424
- await ui.muted("Updating default model")
423
+ utils.user_configuration.set_default_model(model_info.id, context.state_manager)
424
+ await ui.success(
425
+ f"Set default model: {model_info.display_name} {model_info.provider.value[2]}"
426
+ )
425
427
  return "restart"
426
428
  else:
427
- # Show success message with the new model
428
- await ui.success(f"Switched to model: {model}")
429
+ # Show success message with model details
430
+ cost_emoji = selector.get_cost_emoji(model_info.cost_tier)
431
+ await ui.success(
432
+ f"Switched to: {model_info.display_name} "
433
+ f"{model_info.provider.value[2]} {cost_emoji}\n"
434
+ f" → {model_info.description}"
435
+ )
429
436
  return None
430
437
 
438
+ async def _show_model_list(self, selector, state_manager) -> None:
439
+ """Show enhanced model list grouped by provider."""
440
+ from rich.table import Table
441
+ from rich.text import Text
442
+
443
+ # Create table
444
+ table = Table(show_header=True, box=None, padding=(0, 2))
445
+ table.add_column("ID", style="dim", width=3)
446
+ table.add_column("Model", style="bold")
447
+ table.add_column("Short", style="cyan")
448
+ table.add_column("Description", style="dim")
449
+ table.add_column("Cost", justify="center", width=4)
450
+
451
+ # Current model
452
+ current_model = state_manager.session.current_model if state_manager else None
453
+
454
+ # Add models grouped by provider
455
+ model_index = 0
456
+ grouped = selector.get_models_by_provider()
457
+
458
+ for provider in [p for p in grouped if grouped[p]]: # Only show providers with models
459
+ # Add provider header
460
+ table.add_row(
461
+ "",
462
+ Text(f"{provider.value[2]} {provider.value[1]}", style="bold magenta"),
463
+ "",
464
+ "",
465
+ "",
466
+ )
467
+
468
+ # Add models for this provider
469
+ for model in grouped[provider]:
470
+ is_current = model.id == current_model
471
+ style = "bold green" if is_current else ""
472
+
473
+ table.add_row(
474
+ str(model_index),
475
+ Text(model.display_name + (" ← current" if is_current else ""), style=style),
476
+ model.short_name,
477
+ model.description,
478
+ selector.get_cost_emoji(model.cost_tier),
479
+ )
480
+ model_index += 1
481
+
482
+ # Show the table
483
+ await ui.panel("Available Models", table, border_style="cyan")
484
+
485
+ # Show usage hints
486
+ await ui.muted("\n💡 Usage: /model <number|name> [default]")
487
+ await ui.muted(" Examples: /model 3, /model opus, /model gpt-4 default")
488
+
431
489
 
432
490
  @dataclass
433
491
  class CommandDependencies:
@@ -488,8 +546,7 @@ class CommandRegistry:
488
546
  category_commands = self._categories[command.category]
489
547
  # Remove any existing instance of this command class
490
548
  self._categories[command.category] = [
491
- cmd for cmd in category_commands
492
- if cmd.__class__ != command.__class__
549
+ cmd for cmd in category_commands if cmd.__class__ != command.__class__
493
550
  ]
494
551
  # Add the new instance
495
552
  self._categories[command.category].append(command)
@@ -533,7 +590,7 @@ class CommandRegistry:
533
590
  # Only update if callback has changed
534
591
  if self._factory.dependencies.process_request_callback == callback:
535
592
  return
536
-
593
+
537
594
  self._factory.update_dependencies(process_request_callback=callback)
538
595
 
539
596
  # Re-register CompactCommand with new dependency if already registered
@@ -568,10 +625,10 @@ class CommandRegistry:
568
625
  if command_name in self._commands:
569
626
  command = self._commands[command_name]
570
627
  return await command.execute(args, context)
571
-
628
+
572
629
  # Try partial matching
573
630
  matches = self.find_matching_commands(command_name)
574
-
631
+
575
632
  if not matches:
576
633
  raise ValidationError(f"Unknown command: {command_name}")
577
634
  elif len(matches) == 1:
@@ -581,16 +638,17 @@ class CommandRegistry:
581
638
  else:
582
639
  # Ambiguous - show possibilities
583
640
  raise ValidationError(
584
- f"Ambiguous command '{command_name}'. Did you mean: {', '.join(sorted(set(matches)))}?"
641
+ f"Ambiguous command '{command_name}'. Did you mean: "
642
+ f"{', '.join(sorted(set(matches)))}?"
585
643
  )
586
644
 
587
645
  def find_matching_commands(self, partial_command: str) -> List[str]:
588
646
  """
589
647
  Find all commands that start with the given partial command.
590
-
648
+
591
649
  Args:
592
650
  partial_command: The partial command to match
593
-
651
+
594
652
  Returns:
595
653
  List of matching command names
596
654
  """
@@ -608,11 +666,11 @@ class CommandRegistry:
608
666
  return False
609
667
 
610
668
  command_name = parts[0].lower()
611
-
669
+
612
670
  # Check exact match first
613
671
  if command_name in self._commands:
614
672
  return True
615
-
673
+
616
674
  # Check partial match
617
675
  return len(self.find_matching_commands(command_name)) > 0
618
676
 
@@ -0,0 +1,178 @@
1
+ """Interactive model selector with modern UI."""
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Dict, List, Optional
6
+
7
+ from tunacode.configuration.models import ModelRegistry
8
+ from tunacode.types import ModelName
9
+
10
+
11
+ class ModelProvider(Enum):
12
+ """Model providers with their display names."""
13
+
14
+ ANTHROPIC = ("anthropic", "Anthropic", "🤖")
15
+ OPENAI = ("openai", "OpenAI", "🧠")
16
+ GOOGLE = ("google-gla", "Google", "🌐")
17
+ OPENROUTER = ("openrouter", "OpenRouter", "🚀")
18
+
19
+
20
+ @dataclass
21
+ class ModelInfo:
22
+ """Enhanced model information."""
23
+
24
+ id: ModelName
25
+ provider: ModelProvider
26
+ display_name: str
27
+ short_name: str
28
+ description: str
29
+ cost_tier: str # low, medium, high, premium
30
+
31
+
32
+ class ModelSelector:
33
+ """Enhanced model selection with categorization and search."""
34
+
35
+ def __init__(self):
36
+ self.registry = ModelRegistry()
37
+ self.models = self._build_model_info()
38
+
39
+ def _build_model_info(self) -> List[ModelInfo]:
40
+ """Build enhanced model information with metadata."""
41
+ models = []
42
+
43
+ # Model metadata mapping
44
+ model_metadata = {
45
+ # Anthropic models
46
+ "anthropic:claude-opus-4-20250514": (
47
+ "Claude Opus 4",
48
+ "opus-4",
49
+ "Most capable Claude model",
50
+ "high",
51
+ ),
52
+ "anthropic:claude-sonnet-4-20250514": (
53
+ "Claude Sonnet 4",
54
+ "sonnet-4",
55
+ "Balanced performance",
56
+ "medium",
57
+ ),
58
+ "anthropic:claude-3-7-sonnet-latest": (
59
+ "Claude 3.7 Sonnet",
60
+ "sonnet-3.7",
61
+ "Previous generation",
62
+ "medium",
63
+ ),
64
+ # Google models
65
+ "google-gla:gemini-2.0-flash": (
66
+ "Gemini 2.0 Flash",
67
+ "flash-2.0",
68
+ "Fast and efficient",
69
+ "low",
70
+ ),
71
+ "google-gla:gemini-2.5-flash-preview-05-20": (
72
+ "Gemini 2.5 Flash",
73
+ "flash-2.5",
74
+ "Latest preview",
75
+ "low",
76
+ ),
77
+ "google-gla:gemini-2.5-pro-preview-05-06": (
78
+ "Gemini 2.5 Pro",
79
+ "pro-2.5",
80
+ "Most capable Gemini",
81
+ "medium",
82
+ ),
83
+ # OpenAI models
84
+ "openai:gpt-4.1": ("GPT-4.1", "gpt-4.1", "Latest GPT-4", "medium"),
85
+ "openai:gpt-4.1-mini": ("GPT-4.1 Mini", "4.1-mini", "Efficient GPT-4", "low"),
86
+ "openai:gpt-4.1-nano": ("GPT-4.1 Nano", "4.1-nano", "Smallest GPT-4", "low"),
87
+ "openai:gpt-4o": ("GPT-4o", "gpt-4o", "Optimized GPT-4", "medium"),
88
+ "openai:o3": ("O3", "o3", "Advanced reasoning", "premium"),
89
+ "openai:o3-mini": ("O3 Mini", "o3-mini", "Efficient reasoning", "high"),
90
+ # OpenRouter models
91
+ "openrouter:mistralai/devstral-small": (
92
+ "Devstral Small",
93
+ "devstral",
94
+ "Code-focused",
95
+ "low",
96
+ ),
97
+ "openrouter:codex-mini-latest": ("Codex Mini", "codex", "Code completion", "medium"),
98
+ "openrouter:o4-mini-high": ("O4 Mini High", "o4-high", "Enhanced O4", "high"),
99
+ "openrouter:o3": ("O3 (OpenRouter)", "o3-or", "O3 via OpenRouter", "premium"),
100
+ "openrouter:o4-mini": ("O4 Mini", "o4-mini", "Standard O4", "high"),
101
+ "openrouter:openai/gpt-4.1": (
102
+ "GPT-4.1 (OR)",
103
+ "gpt-4.1-or",
104
+ "GPT-4.1 via OpenRouter",
105
+ "medium",
106
+ ),
107
+ "openrouter:openai/gpt-4.1-mini": (
108
+ "GPT-4.1 Mini (OR)",
109
+ "4.1-mini-or",
110
+ "GPT-4.1 Mini via OpenRouter",
111
+ "low",
112
+ ),
113
+ }
114
+
115
+ for model_id in self.registry.list_model_ids():
116
+ provider = self._get_provider(model_id)
117
+ if provider and model_id in model_metadata:
118
+ display_name, short_name, description, cost_tier = model_metadata[model_id]
119
+ models.append(
120
+ ModelInfo(
121
+ id=model_id,
122
+ provider=provider,
123
+ display_name=display_name,
124
+ short_name=short_name,
125
+ description=description,
126
+ cost_tier=cost_tier,
127
+ )
128
+ )
129
+
130
+ return models
131
+
132
+ def _get_provider(self, model_id: str) -> Optional[ModelProvider]:
133
+ """Get provider from model ID."""
134
+ for provider in ModelProvider:
135
+ if model_id.startswith(provider.value[0]):
136
+ return provider
137
+ return None
138
+
139
+ def get_models_by_provider(self) -> Dict[ModelProvider, List[ModelInfo]]:
140
+ """Group models by provider."""
141
+ grouped = {provider: [] for provider in ModelProvider}
142
+ for model in self.models:
143
+ if model.provider:
144
+ grouped[model.provider].append(model)
145
+ return grouped
146
+
147
+ def find_model(self, query: str) -> Optional[ModelInfo]:
148
+ """Find model by index, name, or fuzzy match."""
149
+ query = query.lower().strip()
150
+
151
+ # Try as index first
152
+ try:
153
+ index = int(query)
154
+ if 0 <= index < len(self.models):
155
+ return self.models[index]
156
+ except ValueError:
157
+ pass
158
+
159
+ # Exact match on ID
160
+ for model in self.models:
161
+ if model.id.lower() == query:
162
+ return model
163
+
164
+ # Match on short name
165
+ for model in self.models:
166
+ if model.short_name.lower() == query:
167
+ return model
168
+
169
+ # Fuzzy match on display name or short name
170
+ for model in self.models:
171
+ if query in model.display_name.lower() or query in model.short_name.lower():
172
+ return model
173
+
174
+ return None
175
+
176
+ def get_cost_emoji(self, cost_tier: str) -> str:
177
+ """Get emoji representation of cost tier."""
178
+ return {"low": "💚", "medium": "💛", "high": "🧡", "premium": "❤️"}.get(cost_tier, "⚪")
tunacode/cli/repl.py CHANGED
@@ -10,7 +10,6 @@ from asyncio.exceptions import CancelledError
10
10
 
11
11
  from prompt_toolkit.application import run_in_terminal
12
12
  from prompt_toolkit.application.current import get_app
13
- from pydantic_ai.exceptions import UnexpectedModelBehavior
14
13
 
15
14
  from tunacode.configuration.settings import ApplicationSettings
16
15
  from tunacode.core.agents import main as agent
@@ -183,15 +182,17 @@ async def process_request(text: str, state_manager: StateManager, output: bool =
183
182
  await ui.muted("Request cancelled")
184
183
  except UserAbortError:
185
184
  await ui.muted("Operation aborted.")
186
- except UnexpectedModelBehavior as e:
187
- error_message = str(e)
188
- await ui.muted(error_message)
189
- patch_tool_messages(error_message, state_manager)
190
185
  except Exception as e:
191
- # Wrap unexpected exceptions in AgentError for better tracking
192
- agent_error = AgentError(f"Agent processing failed: {str(e)}")
193
- agent_error.__cause__ = e # Preserve the original exception chain
194
- await ui.error(str(e))
186
+ # Check if this is a model behavior error from tinyAgent
187
+ if "model" in str(e).lower() or "unexpected" in str(e).lower():
188
+ error_message = str(e)
189
+ await ui.muted(error_message)
190
+ patch_tool_messages(error_message, state_manager)
191
+ else:
192
+ # Wrap unexpected exceptions in AgentError for better tracking
193
+ agent_error = AgentError(f"Agent processing failed: {str(e)}")
194
+ agent_error.__cause__ = e # Preserve the original exception chain
195
+ await ui.error(str(e))
195
196
  finally:
196
197
  await ui.spinner(False, state_manager.session.spinner, state_manager)
197
198
  state_manager.session.current_task = None
@@ -214,7 +215,7 @@ async def repl(state_manager: StateManager):
214
215
  await ui.line()
215
216
  await ui.success("ready to hack...")
216
217
  await ui.line()
217
-
218
+
218
219
  instance = agent.get_or_create_agent(state_manager.session.current_model, state_manager)
219
220
 
220
221
  async with instance.run_mcp_servers():
@@ -10,8 +10,18 @@ from tunacode.types import ModelRegistry as ModelRegistryType
10
10
 
11
11
 
12
12
  class ModelRegistry:
13
+ _instance = None
14
+ _models_cache = None
15
+
16
+ def __new__(cls):
17
+ if cls._instance is None:
18
+ cls._instance = super(ModelRegistry, cls).__new__(cls)
19
+ return cls._instance
20
+
13
21
  def __init__(self):
14
- self._models = self._load_default_models()
22
+ if ModelRegistry._models_cache is None:
23
+ ModelRegistry._models_cache = self._load_default_models()
24
+ self._models = ModelRegistry._models_cache
15
25
 
16
26
  def _load_default_models(self) -> ModelRegistryType:
17
27
  return {
tunacode/constants.py CHANGED
@@ -7,7 +7,7 @@ Centralizes all magic strings, UI text, error messages, and application constant
7
7
 
8
8
  # Application info
9
9
  APP_NAME = "TunaCode"
10
- APP_VERSION = "0.5.1"
10
+ APP_VERSION = "0.1.0"
11
11
 
12
12
  # File patterns
13
13
  GUIDE_FILE_PATTERN = "{name}.md"
@@ -63,22 +63,22 @@ COMMAND_CATEGORIES = {
63
63
  }
64
64
 
65
65
  # System paths
66
- SIDEKICK_HOME_DIR = ".tunacode"
66
+ TUNACODE_HOME_DIR = ".tunacode"
67
67
  SESSIONS_SUBDIR = "sessions"
68
68
  DEVICE_ID_FILE = "device_id"
69
69
 
70
70
  # UI colors - Modern sleek color scheme
71
71
  UI_COLORS = {
72
- "primary": "#00d7ff", # Bright cyan
73
- "secondary": "#64748b", # Slate gray
74
- "accent": "#7c3aed", # Purple accent
75
- "success": "#10b981", # Emerald green
76
- "warning": "#f59e0b", # Amber
77
- "error": "#ef4444", # Red
78
- "muted": "#94a3b8", # Light slate
79
- "file_ref": "#00d7ff", # Bright cyan
72
+ "primary": "#00d7ff", # Bright cyan
73
+ "secondary": "#64748b", # Slate gray
74
+ "accent": "#7c3aed", # Purple accent
75
+ "success": "#10b981", # Emerald green
76
+ "warning": "#f59e0b", # Amber
77
+ "error": "#ef4444", # Red
78
+ "muted": "#94a3b8", # Light slate
79
+ "file_ref": "#00d7ff", # Bright cyan
80
80
  "background": "#0f172a", # Dark slate
81
- "border": "#334155", # Slate border
81
+ "border": "#334155", # Slate border
82
82
  }
83
83
 
84
84
  # UI text and formatting
tunacode/context.py CHANGED
@@ -1,11 +1,9 @@
1
- import json
2
- import os
3
1
  import subprocess
4
2
  from pathlib import Path
5
3
  from typing import Dict, List
6
4
 
7
- from tunacode.utils.system import list_cwd
8
5
  from tunacode.utils.ripgrep import ripgrep
6
+ from tunacode.utils.system import list_cwd
9
7
 
10
8
 
11
9
  async def get_git_status() -> Dict[str, object]:
@@ -2,49 +2,28 @@
2
2
 
3
3
  Main agent functionality and coordination for the Sidekick CLI.
4
4
  Provides agent creation, message processing, and tool call management.
5
+ Now using tinyAgent instead of pydantic-ai.
5
6
  """
6
7
 
7
- from datetime import datetime, timezone
8
8
  from typing import Optional
9
9
 
10
- from pydantic_ai import Agent, Tool
11
- from pydantic_ai.messages import ModelRequest, ToolReturnPart
12
-
13
10
  from tunacode.core.state import StateManager
14
- from tunacode.services.mcp import get_mcp_servers
15
- from tunacode.tools.read_file import read_file
16
- from tunacode.tools.run_command import run_command
17
- from tunacode.tools.update_file import update_file
18
- from tunacode.tools.write_file import write_file
19
- from tunacode.types import (AgentRun, ErrorMessage, ModelName, PydanticAgent, ToolCallback,
20
- ToolCallId, ToolName)
21
-
22
-
23
- async def _process_node(node, tool_callback: Optional[ToolCallback], state_manager: StateManager):
24
- if hasattr(node, "request"):
25
- state_manager.session.messages.append(node.request)
26
-
27
- if hasattr(node, "model_response"):
28
- state_manager.session.messages.append(node.model_response)
29
- for part in node.model_response.parts:
30
- if part.part_kind == "tool-call" and tool_callback:
31
- await tool_callback(part, node)
32
-
33
-
34
- def get_or_create_agent(model: ModelName, state_manager: StateManager) -> PydanticAgent:
35
- if model not in state_manager.session.agents:
36
- max_retries = state_manager.session.user_config["settings"]["max_retries"]
37
- state_manager.session.agents[model] = Agent(
38
- model=model,
39
- tools=[
40
- Tool(read_file, max_retries=max_retries),
41
- Tool(run_command, max_retries=max_retries),
42
- Tool(update_file, max_retries=max_retries),
43
- Tool(write_file, max_retries=max_retries),
44
- ],
45
- mcp_servers=get_mcp_servers(state_manager),
46
- )
47
- return state_manager.session.agents[model]
11
+ from tunacode.types import AgentRun, ErrorMessage, ModelName, ToolCallback
12
+
13
+ # Import tinyAgent implementation
14
+ from .tinyagent_main import get_or_create_react_agent
15
+ from .tinyagent_main import patch_tool_messages as tinyagent_patch_tool_messages
16
+ from .tinyagent_main import process_request_with_tinyagent
17
+
18
+ # Wrapper functions for backward compatibility with pydantic-ai interface
19
+
20
+
21
+ def get_or_create_agent(model: ModelName, state_manager: StateManager):
22
+ """
23
+ Wrapper for backward compatibility.
24
+ Returns the ReactAgent instance from tinyAgent.
25
+ """
26
+ return get_or_create_react_agent(model, state_manager)
48
27
 
49
28
 
50
29
  def patch_tool_messages(
@@ -52,57 +31,10 @@ def patch_tool_messages(
52
31
  state_manager: StateManager = None,
53
32
  ):
54
33
  """
55
- Find any tool calls without responses and add synthetic error responses for them.
56
- Takes an error message to use in the synthesized tool response.
57
-
58
- Ignores tools that have corresponding retry prompts as the model is already
59
- addressing them.
34
+ Wrapper for backward compatibility.
35
+ TinyAgent handles tool errors internally, so this is mostly a no-op.
60
36
  """
61
- if state_manager is None:
62
- raise ValueError("state_manager is required for patch_tool_messages")
63
-
64
- messages = state_manager.session.messages
65
-
66
- if not messages:
67
- return
68
-
69
- # Map tool calls to their tool returns
70
- tool_calls: dict[ToolCallId, ToolName] = {} # tool_call_id -> tool_name
71
- tool_returns: set[ToolCallId] = set() # set of tool_call_ids with returns
72
- retry_prompts: set[ToolCallId] = set() # set of tool_call_ids with retry prompts
73
-
74
- for message in messages:
75
- if hasattr(message, "parts"):
76
- for part in message.parts:
77
- if (
78
- hasattr(part, "part_kind")
79
- and hasattr(part, "tool_call_id")
80
- and part.tool_call_id
81
- ):
82
- if part.part_kind == "tool-call":
83
- tool_calls[part.tool_call_id] = part.tool_name
84
- elif part.part_kind == "tool-return":
85
- tool_returns.add(part.tool_call_id)
86
- elif part.part_kind == "retry-prompt":
87
- retry_prompts.add(part.tool_call_id)
88
-
89
- # Identify orphaned tools (those without responses and not being retried)
90
- for tool_call_id, tool_name in list(tool_calls.items()):
91
- if tool_call_id not in tool_returns and tool_call_id not in retry_prompts:
92
- messages.append(
93
- ModelRequest(
94
- parts=[
95
- ToolReturnPart(
96
- tool_name=tool_name,
97
- content=error_message,
98
- tool_call_id=tool_call_id,
99
- timestamp=datetime.now(timezone.utc),
100
- part_kind="tool-return",
101
- )
102
- ],
103
- kind="request",
104
- )
105
- )
37
+ tinyagent_patch_tool_messages(error_message, state_manager)
106
38
 
107
39
 
108
40
  async def process_request(
@@ -111,9 +43,35 @@ async def process_request(
111
43
  state_manager: StateManager,
112
44
  tool_callback: Optional[ToolCallback] = None,
113
45
  ) -> AgentRun:
114
- agent = get_or_create_agent(model, state_manager)
115
- mh = state_manager.session.messages.copy()
116
- async with agent.iter(message, message_history=mh) as agent_run:
117
- async for node in agent_run:
118
- await _process_node(node, tool_callback, state_manager)
119
- return agent_run
46
+ """
47
+ Process a request using tinyAgent.
48
+ Returns a result that mimics the pydantic-ai AgentRun structure.
49
+ """
50
+ result = await process_request_with_tinyagent(model, message, state_manager, tool_callback)
51
+
52
+ # Create a mock AgentRun object for compatibility
53
+ class MockAgentRun:
54
+ def __init__(self, result_dict):
55
+ self._result = result_dict
56
+
57
+ @property
58
+ def result(self):
59
+ class MockResult:
60
+ def __init__(self, content):
61
+ self._content = content
62
+
63
+ @property
64
+ def output(self):
65
+ return self._content
66
+
67
+ return MockResult(self._result.get("result", ""))
68
+
69
+ @property
70
+ def messages(self):
71
+ return state_manager.session.messages
72
+
73
+ @property
74
+ def model(self):
75
+ return self._result.get("model", model)
76
+
77
+ return MockAgentRun(result)