shotgun-sh 0.1.16.dev2__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh might be problematic. Click here for more details.
- shotgun/agents/common.py +4 -5
- shotgun/agents/config/constants.py +21 -5
- shotgun/agents/config/manager.py +171 -63
- shotgun/agents/config/models.py +65 -84
- shotgun/agents/config/provider.py +174 -85
- shotgun/agents/history/compaction.py +1 -1
- shotgun/agents/history/history_processors.py +18 -9
- shotgun/agents/history/token_counting/__init__.py +31 -0
- shotgun/agents/history/token_counting/anthropic.py +89 -0
- shotgun/agents/history/token_counting/base.py +67 -0
- shotgun/agents/history/token_counting/openai.py +80 -0
- shotgun/agents/history/token_counting/sentencepiece_counter.py +119 -0
- shotgun/agents/history/token_counting/tokenizer_cache.py +90 -0
- shotgun/agents/history/token_counting/utils.py +147 -0
- shotgun/agents/history/token_estimation.py +12 -12
- shotgun/agents/llm.py +62 -0
- shotgun/agents/models.py +2 -2
- shotgun/agents/tools/web_search/__init__.py +42 -15
- shotgun/agents/tools/web_search/anthropic.py +54 -50
- shotgun/agents/tools/web_search/gemini.py +31 -20
- shotgun/agents/tools/web_search/openai.py +4 -4
- shotgun/build_constants.py +2 -2
- shotgun/cli/config.py +28 -57
- shotgun/cli/models.py +2 -2
- shotgun/codebase/models.py +4 -4
- shotgun/llm_proxy/__init__.py +16 -0
- shotgun/llm_proxy/clients.py +39 -0
- shotgun/llm_proxy/constants.py +8 -0
- shotgun/main.py +6 -0
- shotgun/posthog_telemetry.py +5 -3
- shotgun/tui/app.py +7 -3
- shotgun/tui/screens/chat.py +2 -8
- shotgun/tui/screens/chat_screen/command_providers.py +118 -11
- shotgun/tui/screens/chat_screen/history.py +3 -1
- shotgun/tui/screens/model_picker.py +327 -0
- shotgun/tui/screens/provider_config.py +57 -26
- shotgun/utils/env_utils.py +12 -0
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.0.dist-info}/METADATA +2 -2
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.0.dist-info}/RECORD +42 -31
- shotgun/agents/history/token_counting.py +0 -429
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.0.dist-info}/WHEEL +0 -0
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.0.dist-info}/entry_points.txt +0 -0
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.0.dist-info}/licenses/LICENSE +0 -0
shotgun/tui/screens/chat.py
CHANGED
|
@@ -54,11 +54,8 @@ from ..components.prompt_input import PromptInput
|
|
|
54
54
|
from ..components.spinner import Spinner
|
|
55
55
|
from ..utils.mode_progress import PlaceholderHints
|
|
56
56
|
from .chat_screen.command_providers import (
|
|
57
|
-
AgentModeProvider,
|
|
58
|
-
CodebaseCommandProvider,
|
|
59
57
|
DeleteCodebasePaletteProvider,
|
|
60
|
-
|
|
61
|
-
UsageProvider,
|
|
58
|
+
UnifiedCommandProvider,
|
|
62
59
|
)
|
|
63
60
|
|
|
64
61
|
logger = logging.getLogger(__name__)
|
|
@@ -233,10 +230,7 @@ class ChatScreen(Screen[None]):
|
|
|
233
230
|
]
|
|
234
231
|
|
|
235
232
|
COMMANDS = {
|
|
236
|
-
|
|
237
|
-
ProviderSetupProvider,
|
|
238
|
-
CodebaseCommandProvider,
|
|
239
|
-
UsageProvider,
|
|
233
|
+
UnifiedCommandProvider,
|
|
240
234
|
}
|
|
241
235
|
|
|
242
236
|
value = reactive("")
|
|
@@ -5,6 +5,8 @@ from textual.command import DiscoveryHit, Hit, Provider
|
|
|
5
5
|
|
|
6
6
|
from shotgun.agents.models import AgentType
|
|
7
7
|
from shotgun.codebase.models import CodebaseGraph
|
|
8
|
+
from shotgun.tui.screens.model_picker import ModelPickerScreen
|
|
9
|
+
from shotgun.tui.screens.provider_config import ProviderConfigScreen
|
|
8
10
|
|
|
9
11
|
if TYPE_CHECKING:
|
|
10
12
|
from shotgun.tui.screens.chat import ChatScreen
|
|
@@ -139,7 +141,11 @@ class ProviderSetupProvider(Provider):
|
|
|
139
141
|
|
|
140
142
|
def open_provider_config(self) -> None:
|
|
141
143
|
"""Show the provider configuration screen."""
|
|
142
|
-
self.chat_screen.app.push_screen(
|
|
144
|
+
self.chat_screen.app.push_screen(ProviderConfigScreen())
|
|
145
|
+
|
|
146
|
+
def open_model_picker(self) -> None:
|
|
147
|
+
"""Show the model picker screen."""
|
|
148
|
+
self.chat_screen.app.push_screen(ModelPickerScreen())
|
|
143
149
|
|
|
144
150
|
async def discover(self) -> AsyncGenerator[DiscoveryHit, None]:
|
|
145
151
|
yield DiscoveryHit(
|
|
@@ -147,9 +153,15 @@ class ProviderSetupProvider(Provider):
|
|
|
147
153
|
self.open_provider_config,
|
|
148
154
|
help="⚙️ Manage API keys for available providers",
|
|
149
155
|
)
|
|
156
|
+
yield DiscoveryHit(
|
|
157
|
+
"Select AI Model",
|
|
158
|
+
self.open_model_picker,
|
|
159
|
+
help="🤖 Choose which AI model to use",
|
|
160
|
+
)
|
|
150
161
|
|
|
151
162
|
async def search(self, query: str) -> AsyncGenerator[Hit, None]:
|
|
152
163
|
matcher = self.matcher(query)
|
|
164
|
+
|
|
153
165
|
title = "Open Provider Setup"
|
|
154
166
|
score = matcher.match(title)
|
|
155
167
|
if score > 0:
|
|
@@ -160,6 +172,16 @@ class ProviderSetupProvider(Provider):
|
|
|
160
172
|
help="⚙️ Manage API keys for available providers",
|
|
161
173
|
)
|
|
162
174
|
|
|
175
|
+
title = "Select AI Model"
|
|
176
|
+
score = matcher.match(title)
|
|
177
|
+
if score > 0:
|
|
178
|
+
yield Hit(
|
|
179
|
+
score,
|
|
180
|
+
matcher.highlight(title),
|
|
181
|
+
self.open_model_picker,
|
|
182
|
+
help="🤖 Choose which AI model to use",
|
|
183
|
+
)
|
|
184
|
+
|
|
163
185
|
|
|
164
186
|
class CodebaseCommandProvider(Provider):
|
|
165
187
|
"""Command palette entries for codebase management."""
|
|
@@ -171,30 +193,30 @@ class CodebaseCommandProvider(Provider):
|
|
|
171
193
|
return cast(ChatScreen, self.screen)
|
|
172
194
|
|
|
173
195
|
async def discover(self) -> AsyncGenerator[DiscoveryHit, None]:
|
|
174
|
-
yield DiscoveryHit(
|
|
175
|
-
"Codebase: Index Codebase",
|
|
176
|
-
self.chat_screen.index_codebase_command,
|
|
177
|
-
help="Index a repository into the codebase graph",
|
|
178
|
-
)
|
|
179
196
|
yield DiscoveryHit(
|
|
180
197
|
"Codebase: Delete Codebase Index",
|
|
181
198
|
self.chat_screen.delete_codebase_command,
|
|
182
199
|
help="Delete an existing codebase index",
|
|
183
200
|
)
|
|
201
|
+
yield DiscoveryHit(
|
|
202
|
+
"Codebase: Index Codebase",
|
|
203
|
+
self.chat_screen.index_codebase_command,
|
|
204
|
+
help="Index a repository into the codebase graph",
|
|
205
|
+
)
|
|
184
206
|
|
|
185
207
|
async def search(self, query: str) -> AsyncGenerator[Hit, None]:
|
|
186
208
|
matcher = self.matcher(query)
|
|
187
209
|
commands = [
|
|
188
|
-
(
|
|
189
|
-
"Codebase: Index Codebase",
|
|
190
|
-
self.chat_screen.index_codebase_command,
|
|
191
|
-
"Index a repository into the codebase graph",
|
|
192
|
-
),
|
|
193
210
|
(
|
|
194
211
|
"Codebase: Delete Codebase Index",
|
|
195
212
|
self.chat_screen.delete_codebase_command,
|
|
196
213
|
"Delete an existing codebase index",
|
|
197
214
|
),
|
|
215
|
+
(
|
|
216
|
+
"Codebase: Index Codebase",
|
|
217
|
+
self.chat_screen.index_codebase_command,
|
|
218
|
+
"Index a repository into the codebase graph",
|
|
219
|
+
),
|
|
198
220
|
]
|
|
199
221
|
for title, callback, help_text in commands:
|
|
200
222
|
score = matcher.match(title)
|
|
@@ -249,3 +271,88 @@ class DeleteCodebasePaletteProvider(Provider):
|
|
|
249
271
|
),
|
|
250
272
|
help=graph.repo_path,
|
|
251
273
|
)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
class UnifiedCommandProvider(Provider):
|
|
277
|
+
"""Unified command provider with all commands in alphabetical order."""
|
|
278
|
+
|
|
279
|
+
@property
|
|
280
|
+
def chat_screen(self) -> "ChatScreen":
|
|
281
|
+
from shotgun.tui.screens.chat import ChatScreen
|
|
282
|
+
|
|
283
|
+
return cast(ChatScreen, self.screen)
|
|
284
|
+
|
|
285
|
+
def open_provider_config(self) -> None:
|
|
286
|
+
"""Show the provider configuration screen."""
|
|
287
|
+
self.chat_screen.app.push_screen(ProviderConfigScreen())
|
|
288
|
+
|
|
289
|
+
def open_model_picker(self) -> None:
|
|
290
|
+
"""Show the model picker screen."""
|
|
291
|
+
self.chat_screen.app.push_screen(ModelPickerScreen())
|
|
292
|
+
|
|
293
|
+
async def discover(self) -> AsyncGenerator[DiscoveryHit, None]:
|
|
294
|
+
"""Provide commands in alphabetical order when palette opens."""
|
|
295
|
+
# Alphabetically ordered commands
|
|
296
|
+
yield DiscoveryHit(
|
|
297
|
+
"Codebase: Delete Codebase Index",
|
|
298
|
+
self.chat_screen.delete_codebase_command,
|
|
299
|
+
help="Delete an existing codebase index",
|
|
300
|
+
)
|
|
301
|
+
yield DiscoveryHit(
|
|
302
|
+
"Codebase: Index Codebase",
|
|
303
|
+
self.chat_screen.index_codebase_command,
|
|
304
|
+
help="Index a repository into the codebase graph",
|
|
305
|
+
)
|
|
306
|
+
yield DiscoveryHit(
|
|
307
|
+
"Open Provider Setup",
|
|
308
|
+
self.open_provider_config,
|
|
309
|
+
help="⚙️ Manage API keys for available providers",
|
|
310
|
+
)
|
|
311
|
+
yield DiscoveryHit(
|
|
312
|
+
"Select AI Model",
|
|
313
|
+
self.open_model_picker,
|
|
314
|
+
help="🤖 Choose which AI model to use",
|
|
315
|
+
)
|
|
316
|
+
yield DiscoveryHit(
|
|
317
|
+
"Show usage",
|
|
318
|
+
self.chat_screen.action_show_usage,
|
|
319
|
+
help="Display usage information for the current session",
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
async def search(self, query: str) -> AsyncGenerator[Hit, None]:
|
|
323
|
+
"""Search for commands in alphabetical order."""
|
|
324
|
+
matcher = self.matcher(query)
|
|
325
|
+
|
|
326
|
+
# Define all commands in alphabetical order
|
|
327
|
+
commands = [
|
|
328
|
+
(
|
|
329
|
+
"Codebase: Delete Codebase Index",
|
|
330
|
+
self.chat_screen.delete_codebase_command,
|
|
331
|
+
"Delete an existing codebase index",
|
|
332
|
+
),
|
|
333
|
+
(
|
|
334
|
+
"Codebase: Index Codebase",
|
|
335
|
+
self.chat_screen.index_codebase_command,
|
|
336
|
+
"Index a repository into the codebase graph",
|
|
337
|
+
),
|
|
338
|
+
(
|
|
339
|
+
"Open Provider Setup",
|
|
340
|
+
self.open_provider_config,
|
|
341
|
+
"⚙️ Manage API keys for available providers",
|
|
342
|
+
),
|
|
343
|
+
(
|
|
344
|
+
"Select AI Model",
|
|
345
|
+
self.open_model_picker,
|
|
346
|
+
"🤖 Choose which AI model to use",
|
|
347
|
+
),
|
|
348
|
+
(
|
|
349
|
+
"Show usage",
|
|
350
|
+
self.chat_screen.action_show_usage,
|
|
351
|
+
"Display usage information for the current session",
|
|
352
|
+
),
|
|
353
|
+
]
|
|
354
|
+
|
|
355
|
+
for title, callback, help_text in commands:
|
|
356
|
+
score = matcher.match(title)
|
|
357
|
+
if score > 0:
|
|
358
|
+
yield Hit(score, matcher.highlight(title), callback, help=help_text)
|
|
@@ -217,7 +217,9 @@ class AgentResponseWidget(Widget):
|
|
|
217
217
|
return ""
|
|
218
218
|
for idx, part in enumerate(self.item.parts):
|
|
219
219
|
if isinstance(part, TextPart):
|
|
220
|
-
|
|
220
|
+
# Only show the circle prefix if there's actual content
|
|
221
|
+
if part.content and part.content.strip():
|
|
222
|
+
acc += f"**⏺** {part.content}\n\n"
|
|
221
223
|
elif isinstance(part, ToolCallPart):
|
|
222
224
|
parts_str = self._format_tool_call_part(part)
|
|
223
225
|
acc += parts_str + "\n\n"
|
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
"""Screen for selecting AI model."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, cast
|
|
6
|
+
|
|
7
|
+
from textual import on
|
|
8
|
+
from textual.app import ComposeResult
|
|
9
|
+
from textual.containers import Horizontal, Vertical
|
|
10
|
+
from textual.reactive import reactive
|
|
11
|
+
from textual.screen import Screen
|
|
12
|
+
from textual.widgets import Button, Label, ListItem, ListView, Static
|
|
13
|
+
|
|
14
|
+
from shotgun.agents.config import ConfigManager
|
|
15
|
+
from shotgun.agents.config.models import MODEL_SPECS, ModelName, ShotgunConfig
|
|
16
|
+
from shotgun.logging_config import get_logger
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from ..app import ShotgunApp
|
|
20
|
+
|
|
21
|
+
logger = get_logger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Available models for selection
|
|
25
|
+
AVAILABLE_MODELS = list(ModelName)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _sanitize_model_name_for_id(model_name: ModelName) -> str:
|
|
29
|
+
"""Convert model name to valid Textual ID by replacing dots with hyphens."""
|
|
30
|
+
return model_name.value.replace(".", "-")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ModelPickerScreen(Screen[None]):
|
|
34
|
+
"""Select AI model to use."""
|
|
35
|
+
|
|
36
|
+
CSS = """
|
|
37
|
+
ModelPicker {
|
|
38
|
+
layout: vertical;
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
#titlebox {
|
|
42
|
+
height: auto;
|
|
43
|
+
margin: 2 0;
|
|
44
|
+
padding: 1;
|
|
45
|
+
border: hkey $border;
|
|
46
|
+
content-align: center middle;
|
|
47
|
+
|
|
48
|
+
& > * {
|
|
49
|
+
text-align: center;
|
|
50
|
+
}
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
#model-picker-title {
|
|
54
|
+
padding: 1 0;
|
|
55
|
+
text-style: bold;
|
|
56
|
+
color: $text-accent;
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
#model-list {
|
|
60
|
+
margin: 2 0;
|
|
61
|
+
height: auto;
|
|
62
|
+
padding: 1;
|
|
63
|
+
& > * {
|
|
64
|
+
padding: 1 0;
|
|
65
|
+
}
|
|
66
|
+
}
|
|
67
|
+
#model-actions {
|
|
68
|
+
padding: 1;
|
|
69
|
+
}
|
|
70
|
+
#model-actions > * {
|
|
71
|
+
margin-right: 2;
|
|
72
|
+
}
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
BINDINGS = [
|
|
76
|
+
("escape", "done", "Back"),
|
|
77
|
+
]
|
|
78
|
+
|
|
79
|
+
selected_model: reactive[ModelName] = reactive(ModelName.GPT_5)
|
|
80
|
+
|
|
81
|
+
def compose(self) -> ComposeResult:
|
|
82
|
+
with Vertical(id="titlebox"):
|
|
83
|
+
yield Static("Model selection", id="model-picker-title")
|
|
84
|
+
yield Static(
|
|
85
|
+
"Select the AI model you want to use for your tasks.",
|
|
86
|
+
id="model-picker-summary",
|
|
87
|
+
)
|
|
88
|
+
yield ListView(id="model-list")
|
|
89
|
+
with Horizontal(id="model-actions"):
|
|
90
|
+
yield Button("Select \\[ENTER]", variant="primary", id="select")
|
|
91
|
+
yield Button("Done \\[ESC]", id="done")
|
|
92
|
+
|
|
93
|
+
def _rebuild_model_list(self) -> None:
|
|
94
|
+
"""Rebuild the model list from current config.
|
|
95
|
+
|
|
96
|
+
This method is called both on first show and when screen is resumed
|
|
97
|
+
to ensure the list always reflects the current configuration.
|
|
98
|
+
"""
|
|
99
|
+
logger.debug("Rebuilding model list from current config")
|
|
100
|
+
|
|
101
|
+
# Load current config with force_reload to get latest API keys
|
|
102
|
+
config_manager = self.config_manager
|
|
103
|
+
config = config_manager.load(force_reload=True)
|
|
104
|
+
|
|
105
|
+
# Log provider key status
|
|
106
|
+
logger.debug(
|
|
107
|
+
"Provider keys: openai=%s, anthropic=%s, google=%s, shotgun=%s",
|
|
108
|
+
config_manager._provider_has_api_key(config.openai),
|
|
109
|
+
config_manager._provider_has_api_key(config.anthropic),
|
|
110
|
+
config_manager._provider_has_api_key(config.google),
|
|
111
|
+
config_manager._provider_has_api_key(config.shotgun),
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
current_model = config.selected_model or ModelName.CLAUDE_SONNET_4_5
|
|
115
|
+
self.selected_model = current_model
|
|
116
|
+
logger.debug("Current selected model: %s", current_model)
|
|
117
|
+
|
|
118
|
+
# Rebuild the model list with current available models
|
|
119
|
+
list_view = self.query_one(ListView)
|
|
120
|
+
|
|
121
|
+
# Remove all existing items
|
|
122
|
+
old_count = len(list(list_view.children))
|
|
123
|
+
for child in list(list_view.children):
|
|
124
|
+
child.remove()
|
|
125
|
+
logger.debug("Removed %d existing model items from list", old_count)
|
|
126
|
+
|
|
127
|
+
# Add new items (labels already have correct text including current indicator)
|
|
128
|
+
new_items = self._build_model_items(config)
|
|
129
|
+
for item in new_items:
|
|
130
|
+
list_view.append(item)
|
|
131
|
+
logger.debug("Added %d available model items to list", len(new_items))
|
|
132
|
+
|
|
133
|
+
# Find and highlight current selection (if it's in the filtered list)
|
|
134
|
+
if list_view.children:
|
|
135
|
+
for i, child in enumerate(list_view.children):
|
|
136
|
+
if isinstance(child, ListItem) and child.id:
|
|
137
|
+
model_id = child.id.removeprefix("model-")
|
|
138
|
+
# Find the model name
|
|
139
|
+
for model_name in AVAILABLE_MODELS:
|
|
140
|
+
if _sanitize_model_name_for_id(model_name) == model_id:
|
|
141
|
+
if model_name == current_model:
|
|
142
|
+
list_view.index = i
|
|
143
|
+
break
|
|
144
|
+
|
|
145
|
+
def on_show(self) -> None:
|
|
146
|
+
"""Rebuild model list when screen is first shown."""
|
|
147
|
+
logger.debug("ModelPickerScreen.on_show() called")
|
|
148
|
+
self._rebuild_model_list()
|
|
149
|
+
|
|
150
|
+
def on_screenresume(self) -> None:
|
|
151
|
+
"""Rebuild model list when screen is resumed (subsequent visits).
|
|
152
|
+
|
|
153
|
+
This is called when returning to the screen after it was suspended,
|
|
154
|
+
ensuring the model list reflects any config changes made while away.
|
|
155
|
+
"""
|
|
156
|
+
logger.debug("ModelPickerScreen.on_screenresume() called")
|
|
157
|
+
self._rebuild_model_list()
|
|
158
|
+
|
|
159
|
+
def action_done(self) -> None:
|
|
160
|
+
self.dismiss()
|
|
161
|
+
|
|
162
|
+
@on(ListView.Highlighted)
|
|
163
|
+
def _on_model_highlighted(self, event: ListView.Highlighted) -> None:
|
|
164
|
+
model_name = self._model_from_item(event.item)
|
|
165
|
+
if model_name:
|
|
166
|
+
self.selected_model = model_name
|
|
167
|
+
|
|
168
|
+
@on(ListView.Selected)
|
|
169
|
+
def _on_model_selected(self, event: ListView.Selected) -> None:
|
|
170
|
+
model_name = self._model_from_item(event.item)
|
|
171
|
+
if model_name:
|
|
172
|
+
self.selected_model = model_name
|
|
173
|
+
self._select_model()
|
|
174
|
+
|
|
175
|
+
@on(Button.Pressed, "#select")
|
|
176
|
+
def _on_select_pressed(self) -> None:
|
|
177
|
+
self._select_model()
|
|
178
|
+
|
|
179
|
+
@on(Button.Pressed, "#done")
|
|
180
|
+
def _on_done_pressed(self) -> None:
|
|
181
|
+
self.action_done()
|
|
182
|
+
|
|
183
|
+
@property
|
|
184
|
+
def config_manager(self) -> ConfigManager:
|
|
185
|
+
app = cast("ShotgunApp", self.app)
|
|
186
|
+
return app.config_manager
|
|
187
|
+
|
|
188
|
+
def refresh_model_labels(self) -> None:
|
|
189
|
+
"""Update the list view entries to reflect current selection.
|
|
190
|
+
|
|
191
|
+
Note: This method only updates labels for currently displayed models.
|
|
192
|
+
To rebuild the entire list after provider changes, on_show() should be used.
|
|
193
|
+
"""
|
|
194
|
+
# Load config once with force_reload
|
|
195
|
+
config = self.config_manager.load(force_reload=True)
|
|
196
|
+
current_model = config.selected_model or ModelName.CLAUDE_SONNET_4_5
|
|
197
|
+
|
|
198
|
+
# Update labels for available models only
|
|
199
|
+
for model_name in AVAILABLE_MODELS:
|
|
200
|
+
# Pass config to avoid multiple force reloads
|
|
201
|
+
if not self._is_model_available(model_name, config):
|
|
202
|
+
continue
|
|
203
|
+
label = self.query_one(
|
|
204
|
+
f"#label-{_sanitize_model_name_for_id(model_name)}", Label
|
|
205
|
+
)
|
|
206
|
+
label.update(
|
|
207
|
+
self._model_label(model_name, is_current=model_name == current_model)
|
|
208
|
+
)
|
|
209
|
+
|
|
210
|
+
def _build_model_items(self, config: ShotgunConfig | None = None) -> list[ListItem]:
|
|
211
|
+
if config is None:
|
|
212
|
+
config = self.config_manager.load(force_reload=True)
|
|
213
|
+
|
|
214
|
+
items: list[ListItem] = []
|
|
215
|
+
current_model = self.selected_model
|
|
216
|
+
for model_name in AVAILABLE_MODELS:
|
|
217
|
+
# Only add models that are available
|
|
218
|
+
if not self._is_model_available(model_name, config):
|
|
219
|
+
continue
|
|
220
|
+
|
|
221
|
+
label = Label(
|
|
222
|
+
self._model_label(model_name, is_current=model_name == current_model),
|
|
223
|
+
id=f"label-{_sanitize_model_name_for_id(model_name)}",
|
|
224
|
+
)
|
|
225
|
+
items.append(
|
|
226
|
+
ListItem(label, id=f"model-{_sanitize_model_name_for_id(model_name)}")
|
|
227
|
+
)
|
|
228
|
+
return items
|
|
229
|
+
|
|
230
|
+
def _model_from_item(self, item: ListItem | None) -> ModelName | None:
|
|
231
|
+
"""Get ModelName from a ListItem."""
|
|
232
|
+
if item is None or item.id is None:
|
|
233
|
+
return None
|
|
234
|
+
sanitized_id = item.id.removeprefix("model-")
|
|
235
|
+
# Find the original model name by comparing sanitized versions
|
|
236
|
+
for model_name in AVAILABLE_MODELS:
|
|
237
|
+
if _sanitize_model_name_for_id(model_name) == sanitized_id:
|
|
238
|
+
return model_name
|
|
239
|
+
return None
|
|
240
|
+
|
|
241
|
+
def _is_model_available(
|
|
242
|
+
self, model_name: ModelName, config: ShotgunConfig | None = None
|
|
243
|
+
) -> bool:
|
|
244
|
+
"""Check if a model is available based on provider key configuration.
|
|
245
|
+
|
|
246
|
+
A model is available if:
|
|
247
|
+
1. Shotgun Account key is configured (provides access to all models), OR
|
|
248
|
+
2. The model's provider has an API key configured (BYOK mode)
|
|
249
|
+
|
|
250
|
+
Args:
|
|
251
|
+
model_name: The model to check availability for
|
|
252
|
+
config: Optional pre-loaded config to avoid multiple reloads
|
|
253
|
+
|
|
254
|
+
Returns:
|
|
255
|
+
True if the model can be used, False otherwise
|
|
256
|
+
"""
|
|
257
|
+
if config is None:
|
|
258
|
+
config = self.config_manager.load(force_reload=True)
|
|
259
|
+
|
|
260
|
+
# If Shotgun Account is configured, all models are available
|
|
261
|
+
if self.config_manager._provider_has_api_key(config.shotgun):
|
|
262
|
+
logger.debug("Model %s available (Shotgun Account configured)", model_name)
|
|
263
|
+
return True
|
|
264
|
+
|
|
265
|
+
# In BYOK mode, check if the model's provider has a key
|
|
266
|
+
if model_name not in MODEL_SPECS:
|
|
267
|
+
logger.debug("Model %s not available (not in MODEL_SPECS)", model_name)
|
|
268
|
+
return False
|
|
269
|
+
|
|
270
|
+
spec = MODEL_SPECS[model_name]
|
|
271
|
+
# Check provider key directly using the loaded config to avoid stale cache
|
|
272
|
+
provider_config = self.config_manager._get_provider_config(
|
|
273
|
+
config, spec.provider
|
|
274
|
+
)
|
|
275
|
+
has_key = self.config_manager._provider_has_api_key(provider_config)
|
|
276
|
+
logger.debug(
|
|
277
|
+
"Model %s available=%s (provider=%s, has_key=%s)",
|
|
278
|
+
model_name,
|
|
279
|
+
has_key,
|
|
280
|
+
spec.provider,
|
|
281
|
+
has_key,
|
|
282
|
+
)
|
|
283
|
+
return has_key
|
|
284
|
+
|
|
285
|
+
def _model_label(self, model_name: ModelName, is_current: bool) -> str:
|
|
286
|
+
"""Generate label for model with specs and current indicator."""
|
|
287
|
+
if model_name not in MODEL_SPECS:
|
|
288
|
+
return model_name.value
|
|
289
|
+
|
|
290
|
+
spec = MODEL_SPECS[model_name]
|
|
291
|
+
display_name = self._model_display_name(model_name)
|
|
292
|
+
|
|
293
|
+
# Format context/output tokens in readable format
|
|
294
|
+
input_k = spec.max_input_tokens // 1000
|
|
295
|
+
output_k = spec.max_output_tokens // 1000
|
|
296
|
+
|
|
297
|
+
label = f"{display_name} · {input_k}K context · {output_k}K output"
|
|
298
|
+
|
|
299
|
+
# Add cost indicator for expensive models
|
|
300
|
+
if model_name == ModelName.CLAUDE_OPUS_4_1:
|
|
301
|
+
label += " · Expensive"
|
|
302
|
+
|
|
303
|
+
if is_current:
|
|
304
|
+
label += " · Current"
|
|
305
|
+
|
|
306
|
+
return label
|
|
307
|
+
|
|
308
|
+
def _model_display_name(self, model_name: ModelName) -> str:
|
|
309
|
+
"""Get human-readable model name."""
|
|
310
|
+
names = {
|
|
311
|
+
ModelName.GPT_5: "GPT-5 (OpenAI)",
|
|
312
|
+
ModelName.CLAUDE_OPUS_4_1: "Claude Opus 4.1 (Anthropic)",
|
|
313
|
+
ModelName.CLAUDE_SONNET_4_5: "Claude Sonnet 4.5 (Anthropic)",
|
|
314
|
+
ModelName.GEMINI_2_5_PRO: "Gemini 2.5 Pro (Google)",
|
|
315
|
+
}
|
|
316
|
+
return names.get(model_name, model_name.value)
|
|
317
|
+
|
|
318
|
+
def _select_model(self) -> None:
|
|
319
|
+
"""Save the selected model."""
|
|
320
|
+
try:
|
|
321
|
+
self.config_manager.update_selected_model(self.selected_model)
|
|
322
|
+
self.refresh_model_labels()
|
|
323
|
+
self.notify(
|
|
324
|
+
f"Selected model: {self._model_display_name(self.selected_model)}"
|
|
325
|
+
)
|
|
326
|
+
except Exception as exc: # pragma: no cover - defensive; textual path
|
|
327
|
+
self.notify(f"Failed to select model: {exc}", severity="error")
|
|
@@ -12,11 +12,25 @@ from textual.screen import Screen
|
|
|
12
12
|
from textual.widgets import Button, Input, Label, ListItem, ListView, Markdown, Static
|
|
13
13
|
|
|
14
14
|
from shotgun.agents.config import ConfigManager, ProviderType
|
|
15
|
+
from shotgun.utils.env_utils import is_shotgun_account_enabled
|
|
15
16
|
|
|
16
17
|
if TYPE_CHECKING:
|
|
17
18
|
from ..app import ShotgunApp
|
|
18
19
|
|
|
19
20
|
|
|
21
|
+
def get_configurable_providers() -> list[str]:
|
|
22
|
+
"""Get list of configurable providers based on feature flags.
|
|
23
|
+
|
|
24
|
+
Returns:
|
|
25
|
+
List of provider identifiers that can be configured.
|
|
26
|
+
Includes shotgun only if SHOTGUN_ACCOUNT_ENABLED is set.
|
|
27
|
+
"""
|
|
28
|
+
providers = ["openai", "anthropic", "google"]
|
|
29
|
+
if is_shotgun_account_enabled():
|
|
30
|
+
providers.append("shotgun")
|
|
31
|
+
return providers
|
|
32
|
+
|
|
33
|
+
|
|
20
34
|
class ProviderConfigScreen(Screen[None]):
|
|
21
35
|
"""Collect API keys for available providers."""
|
|
22
36
|
|
|
@@ -73,7 +87,7 @@ class ProviderConfigScreen(Screen[None]):
|
|
|
73
87
|
("escape", "done", "Back"),
|
|
74
88
|
]
|
|
75
89
|
|
|
76
|
-
selected_provider: reactive[
|
|
90
|
+
selected_provider: reactive[str] = reactive("openai")
|
|
77
91
|
|
|
78
92
|
def compose(self) -> ComposeResult:
|
|
79
93
|
with Vertical(id="titlebox"):
|
|
@@ -102,9 +116,16 @@ class ProviderConfigScreen(Screen[None]):
|
|
|
102
116
|
list_view = self.query_one(ListView)
|
|
103
117
|
if list_view.children:
|
|
104
118
|
list_view.index = 0
|
|
105
|
-
self.selected_provider =
|
|
119
|
+
self.selected_provider = "openai"
|
|
106
120
|
self.set_focus(self.query_one("#api-key", Input))
|
|
107
121
|
|
|
122
|
+
def on_screenresume(self) -> None:
|
|
123
|
+
"""Refresh provider status when screen is resumed.
|
|
124
|
+
|
|
125
|
+
This ensures the UI reflects any provider changes made elsewhere.
|
|
126
|
+
"""
|
|
127
|
+
self.refresh_provider_status()
|
|
128
|
+
|
|
108
129
|
def action_done(self) -> None:
|
|
109
130
|
self.dismiss()
|
|
110
131
|
|
|
@@ -152,45 +173,55 @@ class ProviderConfigScreen(Screen[None]):
|
|
|
152
173
|
|
|
153
174
|
def refresh_provider_status(self) -> None:
|
|
154
175
|
"""Update the list view entries to reflect configured providers."""
|
|
155
|
-
for
|
|
156
|
-
label = self.query_one(f"#label-{
|
|
157
|
-
label.update(self._provider_label(
|
|
176
|
+
for provider_id in get_configurable_providers():
|
|
177
|
+
label = self.query_one(f"#label-{provider_id}", Label)
|
|
178
|
+
label.update(self._provider_label(provider_id))
|
|
158
179
|
|
|
159
180
|
def _build_provider_items(self) -> list[ListItem]:
|
|
160
181
|
items: list[ListItem] = []
|
|
161
|
-
for
|
|
162
|
-
label = Label(self._provider_label(
|
|
163
|
-
items.append(ListItem(label, id=f"provider-{
|
|
182
|
+
for provider_id in get_configurable_providers():
|
|
183
|
+
label = Label(self._provider_label(provider_id), id=f"label-{provider_id}")
|
|
184
|
+
items.append(ListItem(label, id=f"provider-{provider_id}"))
|
|
164
185
|
return items
|
|
165
186
|
|
|
166
|
-
def _provider_from_item(self, item: ListItem | None) ->
|
|
187
|
+
def _provider_from_item(self, item: ListItem | None) -> str | None:
|
|
167
188
|
if item is None or item.id is None:
|
|
168
189
|
return None
|
|
169
190
|
provider_id = item.id.removeprefix("provider-")
|
|
170
|
-
|
|
171
|
-
return ProviderType(provider_id)
|
|
172
|
-
except ValueError:
|
|
173
|
-
return None
|
|
191
|
+
return provider_id if provider_id in get_configurable_providers() else None
|
|
174
192
|
|
|
175
|
-
def _provider_label(self,
|
|
176
|
-
display = self._provider_display_name(
|
|
193
|
+
def _provider_label(self, provider_id: str) -> str:
|
|
194
|
+
display = self._provider_display_name(provider_id)
|
|
177
195
|
status = (
|
|
178
|
-
"Configured"
|
|
179
|
-
if self.config_manager.has_provider_key(provider)
|
|
180
|
-
else "Not configured"
|
|
196
|
+
"Configured" if self._has_provider_key(provider_id) else "Not configured"
|
|
181
197
|
)
|
|
182
198
|
return f"{display} · {status}"
|
|
183
199
|
|
|
184
|
-
def _provider_display_name(self,
|
|
200
|
+
def _provider_display_name(self, provider_id: str) -> str:
|
|
185
201
|
names = {
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
202
|
+
"openai": "OpenAI",
|
|
203
|
+
"anthropic": "Anthropic",
|
|
204
|
+
"google": "Google Gemini",
|
|
205
|
+
"shotgun": "Shotgun Account",
|
|
189
206
|
}
|
|
190
|
-
return names.get(
|
|
191
|
-
|
|
192
|
-
def _input_placeholder(self,
|
|
193
|
-
return f"{self._provider_display_name(
|
|
207
|
+
return names.get(provider_id, provider_id.title())
|
|
208
|
+
|
|
209
|
+
def _input_placeholder(self, provider_id: str) -> str:
|
|
210
|
+
return f"{self._provider_display_name(provider_id)} API key"
|
|
211
|
+
|
|
212
|
+
def _has_provider_key(self, provider_id: str) -> bool:
|
|
213
|
+
"""Check if provider has a configured API key."""
|
|
214
|
+
if provider_id == "shotgun":
|
|
215
|
+
# Check shotgun key directly
|
|
216
|
+
config = self.config_manager.load()
|
|
217
|
+
return self.config_manager._provider_has_api_key(config.shotgun)
|
|
218
|
+
else:
|
|
219
|
+
# Check LLM provider key
|
|
220
|
+
try:
|
|
221
|
+
provider = ProviderType(provider_id)
|
|
222
|
+
return self.config_manager.has_provider_key(provider)
|
|
223
|
+
except ValueError:
|
|
224
|
+
return False
|
|
194
225
|
|
|
195
226
|
def _save_api_key(self) -> None:
|
|
196
227
|
input_widget = self.query_one("#api-key", Input)
|