tunacode-cli 0.0.70__py3-none-any.whl → 0.0.72__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tunacode-cli might be problematic.

@@ -1,61 +1,361 @@
 """Model management commands for TunaCode CLI."""
 
-from typing import Optional
+from typing import Dict, List, Optional
 
-from .... import utils
 from ....exceptions import ConfigurationError
 from ....types import CommandArgs, CommandContext
 from ....ui import console as ui
+from ....ui.model_selector import select_model_interactive
+from ....utils import user_configuration
+from ....utils.models_registry import ModelInfo, ModelsRegistry
 from ..base import CommandCategory, CommandSpec, SimpleCommand
 
 
 class ModelCommand(SimpleCommand):
-    """Manage model selection."""
+    """Manage model selection with models.dev integration."""
 
     spec = CommandSpec(
         name="model",
         aliases=["/model"],
-        description="Switch model (e.g., /model gpt-4 or /model openai:gpt-4)",
+        description="Switch model with interactive selection or search",
         category=CommandCategory.MODEL,
     )
 
+    def __init__(self):
+        """Initialize the model command."""
+        super().__init__()
+        self.registry = ModelsRegistry()
+        self._registry_loaded = False
+
+    async def _ensure_registry(self) -> bool:
+        """Ensure the models registry is loaded."""
+        if not self._registry_loaded:
+            self._registry_loaded = await self.registry.load()
+        return self._registry_loaded
+
     async def execute(self, args: CommandArgs, context: CommandContext) -> Optional[str]:
-        # No arguments - show current model
+        # Handle special flags
+        if args and args[0] in ["--list", "-l"]:
+            return await self._list_models()
+
+        if args and args[0] in ["--info", "-i"]:
+            if len(args) < 2:
+                await ui.error("Usage: /model --info <model-id>")
+                return None
+            return await self._show_model_info(args[1])
+
+        # No arguments - show interactive selector
         if not args:
-            current_model = context.state_manager.session.current_model
-            await ui.info(f"Current model: {current_model}")
-            await ui.muted("Usage: /model <provider:model-name> [default]")
-            await ui.muted("Example: /model openai:gpt-4.1")
+            return await self._interactive_select(context)
+
+        # Single argument - could be search query or model ID
+        model_query = args[0]
+
+        # Check for flags
+        if model_query in ["--search", "-s"]:
+            search_query = " ".join(args[1:]) if len(args) > 1 else ""
+            return await self._interactive_select(context, search_query)
+
+        # Direct model specification
+        return await self._set_model(model_query, args[1:], context)
+
+    async def _interactive_select(
+        self, context: CommandContext, initial_query: str = ""
+    ) -> Optional[str]:
+        """Show interactive model selector."""
+        await self._ensure_registry()
+
+        # Show current model
+        current_model = context.state_manager.session.current_model
+        await ui.info(f"Current model: {current_model}")
+
+        # Check if we have models loaded
+        if not self.registry.models:
+            await ui.error("No models available. Try /model --list to see if models can be loaded.")
            return None
 
-        # Get the model name from args
-        model_name = args[0]
+        # For now, use a simple text-based approach instead of complex UI
+        # This avoids prompt_toolkit compatibility issues
+        if initial_query:
+            models = self.registry.search_models(initial_query)
+            if not models:
+                await ui.error(f"No models found matching '{initial_query}'")
+                return None
+        else:
+            # Show popular models for quick selection
+            popular_searches = ["gpt", "claude", "gemini"]
+            await ui.info("Popular model searches:")
+            for search in popular_searches:
+                models = self.registry.search_models(search)[:3]  # Top 3
+                if models:
+                    await ui.info(f"\n{search.upper()} models:")
+                    for model in models:
+                        await ui.muted(f" • {model.full_id} - {model.name}")
+
+            await ui.info("\nUsage:")
+            await ui.muted(" /model <search-term> - Search for models")
+            await ui.muted(" /model --list - Show all models")
+            await ui.muted(" /model --info <id> - Show model details")
+            await ui.muted(" /model <provider:id> - Set model directly")
+            return None
 
-        # Check if provider prefix is present
-        if ":" not in model_name:
-            await ui.error("Model name must include provider prefix")
-            await ui.muted("Format: provider:model-name")
-            await ui.muted(
-                "Examples: openai:gpt-4.1, anthropic:claude-3-opus, google-gla:gemini-2.0-flash"
-            )
+        # Show search results
+        if len(models) == 1:
+            # Auto-select single result
+            model = models[0]
+            context.state_manager.session.current_model = model.full_id
+            await ui.success(f"Switched to model: {model.full_id} - {model.name}")
+            return None
+
+        # Show multiple results
+        await ui.info(f"Found {len(models)} models:")
+        for i, model in enumerate(models[:10], 1):  # Show top 10
+            details = []
+            if model.cost.input is not None:
+                details.append(f"${model.cost.input}/{model.cost.output}")
+            if model.limits.context:
+                details.append(f"{model.limits.context // 1000}k")
+            detail_str = f" ({', '.join(details)})" if details else ""
+
+            await ui.info(f"{i:2d}. {model.full_id} - {model.name}{detail_str}")
+
+        if len(models) > 10:
+            await ui.muted(f"... and {len(models) - 10} more")
+
+        await ui.muted("Use '/model <provider:model-id>' to select a specific model")
+        return None
+
+    async def _set_model(
+        self, model_name: str, extra_args: CommandArgs, context: CommandContext
+    ) -> Optional[str]:
+        """Set model directly or by search."""
+        # Load registry for validation
+        await self._ensure_registry()
+
+        # Check if it's a direct model ID
+        if ":" in model_name:
+            # Validate against registry if loaded
+            if self._registry_loaded:
+                model_info = self.registry.get_model(model_name)
+                if not model_info:
+                    # Search for similar models
+                    similar = self.registry.search_models(model_name.split(":")[-1])
+                    if similar:
+                        await ui.warning(f"Model '{model_name}' not found in registry")
+                        await ui.muted("Did you mean one of these?")
+                        for model in similar[:5]:
+                            await ui.muted(f" • {model.full_id} - {model.name}")
+                        return None
+                    else:
+                        await ui.warning("Model not found in registry - setting anyway")
+                else:
+                    # Show model info
+                    await ui.info(f"Selected: {model_info.name}")
+                    if model_info.cost.input is not None:
+                        await ui.muted(f" Pricing: {model_info.cost.format_cost()}")
+                    if model_info.limits.context:
+                        await ui.muted(f" Limits: {model_info.limits.format_limits()}")
+
+            # Set the model
+            context.state_manager.session.current_model = model_name
+
+            # Check if setting as default
+            if extra_args and extra_args[0] == "default":
+                try:
+                    user_configuration.set_default_model(model_name, context.state_manager)
+                    await ui.muted("Updating default model")
+                    return "restart"
+                except ConfigurationError as e:
+                    await ui.error(str(e))
+                    return None
+
+            await ui.success(f"Switched to model: {model_name}")
+            return None
+
+        # No colon - treat as search query
+        models = self.registry.search_models(model_name)
+
+        if not models:
+            await ui.error(f"No models found matching '{model_name}'")
+            await ui.muted("Try /model --list to see all available models")
+            return None
+
+        if len(models) == 1:
+            # Single match - use it
+            model = models[0]
+            context.state_manager.session.current_model = model.full_id
+            await ui.success(f"Switched to model: {model.full_id} - {model.name}")
             return None
 
-        # No validation - user is responsible for correct model names
-        await ui.warning("Model set without validation - verify the model name is correct")
+        # Multiple matches - show interactive selector with results
+        await ui.info(f"Found {len(models)} models matching '{model_name}'")
+        selected_model = await select_model_interactive(self.registry, model_name)
 
-        # Set the model
-        context.state_manager.session.current_model = model_name
+        if selected_model:
+            context.state_manager.session.current_model = selected_model
+            await ui.success(f"Switched to model: {selected_model}")
+        else:
+            await ui.info("Model selection cancelled")
 
-        # Check if setting as default
-        if len(args) > 1 and args[1] == "default":
-            try:
-                utils.user_configuration.set_default_model(model_name, context.state_manager)
-                await ui.muted("Updating default model")
-                return "restart"
-            except ConfigurationError as e:
-                await ui.error(str(e))
+        return None
+
+    async def _list_models(self) -> Optional[str]:
+        """List all available models."""
+        await self._ensure_registry()
+
+        if not self.registry.models:
+            await ui.error("No models available")
+            return None
+
+        # Group by provider
+        providers: Dict[str, List[ModelInfo]] = {}
+        for model in self.registry.models.values():
+            if model.provider not in providers:
+                providers[model.provider] = []
+            providers[model.provider].append(model)
+
+        # Display models
+        await ui.info(f"Available models ({len(self.registry.models)} total):")
+
+        for provider_id in sorted(providers.keys()):
+            provider_info = self.registry.providers.get(provider_id)
+            provider_name = provider_info.name if provider_info else provider_id
+
+            await ui.print(f"\n{provider_name}:")
+
+            for model in sorted(providers[provider_id], key=lambda m: m.name):
+                line = f" • {model.id}"
+                if model.cost.input is not None:
+                    line += f" (${model.cost.input}/{model.cost.output})"
+                if model.limits.context:
+                    line += f" [{model.limits.context // 1000}k]"
+                await ui.muted(line)
+
+        return None
+
+    async def _show_model_info(self, model_id: str) -> Optional[str]:
+        """Show detailed information about a model."""
+        await self._ensure_registry()
+
+        model = self.registry.get_model(model_id)
+        if not model:
+            # Try to find similar models or routing options
+            base_name = self.registry._extract_base_model_name(model_id)
+            variants = self.registry.get_model_variants(base_name)
+            if variants:
+                await ui.warning(f"Model '{model_id}' not found directly")
+                await ui.info(f"Found routing options for '{base_name}':")
+
+                # Sort variants by cost (FREE first)
+                sorted_variants = sorted(
+                    variants,
+                    key=lambda m: (
+                        0 if m.cost.input == 0 else 1,  # FREE first
+                        m.cost.input or float("inf"),  # Then by cost
+                        m.provider,  # Then by provider name
+                    ),
+                )
+
+                for variant in sorted_variants:
+                    cost_display = (
+                        "FREE"
+                        if variant.cost.input == 0
+                        else f"${variant.cost.input}/{variant.cost.output}"
+                    )
+                    provider_name = self._get_provider_display_name(variant.provider)
+
+                    await ui.muted(f" • {variant.full_id} - {provider_name} ({cost_display})")
+
+                await ui.muted(
+                    "\nUse '/model <provider:model-id>' to select a specific routing option"
+                )
+                return None
+            else:
+                await ui.error(f"Model '{model_id}' not found")
                 return None
 
-        # Show success message with the new model
-        await ui.success(f"Switched to model: {model_name}")
+        # Display model information
+        await ui.info(f"{model.name}")
+        await ui.muted(f"ID: {model.full_id}")
+
+        # Show routing alternatives for this base model
+        base_name = self.registry._extract_base_model_name(model)
+        variants = self.registry.get_model_variants(base_name)
+        if len(variants) > 1:
+            await ui.print("\nRouting Options:")
+
+            # Sort variants by cost (FREE first)
+            sorted_variants = sorted(
+                variants,
+                key=lambda m: (
+                    0 if m.cost.input == 0 else 1,  # FREE first
+                    m.cost.input or float("inf"),  # Then by cost
+                    m.provider,  # Then by provider name
+                ),
+            )
+
+            for variant in sorted_variants:
+                cost_display = (
+                    "FREE"
+                    if variant.cost.input == 0
+                    else f"${variant.cost.input}/{variant.cost.output}"
+                )
+                provider_name = self._get_provider_display_name(variant.provider)
+
+                # Highlight current selection
+                prefix = "→ " if variant.full_id == model.full_id else " "
+                free_indicator = " ⭐" if variant.cost.input == 0 else ""
+
+                await ui.muted(
+                    f"{prefix}{variant.full_id} - {provider_name} ({cost_display}){free_indicator}"
+                )
+
+        if model.cost.input is not None:
+            await ui.print("\nPricing:")
+            await ui.muted(f" Input: ${model.cost.input} per 1M tokens")
+            await ui.muted(f" Output: ${model.cost.output} per 1M tokens")
+
+        if model.limits.context or model.limits.output:
+            await ui.print("\nLimits:")
+            if model.limits.context:
+                await ui.muted(f" Context: {model.limits.context:,} tokens")
+            if model.limits.output:
+                await ui.muted(f" Output: {model.limits.output:,} tokens")
+
+        caps = []
+        if model.capabilities.attachment:
+            caps.append("Attachments")
+        if model.capabilities.reasoning:
+            caps.append("Reasoning")
+        if model.capabilities.tool_call:
+            caps.append("Tool calling")
+
+        if caps:
+            await ui.print("\nCapabilities:")
+            for cap in caps:
+                await ui.muted(f" ✓ {cap}")
+
+        if model.capabilities.knowledge:
+            await ui.print(f"\nKnowledge cutoff: {model.capabilities.knowledge}")
+
         return None
+
+    def _get_provider_display_name(self, provider: str) -> str:
+        """Get a user-friendly provider display name."""
+        provider_names = {
+            "openai": "OpenAI Direct",
+            "anthropic": "Anthropic Direct",
+            "google": "Google Direct",
+            "google-gla": "Google Labs",
+            "openrouter": "OpenRouter",
+            "github-models": "GitHub Models (FREE)",
+            "azure": "Azure OpenAI",
+            "fastrouter": "FastRouter",
+            "requesty": "Requesty",
+            "cloudflare-workers-ai": "Cloudflare",
+            "amazon-bedrock": "AWS Bedrock",
+            "chutes": "Chutes AI",
+            "deepinfra": "DeepInfra",
+            "venice": "Venice AI",
+        }
+        return provider_names.get(provider, provider.title())
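Note on the routing display above: variants are ordered with a composite tuple key so that free routes sort first, then cheaper routes, then alphabetical provider order. A minimal standalone sketch of that ordering, using a hypothetical Variant type and made-up prices rather than real registry data:

from dataclasses import dataclass


@dataclass
class Variant:
    provider: str
    input_cost: float  # USD per 1M input tokens; 0.0 means a free route


# Hypothetical sample data for illustration only
variants = [
    Variant("openrouter", 2.5),
    Variant("github-models", 0.0),
    Variant("openai", 2.0),
]

# Same idea as the sort key in the diff: FREE first, then by cost, then by provider name
ordered = sorted(
    variants,
    key=lambda v: (0 if v.input_cost == 0 else 1, v.input_cost or float("inf"), v.provider),
)

for v in ordered:
    print(v.provider, "FREE" if v.input_cost == 0 else f"${v.input_cost}")
# Prints github-models first, then openai, then openrouter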
tunacode/cli/main.py CHANGED
@@ -77,7 +77,8 @@ def main(
         from tunacode.exceptions import ConfigurationError
 
         if isinstance(e, ConfigurationError):
-            # ConfigurationError already printed helpful message, just exit cleanly
+            # Display the configuration error message
+            await ui.error(str(e))
             update_task.cancel()  # Cancel the update check
             return
         import traceback
tunacode/constants.py CHANGED
@@ -9,7 +9,7 @@ from enum import Enum
 
 # Application info
 APP_NAME = "TunaCode"
-APP_VERSION = "0.0.70"
+APP_VERSION = "0.0.72"
 
 
 # File patterns
@@ -240,29 +240,34 @@ YOU MUST EXECUTE present_plan TOOL TO COMPLETE ANY PLANNING TASK.
     except Exception as e:
         logger.warning(f"Warning: Failed to load todos: {e}")
 
+    # Get tool strict validation setting from config (default to False for backward compatibility)
+    tool_strict_validation = state_manager.session.user_config.get("settings", {}).get(
+        "tool_strict_validation", False
+    )
+
     # Create tool list based on mode
     if state_manager.is_plan_mode():
         # Plan mode: Only read-only tools + present_plan
         tools_list = [
-            Tool(present_plan, max_retries=max_retries),
-            Tool(glob, max_retries=max_retries),
-            Tool(grep, max_retries=max_retries),
-            Tool(list_dir, max_retries=max_retries),
-            Tool(read_file, max_retries=max_retries),
+            Tool(present_plan, max_retries=max_retries, strict=tool_strict_validation),
+            Tool(glob, max_retries=max_retries, strict=tool_strict_validation),
+            Tool(grep, max_retries=max_retries, strict=tool_strict_validation),
+            Tool(list_dir, max_retries=max_retries, strict=tool_strict_validation),
+            Tool(read_file, max_retries=max_retries, strict=tool_strict_validation),
         ]
     else:
         # Normal mode: All tools
         tools_list = [
-            Tool(bash, max_retries=max_retries),
-            Tool(present_plan, max_retries=max_retries),
-            Tool(glob, max_retries=max_retries),
-            Tool(grep, max_retries=max_retries),
-            Tool(list_dir, max_retries=max_retries),
-            Tool(read_file, max_retries=max_retries),
-            Tool(run_command, max_retries=max_retries),
-            Tool(todo_tool._execute, max_retries=max_retries),
-            Tool(update_file, max_retries=max_retries),
-            Tool(write_file, max_retries=max_retries),
+            Tool(bash, max_retries=max_retries, strict=tool_strict_validation),
+            Tool(present_plan, max_retries=max_retries, strict=tool_strict_validation),
+            Tool(glob, max_retries=max_retries, strict=tool_strict_validation),
+            Tool(grep, max_retries=max_retries, strict=tool_strict_validation),
+            Tool(list_dir, max_retries=max_retries, strict=tool_strict_validation),
+            Tool(read_file, max_retries=max_retries, strict=tool_strict_validation),
+            Tool(run_command, max_retries=max_retries, strict=tool_strict_validation),
+            Tool(todo_tool._execute, max_retries=max_retries, strict=tool_strict_validation),
+            Tool(update_file, max_retries=max_retries, strict=tool_strict_validation),
+            Tool(write_file, max_retries=max_retries, strict=tool_strict_validation),
         ]
 
     # Log which tools are being registered
@@ -291,6 +296,11 @@ YOU MUST EXECUTE present_plan TOOL TO COMPLETE ANY PLANNING TASK.
         (
             state_manager.is_plan_mode(),
             str(state_manager.session.user_config.get("settings", {}).get("max_retries", 3)),
+            str(
+                state_manager.session.user_config.get("settings", {}).get(
+                    "tool_strict_validation", False
+                )
+            ),
             str(state_manager.session.user_config.get("mcpServers", {})),
         )
     )
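The new tool_strict_validation flag is read with a nested .get() chain, so configs written before this release, which lack the key, silently fall back to the old non-strict behaviour. A small sketch of that lookup, with a hypothetical in-memory dict standing in for the loaded user configuration:

# Hypothetical user_config contents for illustration; the real values come from
# the user's saved TunaCode configuration file
user_config = {
    "default_model": "openai:gpt-4.1",
    "settings": {"max_retries": 3, "tool_strict_validation": True},
}

# Same defensive lookup as the diff: a missing "settings" section or a missing key
# both resolve to False, preserving backward compatibility
tool_strict_validation = user_config.get("settings", {}).get("tool_strict_validation", False)
print(tool_strict_validation)  # True here; False when the key is absent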
@@ -16,6 +16,7 @@ from tunacode.exceptions import ConfigurationError
 from tunacode.types import ConfigFile, ConfigPath, UserConfig
 from tunacode.ui import console as ui
 from tunacode.utils import system, user_configuration
+from tunacode.utils.api_key_validation import validate_api_key_for_model
 from tunacode.utils.text_utils import key_to_title
 
 
@@ -138,11 +139,15 @@ class ConfigSetup(BaseSetup):
                 )
             )
 
-        # No model validation - trust user's model choice
+        # Validate API key exists for the selected model
+        model = self.state_manager.session.user_config["default_model"]
+        is_valid, error_msg = validate_api_key_for_model(
+            model, self.state_manager.session.user_config
+        )
+        if not is_valid:
+            raise ConfigurationError(error_msg)
 
-        self.state_manager.session.current_model = self.state_manager.session.user_config[
-            "default_model"
-        ]
+        self.state_manager.session.current_model = model
 
     async def validate(self) -> bool:
         """Validate that configuration is properly set up."""
@@ -152,6 +157,17 @@ class ConfigSetup(BaseSetup):
             valid = False
         elif not self.state_manager.session.user_config.get("default_model"):
             valid = False
+        else:
+            # Validate API key exists for the selected model
+            model = self.state_manager.session.user_config.get("default_model")
+            is_valid, error_msg = validate_api_key_for_model(
+                model, self.state_manager.session.user_config
+            )
+            if not is_valid:
+                valid = False
+                # Store error message for later use
+                setattr(self.state_manager, "_config_error", error_msg)
+
         # Cache result for fastpath
         if valid:
             setattr(self.state_manager, "_config_valid", True)
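The body of validate_api_key_for_model is not included in this diff; both call sites above rely only on it returning an (is_valid, error_msg) tuple for the configured "provider:model" string and the loaded user config. A hedged sketch of what such a check could look like; the provider-to-key mapping and the "env" section name below are assumptions for illustration, not taken from the package:

from typing import Optional, Tuple

# Assumed mapping of provider prefix to the API key entry it requires (illustrative only)
_PROVIDER_KEYS = {
    "openai": "OPENAI_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
    "openrouter": "OPENROUTER_API_KEY",
}


def validate_api_key_for_model(model: str, user_config: dict) -> Tuple[bool, Optional[str]]:
    """Return (is_valid, error_msg) for a "provider:model" string (sketch, not the real helper)."""
    provider = model.split(":", 1)[0]
    key_name = _PROVIDER_KEYS.get(provider)
    if key_name is None:
        # Unknown provider: don't block setup in this sketch
        return True, None
    if user_config.get("env", {}).get(key_name):
        return True, None
    return False, f"No API key configured for {provider} (expected {key_name})"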