tunacode-cli 0.0.55__py3-none-any.whl → 0.0.78.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tunacode-cli might be problematic. Click here for more details.

Files changed (114) hide show
  1. tunacode/cli/commands/__init__.py +2 -2
  2. tunacode/cli/commands/implementations/__init__.py +2 -3
  3. tunacode/cli/commands/implementations/command_reload.py +48 -0
  4. tunacode/cli/commands/implementations/debug.py +2 -2
  5. tunacode/cli/commands/implementations/development.py +10 -8
  6. tunacode/cli/commands/implementations/model.py +357 -29
  7. tunacode/cli/commands/implementations/quickstart.py +43 -0
  8. tunacode/cli/commands/implementations/system.py +96 -3
  9. tunacode/cli/commands/implementations/template.py +0 -2
  10. tunacode/cli/commands/registry.py +139 -5
  11. tunacode/cli/commands/slash/__init__.py +32 -0
  12. tunacode/cli/commands/slash/command.py +157 -0
  13. tunacode/cli/commands/slash/loader.py +135 -0
  14. tunacode/cli/commands/slash/processor.py +294 -0
  15. tunacode/cli/commands/slash/types.py +93 -0
  16. tunacode/cli/commands/slash/validator.py +400 -0
  17. tunacode/cli/main.py +23 -2
  18. tunacode/cli/repl.py +217 -190
  19. tunacode/cli/repl_components/command_parser.py +38 -4
  20. tunacode/cli/repl_components/error_recovery.py +85 -4
  21. tunacode/cli/repl_components/output_display.py +12 -1
  22. tunacode/cli/repl_components/tool_executor.py +1 -1
  23. tunacode/configuration/defaults.py +12 -3
  24. tunacode/configuration/key_descriptions.py +284 -0
  25. tunacode/configuration/settings.py +0 -1
  26. tunacode/constants.py +12 -40
  27. tunacode/core/agents/__init__.py +43 -2
  28. tunacode/core/agents/agent_components/__init__.py +7 -0
  29. tunacode/core/agents/agent_components/agent_config.py +249 -55
  30. tunacode/core/agents/agent_components/agent_helpers.py +43 -13
  31. tunacode/core/agents/agent_components/node_processor.py +179 -139
  32. tunacode/core/agents/agent_components/response_state.py +123 -6
  33. tunacode/core/agents/agent_components/state_transition.py +116 -0
  34. tunacode/core/agents/agent_components/streaming.py +296 -0
  35. tunacode/core/agents/agent_components/task_completion.py +19 -6
  36. tunacode/core/agents/agent_components/tool_buffer.py +21 -1
  37. tunacode/core/agents/agent_components/tool_executor.py +10 -0
  38. tunacode/core/agents/main.py +522 -370
  39. tunacode/core/agents/main_legact.py +538 -0
  40. tunacode/core/agents/prompts.py +66 -0
  41. tunacode/core/agents/utils.py +29 -121
  42. tunacode/core/code_index.py +83 -29
  43. tunacode/core/setup/__init__.py +0 -2
  44. tunacode/core/setup/config_setup.py +110 -20
  45. tunacode/core/setup/config_wizard.py +230 -0
  46. tunacode/core/setup/coordinator.py +14 -5
  47. tunacode/core/state.py +16 -20
  48. tunacode/core/token_usage/usage_tracker.py +5 -3
  49. tunacode/core/tool_authorization.py +352 -0
  50. tunacode/core/tool_handler.py +67 -40
  51. tunacode/exceptions.py +119 -5
  52. tunacode/prompts/system.xml +751 -0
  53. tunacode/services/mcp.py +125 -7
  54. tunacode/setup.py +5 -25
  55. tunacode/tools/base.py +163 -0
  56. tunacode/tools/bash.py +110 -1
  57. tunacode/tools/glob.py +332 -34
  58. tunacode/tools/grep.py +179 -82
  59. tunacode/tools/grep_components/result_formatter.py +98 -4
  60. tunacode/tools/list_dir.py +132 -2
  61. tunacode/tools/prompts/bash_prompt.xml +72 -0
  62. tunacode/tools/prompts/glob_prompt.xml +45 -0
  63. tunacode/tools/prompts/grep_prompt.xml +98 -0
  64. tunacode/tools/prompts/list_dir_prompt.xml +31 -0
  65. tunacode/tools/prompts/react_prompt.xml +23 -0
  66. tunacode/tools/prompts/read_file_prompt.xml +54 -0
  67. tunacode/tools/prompts/run_command_prompt.xml +64 -0
  68. tunacode/tools/prompts/update_file_prompt.xml +53 -0
  69. tunacode/tools/prompts/write_file_prompt.xml +37 -0
  70. tunacode/tools/react.py +153 -0
  71. tunacode/tools/read_file.py +91 -0
  72. tunacode/tools/run_command.py +114 -0
  73. tunacode/tools/schema_assembler.py +167 -0
  74. tunacode/tools/update_file.py +94 -0
  75. tunacode/tools/write_file.py +86 -0
  76. tunacode/tools/xml_helper.py +83 -0
  77. tunacode/tutorial/__init__.py +9 -0
  78. tunacode/tutorial/content.py +98 -0
  79. tunacode/tutorial/manager.py +182 -0
  80. tunacode/tutorial/steps.py +124 -0
  81. tunacode/types.py +20 -27
  82. tunacode/ui/completers.py +434 -50
  83. tunacode/ui/config_dashboard.py +585 -0
  84. tunacode/ui/console.py +63 -11
  85. tunacode/ui/input.py +20 -3
  86. tunacode/ui/keybindings.py +7 -4
  87. tunacode/ui/model_selector.py +395 -0
  88. tunacode/ui/output.py +40 -19
  89. tunacode/ui/panels.py +212 -43
  90. tunacode/ui/path_heuristics.py +91 -0
  91. tunacode/ui/prompt_manager.py +5 -1
  92. tunacode/ui/tool_ui.py +33 -10
  93. tunacode/utils/api_key_validation.py +93 -0
  94. tunacode/utils/config_comparator.py +340 -0
  95. tunacode/utils/json_utils.py +206 -0
  96. tunacode/utils/message_utils.py +14 -4
  97. tunacode/utils/models_registry.py +593 -0
  98. tunacode/utils/ripgrep.py +332 -9
  99. tunacode/utils/text_utils.py +18 -1
  100. tunacode/utils/user_configuration.py +45 -0
  101. tunacode_cli-0.0.78.6.dist-info/METADATA +260 -0
  102. tunacode_cli-0.0.78.6.dist-info/RECORD +158 -0
  103. {tunacode_cli-0.0.55.dist-info → tunacode_cli-0.0.78.6.dist-info}/WHEEL +1 -2
  104. tunacode/cli/commands/implementations/todo.py +0 -217
  105. tunacode/context.py +0 -71
  106. tunacode/core/setup/git_safety_setup.py +0 -182
  107. tunacode/prompts/system.md +0 -731
  108. tunacode/tools/read_file_async_poc.py +0 -196
  109. tunacode/tools/todo.py +0 -349
  110. tunacode_cli-0.0.55.dist-info/METADATA +0 -322
  111. tunacode_cli-0.0.55.dist-info/RECORD +0 -126
  112. tunacode_cli-0.0.55.dist-info/top_level.txt +0 -1
  113. {tunacode_cli-0.0.55.dist-info → tunacode_cli-0.0.78.6.dist-info}/entry_points.txt +0 -0
  114. {tunacode_cli-0.0.55.dist-info → tunacode_cli-0.0.78.6.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,593 @@
1
+ """Models.dev integration for model discovery and validation."""
2
+
3
+ import json
4
+ from datetime import datetime, timedelta
5
+ from difflib import SequenceMatcher
6
+ from pathlib import Path
7
+ from typing import Any, Dict, List, Optional
8
+ from urllib.error import URLError
9
+ from urllib.request import urlopen
10
+
11
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
12
+
13
+
14
class ModelCapabilities(BaseModel):
    """Feature switches describing what a model supports.

    Keys that are not declared below are dropped on construction
    (``extra="ignore"``).
    """

    model_config = ConfigDict(extra="ignore")

    # Boolean capability flags; only ``temperature`` is on by default.
    attachment: bool = False
    reasoning: bool = False
    tool_call: bool = False
    temperature: bool = True

    # Optional knowledge marker, stored exactly as supplied.
    knowledge: Optional[str] = None
24
+
25
+
26
class ModelCost(BaseModel):
    """Pricing figures for a model; each figure may be absent."""

    model_config = ConfigDict(extra="ignore")

    input: Optional[float] = None
    output: Optional[float] = None
    cache: Optional[float] = None

    @field_validator("input", "output", "cache")
    @classmethod
    def _non_negative(cls, v: Optional[float]) -> Optional[float]:
        """Pass ``None`` through and reject negative prices.

        Raises:
            ValueError: If a supplied price is below zero.
        """
        if v is not None:
            if v < 0:
                raise ValueError("cost values must be non-negative")
            return float(v)
        return v

    def format_cost(self) -> str:
        """Return ``"$<input>/<output> per 1M tokens"`` or a placeholder.

        The placeholder is used whenever the input or the output price is
        missing.
        """
        has_both_prices = self.input is not None and self.output is not None
        if not has_both_prices:
            return "Pricing not available"
        return f"${self.input}/{self.output} per 1M tokens"
49
+
50
+
51
class ModelLimits(BaseModel):
    """Context-window and output-size limits for a model."""

    model_config = ConfigDict(extra="ignore")

    context: Optional[int] = None
    output: Optional[int] = None

    @field_validator("context", "output")
    @classmethod
    def _positive_int(cls, v: Optional[int]) -> Optional[int]:
        """Normalise a limit: negatives are rejected, zero becomes ``None``.

        Raises:
            ValueError: If the limit is below zero.
        """
        if v is None:
            return None
        as_int = int(v)
        if as_int < 0:
            raise ValueError("limits must be non-negative integers")
        if as_int == 0:
            # A zero limit carries no information, so it is stored as unset.
            return None
        return as_int

    def format_limits(self) -> str:
        """Return the set limits as ``"<n> context, <m> output"``.

        Numbers use thousands separators; a placeholder is returned when
        neither limit is set.
        """
        labelled = ((self.context, "context"), (self.output, "output"))
        parts: List[str] = [f"{amount:,} {label}" for amount, label in labelled if amount]
        return ", ".join(parts) or "Limits not specified"
77
+
78
+
79
class ModelInfo(BaseModel):
    """Complete description of one model offered by one provider.

    Undeclared keys are dropped on construction (``extra="ignore"``).
    """

    model_config = ConfigDict(extra="ignore")

    id: str
    name: str
    provider: str
    capabilities: ModelCapabilities = Field(default_factory=ModelCapabilities)
    cost: ModelCost = Field(default_factory=ModelCost)
    limits: ModelLimits = Field(default_factory=ModelLimits)
    release_date: Optional[str] = None
    last_updated: Optional[str] = None
    open_weights: bool = False
    modalities: Dict[str, List[str]] = Field(default_factory=dict)

    @property
    def full_id(self) -> str:
        """Return the identifier in ``"<provider>:<id>"`` form."""
        return f"{self.provider}:{self.id}"

    def format_display(self, include_details: bool = True) -> str:
        """Return a one-line label for the model.

        Args:
            include_details: When true, append pricing (if an input price is
                set) and the context window (if set) in parentheses.

        Returns:
            ``"<full_id> - <name>"``, optionally followed by
            ``" (<detail>, <detail>)"``.
        """
        display = f"{self.full_id} - {self.name}"
        if not include_details:
            return display

        details: List[str] = []
        if self.cost.input is not None:
            details.append(self.cost.format_cost())

        context = self.limits.context
        if context:
            if context >= 1000:
                details.append(f"{context // 1000}k context")
            else:
                # Integer division would render small windows as "0k";
                # show the exact token count instead.
                details.append(f"{context} context")

        if details:
            display += f" ({', '.join(details)})"
        return display

    # TODO: make this lighter for future development (low priority).
    def matches_search(self, query: str) -> float:
        """Score how well ``query`` matches this model, from 0 to 1.

        A case-insensitive substring hit scores 1.0 on the id, 0.9 on the
        name and 0.8 on the provider, checked in that order. Otherwise the
        best ``SequenceMatcher`` ratio against those three fields is returned.
        """
        query_lower = query.lower()

        # Substring hits, ranked by which field matched.
        if query_lower in self.id.lower():
            return 1.0
        if query_lower in self.name.lower():
            return 0.9
        if query_lower in self.provider.lower():
            return 0.8

        # Fuzzy fallback: best similarity ratio across the three fields.
        return max(
            SequenceMatcher(None, query_lower, field_value.lower()).ratio()
            for field_value in (self.id, self.name, self.provider)
        )
133
+
134
+
135
class ProviderInfo(BaseModel):
    """Descriptive metadata for a model provider.

    Undeclared keys are dropped on construction (``extra="ignore"``).
    """

    model_config = ConfigDict(extra="ignore")

    # Provider key and display name.
    id: str
    name: str

    # Values taken from the provider entry's "env", "npm" and "doc" keys.
    env: List[str] = Field(default_factory=list)
    npm: Optional[str] = None
    doc: Optional[str] = None
145
+
146
+
147
class ModelsRegistry:
    """Registry of providers and models sourced from models.dev.

    :meth:`load` resolves data in this order: a fresh on-disk cache, the
    models.dev API, a stale cache, and finally a small built-in list.
    Models are stored in ``self.models`` keyed by ``"<provider>:<id>"`` and
    providers in ``self.providers`` keyed by provider id.
    """

    API_URL = "https://models.dev/api.json"
    CACHE_FILE = "models_cache.json"
    CACHE_TTL = timedelta(hours=24)

    # Sort stand-in for models whose input price is unknown.
    _UNKNOWN_COST = 999

    # Suffixes stripped by _extract_base_model_name, tested once each in
    # this order.
    _BASE_NAME_SUFFIXES = (
        "-latest",
        "-preview",
        "-turbo",
        "-instruct",
        "-chat",
        "-base",
        "-20240229",
        "-20240307",
        "-20240620",
        "-20241022",
        "-20250514",
        "-0613",
        "-0125",
        "-0301",
        "-1106",
        "-2024",
        "-2025",
    )

    def __init__(self, cache_dir: Optional[Path] = None):
        """Prepare the cache location and empty lookup tables.

        Args:
            cache_dir: Directory holding the cache file. Defaults to
                ``~/.tunacode/cache``. The directory is created if missing.
        """
        if cache_dir is None:
            cache_dir = Path.home() / ".tunacode" / "cache"
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.cache_file = self.cache_dir / self.CACHE_FILE

        self.models: Dict[str, ModelInfo] = {}
        self.providers: Dict[str, ProviderInfo] = {}
        self._loaded = False

    def _is_cache_valid(self) -> bool:
        """Return True if the cache file exists and is younger than CACHE_TTL."""
        try:
            modified = datetime.fromtimestamp(self.cache_file.stat().st_mtime)
        except OSError:
            # Missing or unreadable file: there is no usable cache.
            return False
        return datetime.now() - modified < self.CACHE_TTL

    def _load_from_cache(self) -> bool:
        """Load models from the cache file.

        Returns:
            True only if the file was read, parsed, and produced at least one
            model; False on any read, decode or validation failure.
        """
        try:
            with open(self.cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
            self._parse_data(data)
        except (OSError, ValueError, KeyError):
            # ValueError covers json.JSONDecodeError and pydantic's
            # ValidationError, both of which subclass it.
            return False
        # An empty or provider-less cache is not a usable result.
        return bool(self.models)

    def _fetch_from_api(self) -> bool:
        """Fetch models from the models.dev API and refresh the cache.

        The payload is parsed before it is cached so that an unusable
        response never overwrites a good cache file. Writing the cache is
        best-effort: a write failure does not discard the fetched data.

        Returns:
            True if the response was fetched and yielded at least one model.
        """
        # Only ``urlopen`` is imported at module level; Request is needed too.
        import urllib.request

        try:
            # Add User-Agent header to avoid blocking
            req = urllib.request.Request(self.API_URL, headers={"User-Agent": "TunaCode-CLI/1.0"})
            with urlopen(req, timeout=10) as response:  # nosec B310 - Using trusted models.dev API
                data = json.loads(response.read())
            self._parse_data(data)
        except (URLError, OSError, ValueError):
            # Network, decode and validation failures all mean "no data";
            # the caller falls back to other sources.
            return False

        if not self.models:
            return False

        # Write to a sibling temp file and rename it into place so a failed
        # write cannot leave a truncated cache behind.
        tmp_file = self.cache_file.with_name(self.cache_file.name + ".tmp")
        try:
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            tmp_file.replace(self.cache_file)
        except OSError:
            # Caching is an optimisation; the parsed data is already loaded.
            pass
        return True

    def _load_fallback_models(self) -> None:
        """Load the hardcoded list of popular models as a last resort."""
        fallback_data = {
            "openai": {
                "name": "OpenAI",
                "env": ["OPENAI_API_KEY"],
                "npm": "@ai-sdk/openai",
                "doc": "https://platform.openai.com/docs",
                "models": {
                    "gpt-4": {
                        "name": "GPT-4",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 30.0, "output": 60.0},
                        "limit": {"context": 128000, "output": 4096},
                    },
                    "gpt-4-turbo": {
                        "name": "GPT-4 Turbo",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 10.0, "output": 30.0},
                        "limit": {"context": 128000, "output": 4096},
                    },
                    "gpt-3.5-turbo": {
                        "name": "GPT-3.5 Turbo",
                        "attachment": False,
                        "reasoning": False,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-01",
                        "cost": {"input": 0.5, "output": 1.5},
                        "limit": {"context": 16000, "output": 4096},
                    },
                },
            },
            "anthropic": {
                "name": "Anthropic",
                "env": ["ANTHROPIC_API_KEY"],
                "npm": "@ai-sdk/anthropic",
                "doc": "https://docs.anthropic.com",
                "models": {
                    "claude-3-opus-20240229": {
                        "name": "Claude 3 Opus",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 15.0, "output": 75.0},
                        "limit": {"context": 200000, "output": 4096},
                    },
                    "claude-3-sonnet-20240229": {
                        "name": "Claude 3 Sonnet",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 3.0, "output": 15.0},
                        "limit": {"context": 200000, "output": 4096},
                    },
                    "claude-3-haiku-20240307": {
                        "name": "Claude 3 Haiku",
                        "attachment": True,
                        "reasoning": False,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 0.25, "output": 1.25},
                        "limit": {"context": 200000, "output": 4096},
                    },
                },
            },
            "google": {
                "name": "Google",
                "env": ["GOOGLE_API_KEY"],
                "npm": "@ai-sdk/google",
                "doc": "https://ai.google.dev",
                "models": {
                    "gemini-1.5-pro": {
                        "name": "Gemini 1.5 Pro",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 3.5, "output": 10.5},
                        "limit": {"context": 2000000, "output": 8192},
                    },
                    "gemini-1.5-flash": {
                        "name": "Gemini 1.5 Flash",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 0.075, "output": 0.3},
                        "limit": {"context": 1000000, "output": 8192},
                    },
                },
            },
        }

        self._parse_data(fallback_data)

    def _parse_data(self, data: Dict[str, Any]) -> None:
        """Replace the registry contents with providers/models from ``data``.

        ``data`` maps provider ids to provider dicts. Entries that are not
        dicts or have no ``"models"`` key are skipped, as are model entries
        that are not dicts. New tables are built first and swapped in only
        once parsing succeeds, so a validation error leaves the previous
        contents untouched.
        """
        providers: Dict[str, ProviderInfo] = {}
        models: Dict[str, ModelInfo] = {}

        entries = data.items() if isinstance(data, dict) else ()
        for provider_id, provider_data in entries:
            # Skip non-provider keys
            if not isinstance(provider_data, dict) or "models" not in provider_data:
                continue

            providers[provider_id] = ProviderInfo(
                id=provider_id,
                name=provider_data.get("name", provider_id),
                env=provider_data.get("env", []),
                npm=provider_data.get("npm"),
                doc=provider_data.get("doc"),
            )

            models_data = provider_data["models"]
            if not isinstance(models_data, dict):
                continue

            for model_id, model_data in models_data.items():
                if not isinstance(model_data, dict):
                    continue

                capabilities = ModelCapabilities(
                    attachment=bool(model_data.get("attachment", False)),
                    reasoning=bool(model_data.get("reasoning", False)),
                    tool_call=bool(model_data.get("tool_call", False)),
                    temperature=bool(model_data.get("temperature", True)),
                    knowledge=model_data.get("knowledge"),
                )

                # Non-dict cost/limit/modalities values are treated as absent.
                cost_data = model_data.get("cost", {})
                if not isinstance(cost_data, dict):
                    cost_data = {}
                cost = ModelCost(
                    input=cost_data.get("input"),
                    output=cost_data.get("output"),
                    cache=cost_data.get("cache"),
                )

                limit_data = model_data.get("limit", {})
                if not isinstance(limit_data, dict):
                    limit_data = {}
                limits = ModelLimits(
                    context=limit_data.get("context"),
                    output=limit_data.get("output"),
                )

                modalities = model_data.get("modalities", {})
                if not isinstance(modalities, dict):
                    modalities = {}

                model = ModelInfo(
                    id=model_id,
                    name=model_data.get("name", model_id),
                    provider=provider_id,
                    capabilities=capabilities,
                    cost=cost,
                    limits=limits,
                    release_date=model_data.get("release_date"),
                    last_updated=model_data.get("last_updated"),
                    open_weights=bool(model_data.get("open_weights", False)),
                    modalities=modalities,
                )

                # Store with full ID as key
                models[model.full_id] = model

        # Update in place so existing references to these dicts stay valid.
        self.models.clear()
        self.models.update(models)
        self.providers.clear()
        self.providers.update(providers)

    async def load(self, force_refresh: bool = False) -> bool:
        """Load models data, using the cache if available.

        Args:
            force_refresh: Skip the "already loaded" and fresh-cache shortcuts
                and go to the API first.

        Returns:
            Always True; the built-in list is used when every other source
            fails.

        NOTE(review): ``_fetch_from_api`` does blocking network I/O (10 s
        timeout) inside this coroutine.
        """
        if self._loaded and not force_refresh:
            return True

        # Try cache first
        if not force_refresh and self._is_cache_valid() and self._load_from_cache():
            self._loaded = True
            return True

        # Fetch from API
        if self._fetch_from_api():
            self._loaded = True
            return True

        # Import ui locally to avoid circular imports
        from ..ui import console as ui

        # Try cache as fallback even if expired
        if self._load_from_cache():
            self._loaded = True
            await ui.warning("Using cached models data (API unavailable)")
            return True

        # Use fallback models as last resort
        await ui.warning("models.dev API unavailable, using fallback model list")
        self._load_fallback_models()
        self._loaded = True
        return True

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """Return a model by full id, or by bare id if no full id matches.

        For a bare id, the first model in registry order whose ``id`` equals
        ``model_id`` is returned. Returns None when nothing matches.
        """
        exact = self.models.get(model_id)
        if exact is not None:
            return exact

        # Try without provider prefix
        return next((m for m in self.models.values() if m.id == model_id), None)

    def validate_model(self, model_id: str) -> bool:
        """Return True if ``get_model`` can resolve ``model_id``."""
        return self.get_model(model_id) is not None

    def search_models(
        self, query: str = "", provider: Optional[str] = None, min_score: float = 0.3
    ) -> List[ModelInfo]:
        """Search for models matching ``query``.

        Args:
            query: Search text; an empty string matches every model.
            provider: If given, only models of this provider are considered.
            min_score: Minimum ``matches_search`` score to keep a model.

        Returns:
            Matches sorted by descending score, then by name.
        """
        results = []

        for model in self.models.values():
            if provider and model.provider != provider:
                continue

            if query:
                score = model.matches_search(query)
                if score < min_score:
                    continue
            else:
                # No query, include all
                score = 1.0
            results.append((score, model))

        results.sort(key=lambda scored: (-scored[0], scored[1].name))
        return [model for _, model in results]

    def get_providers(self) -> List[ProviderInfo]:
        """Return all providers sorted by name."""
        return sorted(self.providers.values(), key=lambda p: p.name)

    def get_models_by_provider(self, provider: str) -> List[ModelInfo]:
        """Return every model whose provider equals ``provider``."""
        return [m for m in self.models.values() if m.provider == provider]

    def _variant_sort_key(self, model: ModelInfo) -> tuple:
        """Sort key: free models first, then ascending input cost.

        A missing price is replaced by ``_UNKNOWN_COST``. The check is
        ``is None`` rather than truthiness so that a price of 0 still counts
        as free. Ties are broken by provider and then id.
        """
        cost = model.cost.input
        if cost is None:
            cost = self._UNKNOWN_COST
        return (cost != 0, cost, model.provider, model.id)

    def _extract_base_model_name(self, model: ModelInfo) -> str:
        """Extract the base model name (e.g. 'gpt-4o' from 'openai/gpt-4o').

        The lower-cased id is processed in three steps: known suffixes are
        stripped in a single pass in list order, 'claude-3-5' / 'claude-3-7'
        are collapsed to 'claude-3', and anything up to the last '/' is
        dropped.
        """
        base_name = model.id.lower()

        # Each suffix is tested once, so stacked suffixes are removed only
        # when they appear in list order from the end of the name.
        for suffix in self._BASE_NAME_SUFFIXES:
            if base_name.endswith(suffix):
                base_name = base_name[: -len(suffix)]

        # Handle versioned models (e.g., 'claude-3-5-sonnet' -> 'claude-3-sonnet')
        if "claude-3-5" in base_name:
            base_name = base_name.replace("claude-3-5", "claude-3")
        elif "claude-3-7" in base_name:
            base_name = base_name.replace("claude-3-7", "claude-3")

        # Handle OpenRouter nested paths (e.g., 'openai/gpt-4o' -> 'gpt-4o')
        if "/" in base_name:
            base_name = base_name.split("/")[-1]

        return base_name

    def get_model_variants(self, base_model_name: str) -> List[ModelInfo]:
        """Get all variants of a base model across different providers.

        A model is a variant when the lower-cased ``base_model_name`` is a
        substring of its extracted base name. Results are sorted free first,
        then by ascending input cost.
        """
        base_name = base_model_name.lower()
        variants = [
            model
            for model in self.models.values()
            if base_name in self._extract_base_model_name(model)
        ]
        variants.sort(key=self._variant_sort_key)
        return variants

    def find_base_models(self, query: str) -> Dict[str, List[ModelInfo]]:
        """Find models matching ``query`` and group them by base model name.

        A model matches when the lower-cased query is a substring of its id,
        name, extracted base name, or provider. Each group is sorted free
        first, then by ascending input cost.
        """
        query_lower = query.lower()
        base_models: Dict[str, List[ModelInfo]] = {}

        for model in self.models.values():
            base_name = self._extract_base_model_name(model)
            if (
                query_lower in model.id.lower()
                or query_lower in model.name.lower()
                or query_lower in base_name
                or query_lower in model.provider.lower()
            ):
                base_models.setdefault(base_name, []).append(model)

        for variants in base_models.values():
            variants.sort(key=self._variant_sort_key)

        return base_models

    def get_popular_base_models(self) -> List[str]:
        """Return the popular base model names that have a loaded variant."""
        popular_patterns = [
            "gpt-4o",
            "gpt-4",
            "gpt-3.5-turbo",
            "claude-3-opus",
            "claude-3-sonnet",
            "claude-3-haiku",
            "gemini-2",
            "gemini-1.5-pro",
            "gemini-1.5-flash",
            "o1-preview",
            "o1-mini",
            "o3",
            "o3-mini",
        ]

        return [pattern for pattern in popular_patterns if self.get_model_variants(pattern)]