tunacode-cli 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tunacode-cli might be problematic. Click here for more details.
- tunacode/__init__.py +0 -0
- tunacode/cli/textual_repl.tcss +283 -0
- tunacode/configuration/__init__.py +1 -0
- tunacode/configuration/defaults.py +45 -0
- tunacode/configuration/models.py +147 -0
- tunacode/configuration/models_registry.json +1 -0
- tunacode/configuration/pricing.py +74 -0
- tunacode/configuration/settings.py +35 -0
- tunacode/constants.py +227 -0
- tunacode/core/__init__.py +6 -0
- tunacode/core/agents/__init__.py +39 -0
- tunacode/core/agents/agent_components/__init__.py +48 -0
- tunacode/core/agents/agent_components/agent_config.py +441 -0
- tunacode/core/agents/agent_components/agent_helpers.py +290 -0
- tunacode/core/agents/agent_components/message_handler.py +99 -0
- tunacode/core/agents/agent_components/node_processor.py +477 -0
- tunacode/core/agents/agent_components/response_state.py +129 -0
- tunacode/core/agents/agent_components/result_wrapper.py +51 -0
- tunacode/core/agents/agent_components/state_transition.py +112 -0
- tunacode/core/agents/agent_components/streaming.py +271 -0
- tunacode/core/agents/agent_components/task_completion.py +40 -0
- tunacode/core/agents/agent_components/tool_buffer.py +44 -0
- tunacode/core/agents/agent_components/tool_executor.py +101 -0
- tunacode/core/agents/agent_components/truncation_checker.py +37 -0
- tunacode/core/agents/delegation_tools.py +109 -0
- tunacode/core/agents/main.py +545 -0
- tunacode/core/agents/prompts.py +66 -0
- tunacode/core/agents/research_agent.py +231 -0
- tunacode/core/compaction.py +218 -0
- tunacode/core/prompting/__init__.py +27 -0
- tunacode/core/prompting/loader.py +66 -0
- tunacode/core/prompting/prompting_engine.py +98 -0
- tunacode/core/prompting/sections.py +50 -0
- tunacode/core/prompting/templates.py +69 -0
- tunacode/core/state.py +409 -0
- tunacode/exceptions.py +313 -0
- tunacode/indexing/__init__.py +5 -0
- tunacode/indexing/code_index.py +432 -0
- tunacode/indexing/constants.py +86 -0
- tunacode/lsp/__init__.py +112 -0
- tunacode/lsp/client.py +351 -0
- tunacode/lsp/diagnostics.py +19 -0
- tunacode/lsp/servers.py +101 -0
- tunacode/prompts/default_prompt.md +952 -0
- tunacode/prompts/research/sections/agent_role.xml +5 -0
- tunacode/prompts/research/sections/constraints.xml +14 -0
- tunacode/prompts/research/sections/output_format.xml +57 -0
- tunacode/prompts/research/sections/tool_use.xml +23 -0
- tunacode/prompts/sections/advanced_patterns.xml +255 -0
- tunacode/prompts/sections/agent_role.xml +8 -0
- tunacode/prompts/sections/completion.xml +10 -0
- tunacode/prompts/sections/critical_rules.xml +37 -0
- tunacode/prompts/sections/examples.xml +220 -0
- tunacode/prompts/sections/output_style.xml +94 -0
- tunacode/prompts/sections/parallel_exec.xml +105 -0
- tunacode/prompts/sections/search_pattern.xml +100 -0
- tunacode/prompts/sections/system_info.xml +6 -0
- tunacode/prompts/sections/tool_use.xml +84 -0
- tunacode/prompts/sections/user_instructions.xml +3 -0
- tunacode/py.typed +0 -0
- tunacode/templates/__init__.py +5 -0
- tunacode/templates/loader.py +15 -0
- tunacode/tools/__init__.py +10 -0
- tunacode/tools/authorization/__init__.py +29 -0
- tunacode/tools/authorization/context.py +32 -0
- tunacode/tools/authorization/factory.py +20 -0
- tunacode/tools/authorization/handler.py +58 -0
- tunacode/tools/authorization/notifier.py +35 -0
- tunacode/tools/authorization/policy.py +19 -0
- tunacode/tools/authorization/requests.py +119 -0
- tunacode/tools/authorization/rules.py +72 -0
- tunacode/tools/bash.py +222 -0
- tunacode/tools/decorators.py +213 -0
- tunacode/tools/glob.py +353 -0
- tunacode/tools/grep.py +468 -0
- tunacode/tools/grep_components/__init__.py +9 -0
- tunacode/tools/grep_components/file_filter.py +93 -0
- tunacode/tools/grep_components/pattern_matcher.py +158 -0
- tunacode/tools/grep_components/result_formatter.py +87 -0
- tunacode/tools/grep_components/search_result.py +34 -0
- tunacode/tools/list_dir.py +205 -0
- tunacode/tools/prompts/bash_prompt.xml +10 -0
- tunacode/tools/prompts/glob_prompt.xml +7 -0
- tunacode/tools/prompts/grep_prompt.xml +10 -0
- tunacode/tools/prompts/list_dir_prompt.xml +7 -0
- tunacode/tools/prompts/read_file_prompt.xml +9 -0
- tunacode/tools/prompts/todoclear_prompt.xml +12 -0
- tunacode/tools/prompts/todoread_prompt.xml +16 -0
- tunacode/tools/prompts/todowrite_prompt.xml +28 -0
- tunacode/tools/prompts/update_file_prompt.xml +9 -0
- tunacode/tools/prompts/web_fetch_prompt.xml +11 -0
- tunacode/tools/prompts/write_file_prompt.xml +7 -0
- tunacode/tools/react.py +111 -0
- tunacode/tools/read_file.py +68 -0
- tunacode/tools/todo.py +222 -0
- tunacode/tools/update_file.py +62 -0
- tunacode/tools/utils/__init__.py +1 -0
- tunacode/tools/utils/ripgrep.py +311 -0
- tunacode/tools/utils/text_match.py +352 -0
- tunacode/tools/web_fetch.py +245 -0
- tunacode/tools/write_file.py +34 -0
- tunacode/tools/xml_helper.py +34 -0
- tunacode/types/__init__.py +166 -0
- tunacode/types/base.py +94 -0
- tunacode/types/callbacks.py +53 -0
- tunacode/types/dataclasses.py +121 -0
- tunacode/types/pydantic_ai.py +31 -0
- tunacode/types/state.py +122 -0
- tunacode/ui/__init__.py +6 -0
- tunacode/ui/app.py +542 -0
- tunacode/ui/commands/__init__.py +430 -0
- tunacode/ui/components/__init__.py +1 -0
- tunacode/ui/headless/__init__.py +5 -0
- tunacode/ui/headless/output.py +72 -0
- tunacode/ui/main.py +252 -0
- tunacode/ui/renderers/__init__.py +41 -0
- tunacode/ui/renderers/errors.py +197 -0
- tunacode/ui/renderers/panels.py +550 -0
- tunacode/ui/renderers/search.py +314 -0
- tunacode/ui/renderers/tools/__init__.py +21 -0
- tunacode/ui/renderers/tools/bash.py +247 -0
- tunacode/ui/renderers/tools/diagnostics.py +186 -0
- tunacode/ui/renderers/tools/glob.py +226 -0
- tunacode/ui/renderers/tools/grep.py +228 -0
- tunacode/ui/renderers/tools/list_dir.py +198 -0
- tunacode/ui/renderers/tools/read_file.py +226 -0
- tunacode/ui/renderers/tools/research.py +294 -0
- tunacode/ui/renderers/tools/update_file.py +237 -0
- tunacode/ui/renderers/tools/web_fetch.py +182 -0
- tunacode/ui/repl_support.py +226 -0
- tunacode/ui/screens/__init__.py +16 -0
- tunacode/ui/screens/model_picker.py +303 -0
- tunacode/ui/screens/session_picker.py +181 -0
- tunacode/ui/screens/setup.py +218 -0
- tunacode/ui/screens/theme_picker.py +90 -0
- tunacode/ui/screens/update_confirm.py +69 -0
- tunacode/ui/shell_runner.py +129 -0
- tunacode/ui/styles/layout.tcss +98 -0
- tunacode/ui/styles/modals.tcss +38 -0
- tunacode/ui/styles/panels.tcss +81 -0
- tunacode/ui/styles/theme-nextstep.tcss +303 -0
- tunacode/ui/styles/widgets.tcss +33 -0
- tunacode/ui/styles.py +18 -0
- tunacode/ui/widgets/__init__.py +23 -0
- tunacode/ui/widgets/command_autocomplete.py +62 -0
- tunacode/ui/widgets/editor.py +402 -0
- tunacode/ui/widgets/file_autocomplete.py +47 -0
- tunacode/ui/widgets/messages.py +46 -0
- tunacode/ui/widgets/resource_bar.py +182 -0
- tunacode/ui/widgets/status_bar.py +98 -0
- tunacode/utils/__init__.py +0 -0
- tunacode/utils/config/__init__.py +13 -0
- tunacode/utils/config/user_configuration.py +91 -0
- tunacode/utils/messaging/__init__.py +10 -0
- tunacode/utils/messaging/message_utils.py +34 -0
- tunacode/utils/messaging/token_counter.py +77 -0
- tunacode/utils/parsing/__init__.py +13 -0
- tunacode/utils/parsing/command_parser.py +55 -0
- tunacode/utils/parsing/json_utils.py +188 -0
- tunacode/utils/parsing/retry.py +146 -0
- tunacode/utils/parsing/tool_parser.py +267 -0
- tunacode/utils/security/__init__.py +15 -0
- tunacode/utils/security/command.py +106 -0
- tunacode/utils/system/__init__.py +25 -0
- tunacode/utils/system/gitignore.py +155 -0
- tunacode/utils/system/paths.py +190 -0
- tunacode/utils/ui/__init__.py +9 -0
- tunacode/utils/ui/file_filter.py +135 -0
- tunacode/utils/ui/helpers.py +24 -0
- tunacode_cli-0.1.21.dist-info/METADATA +170 -0
- tunacode_cli-0.1.21.dist-info/RECORD +174 -0
- tunacode_cli-0.1.21.dist-info/WHEEL +4 -0
- tunacode_cli-0.1.21.dist-info/entry_points.txt +2 -0
- tunacode_cli-0.1.21.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
"""Support helpers for the Textual REPL app.
|
|
2
|
+
|
|
3
|
+
This module exists to keep `tunacode.ui.app` focused on UI composition and lifecycle
|
|
4
|
+
while hosting small, testable helpers and callback builders.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import logging
|
|
11
|
+
import re
|
|
12
|
+
from collections.abc import Awaitable, Callable
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from typing import Any, Protocol
|
|
15
|
+
|
|
16
|
+
from rich.console import Console
|
|
17
|
+
from rich.text import Text
|
|
18
|
+
|
|
19
|
+
from tunacode.constants import MAX_CALLBACK_CONTENT
|
|
20
|
+
from tunacode.tools.authorization.handler import ToolHandler
|
|
21
|
+
from tunacode.types import (
|
|
22
|
+
StateManager,
|
|
23
|
+
ToolConfirmationRequest,
|
|
24
|
+
ToolConfirmationResponse,
|
|
25
|
+
ToolProgress,
|
|
26
|
+
ToolProgressCallback,
|
|
27
|
+
)
|
|
28
|
+
from tunacode.ui.widgets import ToolResultDisplay
|
|
29
|
+
|
|
30
|
+
# Messages with more lines than this are rendered collapsed by format_collapsed_message.
COLLAPSE_THRESHOLD: int = 10

# Tools whose "completed" result adds their "filepath" argument to the status bar's
# edited-file list (see build_tool_result_callback).
FILE_EDIT_TOOLS: frozenset[str] = frozenset({"write_file", "update_file"})

# Left-gutter prefix drawn before every rendered row of a user message.
USER_MESSAGE_PREFIX: str = "│ "
# Fallback total render width used when the caller passes a non-positive width.
DEFAULT_USER_MESSAGE_WIDTH: int = 80
# Tags delimiting a diagnostics block that _truncate_for_safety keeps intact when it
# appears at the very start of a tool result.
DIAGNOSTICS_BLOCK_START: str = "<file_diagnostics>"
DIAGNOSTICS_BLOCK_END: str = "</file_diagnostics>"
# Non-greedy with DOTALL: spans newlines and stops at the FIRST closing tag.
# The tags contain no regex metacharacters, so they are used unescaped.
DIAGNOSTICS_BLOCK_PATTERN: str = f"{DIAGNOSTICS_BLOCK_START}.*?{DIAGNOSTICS_BLOCK_END}"
DIAGNOSTICS_BLOCK_RE = re.compile(DIAGNOSTICS_BLOCK_PATTERN, re.DOTALL)
# Suffix appended to tool output shortened by _truncate_for_safety; its length is
# reserved inside the MAX_CALLBACK_CONTENT budget.
CALLBACK_TRUNCATION_NOTICE: str = "\n... [truncated for safety]"
CALLBACK_TRUNCATION_NOTICE_LEN: int = len(CALLBACK_TRUNCATION_NOTICE)

logger = logging.getLogger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _format_prefixed_wrapped_lines(
    lines: list[tuple[str, str]],
    *,
    width: int,
) -> Text:
    """Render ``(text, style)`` pairs as one gutter-prefixed, hard-wrapped block.

    Each source line is folded to the columns left after ``USER_MESSAGE_PREFIX``.
    Every wrapped row is emitted as prefix + row + newline, with the prefix drawn
    in the same style as its source line. A non-positive ``width`` falls back to
    ``DEFAULT_USER_MESSAGE_WIDTH``.
    """
    total_width = DEFAULT_USER_MESSAGE_WIDTH if width <= 0 else width
    wrap_width = max(1, total_width - len(USER_MESSAGE_PREFIX))
    # Colourless, non-terminal console used only for measuring while wrapping.
    measurer = Console(width=wrap_width, color_system=None, force_terminal=False)

    rendered = Text()
    for source_text, source_style in lines:
        source = Text(source_text, style=source_style, overflow="fold")
        for row in source.wrap(measurer, wrap_width):
            rendered.append(USER_MESSAGE_PREFIX, style=source_style)
            rendered.append_text(row)
            rendered.append("\n")
    return rendered
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def format_user_message(text: str, style: str, *, width: int) -> Text:
    """Render user text behind the left gutter, hard-wrapped to ``width`` columns.

    Every line of ``text`` gets the same ``style``. Empty text still renders as a
    single, empty, prefixed row.
    """
    source_lines = text.splitlines()
    if not source_lines:
        source_lines = [""]
    return _format_prefixed_wrapped_lines(
        [(line, style) for line in source_lines],
        width=width,
    )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def format_collapsed_message(
    text: str,
    style: str,
    *,
    width: int,
    head_lines: int = 3,
    tail_lines: int = 2,
) -> Text:
    """Format long pasted text with a collapsed middle section.

    Text of at most ``COLLAPSE_THRESHOLD`` lines is rendered in full via
    ``format_user_message``. Longer text shows the first ``head_lines`` lines,
    then a dimmed ``[[ N more lines ]]`` indicator, then the last ``tail_lines``
    lines (3 and 2 by default).

    Args:
        text: Raw user text, possibly multi-line.
        style: Rich style applied to every rendered line.
        width: Total render width in columns (non-positive selects the default).
        head_lines: Number of leading lines kept visible (negative counts as 0).
        tail_lines: Number of trailing lines kept visible (negative counts as 0).

    Returns:
        A gutter-prefixed, wrapped ``Text`` block.
    """
    lines = text.splitlines()
    line_count = max(1, len(lines))
    head = max(0, head_lines)
    tail = max(0, tail_lines)

    # Only collapse when it actually hides something; this also guards against
    # head + tail covering the whole message (which would give a count <= 0).
    if line_count <= COLLAPSE_THRESHOLD or line_count <= head + tail:
        return format_user_message(text, style, width=width)

    # Derived from head/tail so the indicator can never disagree with what is shown.
    collapsed = line_count - head - tail

    render_lines: list[tuple[str, str]] = [(line, style) for line in lines[:head]]
    render_lines.append((f"[[ {collapsed} more lines ]]", f"dim {style}"))
    # Explicit start index: ``lines[-0:]`` would return every line when tail == 0.
    render_lines.extend((line, style) for line in lines[len(lines) - tail :])
    return _format_prefixed_wrapped_lines(render_lines, width=width)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@dataclass
class PendingConfirmationState:
    """Tracks pending tool confirmation state.

    Pairs a tool confirmation request with the future that will carry the user's
    response. Who creates and resolves the future is not visible in this module
    (presumably the app's confirmation UI — TODO confirm).
    """

    # Future that resolves to the user's answer for ``request``.
    future: asyncio.Future[ToolConfirmationResponse]
    # The confirmation request being presented.
    request: ToolConfirmationRequest
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class ConfirmationRequester(Protocol):
    """Structural type for objects that can ask the user to confirm a tool call.

    ``build_textual_tool_callback`` awaits ``request_tool_confirmation`` for every
    tool that its ``ToolHandler`` says requires confirmation.
    """

    async def request_tool_confirmation(
        self, request: ToolConfirmationRequest
    ) -> ToolConfirmationResponse: ...
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class StatusBarLike(Protocol):
    """Structural type for the status-bar methods called by the tool callbacks."""

    # Called with the "filepath" arg after a write_file/update_file tool completes.
    def add_edited_file(self, filepath: str) -> None: ...

    # Called with the tool name for every tool result.
    def update_last_action(self, tool_name: str) -> None: ...

    # Called with the tool name when a tool starts.
    def update_running_action(self, tool_name: str) -> None: ...

    # Called with progress reported by subagent tools.
    def update_subagent_progress(self, progress: ToolProgress) -> None: ...
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class AppForCallbacks(ConfirmationRequester, Protocol):
    """Structural type for the app passed to the ``build_tool_*_callback`` factories.

    Extends ``ConfirmationRequester`` with a status bar and the ability to post a
    ``ToolResultDisplay`` message.
    """

    # Read on every callback invocation (not captured at build time).
    status_bar: StatusBarLike

    def post_message(self, message: ToolResultDisplay) -> None: ...
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def build_textual_tool_callback(
    app: ConfirmationRequester,
    state_manager: StateManager,
) -> Callable[[Any, Any | None], Awaitable[None]]:
    """Build the pre-execution hook that asks the user to confirm tool calls.

    On each call, the returned coroutine:

    - ensures ``state_manager`` holds a ``ToolHandler``, creating one when the
      current handler is falsy, and re-registers it;
    - returns early for tools that do not need confirmation;
    - otherwise shows a confirmation request through ``app``.

    Raises:
        UserAbortError: from the returned coroutine, when the user rejects the tool.
    """

    async def _confirm_tool(part: Any, _node: Any = None) -> None:
        handler = state_manager.tool_handler
        if not handler:
            handler = ToolHandler(state_manager)
        state_manager.set_tool_handler(handler)

        if not handler.should_confirm(part.tool_name):
            return

        # Imported lazily: only needed once a confirmation is actually shown.
        from tunacode.exceptions import UserAbortError
        from tunacode.utils.parsing.command_parser import parse_args

        confirmation = handler.create_confirmation_request(
            part.tool_name, parse_args(part.args)
        )
        answer = await app.request_tool_confirmation(confirmation)
        approved = handler.process_confirmation(answer, part.tool_name)
        if not approved:
            raise UserAbortError("User aborted tool execution")

    return _confirm_tool
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _truncate_for_safety(content: str | None) -> str | None:
    """Emergency truncation - prevents UI freeze on massive outputs.

    ``None`` passes through, and content of at most ``MAX_CALLBACK_CONTENT``
    characters is returned unchanged. Longer content is cut and suffixed with
    ``CALLBACK_TRUNCATION_NOTICE`` so the result fits the limit.

    If the content starts with a complete ``<file_diagnostics>...</file_diagnostics>``
    block, that block is kept whole and only the text after it is cut. If the block
    alone leaves no room, the result is the block plus the notice, which may then
    exceed the limit.

    Args:
        content: Raw tool result text, or ``None``.

    Returns:
        The original or truncated text, or ``None``.
    """
    if content is None:
        return None
    if len(content) <= MAX_CALLBACK_CONTENT:
        return content

    diagnostics_match = DIAGNOSTICS_BLOCK_RE.match(content)
    if diagnostics_match is None:
        if content.startswith(DIAGNOSTICS_BLOCK_START):
            logger.warning("Diagnostics block missing closing tag; truncating content.")
        # Clamp at 0. A negative slice bound would keep almost everything
        # (``s[:-k]`` drops only the last k chars) and defeat the truncation.
        truncation_limit = max(0, MAX_CALLBACK_CONTENT - CALLBACK_TRUNCATION_NOTICE_LEN)
        return content[:truncation_limit] + CALLBACK_TRUNCATION_NOTICE

    diagnostics_block = diagnostics_match.group(0)
    # The match is anchored at index 0, so end() is where the remainder begins.
    remainder_start = diagnostics_match.end()
    remaining_budget = (
        MAX_CALLBACK_CONTENT - len(diagnostics_block) - CALLBACK_TRUNCATION_NOTICE_LEN
    )

    if remaining_budget <= 0:
        logger.warning("Diagnostics block exceeds safety limit; truncating remainder.")
        return diagnostics_block + CALLBACK_TRUNCATION_NOTICE

    truncated_remainder = content[remainder_start : remainder_start + remaining_budget]
    return f"{diagnostics_block}{truncated_remainder}{CALLBACK_TRUNCATION_NOTICE}"
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def build_tool_result_callback(app: AppForCallbacks) -> Callable[..., None]:
    """Build the callback invoked when a tool reports a result.

    The callback does three things, in order:

    1. For a "completed" write_file/update_file call with a truthy "filepath"
       argument, records that file in the status bar.
    2. Updates the status bar's last action.
    3. Posts a ``ToolResultDisplay`` whose result has been passed through
       ``_truncate_for_safety``.
    """

    def _on_tool_result(
        tool_name: str,
        status: str,
        args: dict[str, Any],
        result: str | None = None,
        duration_ms: float | None = None,
    ) -> None:
        is_finished_edit = tool_name in FILE_EDIT_TOOLS and status == "completed"
        edited_path = args.get("filepath") if is_finished_edit else None
        if edited_path:
            app.status_bar.add_edited_file(edited_path)

        app.status_bar.update_last_action(tool_name)

        app.post_message(
            ToolResultDisplay(
                tool_name=tool_name,
                status=status,
                args=args,
                result=_truncate_for_safety(result),
                duration_ms=duration_ms,
            )
        )

    return _on_tool_result
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def build_tool_start_callback(app: AppForCallbacks) -> Callable[[str], None]:
    """Return a hook that shows a starting tool as the status bar's running action.

    ``app.status_bar`` is looked up on every call, not captured once here.
    """

    def _on_tool_start(tool_name: str) -> None:
        app.status_bar.update_running_action(tool_name)

    return _on_tool_start
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def build_tool_progress_callback(app: AppForCallbacks) -> ToolProgressCallback:
    """Return a hook that forwards subagent tool progress to the status bar.

    ``app.status_bar`` is looked up on every call, not captured once here.
    """

    def _on_progress(progress: ToolProgress) -> None:
        app.status_bar.update_subagent_progress(progress)

    return _on_progress
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
async def run_textual_repl(state_manager: StateManager, show_setup: bool = False) -> None:
    """Create the Textual REPL application and run it until it exits.

    The app module is imported at call time rather than at module import.
    """
    from tunacode.ui.app import TextualReplApp

    await TextualReplApp(state_manager=state_manager, show_setup=show_setup).run_async()
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Textual screens for TunaCode REPL."""
|
|
2
|
+
|
|
3
|
+
from tunacode.ui.screens.model_picker import ModelPickerScreen, ProviderPickerScreen
|
|
4
|
+
from tunacode.ui.screens.session_picker import SessionPickerScreen
|
|
5
|
+
from tunacode.ui.screens.setup import SetupScreen
|
|
6
|
+
from tunacode.ui.screens.theme_picker import ThemePickerScreen
|
|
7
|
+
from tunacode.ui.screens.update_confirm import UpdateConfirmScreen
|
|
8
|
+
|
|
9
|
+
# Public API of this package: the screen classes imported above.
__all__ = [
    "ModelPickerScreen",
    "ProviderPickerScreen",
    "SessionPickerScreen",
    "SetupScreen",
    "ThemePickerScreen",
    "UpdateConfirmScreen",
]
|
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
"""Model picker modal screens for TunaCode."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from textual import events
|
|
6
|
+
from textual.app import ComposeResult
|
|
7
|
+
from textual.containers import Vertical
|
|
8
|
+
from textual.screen import Screen
|
|
9
|
+
from textual.widgets import Input, OptionList, Static
|
|
10
|
+
from textual.widgets.option_list import Option
|
|
11
|
+
|
|
12
|
+
from tunacode.configuration.models import get_models_for_provider, get_providers
|
|
13
|
+
from tunacode.constants import MODEL_PICKER_UNFILTERED_LIMIT
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _filter_visible_items(
|
|
17
|
+
items: list[tuple[str, str]],
|
|
18
|
+
filter_query: str,
|
|
19
|
+
limit: int,
|
|
20
|
+
) -> tuple[list[tuple[str, str]], bool]:
|
|
21
|
+
normalized_query = filter_query.strip().lower()
|
|
22
|
+
has_query = bool(normalized_query)
|
|
23
|
+
visible_items: list[tuple[str, str]] = []
|
|
24
|
+
total_matches = 0
|
|
25
|
+
|
|
26
|
+
for display_name, item_id in items:
|
|
27
|
+
display_name_lower = display_name.lower()
|
|
28
|
+
if has_query and normalized_query not in display_name_lower:
|
|
29
|
+
continue
|
|
30
|
+
total_matches += 1
|
|
31
|
+
if not has_query and len(visible_items) >= limit:
|
|
32
|
+
continue
|
|
33
|
+
visible_items.append((display_name, item_id))
|
|
34
|
+
|
|
35
|
+
is_truncated = (not has_query) and (total_matches > len(visible_items))
|
|
36
|
+
return visible_items, is_truncated
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _choose_highlight_index(
|
|
40
|
+
items: list[tuple[str, str]],
|
|
41
|
+
current_id: str,
|
|
42
|
+
) -> int | None:
|
|
43
|
+
if not items:
|
|
44
|
+
return None
|
|
45
|
+
|
|
46
|
+
for index, (_, item_id) in enumerate(items):
|
|
47
|
+
if item_id == current_id:
|
|
48
|
+
return index
|
|
49
|
+
|
|
50
|
+
return 0
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _append_truncation_notice(option_list: OptionList, limit: int, label: str) -> None:
    """Append a disabled row telling the user only the first ``limit`` ``label`` are listed."""
    notice = Option(
        f"Showing first {limit} {label}. Type to filter to see more.",
        disabled=True,
    )
    option_list.add_option(notice)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class ProviderPickerScreen(Screen[str | None]):
    """Modal screen for provider selection (step 1 of model picker).

    Lists providers from ``get_providers()`` in an ``OptionList`` that a text
    ``Input`` filters. Dismisses with the chosen provider id, or with ``None``
    on cancel.
    """

    CSS = """
    ProviderPickerScreen {
        align: center middle;
    }

    #provider-container {
        width: 50;
        height: auto;
        max-height: 28;
        border: solid $primary;
        background: $surface;
        padding: 1 2;
    }

    #provider-title {
        text-style: bold;
        color: $accent;
        text-align: center;
        margin-bottom: 1;
    }

    #provider-filter {
        height: 3;
        margin-bottom: 1;
    }

    #provider-list {
        height: auto;
        max-height: 18;
    }
    """

    BINDINGS = [
        ("escape", "cancel", "Cancel"),
    ]

    def __init__(self, current_model: str) -> None:
        super().__init__()
        self._current_model = current_model
        # Provider part of a "provider:model" string; "" when there is no colon.
        self._current_provider = current_model.split(":")[0] if ":" in current_model else ""
        # (display_name, provider_id) pairs, loaded in compose().
        self._all_providers: list[tuple[str, str]] = []
        self._filter_query: str = ""

    def compose(self) -> ComposeResult:
        self._all_providers = get_providers()

        with Vertical(id="provider-container"):
            yield Static("Select Provider", id="provider-title")
            yield Input(placeholder="Filter providers...", id="provider-filter")
            yield OptionList(id="provider-list")

        # Populate the list once the widgets exist.
        self.call_after_refresh(self._rebuild_options)

    def _rebuild_options(self) -> None:
        """Rebuild OptionList with filtered items.

        Highlights the current provider when it is visible (else the first row).
        Appends a disabled notice when the unfiltered list was capped.
        """
        option_list = self.query_one("#provider-list", OptionList)
        option_list.clear_options()

        filter_query = self._filter_query
        visible_providers, is_truncated = _filter_visible_items(
            self._all_providers,
            filter_query,
            MODEL_PICKER_UNFILTERED_LIMIT,
        )

        highlight_index = _choose_highlight_index(visible_providers, self._current_provider)

        for display_name, provider_id in visible_providers:
            option_list.add_option(Option(display_name, id=provider_id))

        if highlight_index is not None:
            option_list.highlighted = highlight_index

        if is_truncated:
            _append_truncation_notice(
                option_list,
                MODEL_PICKER_UNFILTERED_LIMIT,
                "providers",
            )

    def on_input_changed(self, event: Input.Changed) -> None:
        """Filter options as user types."""
        if event.input.id != "provider-filter":
            return
        self._filter_query = event.value
        self._rebuild_options()

    def on_key(self, event: events.Key) -> None:
        """Handle focus transitions between Input and OptionList.

        - down in the filter moves focus to the list;
        - up on the list's first row, or on a list with nothing highlighted,
          moves focus back to the filter;
        - escape with an active filter clears the filter instead of cancelling.

        Handled keys are stopped so they do not propagate further.
        """
        filter_input = self.query_one("#provider-filter", Input)
        option_list = self.query_one("#provider-list", OptionList)

        if event.key == "down" and self.focused == filter_input:
            self.set_focus(option_list)
            event.stop()
        elif event.key == "up" and self.focused == option_list:
            # ``highlighted`` is None when nothing is highlighted (e.g. the filter
            # matched nothing). Without treating that as the top row, focus
            # could not return to the filter with "up".
            if option_list.highlighted is None or option_list.highlighted == 0:
                self.set_focus(filter_input)
                event.stop()
        elif event.key == "escape" and self._filter_query:
            filter_input.value = ""
            self._filter_query = ""
            self._rebuild_options()
            event.stop()

    def on_option_list_option_selected(self, event: OptionList.OptionSelected) -> None:
        """Confirm selection and dismiss with provider ID."""
        # The truncation notice has no id, so selecting it is ignored.
        if event.option and event.option.id:
            self.dismiss(str(event.option.id))

    def action_cancel(self) -> None:
        """Cancel selection."""
        self.dismiss(None)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class ModelPickerScreen(Screen[str | None]):
    """Modal screen for model selection (step 2 of model picker).

    Lists the models of one provider (from ``get_models_for_provider``). Each
    label gets pricing appended when ``get_model_pricing`` returns a value, and
    a text ``Input`` filters the list. Dismisses with the full
    ``"provider:model_id"`` string, or with ``None`` on cancel.
    """

    CSS = """
    ModelPickerScreen {
        align: center middle;
    }

    #model-container {
        width: 60;
        height: auto;
        max-height: 28;
        border: solid $primary;
        background: $surface;
        padding: 1 2;
    }

    #model-title {
        text-style: bold;
        color: $accent;
        text-align: center;
        margin-bottom: 1;
    }

    #model-filter {
        height: 3;
        margin-bottom: 1;
    }

    #model-list {
        height: auto;
        max-height: 18;
    }
    """

    BINDINGS = [
        ("escape", "cancel", "Cancel"),
    ]

    def __init__(self, provider_id: str, current_model: str) -> None:
        super().__init__()
        self._provider_id = provider_id
        self._current_model = current_model
        # Everything after the first colon of "provider:model"; "" when there is no colon.
        current_model_id = current_model.split(":", 1)[1] if ":" in current_model else ""
        self._current_model_id = current_model_id
        # (display_name, model_id) pairs, loaded in compose().
        self._all_models: list[tuple[str, str]] = []
        self._filter_query: str = ""

    def compose(self) -> ComposeResult:
        self._all_models = get_models_for_provider(self._provider_id)

        with Vertical(id="model-container"):
            yield Static(f"Select Model ({self._provider_id})", id="model-title")
            yield Input(placeholder="Filter models...", id="model-filter")
            yield OptionList(id="model-list")

        # Populate the list once the widgets exist.
        self.call_after_refresh(self._rebuild_options)

    def _rebuild_options(self) -> None:
        """Rebuild OptionList with filtered items and pricing.

        Highlights the current model when it is visible (else the first row).
        Appends a disabled notice when the unfiltered list was capped.
        """
        from tunacode.configuration.pricing import format_pricing_display, get_model_pricing

        option_list = self.query_one("#model-list", OptionList)
        option_list.clear_options()

        filter_query = self._filter_query
        visible_models, is_truncated = _filter_visible_items(
            self._all_models,
            filter_query,
            MODEL_PICKER_UNFILTERED_LIMIT,
        )

        highlight_index = _choose_highlight_index(visible_models, self._current_model_id)

        for display_name, model_id in visible_models:
            full_model = f"{self._provider_id}:{model_id}"
            pricing = get_model_pricing(full_model)
            if pricing is not None:
                label = f"{display_name} {format_pricing_display(pricing)}"
            else:
                label = display_name

            option_list.add_option(Option(label, id=model_id))

        if highlight_index is not None:
            option_list.highlighted = highlight_index

        if is_truncated:
            _append_truncation_notice(
                option_list,
                MODEL_PICKER_UNFILTERED_LIMIT,
                "models",
            )

    def on_input_changed(self, event: Input.Changed) -> None:
        """Filter options as user types."""
        if event.input.id != "model-filter":
            return
        self._filter_query = event.value
        self._rebuild_options()

    def on_key(self, event: events.Key) -> None:
        """Handle focus transitions between Input and OptionList.

        - down in the filter moves focus to the list;
        - up on the list's first row, or on a list with nothing highlighted,
          moves focus back to the filter;
        - escape with an active filter clears the filter instead of cancelling.

        Handled keys are stopped so they do not propagate further.
        """
        filter_input = self.query_one("#model-filter", Input)
        option_list = self.query_one("#model-list", OptionList)

        if event.key == "down" and self.focused == filter_input:
            self.set_focus(option_list)
            event.stop()
        elif event.key == "up" and self.focused == option_list:
            # ``highlighted`` is None when nothing is highlighted (e.g. the filter
            # matched nothing). Without treating that as the top row, focus
            # could not return to the filter with "up".
            if option_list.highlighted is None or option_list.highlighted == 0:
                self.set_focus(filter_input)
                event.stop()
        elif event.key == "escape" and self._filter_query:
            filter_input.value = ""
            self._filter_query = ""
            self._rebuild_options()
            event.stop()

    def on_option_list_option_selected(self, event: OptionList.OptionSelected) -> None:
        """Confirm selection and dismiss with full model string."""
        # The truncation notice has no id, so selecting it is ignored.
        if event.option and event.option.id:
            full_model = f"{self._provider_id}:{event.option.id}"
            self.dismiss(full_model)

    def action_cancel(self) -> None:
        """Cancel selection."""
        self.dismiss(None)
|