tunacode-cli 0.0.70__py3-none-any.whl → 0.0.72__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tunacode-cli might be problematic.
- tunacode/cli/commands/implementations/model.py +332 -32
- tunacode/cli/main.py +2 -1
- tunacode/constants.py +1 -1
- tunacode/core/agents/agent_components/agent_config.py +25 -15
- tunacode/core/setup/config_setup.py +20 -4
- tunacode/ui/completers.py +211 -9
- tunacode/ui/input.py +7 -1
- tunacode/ui/model_selector.py +394 -0
- tunacode/utils/api_key_validation.py +93 -0
- tunacode/utils/models_registry.py +563 -0
- {tunacode_cli-0.0.70.dist-info → tunacode_cli-0.0.72.dist-info}/METADATA +2 -2
- {tunacode_cli-0.0.70.dist-info → tunacode_cli-0.0.72.dist-info}/RECORD +15 -12
- {tunacode_cli-0.0.70.dist-info → tunacode_cli-0.0.72.dist-info}/WHEEL +0 -0
- {tunacode_cli-0.0.70.dist-info → tunacode_cli-0.0.72.dist-info}/entry_points.txt +0 -0
- {tunacode_cli-0.0.70.dist-info → tunacode_cli-0.0.72.dist-info}/licenses/LICENSE +0 -0
tunacode/utils/api_key_validation.py (new file)
@@ -0,0 +1,93 @@
+"""
+Module: tunacode.utils.api_key_validation
+
+Utilities for validating API keys are configured for the selected model.
+"""
+
+from typing import Optional, Tuple
+
+from tunacode.types import UserConfig
+
+
+def get_required_api_key_for_model(model: str) -> Tuple[Optional[str], str]:
+    """
+    Determine which API key is required for a given model.
+
+    Args:
+        model: Model identifier in format "provider:model-name"
+
+    Returns:
+        Tuple of (api_key_name, provider_name) or (None, "unknown") if no specific key required
+    """
+    if not model or ":" not in model:
+        return None, "unknown"
+
+    provider = model.split(":")[0].lower()
+
+    # Map providers to their required API keys
+    provider_key_map = {
+        "openrouter": ("OPENROUTER_API_KEY", "OpenRouter"),
+        "openai": ("OPENAI_API_KEY", "OpenAI"),
+        "anthropic": ("ANTHROPIC_API_KEY", "Anthropic"),
+        "google": ("GEMINI_API_KEY", "Google"),
+        "google-gla": ("GEMINI_API_KEY", "Google"),
+        "gemini": ("GEMINI_API_KEY", "Google"),
+    }
+
+    return provider_key_map.get(provider, (None, provider))
+
+
+def validate_api_key_for_model(model: str, user_config: UserConfig) -> Tuple[bool, Optional[str]]:
+    """
+    Check if the required API key exists for the given model.
+
+    Args:
+        model: Model identifier in format "provider:model-name"
+        user_config: User configuration containing env variables
+
+    Returns:
+        Tuple of (is_valid, error_message)
+    """
+    api_key_name, provider_name = get_required_api_key_for_model(model)
+
+    if not api_key_name:
+        # No specific API key required (might be custom endpoint)
+        return True, None
+
+    env_config = user_config.get("env", {})
+    api_key = env_config.get(api_key_name, "").strip()
+
+    if not api_key:
+        return False, (
+            f"No API key found for {provider_name}.\n"
+            f"Please run 'tunacode --setup' to configure your API key."
+        )
+
+    return True, None
+
+
+def get_configured_providers(user_config: UserConfig) -> list[str]:
+    """
+    Get list of providers that have API keys configured.
+
+    Args:
+        user_config: User configuration containing env variables
+
+    Returns:
+        List of provider names that have API keys set
+    """
+    env_config = user_config.get("env", {})
+    configured = []
+
+    provider_map = {
+        "OPENROUTER_API_KEY": "openrouter",
+        "OPENAI_API_KEY": "openai",
+        "ANTHROPIC_API_KEY": "anthropic",
+        "GEMINI_API_KEY": "google",
+    }
+
+    for key_name, provider in provider_map.items():
+        if env_config.get(key_name, "").strip():
+            configured.append(provider)
+
+    return configured
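To illustrate how these helpers fit together, here is a minimal usage sketch (not part of the diff). It assumes `UserConfig` behaves like a plain dict with an "env" mapping, which is all these functions rely on; the config values shown are placeholders:

```python
from tunacode.utils.api_key_validation import (
    get_configured_providers,
    validate_api_key_for_model,
)

# Hypothetical user config; real configs come from tunacode's setup flow.
user_config = {"env": {"OPENROUTER_API_KEY": "sk-or-placeholder"}}

ok, error = validate_api_key_for_model("openrouter:openai/gpt-4o", user_config)
print(ok, error)  # True None - the OpenRouter key is present

ok, error = validate_api_key_for_model("anthropic:claude-3-haiku-20240307", user_config)
print(error)  # Error message directing the user to run 'tunacode --setup'

print(get_configured_providers(user_config))  # ['openrouter']
```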
tunacode/utils/models_registry.py (new file)
@@ -0,0 +1,563 @@
+"""Models.dev integration for model discovery and validation."""
+
+import json
+from dataclasses import dataclass, field
+from datetime import datetime, timedelta
+from difflib import SequenceMatcher
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+from urllib.error import URLError
+from urllib.request import urlopen
+
+
+@dataclass
+class ModelCapabilities:
+    """Model capabilities and features."""
+
+    attachment: bool = False
+    reasoning: bool = False
+    tool_call: bool = False
+    temperature: bool = True
+    knowledge: Optional[str] = None
+
+
+@dataclass
+class ModelCost:
+    """Model pricing information."""
+
+    input: Optional[float] = None
+    output: Optional[float] = None
+    cache: Optional[float] = None
+
+    def format_cost(self) -> str:
+        """Format cost as a readable string."""
+        if self.input is None or self.output is None:
+            return "Pricing not available"
+        return f"${self.input}/{self.output} per 1M tokens"
+
+
+@dataclass
+class ModelLimits:
+    """Model context and output limits."""
+
+    context: Optional[int] = None
+    output: Optional[int] = None
+
+    def format_limits(self) -> str:
+        """Format limits as a readable string."""
+        parts = []
+        if self.context:
+            parts.append(f"{self.context:,} context")
+        if self.output:
+            parts.append(f"{self.output:,} output")
+        return ", ".join(parts) if parts else "Limits not specified"
+
+
+@dataclass
+class ModelInfo:
+    """Complete model information."""
+
+    id: str
+    name: str
+    provider: str
+    capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
+    cost: ModelCost = field(default_factory=ModelCost)
+    limits: ModelLimits = field(default_factory=ModelLimits)
+    release_date: Optional[str] = None
+    last_updated: Optional[str] = None
+    open_weights: bool = False
+    modalities: Dict[str, List[str]] = field(default_factory=dict)
+
+    @property
+    def full_id(self) -> str:
+        """Get the full model identifier with provider prefix."""
+        return f"{self.provider}:{self.id}"
+
+    def format_display(self, include_details: bool = True) -> str:
+        """Format model for display."""
+        display = f"{self.full_id} - {self.name}"
+        if include_details:
+            details = []
+            if self.cost.input is not None:
+                details.append(self.cost.format_cost())
+            if self.limits.context:
+                details.append(f"{self.limits.context // 1000}k context")
+            if details:
+                display += f" ({', '.join(details)})"
+        return display
+
+    def matches_search(self, query: str) -> float:
+        """Calculate match score for search query (0-1)."""
+        query_lower = query.lower()
+
+        # Exact match in ID or name
+        if query_lower in self.id.lower():
+            return 1.0
+        if query_lower in self.name.lower():
+            return 0.9
+        if query_lower in self.provider.lower():
+            return 0.8
+
+        # Fuzzy match
+        best_ratio = 0.0
+        for field_value in [self.id, self.name, self.provider]:
+            ratio = SequenceMatcher(None, query_lower, field_value.lower()).ratio()
+            best_ratio = max(best_ratio, ratio)
+
+        return best_ratio
+
+
+@dataclass
+class ProviderInfo:
+    """Provider information."""
+
+    id: str
+    name: str
+    env: List[str] = field(default_factory=list)
+    npm: Optional[str] = None
+    doc: Optional[str] = None
+
+
+class ModelsRegistry:
+    """Registry for managing models from models.dev."""
+
+    API_URL = "https://models.dev/api.json"
+    CACHE_FILE = "models_cache.json"
+    CACHE_TTL = timedelta(hours=24)
+
+    def __init__(self, cache_dir: Optional[Path] = None):
+        """Initialize the models registry."""
+        if cache_dir is None:
+            cache_dir = Path.home() / ".tunacode" / "cache"
+        self.cache_dir = cache_dir
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        self.cache_file = self.cache_dir / self.CACHE_FILE
+
+        self.models: Dict[str, ModelInfo] = {}
+        self.providers: Dict[str, ProviderInfo] = {}
+        self._loaded = False
+
+    def _is_cache_valid(self) -> bool:
+        """Check if cache file exists and is still valid."""
+        if not self.cache_file.exists():
+            return False
+
+        # Check cache age
+        cache_age = datetime.now() - datetime.fromtimestamp(self.cache_file.stat().st_mtime)
+        return cache_age < self.CACHE_TTL
+
+    def _load_from_cache(self) -> bool:
+        """Load models from cache file."""
+        try:
+            with open(self.cache_file, "r") as f:
+                data = json.load(f)
+            self._parse_data(data)
+            return True
+        except (json.JSONDecodeError, OSError, KeyError):
+            return False
+
+    def _fetch_from_api(self) -> bool:
+        """Fetch models from models.dev API."""
+        try:
+            # Add User-Agent header to avoid blocking
+            import urllib.request
+
+            req = urllib.request.Request(self.API_URL, headers={"User-Agent": "TunaCode-CLI/1.0"})
+            with urlopen(req, timeout=10) as response:  # nosec B310 - Using trusted models.dev API
+                data = json.loads(response.read())
+
+            # Save to cache
+            with open(self.cache_file, "w") as f:
+                json.dump(data, f, indent=2)
+
+            self._parse_data(data)
+            return True
+        except (URLError, json.JSONDecodeError, OSError):
+            # Log error but don't fail
+            return False
+
+    def _load_fallback_models(self) -> None:
+        """Load hardcoded popular models as fallback."""
+        fallback_data = {
+            "openai": {
+                "name": "OpenAI",
+                "env": ["OPENAI_API_KEY"],
+                "npm": "@ai-sdk/openai",
+                "doc": "https://platform.openai.com/docs",
+                "models": {
+                    "gpt-4": {
+                        "name": "GPT-4",
+                        "attachment": True,
+                        "reasoning": True,
+                        "tool_call": True,
+                        "temperature": True,
+                        "knowledge": "2024-04",
+                        "cost": {"input": 30.0, "output": 60.0},
+                        "limit": {"context": 128000, "output": 4096},
+                    },
+                    "gpt-4-turbo": {
+                        "name": "GPT-4 Turbo",
+                        "attachment": True,
+                        "reasoning": True,
+                        "tool_call": True,
+                        "temperature": True,
+                        "knowledge": "2024-04",
+                        "cost": {"input": 10.0, "output": 30.0},
+                        "limit": {"context": 128000, "output": 4096},
+                    },
+                    "gpt-3.5-turbo": {
+                        "name": "GPT-3.5 Turbo",
+                        "attachment": False,
+                        "reasoning": False,
+                        "tool_call": True,
+                        "temperature": True,
+                        "knowledge": "2024-01",
+                        "cost": {"input": 0.5, "output": 1.5},
+                        "limit": {"context": 16000, "output": 4096},
+                    },
+                },
+            },
+            "anthropic": {
+                "name": "Anthropic",
+                "env": ["ANTHROPIC_API_KEY"],
+                "npm": "@ai-sdk/anthropic",
+                "doc": "https://docs.anthropic.com",
+                "models": {
+                    "claude-3-opus-20240229": {
+                        "name": "Claude 3 Opus",
+                        "attachment": True,
+                        "reasoning": True,
+                        "tool_call": True,
+                        "temperature": True,
+                        "knowledge": "2024-04",
+                        "cost": {"input": 15.0, "output": 75.0},
+                        "limit": {"context": 200000, "output": 4096},
+                    },
+                    "claude-3-sonnet-20240229": {
+                        "name": "Claude 3 Sonnet",
+                        "attachment": True,
+                        "reasoning": True,
+                        "tool_call": True,
+                        "temperature": True,
+                        "knowledge": "2024-04",
+                        "cost": {"input": 3.0, "output": 15.0},
+                        "limit": {"context": 200000, "output": 4096},
+                    },
+                    "claude-3-haiku-20240307": {
+                        "name": "Claude 3 Haiku",
+                        "attachment": True,
+                        "reasoning": False,
+                        "tool_call": True,
+                        "temperature": True,
+                        "knowledge": "2024-04",
+                        "cost": {"input": 0.25, "output": 1.25},
+                        "limit": {"context": 200000, "output": 4096},
+                    },
+                },
+            },
+            "google": {
+                "name": "Google",
+                "env": ["GOOGLE_API_KEY"],
+                "npm": "@ai-sdk/google",
+                "doc": "https://ai.google.dev",
+                "models": {
+                    "gemini-1.5-pro": {
+                        "name": "Gemini 1.5 Pro",
+                        "attachment": True,
+                        "reasoning": True,
+                        "tool_call": True,
+                        "temperature": True,
+                        "knowledge": "2024-04",
+                        "cost": {"input": 3.5, "output": 10.5},
+                        "limit": {"context": 2000000, "output": 8192},
+                    },
+                    "gemini-1.5-flash": {
+                        "name": "Gemini 1.5 Flash",
+                        "attachment": True,
+                        "reasoning": True,
+                        "tool_call": True,
+                        "temperature": True,
+                        "knowledge": "2024-04",
+                        "cost": {"input": 0.075, "output": 0.3},
+                        "limit": {"context": 1000000, "output": 8192},
+                    },
+                },
+            },
+        }
+
+        self._parse_data(fallback_data)
+
+    def _parse_data(self, data: Dict[str, Any]) -> None:
+        """Parse models data from API response."""
+        self.models.clear()
+        self.providers.clear()
+
+        for provider_id, provider_data in data.items():
+            # Skip non-provider keys
+            if not isinstance(provider_data, dict) or "models" not in provider_data:
+                continue
+
+            # Parse provider info
+            provider = ProviderInfo(
+                id=provider_id,
+                name=provider_data.get("name", provider_id),
+                env=provider_data.get("env", []),
+                npm=provider_data.get("npm"),
+                doc=provider_data.get("doc"),
+            )
+            self.providers[provider_id] = provider
+
+            # Parse models
+            models_data = provider_data.get("models", {})
+            for model_id, model_data in models_data.items():
+                if not isinstance(model_data, dict):
+                    continue
+
+                # Parse capabilities
+                capabilities = ModelCapabilities(
+                    attachment=model_data.get("attachment", False),
+                    reasoning=model_data.get("reasoning", False),
+                    tool_call=model_data.get("tool_call", False),
+                    temperature=model_data.get("temperature", True),
+                    knowledge=model_data.get("knowledge"),
+                )
+
+                # Parse cost
+                cost_data = model_data.get("cost", {})
+                cost = ModelCost(
+                    input=cost_data.get("input") if isinstance(cost_data, dict) else None,
+                    output=cost_data.get("output") if isinstance(cost_data, dict) else None,
+                    cache=cost_data.get("cache") if isinstance(cost_data, dict) else None,
+                )
+
+                # Parse limits
+                limit_data = model_data.get("limit", {})
+                limits = ModelLimits(
+                    context=limit_data.get("context") if isinstance(limit_data, dict) else None,
+                    output=limit_data.get("output") if isinstance(limit_data, dict) else None,
+                )
+
+                # Create model info
+                model = ModelInfo(
+                    id=model_id,
+                    name=model_data.get("name", model_id),
+                    provider=provider_id,
+                    capabilities=capabilities,
+                    cost=cost,
+                    limits=limits,
+                    release_date=model_data.get("release_date"),
+                    last_updated=model_data.get("last_updated"),
+                    open_weights=model_data.get("open_weights", False),
+                    modalities=model_data.get("modalities", {}),
+                )
+
+                # Store with full ID as key
+                self.models[model.full_id] = model
+
+    async def load(self, force_refresh: bool = False) -> bool:
+        """Load models data, using cache if available."""
+        if self._loaded and not force_refresh:
+            return True
+
+        # Try cache first
+        if not force_refresh and self._is_cache_valid():
+            if self._load_from_cache():
+                self._loaded = True
+                return True
+
+        # Fetch from API
+        if self._fetch_from_api():
+            self._loaded = True
+            return True
+
+        # Try cache as fallback even if expired
+        if self._load_from_cache():
+            self._loaded = True
+            # Import ui locally to avoid circular imports
+            from ..ui import console as ui
+
+            await ui.warning("Using cached models data (API unavailable)")
+            return True
+
+        # Use fallback models as last resort
+        from ..ui import console as ui
+
+        await ui.warning("models.dev API unavailable, using fallback model list")
+        self._load_fallback_models()
+        self._loaded = True
+        return True
+
+    def get_model(self, model_id: str) -> Optional[ModelInfo]:
+        """Get a specific model by ID."""
+        # Try exact match first
+        if model_id in self.models:
+            return self.models[model_id]
+
+        # Try without provider prefix
+        for full_id, model in self.models.items():
+            if model.id == model_id:
+                return model
+
+        return None
+
+    def validate_model(self, model_id: str) -> bool:
+        """Check if a model ID is valid."""
+        return self.get_model(model_id) is not None
+
+    def search_models(
+        self, query: str = "", provider: Optional[str] = None, min_score: float = 0.3
+    ) -> List[ModelInfo]:
+        """Search for models matching query."""
+        results = []
+
+        for model in self.models.values():
+            # Filter by provider if specified
+            if provider and model.provider != provider:
+                continue
+
+            # Calculate match score
+            if query:
+                score = model.matches_search(query)
+                if score < min_score:
+                    continue
+                results.append((score, model))
+            else:
+                # No query, include all
+                results.append((1.0, model))
+
+        # Sort by score (descending) and name
+        results.sort(key=lambda x: (-x[0], x[1].name))
+
+        return [model for _, model in results]
+
+    def get_providers(self) -> List[ProviderInfo]:
+        """Get list of all providers."""
+        return sorted(self.providers.values(), key=lambda p: p.name)
+
+    def get_models_by_provider(self, provider: str) -> List[ModelInfo]:
+        """Get all models for a specific provider."""
+        return [m for m in self.models.values() if m.provider == provider]
+
+    def _extract_base_model_name(self, model: ModelInfo) -> str:
+        """Extract the base model name from a model (e.g., 'gpt-4o' from 'openai:gpt-4o')."""
+        model_id = model.id.lower()
+
+        # Handle common patterns
+        base_name = model_id
+
+        # Remove common suffixes
+        suffixes_to_remove = [
+            "-latest",
+            "-preview",
+            "-turbo",
+            "-instruct",
+            "-chat",
+            "-base",
+            "-20240229",
+            "-20240307",
+            "-20240620",
+            "-20241022",
+            "-20250514",
+            "-0613",
+            "-0125",
+            "-0301",
+            "-1106",
+            "-2024",
+            "-2025",
+        ]
+
+        for suffix in suffixes_to_remove:
+            if base_name.endswith(suffix):
+                base_name = base_name[: -len(suffix)]
+
+        # Handle versioned models (e.g., 'claude-3-5-sonnet' -> 'claude-3-sonnet')
+        if "claude-3-5" in base_name:
+            base_name = base_name.replace("claude-3-5", "claude-3")
+        elif "claude-3-7" in base_name:
+            base_name = base_name.replace("claude-3-7", "claude-3")
+
+        # Handle OpenRouter nested paths (e.g., 'openai/gpt-4o' -> 'gpt-4o')
+        if "/" in base_name:
+            base_name = base_name.split("/")[-1]
+
+        return base_name
+
+    def get_model_variants(self, base_model_name: str) -> List[ModelInfo]:
+        """Get all variants of a base model across different providers."""
+        base_name = base_model_name.lower()
+        variants = []
+
+        for model in self.models.values():
+            model_base = self._extract_base_model_name(model)
+            if model_base == base_name or base_name in model_base:
+                variants.append(model)
+
+        # Sort by cost (free first, then ascending cost)
+        def sort_key(model: ModelInfo) -> tuple:
+            cost = model.cost.input or 999
+            is_free = cost == 0
+            return (not is_free, cost, model.provider, model.id)
+
+        variants.sort(key=sort_key)
+        return variants
+
+    def find_base_models(self, query: str) -> Dict[str, List[ModelInfo]]:
+        """Find base models and group their variants by routing source."""
+        query_lower = query.lower()
+        base_models: Dict[str, List[ModelInfo]] = {}
+
+        # Find all models matching the query
+        matching_models = []
+        for model in self.models.values():
+            base_name = self._extract_base_model_name(model)
+            if (
+                query_lower in model.id.lower()
+                or query_lower in model.name.lower()
+                or query_lower in base_name
+                or query_lower in model.provider.lower()
+            ):
+                matching_models.append((base_name, model))
+
+        # Group by base model name
+        for base_name, model in matching_models:
+            if base_name not in base_models:
+                base_models[base_name] = []
+            base_models[base_name].append(model)
+
+        # Sort variants within each base model
+        for base_name in base_models:
+
+            def sort_key(model: ModelInfo) -> tuple:
+                cost = model.cost.input or 999
+                is_free = cost == 0
+                return (not is_free, cost, model.provider, model.id)
+
+            base_models[base_name].sort(key=sort_key)
+
+        return base_models
+
+    def get_popular_base_models(self) -> List[str]:
+        """Get list of popular base model names for suggestions."""
+        popular_patterns = [
+            "gpt-4o",
+            "gpt-4",
+            "gpt-3.5-turbo",
+            "claude-3-opus",
+            "claude-3-sonnet",
+            "claude-3-haiku",
+            "gemini-2",
+            "gemini-1.5-pro",
+            "gemini-1.5-flash",
+            "o1-preview",
+            "o1-mini",
+            "o3",
+            "o3-mini",
+        ]
+
+        available_base_models = []
+        for pattern in popular_patterns:
+            variants = self.get_model_variants(pattern)
+            if variants:
+                available_base_models.append(pattern)
+
+        return available_base_models
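A sketch of how this registry might be driven, assuming network access to models.dev or a warm cache (per the code above, `load()` falls back to a stale cache and then to the hardcoded model list; the search query shown is illustrative):

```python
import asyncio

from tunacode.utils.models_registry import ModelsRegistry


async def main() -> None:
    registry = ModelsRegistry()  # cache lives under ~/.tunacode/cache/models_cache.json
    await registry.load()

    # Fuzzy search scoped to one provider, best matches first.
    for model in registry.search_models("claude", provider="anthropic")[:3]:
        print(model.format_display())

    # Validation accepts both "provider:id" and bare model IDs.
    print(registry.validate_model("openai:gpt-4"))  # True with the fallback data


asyncio.run(main())
```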
{tunacode_cli-0.0.70.dist-info → tunacode_cli-0.0.72.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tunacode-cli
-Version: 0.0.70
+Version: 0.0.72
 Summary: Your agentic CLI developer.
 Project-URL: Homepage, https://tunacode.xyz/
 Project-URL: Repository, https://github.com/alchemiststudiosDOTai/tunacode
@@ -100,7 +100,7 @@ See the [Hatch Build System Guide](documentation/development/hatch-build-system.
 
 ## Configuration
 
-Choose your AI provider and set your API key. For more details, see the [Configuration Section](documentation/user/getting-started.md#2-configuration) in the Getting Started Guide.
+Choose your AI provider and set your API key. For more details, see the [Configuration Section](documentation/user/getting-started.md#2-configuration) in the Getting Started Guide. For local models (LM Studio, Ollama, etc.), see the [Local Models Setup Guide](documentation/configuration/local-models.md).
 
 ### Recommended Models
 