tunacode-cli 0.0.71__py3-none-any.whl → 0.0.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tunacode-cli might be problematic. Click here for more details.
- tunacode/cli/commands/implementations/model.py +332 -32
- tunacode/cli/repl.py +2 -1
- tunacode/constants.py +1 -1
- tunacode/core/agents/agent_components/agent_config.py +32 -19
- tunacode/core/agents/agent_components/agent_helpers.py +1 -1
- tunacode/core/agents/agent_components/node_processor.py +35 -3
- tunacode/core/agents/agent_components/response_state.py +109 -6
- tunacode/core/agents/agent_components/state_transition.py +116 -0
- tunacode/core/agents/agent_components/task_completion.py +10 -6
- tunacode/core/agents/main.py +4 -4
- tunacode/prompts/system.md +11 -1
- tunacode/types.py +9 -0
- tunacode/ui/completers.py +211 -9
- tunacode/ui/input.py +7 -1
- tunacode/ui/model_selector.py +394 -0
- tunacode/utils/models_registry.py +563 -0
- {tunacode_cli-0.0.71.dist-info → tunacode_cli-0.0.73.dist-info}/METADATA +1 -1
- {tunacode_cli-0.0.71.dist-info → tunacode_cli-0.0.73.dist-info}/RECORD +21 -18
- {tunacode_cli-0.0.71.dist-info → tunacode_cli-0.0.73.dist-info}/WHEEL +0 -0
- {tunacode_cli-0.0.71.dist-info → tunacode_cli-0.0.73.dist-info}/entry_points.txt +0 -0
- {tunacode_cli-0.0.71.dist-info → tunacode_cli-0.0.73.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,13 +1,116 @@
|
|
|
1
1
|
"""Response state management for tracking agent processing state."""
|
|
2
2
|
|
|
3
|
-
from dataclasses import dataclass
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from tunacode.types import AgentState
|
|
7
|
+
|
|
8
|
+
from .state_transition import AGENT_TRANSITION_RULES, AgentStateMachine
|
|
4
9
|
|
|
5
10
|
|
|
6
11
|
@dataclass
class ResponseState:
    """Track agent response progress through an enum-based state machine.

    The enum state lives in an ``AgentStateMachine``; four boolean flags are
    still exposed as properties so older call sites that read or assign
    ``has_user_response``, ``task_completed``, ``awaiting_user_guidance`` and
    ``has_final_synthesis`` keep working.
    """

    # State machine owning the current AgentState; one fresh machine per instance.
    _state_machine: AgentStateMachine = field(
        default_factory=lambda: AgentStateMachine(AgentState.USER_INPUT, AGENT_TRANSITION_RULES)
    )

    # Storage behind the legacy boolean properties.
    _has_user_response: bool = False
    _task_completed: bool = False
    _awaiting_user_guidance: bool = False
    _has_final_synthesis: bool = False

    def __post_init__(self):
        """Create the state machine if the instance does not have one yet."""
        if hasattr(self, "_state_machine"):
            return
        self._state_machine = AgentStateMachine(AgentState.USER_INPUT, AGENT_TRANSITION_RULES)

    # ------------------------------------------------------------------
    # Enum state access (delegates to the state machine)
    # ------------------------------------------------------------------
    @property
    def current_state(self) -> AgentState:
        """Return the state the underlying machine is currently in."""
        machine = self._state_machine
        return machine.current_state

    def transition_to(self, new_state: AgentState) -> None:
        """Ask the underlying machine to move to ``new_state``."""
        machine = self._state_machine
        machine.transition_to(new_state)

    def can_transition_to(self, target_state: AgentState) -> bool:
        """Return whether the machine allows moving to ``target_state``."""
        machine = self._state_machine
        return machine.can_transition_to(target_state)

    # ------------------------------------------------------------------
    # Legacy boolean flags
    # ------------------------------------------------------------------
    @property
    def has_user_response(self) -> bool:
        """Legacy flag: a user response has been recorded."""
        return self._has_user_response

    @has_user_response.setter
    def has_user_response(self, value: bool) -> None:
        """Store the legacy ``has_user_response`` flag as given."""
        self._has_user_response = value

    @property
    def task_completed(self) -> bool:
        """Legacy flag: True if set explicitly or if the machine reports completion."""
        if self._task_completed:
            return True
        return bool(self._state_machine.is_completed())

    @task_completed.setter
    def task_completed(self, value: bool) -> None:
        """Store the legacy flag and mirror it onto the state machine.

        Setting a truthy value also tries to move the machine into
        ``AgentState.RESPONSE`` before marking completion as detected.
        """
        self._task_completed = bool(value)
        machine = self._state_machine

        if not value:
            machine.set_completion_detected(False)
            return

        try:
            already_in_response = machine.current_state == AgentState.RESPONSE
            if not already_in_response and machine.can_transition_to(AgentState.RESPONSE):
                machine.transition_to(AgentState.RESPONSE)
        except Exception:
            # Best effort only: legacy callers may set the flag from a state
            # that cannot reach RESPONSE, and that must not raise here.
            pass
        machine.set_completion_detected(True)

    @property
    def awaiting_user_guidance(self) -> bool:
        """Legacy flag: the agent is waiting for guidance from the user."""
        return self._awaiting_user_guidance

    @awaiting_user_guidance.setter
    def awaiting_user_guidance(self, value: bool) -> None:
        """Store the legacy ``awaiting_user_guidance`` flag as given."""
        self._awaiting_user_guidance = value

    @property
    def has_final_synthesis(self) -> bool:
        """Legacy flag: a final synthesis has been produced."""
        return self._has_final_synthesis

    @has_final_synthesis.setter
    def has_final_synthesis(self, value: bool) -> None:
        """Store the legacy ``has_final_synthesis`` flag as given."""
        self._has_final_synthesis = value

    # ------------------------------------------------------------------
    # State-machine helpers
    # ------------------------------------------------------------------
    def set_completion_detected(self, detected: bool = True) -> None:
        """Forward the completion-detected marker to the state machine."""
        machine = self._state_machine
        machine.set_completion_detected(detected)

    def is_completed(self) -> bool:
        """Return the state machine's own view of whether the task is complete."""
        machine = self._state_machine
        return machine.is_completed()

    def reset_state(self, initial_state: Optional[AgentState] = None) -> None:
        """Reset the state machine and clear every legacy flag.

        Args:
            initial_state: State handed to the machine's ``reset``; ``None``
                lets the machine choose its default.
        """
        self._state_machine.reset(initial_state)
        for flag_name in (
            "_has_user_response",
            "_task_completed",
            "_awaiting_user_guidance",
            "_has_final_synthesis",
        ):
            setattr(self, flag_name, False)
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""State transition management for agent response processing."""
|
|
2
|
+
|
|
3
|
+
import threading
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import TYPE_CHECKING, Dict, Set
|
|
7
|
+
|
|
8
|
+
from tunacode.types import AgentState
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
pass
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class InvalidStateTransitionError(Exception):
    """Error signalling that a requested state change is not permitted.

    Attributes:
        from_state: State the transition was requested from.
        to_state: State that was requested.
        message: Description of the failure; also passed to ``Exception``.
    """

    def __init__(self, from_state: Enum, to_state: Enum, message: str = None):
        # A falsy message (None or "") is replaced by a generated description
        # built from the two states' ``value`` attributes.
        if not message:
            message = f"Invalid state transition: {from_state.value} → {to_state.value}"

        self.from_state = from_state
        self.to_state = to_state
        self.message = message
        super().__init__(message)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class StateTransitionRules:
    """Table of permitted state transitions for a state machine.

    Attributes:
        valid_transitions: Maps each source state to the set of states that
            may be entered directly from it. A state missing from the mapping
            has no outgoing transitions.
    """

    valid_transitions: Dict[Enum, Set[Enum]]

    def is_valid_transition(self, from_state: Enum, to_state: Enum) -> bool:
        """Return True if moving from ``from_state`` to ``to_state`` is permitted."""
        return to_state in self.valid_transitions.get(from_state, ())

    def get_valid_next_states(self, current_state: Enum) -> Set[Enum]:
        """Return the states reachable in one step from ``current_state``.

        A new set is returned on every call, so mutating the result never
        alters the rule table itself (an instance of this class may be shared
        by many state machines).
        """
        return set(self.valid_transitions.get(current_state, ()))
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class AgentStateMachine:
    """State machine for agent response processing.

    Every read and write of the internal state is performed while holding a
    reentrant lock.
    """

    def __init__(self, initial_state: "AgentState", rules: "StateTransitionRules"):
        """
        Initialize the state machine.

        Args:
            initial_state: The starting state
            rules: Transition rules defining valid state changes
        """
        self._state = initial_state
        self._rules = rules
        self._lock = threading.RLock()  # Reentrant: methods may nest while holding it
        self._completion_detected = False

    @property
    def current_state(self) -> "AgentState":
        """Get the current state."""
        with self._lock:
            return self._state

    def transition_to(self, new_state: "AgentState") -> None:
        """
        Transition to a new state.

        Requesting the state the machine is already in is a no-op and never
        raises, whether or not the rules list that self-loop.

        Args:
            new_state: The state to transition to

        Raises:
            InvalidStateTransitionError: If the transition is not valid
        """
        with self._lock:
            # Self-transitions must be handled before validation; checking the
            # rules first would reject them whenever the table has no explicit
            # self-loop, making the no-op branch unreachable.
            if self._state == new_state:
                return

            if not self._rules.is_valid_transition(self._state, new_state):
                raise InvalidStateTransitionError(
                    self._state,
                    new_state,
                    f"Invalid state transition: {self._state.value} → {new_state.value}",
                )

            self._state = new_state

    def can_transition_to(self, target_state: "AgentState") -> bool:
        """Check whether the rules list a transition from the current state to the target.

        This reports the rule table only; it does not special-case the
        current state.
        """
        with self._lock:
            return self._rules.is_valid_transition(self._state, target_state)

    def set_completion_detected(self, detected: bool = True) -> None:
        """Record whether completion has been detected."""
        with self._lock:
            self._completion_detected = detected

    def is_completed(self) -> bool:
        """Return True only when in the RESPONSE state with completion detected."""
        with self._lock:
            return self._state == AgentState.RESPONSE and self._completion_detected

    def reset(self, initial_state: "AgentState" = None) -> None:
        """Reset the machine and clear the completion marker.

        Args:
            initial_state: State to restart from; ``None`` selects
                ``AgentState.USER_INPUT``.
        """
        with self._lock:
            if initial_state is None:
                initial_state = AgentState.USER_INPUT
            self._state = initial_state
            self._completion_detected = False
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# Define the transition rules for the agent state machine
|
|
109
|
+
AGENT_TRANSITION_RULES = StateTransitionRules(
    valid_transitions={
        # A user prompt is always handed to the assistant first.
        AgentState.USER_INPUT: {AgentState.ASSISTANT},
        # The assistant either runs tools or goes straight to a response.
        AgentState.ASSISTANT: {AgentState.TOOL_EXECUTION, AgentState.RESPONSE},
        # Tool results are always followed by a response.
        AgentState.TOOL_EXECUTION: {AgentState.RESPONSE},
        # A response may loop back to the assistant to continue working.
        AgentState.RESPONSE: {AgentState.ASSISTANT},
    }
)
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Task completion detection utilities."""
|
|
2
2
|
|
|
3
|
+
import re
|
|
3
4
|
from typing import Tuple
|
|
4
5
|
|
|
5
6
|
|
|
@@ -18,11 +19,14 @@ def check_task_completion(content: str) -> Tuple[bool, str]:
|
|
|
18
19
|
if not content:
|
|
19
20
|
return False, content
|
|
20
21
|
|
|
21
|
-
lines = content.
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
22
|
+
lines = content.split("\n")
|
|
23
|
+
|
|
24
|
+
# New marker: any line starting with "TUNACODE DONE:" (case-insensitive, allow leading whitespace)
|
|
25
|
+
done_pattern = re.compile(r"^\s*TUNACODE\s+DONE:\s*", re.IGNORECASE)
|
|
26
|
+
for idx, line in enumerate(lines):
|
|
27
|
+
if done_pattern.match(line):
|
|
28
|
+
# Remove the marker line and return remaining content
|
|
29
|
+
cleaned = "\n".join(lines[:idx] + lines[idx + 1 :]).strip()
|
|
30
|
+
return True, cleaned
|
|
27
31
|
|
|
28
32
|
return False, content
|
tunacode/core/agents/main.py
CHANGED
|
@@ -97,7 +97,7 @@ async def check_query_satisfaction(
|
|
|
97
97
|
state_manager: StateManager,
|
|
98
98
|
) -> bool:
|
|
99
99
|
"""Check if the response satisfies the original query."""
|
|
100
|
-
return True #
|
|
100
|
+
return True # Completion decided via DONE marker in RESPONSE
|
|
101
101
|
|
|
102
102
|
|
|
103
103
|
async def process_request(
|
|
@@ -251,7 +251,7 @@ Task: {message[:200]}...
|
|
|
251
251
|
|
|
252
252
|
You're describing actions but not executing them. You MUST:
|
|
253
253
|
|
|
254
|
-
1. If task is COMPLETE: Start response with
|
|
254
|
+
1. If task is COMPLETE: Start response with TUNACODE DONE:
|
|
255
255
|
2. If task needs work: Execute a tool RIGHT NOW (grep, read_file, bash, etc.)
|
|
256
256
|
3. If stuck: Explain the specific blocker
|
|
257
257
|
|
|
@@ -269,7 +269,7 @@ NO MORE DESCRIPTIONS. Take ACTION or mark COMPLETE."""
|
|
|
269
269
|
unproductive_iterations = 0
|
|
270
270
|
|
|
271
271
|
# REMOVED: Recursive satisfaction check that caused empty responses
|
|
272
|
-
# The agent now decides completion using
|
|
272
|
+
# The agent now decides completion using a DONE marker
|
|
273
273
|
# This eliminates recursive agent calls and gives control back to the agent
|
|
274
274
|
|
|
275
275
|
# Store original query for reference
|
|
@@ -302,7 +302,7 @@ Progress so far:
|
|
|
302
302
|
- Iterations: {i}
|
|
303
303
|
- Tools used: {tools_used_str}
|
|
304
304
|
|
|
305
|
-
If the task is complete, I should respond with
|
|
305
|
+
If the task is complete, I should respond with TUNACODE DONE:
|
|
306
306
|
Otherwise, please provide specific guidance on what to do next."""
|
|
307
307
|
|
|
308
308
|
create_user_message(clarification_content, state_manager)
|
tunacode/prompts/system.md
CHANGED
|
@@ -10,12 +10,22 @@ CRITICAL BEHAVIOR RULES:
|
|
|
10
10
|
1. ALWAYS ANNOUNCE YOUR INTENTIONS FIRST: Before executing any tools, briefly state what you're about to do (e.g., "I'll search for the main agent implementation" or "Let me examine the file structure")
|
|
11
11
|
2. When you say "Let me..." or "I will..." you MUST execute the corresponding tool in THE SAME RESPONSE
|
|
12
12
|
3. Never describe what you'll do without doing it - ALWAYS execute tools when discussing actions
|
|
13
|
-
4. When a task is COMPLETE, start your response with:
|
|
13
|
+
4. When a task is COMPLETE, start your response with: TUNACODE DONE:
|
|
14
14
|
5. If your response is cut off or truncated, you'll be prompted to continue - complete your action
|
|
15
15
|
6. YOU MUST NOT USE ANY EMOJIS, YOU WILL BE PUNISHED FOR EMOJI USE
|
|
16
16
|
|
|
17
17
|
You MUST follow these rules:
|
|
18
18
|
|
|
19
|
+
### Completion Signaling
|
|
20
|
+
|
|
21
|
+
When you have fully completed the user’s task:
|
|
22
|
+
|
|
23
|
+
- Start your response with a single line: `TUNACODE DONE:` followed by a brief outcome summary.
|
|
24
|
+
- Do not add explanations before the DONE line; keep it as the first line.
|
|
25
|
+
- Do NOT mark DONE if you have queued tools in the same response — execute tools first, then mark DONE.
|
|
26
|
+
- Example:
|
|
27
|
+
- `TUNACODE DONE: Implemented enum state machine and updated completion logic`
|
|
28
|
+
|
|
19
29
|
###Tool Access Rules###
|
|
20
30
|
|
|
21
31
|
You have 9 powerful tools at your disposal. Understanding their categories is CRITICAL for performance:
|
tunacode/types.py
CHANGED
|
@@ -209,6 +209,15 @@ class PlanPhase(Enum):
|
|
|
209
209
|
REVIEW_DECISION = "review"
|
|
210
210
|
|
|
211
211
|
|
|
212
|
+
class AgentState(Enum):
    """States of the agent loop, used for enhanced completion detection.

    Members:
        USER_INPUT: Initial state; a user prompt has been received.
        ASSISTANT: Reasoning/deciding phase.
        TOOL_EXECUTION: Tool execution phase.
        RESPONSE: Handling results; may complete or loop.
    """

    USER_INPUT = "user_input"
    ASSISTANT = "assistant"
    TOOL_EXECUTION = "tool_execution"
    RESPONSE = "response"
|
|
219
|
+
|
|
220
|
+
|
|
212
221
|
@dataclass
|
|
213
222
|
class PlanDoc:
|
|
214
223
|
"""Structured plan document with all required sections."""
|
tunacode/ui/completers.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Completers for file references and commands."""
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
|
-
from typing import TYPE_CHECKING, Iterable, Optional
|
|
4
|
+
from typing import TYPE_CHECKING, Iterable, List, Optional
|
|
5
5
|
|
|
6
6
|
from prompt_toolkit.completion import (
|
|
7
7
|
CompleteEvent,
|
|
@@ -13,6 +13,7 @@ from prompt_toolkit.document import Document
|
|
|
13
13
|
|
|
14
14
|
if TYPE_CHECKING:
|
|
15
15
|
from ..cli.commands import CommandRegistry
|
|
16
|
+
from ..utils.models_registry import ModelInfo, ModelsRegistry
|
|
16
17
|
|
|
17
18
|
|
|
18
19
|
class CommandCompleter(Completer):
|
|
@@ -126,11 +127,212 @@ class FileReferenceCompleter(Completer):
|
|
|
126
127
|
pass
|
|
127
128
|
|
|
128
129
|
|
|
129
|
-
|
|
130
|
-
"""
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
]
|
|
136
|
-
|
|
130
|
+
class ModelCompleter(Completer):
    """Completer for model names in the ``/model`` command."""

    # Command token that activates this completer.
    _COMMAND = "/model"

    # Search terms offered when nothing has been typed after the command.
    _POPULAR_SEARCHES = ("claude", "gpt", "gemini", "openai", "anthropic")

    # Substrings of model ids featured alongside the search terms.
    _FEATURED_MODEL_TAGS = ("gpt-4o", "claude-3", "gemini-2")

    # Substrings of base-model names that sort ahead of everything else.
    _PRIORITY_BASE_TAGS = ("gpt-4o", "gpt-4", "claude-3", "gemini-2", "o3", "o1")

    _MAX_FEATURED_MODELS = 3
    _MAX_BASE_MODELS = 5
    _MAX_VARIANTS_PER_BASE = 3

    # Friendly labels for known provider ids, built once instead of per call.
    _PROVIDER_DISPLAY_NAMES = {
        "openai": "OpenAI Direct",
        "anthropic": "Anthropic Direct",
        "google": "Google Direct",
        "google-gla": "Google Labs",
        "openrouter": "OpenRouter",
        "github-models": "GitHub Models (FREE)",
        "azure": "Azure OpenAI",
        "fastrouter": "FastRouter",
        "requesty": "Requesty",
        "cloudflare-workers-ai": "Cloudflare",
        "amazon-bedrock": "AWS Bedrock",
        "chutes": "Chutes AI",
        "deepinfra": "DeepInfra",
        "venice": "Venice AI",
    }

    def __init__(self, registry: Optional["ModelsRegistry"] = None):
        """Initialize the model completer.

        Args:
            registry: Models registry to read from; without one the completer
                yields nothing.
        """
        self.registry = registry
        self._models_cache: Optional[List["ModelInfo"]] = None
        self._registry_loaded = False

    async def _ensure_registry_loaded(self):
        """Load the registry once and snapshot its models into the cache."""
        if not self.registry or self._registry_loaded:
            return
        try:
            await self.registry.load()
            self._registry_loaded = True
            self._models_cache = (
                list(self.registry.models.values()) if self.registry.models else []
            )
        except Exception:
            # Loading is best effort: fall back to an empty cache and do not retry.
            self._models_cache = []
            self._registry_loaded = True

    def get_completions(
        self, document: Document, _complete_event: CompleteEvent
    ) -> Iterable[Completion]:
        """Yield completions for the argument of a ``/model`` command.

        Only the last line before the cursor is inspected, and its first
        token must be exactly ``/model``. With no argument typed yet, search
        terms and a few featured models are offered once the command is
        followed by whitespace; otherwise models matching the word before the
        cursor are offered, grouped by base model.
        """
        if not self.registry:
            return

        raw_line = document.text_before_cursor.split("\n")[-1]
        parts = raw_line.split()

        # Exact token match so that other commands sharing the prefix
        # (for example "/models") do not trigger model-name completion.
        if not parts or parts[0] != self._COMMAND:
            return

        if not self._registry_loaded and not self._load_models_sync():
            return  # Skip completion if we can't load models

        if len(parts) < 2:
            # These suggestions are inserted at the cursor, so offering them
            # while the cursor still touches the command would produce text
            # such as "/modelclaude". Wait for the separating whitespace.
            if raw_line[-1].isspace():
                yield from self._default_suggestions()
            return

        word_before_cursor = document.get_word_before_cursor(WORD=True)
        if not word_before_cursor or not self._models_cache:
            return

        base_models = self.registry.find_base_models(word_before_cursor.lower())
        if not base_models:
            return

        # Every completion replaces the word currently being typed.
        start_pos = -len(word_before_cursor)

        # Priority families first, then alphabetical by base-model name.
        ranked = sorted(
            base_models.items(),
            key=lambda item: (
                -1 if any(tag in item[0] for tag in self._PRIORITY_BASE_TAGS) else 0,
                item[0],
            ),
        )

        for _base_name, variants in ranked[: self._MAX_BASE_MODELS]:
            for index, model in enumerate(variants):
                if index >= self._MAX_VARIANTS_PER_BASE:
                    break

                # The first variant is the primary option and gets the bullet.
                prefix = "● " if index == 0 else " "
                display = f"{prefix}{model.full_id} - {model.name}{self._format_cost(model)}"
                if model.cost.input == 0:
                    display += " ⭐"  # Star for free models

                meta_info = self._get_provider_display_name(model.provider)
                if len(variants) > 1:
                    meta_info += f" ({len(variants)} sources)"

                yield Completion(
                    text=model.full_id,
                    start_position=start_pos,
                    display=display,
                    display_meta=meta_info,
                )

    def _load_models_sync(self) -> bool:
        """Fill the model cache without awaiting; return False if that raised.

        Uses the registry's on-disk cache when it is valid, otherwise its
        built-in fallback models. This is a compromise: completion here is
        synchronous, so the full asynchronous load cannot be awaited.
        """
        try:
            if self.registry._is_cache_valid() and self.registry._load_from_cache():
                self._registry_loaded = True
                self._models_cache = list(self.registry.models.values())
            elif not self._models_cache:
                self.registry._load_fallback_models()
                self._registry_loaded = True
                self._models_cache = list(self.registry.models.values())
        except Exception:
            # Best effort: completion must never break the prompt.
            return False
        return True

    def _default_suggestions(self) -> Iterable[Completion]:
        """Yield the search terms and featured models shown for a bare command."""
        for search_term in self._POPULAR_SEARCHES:
            yield Completion(
                text=search_term, display=f"{search_term} (search)", display_meta="search term"
            )

        featured = [
            model
            for model in (self._models_cache or [])
            if any(tag in model.id.lower() for tag in self._FEATURED_MODEL_TAGS)
        ][: self._MAX_FEATURED_MODELS]

        for model in featured:
            display = f"{model.full_id} - {model.name}"
            if model.cost.input is not None:
                display += f" (${model.cost.input}/{model.cost.output})"

            yield Completion(
                text=model.full_id, display=display, display_meta=f"{model.provider} model"
            )

    @staticmethod
    def _format_cost(model) -> str:
        """Return the cost suffix for a model's display text ("" if unknown)."""
        input_cost = model.cost.input
        if input_cost is None:
            return ""
        if input_cost == 0:
            return " (FREE)"
        return f" (${input_cost}/{model.cost.output})"

    def _get_provider_display_name(self, provider: str) -> str:
        """Get a user-friendly provider display name.

        Unknown providers fall back to the title-cased provider id.
        """
        return self._PROVIDER_DISPLAY_NAMES.get(provider, provider.title())
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def create_completer(
    command_registry: Optional["CommandRegistry"] = None,
    models_registry: Optional["ModelsRegistry"] = None,
) -> Completer:
    """Build one completer covering commands, file references and models.

    Args:
        command_registry: Registry handed to the command completer.
        models_registry: Registry handed to the model completer; when falsy,
            no model completer is included.

    Returns:
        The merged completer.
    """
    active = [CommandCompleter(command_registry), FileReferenceCompleter()]
    if models_registry:
        active += [ModelCompleter(models_registry)]
    return merge_completers(active)
|
tunacode/ui/input.py
CHANGED
|
@@ -95,6 +95,12 @@ async def multiline_input(
|
|
|
95
95
|
)
|
|
96
96
|
)
|
|
97
97
|
|
|
98
|
+
# Create models registry for auto-completion (lazy loaded)
|
|
99
|
+
from ..utils.models_registry import ModelsRegistry
|
|
100
|
+
|
|
101
|
+
models_registry = ModelsRegistry()
|
|
102
|
+
# Note: Registry will be loaded lazily by the completer when needed
|
|
103
|
+
|
|
98
104
|
# Display input area (Plan Mode indicator is handled dynamically in prompt manager)
|
|
99
105
|
result = await input(
|
|
100
106
|
"multiline",
|
|
@@ -102,7 +108,7 @@ async def multiline_input(
|
|
|
102
108
|
key_bindings=kb,
|
|
103
109
|
multiline=True,
|
|
104
110
|
placeholder=placeholder,
|
|
105
|
-
completer=create_completer(command_registry),
|
|
111
|
+
completer=create_completer(command_registry, models_registry),
|
|
106
112
|
lexer=FileReferenceLexer(),
|
|
107
113
|
state_manager=state_manager,
|
|
108
114
|
)
|