tunacode-cli 0.0.55__py3-none-any.whl → 0.0.78.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tunacode-cli might be problematic.
- tunacode/cli/commands/__init__.py +2 -2
- tunacode/cli/commands/implementations/__init__.py +2 -3
- tunacode/cli/commands/implementations/command_reload.py +48 -0
- tunacode/cli/commands/implementations/debug.py +2 -2
- tunacode/cli/commands/implementations/development.py +10 -8
- tunacode/cli/commands/implementations/model.py +357 -29
- tunacode/cli/commands/implementations/quickstart.py +43 -0
- tunacode/cli/commands/implementations/system.py +96 -3
- tunacode/cli/commands/implementations/template.py +0 -2
- tunacode/cli/commands/registry.py +139 -5
- tunacode/cli/commands/slash/__init__.py +32 -0
- tunacode/cli/commands/slash/command.py +157 -0
- tunacode/cli/commands/slash/loader.py +135 -0
- tunacode/cli/commands/slash/processor.py +294 -0
- tunacode/cli/commands/slash/types.py +93 -0
- tunacode/cli/commands/slash/validator.py +400 -0
- tunacode/cli/main.py +23 -2
- tunacode/cli/repl.py +217 -190
- tunacode/cli/repl_components/command_parser.py +38 -4
- tunacode/cli/repl_components/error_recovery.py +85 -4
- tunacode/cli/repl_components/output_display.py +12 -1
- tunacode/cli/repl_components/tool_executor.py +1 -1
- tunacode/configuration/defaults.py +12 -3
- tunacode/configuration/key_descriptions.py +284 -0
- tunacode/configuration/settings.py +0 -1
- tunacode/constants.py +12 -40
- tunacode/core/agents/__init__.py +43 -2
- tunacode/core/agents/agent_components/__init__.py +7 -0
- tunacode/core/agents/agent_components/agent_config.py +249 -55
- tunacode/core/agents/agent_components/agent_helpers.py +43 -13
- tunacode/core/agents/agent_components/node_processor.py +179 -139
- tunacode/core/agents/agent_components/response_state.py +123 -6
- tunacode/core/agents/agent_components/state_transition.py +116 -0
- tunacode/core/agents/agent_components/streaming.py +296 -0
- tunacode/core/agents/agent_components/task_completion.py +19 -6
- tunacode/core/agents/agent_components/tool_buffer.py +21 -1
- tunacode/core/agents/agent_components/tool_executor.py +10 -0
- tunacode/core/agents/main.py +522 -370
- tunacode/core/agents/main_legact.py +538 -0
- tunacode/core/agents/prompts.py +66 -0
- tunacode/core/agents/utils.py +29 -121
- tunacode/core/code_index.py +83 -29
- tunacode/core/setup/__init__.py +0 -2
- tunacode/core/setup/config_setup.py +110 -20
- tunacode/core/setup/config_wizard.py +230 -0
- tunacode/core/setup/coordinator.py +14 -5
- tunacode/core/state.py +16 -20
- tunacode/core/token_usage/usage_tracker.py +5 -3
- tunacode/core/tool_authorization.py +352 -0
- tunacode/core/tool_handler.py +67 -40
- tunacode/exceptions.py +119 -5
- tunacode/prompts/system.xml +751 -0
- tunacode/services/mcp.py +125 -7
- tunacode/setup.py +5 -25
- tunacode/tools/base.py +163 -0
- tunacode/tools/bash.py +110 -1
- tunacode/tools/glob.py +332 -34
- tunacode/tools/grep.py +179 -82
- tunacode/tools/grep_components/result_formatter.py +98 -4
- tunacode/tools/list_dir.py +132 -2
- tunacode/tools/prompts/bash_prompt.xml +72 -0
- tunacode/tools/prompts/glob_prompt.xml +45 -0
- tunacode/tools/prompts/grep_prompt.xml +98 -0
- tunacode/tools/prompts/list_dir_prompt.xml +31 -0
- tunacode/tools/prompts/react_prompt.xml +23 -0
- tunacode/tools/prompts/read_file_prompt.xml +54 -0
- tunacode/tools/prompts/run_command_prompt.xml +64 -0
- tunacode/tools/prompts/update_file_prompt.xml +53 -0
- tunacode/tools/prompts/write_file_prompt.xml +37 -0
- tunacode/tools/react.py +153 -0
- tunacode/tools/read_file.py +91 -0
- tunacode/tools/run_command.py +114 -0
- tunacode/tools/schema_assembler.py +167 -0
- tunacode/tools/update_file.py +94 -0
- tunacode/tools/write_file.py +86 -0
- tunacode/tools/xml_helper.py +83 -0
- tunacode/tutorial/__init__.py +9 -0
- tunacode/tutorial/content.py +98 -0
- tunacode/tutorial/manager.py +182 -0
- tunacode/tutorial/steps.py +124 -0
- tunacode/types.py +20 -27
- tunacode/ui/completers.py +434 -50
- tunacode/ui/config_dashboard.py +585 -0
- tunacode/ui/console.py +63 -11
- tunacode/ui/input.py +20 -3
- tunacode/ui/keybindings.py +7 -4
- tunacode/ui/model_selector.py +395 -0
- tunacode/ui/output.py +40 -19
- tunacode/ui/panels.py +212 -43
- tunacode/ui/path_heuristics.py +91 -0
- tunacode/ui/prompt_manager.py +5 -1
- tunacode/ui/tool_ui.py +33 -10
- tunacode/utils/api_key_validation.py +93 -0
- tunacode/utils/config_comparator.py +340 -0
- tunacode/utils/json_utils.py +206 -0
- tunacode/utils/message_utils.py +14 -4
- tunacode/utils/models_registry.py +593 -0
- tunacode/utils/ripgrep.py +332 -9
- tunacode/utils/text_utils.py +18 -1
- tunacode/utils/user_configuration.py +45 -0
- tunacode_cli-0.0.78.6.dist-info/METADATA +260 -0
- tunacode_cli-0.0.78.6.dist-info/RECORD +158 -0
- {tunacode_cli-0.0.55.dist-info → tunacode_cli-0.0.78.6.dist-info}/WHEEL +1 -2
- tunacode/cli/commands/implementations/todo.py +0 -217
- tunacode/context.py +0 -71
- tunacode/core/setup/git_safety_setup.py +0 -182
- tunacode/prompts/system.md +0 -731
- tunacode/tools/read_file_async_poc.py +0 -196
- tunacode/tools/todo.py +0 -349
- tunacode_cli-0.0.55.dist-info/METADATA +0 -322
- tunacode_cli-0.0.55.dist-info/RECORD +0 -126
- tunacode_cli-0.0.55.dist-info/top_level.txt +0 -1
- {tunacode_cli-0.0.55.dist-info → tunacode_cli-0.0.78.6.dist-info}/entry_points.txt +0 -0
- {tunacode_cli-0.0.55.dist-info → tunacode_cli-0.0.78.6.dist-info}/licenses/LICENSE +0 -0
tunacode/utils/models_registry.py
@@ -0,0 +1,593 @@
"""Models.dev integration for model discovery and validation."""

import json
from datetime import datetime, timedelta
from difflib import SequenceMatcher
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.error import URLError
from urllib.request import urlopen

from pydantic import BaseModel, ConfigDict, Field, field_validator


class ModelCapabilities(BaseModel):
    """Model capabilities and features."""

    model_config = ConfigDict(extra="ignore")

    attachment: bool = False
    reasoning: bool = False
    tool_call: bool = False
    temperature: bool = True
    knowledge: Optional[str] = None


class ModelCost(BaseModel):
    """Model pricing information."""

    model_config = ConfigDict(extra="ignore")

    input: Optional[float] = None
    output: Optional[float] = None
    cache: Optional[float] = None

    @field_validator("input", "output", "cache")
    @classmethod
    def _non_negative(cls, v: Optional[float]) -> Optional[float]:
        if v is None:
            return v
        if v < 0:
            raise ValueError("cost values must be non-negative")
        return float(v)

    def format_cost(self) -> str:
        """Format cost as a readable string."""
        if self.input is None or self.output is None:
            return "Pricing not available"
        return f"${self.input}/{self.output} per 1M tokens"


class ModelLimits(BaseModel):
    """Model context and output limits."""

    model_config = ConfigDict(extra="ignore")

    context: Optional[int] = None
    output: Optional[int] = None

    @field_validator("context", "output")
    @classmethod
    def _positive_int(cls, v: Optional[int]) -> Optional[int]:
        if v is None:
            return v
        iv = int(v)
        if iv < 0:
            raise ValueError("limits must be non-negative integers")
        return iv if iv > 0 else None

    def format_limits(self) -> str:
        """Format limits as a readable string."""
        parts: List[str] = []
        if self.context:
            parts.append(f"{self.context:,} context")
        if self.output:
            parts.append(f"{self.output:,} output")
        return ", ".join(parts) if parts else "Limits not specified"


class ModelInfo(BaseModel):
    """Complete model information."""

    model_config = ConfigDict(extra="ignore")

    id: str
    name: str
    provider: str
    capabilities: ModelCapabilities = Field(default_factory=ModelCapabilities)
    cost: ModelCost = Field(default_factory=ModelCost)
    limits: ModelLimits = Field(default_factory=ModelLimits)
    release_date: Optional[str] = None
    last_updated: Optional[str] = None
    open_weights: bool = False
    modalities: Dict[str, List[str]] = Field(default_factory=dict)

    @property
    def full_id(self) -> str:
        """Get the full model identifier with provider prefix."""
        return f"{self.provider}:{self.id}"

    def format_display(self, include_details: bool = True) -> str:
        """Format model for display."""
        display = f"{self.full_id} - {self.name}"
        if include_details:
            details: List[str] = []
            if self.cost.input is not None:
                details.append(self.cost.format_cost())
            if self.limits.context:
                details.append(f"{self.limits.context // 1000}k context")
            if details:
                display += f" ({', '.join(details)})"
        return display

    # we need to make this lighter for future dev, low priority
    def matches_search(self, query: str) -> float:
        """Calculate match score for search query (0-1)."""
        query_lower = query.lower()

        # Exact match in ID or name
        if query_lower in self.id.lower():
            return 1.0
        if query_lower in self.name.lower():
            return 0.9
        if query_lower in self.provider.lower():
            return 0.8

        # Fuzzy match
        best_ratio = 0.0
        for field_value in [self.id, self.name, self.provider]:
            ratio = SequenceMatcher(None, query_lower, field_value.lower()).ratio()
            best_ratio = max(best_ratio, ratio)

        return best_ratio


class ProviderInfo(BaseModel):
    """Provider information."""

    model_config = ConfigDict(extra="ignore")

    id: str
    name: str
    env: List[str] = Field(default_factory=list)
    npm: Optional[str] = None
    doc: Optional[str] = None


class ModelsRegistry:
    """Registry for managing models from models.dev."""

    API_URL = "https://models.dev/api.json"
    CACHE_FILE = "models_cache.json"
    CACHE_TTL = timedelta(hours=24)

    def __init__(self, cache_dir: Optional[Path] = None):
        """Initialize the models registry."""
        if cache_dir is None:
            cache_dir = Path.home() / ".tunacode" / "cache"
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.cache_file = self.cache_dir / self.CACHE_FILE

        self.models: Dict[str, ModelInfo] = {}
        self.providers: Dict[str, ProviderInfo] = {}
        self._loaded = False

    def _is_cache_valid(self) -> bool:
        """Check if cache file exists and is still valid."""
        if not self.cache_file.exists():
            return False

        # Check cache age
        cache_age = datetime.now() - datetime.fromtimestamp(self.cache_file.stat().st_mtime)
        return cache_age < self.CACHE_TTL

    def _load_from_cache(self) -> bool:
        """Load models from cache file."""
        try:
            with open(self.cache_file, "r") as f:
                data = json.load(f)
            self._parse_data(data)
            return True
        except (json.JSONDecodeError, OSError, KeyError):
            return False

    def _fetch_from_api(self) -> bool:
        """Fetch models from models.dev API."""
        try:
            # Add User-Agent header to avoid blocking
            import urllib.request

            req = urllib.request.Request(self.API_URL, headers={"User-Agent": "TunaCode-CLI/1.0"})
            with urlopen(req, timeout=10) as response:  # nosec B310 - Using trusted models.dev API
                data = json.loads(response.read())

            # Save to cache
            with open(self.cache_file, "w") as f:
                json.dump(data, f, indent=2)

            self._parse_data(data)
            return True
        except (URLError, json.JSONDecodeError, OSError):
            # Log error but don't fail
            return False

    def _load_fallback_models(self) -> None:
        """Load hardcoded popular models as fallback."""
        fallback_data = {
            "openai": {
                "name": "OpenAI",
                "env": ["OPENAI_API_KEY"],
                "npm": "@ai-sdk/openai",
                "doc": "https://platform.openai.com/docs",
                "models": {
                    "gpt-4": {
                        "name": "GPT-4",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 30.0, "output": 60.0},
                        "limit": {"context": 128000, "output": 4096},
                    },
                    "gpt-4-turbo": {
                        "name": "GPT-4 Turbo",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 10.0, "output": 30.0},
                        "limit": {"context": 128000, "output": 4096},
                    },
                    "gpt-3.5-turbo": {
                        "name": "GPT-3.5 Turbo",
                        "attachment": False,
                        "reasoning": False,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-01",
                        "cost": {"input": 0.5, "output": 1.5},
                        "limit": {"context": 16000, "output": 4096},
                    },
                },
            },
            "anthropic": {
                "name": "Anthropic",
                "env": ["ANTHROPIC_API_KEY"],
                "npm": "@ai-sdk/anthropic",
                "doc": "https://docs.anthropic.com",
                "models": {
                    "claude-3-opus-20240229": {
                        "name": "Claude 3 Opus",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 15.0, "output": 75.0},
                        "limit": {"context": 200000, "output": 4096},
                    },
                    "claude-3-sonnet-20240229": {
                        "name": "Claude 3 Sonnet",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 3.0, "output": 15.0},
                        "limit": {"context": 200000, "output": 4096},
                    },
                    "claude-3-haiku-20240307": {
                        "name": "Claude 3 Haiku",
                        "attachment": True,
                        "reasoning": False,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 0.25, "output": 1.25},
                        "limit": {"context": 200000, "output": 4096},
                    },
                },
            },
            "google": {
                "name": "Google",
                "env": ["GOOGLE_API_KEY"],
                "npm": "@ai-sdk/google",
                "doc": "https://ai.google.dev",
                "models": {
                    "gemini-1.5-pro": {
                        "name": "Gemini 1.5 Pro",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 3.5, "output": 10.5},
                        "limit": {"context": 2000000, "output": 8192},
                    },
                    "gemini-1.5-flash": {
                        "name": "Gemini 1.5 Flash",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 0.075, "output": 0.3},
                        "limit": {"context": 1000000, "output": 8192},
                    },
                },
            },
        }

        self._parse_data(fallback_data)

    def _parse_data(self, data: Dict[str, Any]) -> None:
        """Parse models data from API response."""
        self.models.clear()
        self.providers.clear()

        for provider_id, provider_data in data.items():
            # Skip non-provider keys
            if not isinstance(provider_data, dict) or "models" not in provider_data:
                continue

            # Parse provider info
            provider = ProviderInfo(
                id=provider_id,
                name=provider_data.get("name", provider_id),
                env=provider_data.get("env", []),
                npm=provider_data.get("npm"),
                doc=provider_data.get("doc"),
            )
            self.providers[provider_id] = provider

            # Parse models
            models_data = provider_data.get("models", {})
            for model_id, model_data in models_data.items():
                if not isinstance(model_data, dict):
                    continue

                # Parse capabilities
                capabilities = ModelCapabilities(
                    attachment=bool(model_data.get("attachment", False)),
                    reasoning=bool(model_data.get("reasoning", False)),
                    tool_call=bool(model_data.get("tool_call", False)),
                    temperature=bool(model_data.get("temperature", True)),
                    knowledge=model_data.get("knowledge"),
                )

                # Parse cost
                cost_data = model_data.get("cost", {})
                cost = ModelCost(
                    input=(cost_data.get("input") if isinstance(cost_data, dict) else None),
                    output=(cost_data.get("output") if isinstance(cost_data, dict) else None),
                    cache=(cost_data.get("cache") if isinstance(cost_data, dict) else None),
                )

                # Parse limits
                limit_data = model_data.get("limit", {})
                limits = ModelLimits(
                    context=(limit_data.get("context") if isinstance(limit_data, dict) else None),
                    output=(limit_data.get("output") if isinstance(limit_data, dict) else None),
                )

                # Create model info
                model = ModelInfo(
                    id=model_id,
                    name=model_data.get("name", model_id),
                    provider=provider_id,
                    capabilities=capabilities,
                    cost=cost,
                    limits=limits,
                    release_date=model_data.get("release_date"),
                    last_updated=model_data.get("last_updated"),
                    open_weights=bool(model_data.get("open_weights", False)),
                    modalities=(
                        model_data.get("modalities", {})
                        if isinstance(model_data.get("modalities", {}), dict)
                        else {}
                    ),
                )

                # Store with full ID as key
                self.models[model.full_id] = model

    async def load(self, force_refresh: bool = False) -> bool:
        """Load models data, using cache if available."""
        if self._loaded and not force_refresh:
            return True

        # Try cache first
        if not force_refresh and self._is_cache_valid():
            if self._load_from_cache():
                self._loaded = True
                return True

        # Fetch from API
        if self._fetch_from_api():
            self._loaded = True
            return True

        # Try cache as fallback even if expired
        if self._load_from_cache():
            self._loaded = True
            # Import ui locally to avoid circular imports
            from ..ui import console as ui

            await ui.warning("Using cached models data (API unavailable)")
            return True

        # Use fallback models as last resort
        from ..ui import console as ui

        await ui.warning("models.dev API unavailable, using fallback model list")
        self._load_fallback_models()
        self._loaded = True
        return True

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """Get a specific model by ID."""
        # Try exact match first
        if model_id in self.models:
            return self.models[model_id]

        # Try without provider prefix
        for full_id, model in self.models.items():
            if model.id == model_id:
                return model

        return None

    def validate_model(self, model_id: str) -> bool:
        """Check if a model ID is valid."""
        return self.get_model(model_id) is not None

    def search_models(
        self, query: str = "", provider: Optional[str] = None, min_score: float = 0.3
    ) -> List[ModelInfo]:
        """Search for models matching query."""
        results = []

        for model in self.models.values():
            # Filter by provider if specified
            if provider and model.provider != provider:
                continue

            # Calculate match score
            if query:
                score = model.matches_search(query)
                if score < min_score:
                    continue
                results.append((score, model))
            else:
                # No query, include all
                results.append((1.0, model))

        # Sort by score (descending) and name
        results.sort(key=lambda x: (-x[0], x[1].name))

        return [model for _, model in results]

    def get_providers(self) -> List[ProviderInfo]:
        """Get list of all providers."""
        return sorted(self.providers.values(), key=lambda p: p.name)

    def get_models_by_provider(self, provider: str) -> List[ModelInfo]:
        """Get all models for a specific provider."""
        return [m for m in self.models.values() if m.provider == provider]

    def _extract_base_model_name(self, model: ModelInfo) -> str:
        """Extract the base model name from a model (e.g., 'gpt-4o' from 'openai:gpt-4o')."""
        model_id = model.id.lower()

        # Handle common patterns
        base_name = model_id

        # Remove common suffixes
        suffixes_to_remove = [
            "-latest",
            "-preview",
            "-turbo",
            "-instruct",
            "-chat",
            "-base",
            "-20240229",
            "-20240307",
            "-20240620",
            "-20241022",
            "-20250514",
            "-0613",
            "-0125",
            "-0301",
            "-1106",
            "-2024",
            "-2025",
        ]

        for suffix in suffixes_to_remove:
            if base_name.endswith(suffix):
                base_name = base_name[: -len(suffix)]

        # Handle versioned models (e.g., 'claude-3-5-sonnet' -> 'claude-3-sonnet')
        if "claude-3-5" in base_name:
            base_name = base_name.replace("claude-3-5", "claude-3")
        elif "claude-3-7" in base_name:
            base_name = base_name.replace("claude-3-7", "claude-3")

        # Handle OpenRouter nested paths (e.g., 'openai/gpt-4o' -> 'gpt-4o')
        if "/" in base_name:
            base_name = base_name.split("/")[-1]

        return base_name

    def get_model_variants(self, base_model_name: str) -> List[ModelInfo]:
        """Get all variants of a base model across different providers."""
        base_name = base_model_name.lower()
        variants = []

        for model in self.models.values():
            model_base = self._extract_base_model_name(model)
            if model_base == base_name or base_name in model_base:
                variants.append(model)

        # Sort by cost (free first, then ascending cost)
        def sort_key(model: ModelInfo) -> tuple:
            cost = model.cost.input or 999
            is_free = cost == 0
            return (not is_free, cost, model.provider, model.id)

        variants.sort(key=sort_key)
        return variants

    def find_base_models(self, query: str) -> Dict[str, List[ModelInfo]]:
        """Find base models and group their variants by routing source."""
        query_lower = query.lower()
        base_models: Dict[str, List[ModelInfo]] = {}

        # Find all models matching the query
        matching_models = []
        for model in self.models.values():
            base_name = self._extract_base_model_name(model)
            if (
                query_lower in model.id.lower()
                or query_lower in model.name.lower()
                or query_lower in base_name
                or query_lower in model.provider.lower()
            ):
                matching_models.append((base_name, model))

        # Group by base model name
        for base_name, model in matching_models:
            if base_name not in base_models:
                base_models[base_name] = []
            base_models[base_name].append(model)

        # Sort variants within each base model
        for base_name in base_models:

            def sort_key(model: ModelInfo) -> tuple:
                cost = model.cost.input or 999
                is_free = cost == 0
                return (not is_free, cost, model.provider, model.id)

            base_models[base_name].sort(key=sort_key)

        return base_models

    def get_popular_base_models(self) -> List[str]:
        """Get list of popular base model names for suggestions."""
        popular_patterns = [
            "gpt-4o",
            "gpt-4",
            "gpt-3.5-turbo",
            "claude-3-opus",
            "claude-3-sonnet",
            "claude-3-haiku",
            "gemini-2",
            "gemini-1.5-pro",
            "gemini-1.5-flash",
            "o1-preview",
            "o1-mini",
            "o3",
            "o3-mini",
        ]

        available_base_models = []
        for pattern in popular_patterns:
            variants = self.get_model_variants(pattern)
            if variants:
                available_base_models.append(pattern)

        return available_base_models