tunacode-cli 0.0.70__py3-none-any.whl → 0.0.72__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tunacode-cli might be problematic. Click here for more details.

@@ -0,0 +1,93 @@
1
+ """
2
+ Module: tunacode.utils.api_key_validation
3
+
4
+ Utilities for validating API keys are configured for the selected model.
5
+ """
6
+
7
+ from typing import Optional, Tuple
8
+
9
+ from tunacode.types import UserConfig
10
+
11
+
12
def get_required_api_key_for_model(model: str) -> Tuple[Optional[str], str]:
    """
    Determine which API key is required for a given model.

    Args:
        model: Model identifier in format "provider:model-name". Only the part
            before the first ":" is used; it is matched case-insensitively and
            surrounding whitespace is ignored.

    Returns:
        Tuple of (api_key_name, provider_name):
        - (key name, display name) for a recognised provider;
        - (None, provider) for an unrecognised provider (no specific key is
          known, e.g. a custom endpoint), where ``provider`` is the normalised
          provider prefix;
        - (None, "unknown") when ``model`` is empty, has no ":" separator, or
          has an empty provider part.
    """
    if not model or ":" not in model:
        return None, "unknown"

    provider = model.partition(":")[0].strip().lower()
    if not provider:
        # e.g. ":gpt-4" - there is no provider to look up.
        return None, "unknown"

    # Map providers to their required API keys; the Google aliases share one key.
    provider_key_map = {
        "openrouter": ("OPENROUTER_API_KEY", "OpenRouter"),
        "openai": ("OPENAI_API_KEY", "OpenAI"),
        "anthropic": ("ANTHROPIC_API_KEY", "Anthropic"),
        "google": ("GEMINI_API_KEY", "Google"),
        "google-gla": ("GEMINI_API_KEY", "Google"),
        "gemini": ("GEMINI_API_KEY", "Google"),
    }

    return provider_key_map.get(provider, (None, provider))
38
+
39
+
40
def validate_api_key_for_model(model: str, user_config: "UserConfig") -> Tuple[bool, Optional[str]]:
    """
    Check if the required API key exists for the given model.

    Args:
        model: Model identifier in format "provider:model-name"
        user_config: User configuration; API keys are read from its "env" mapping.

    Returns:
        Tuple of (is_valid, error_message). ``error_message`` is None when the
        model is valid. Models whose provider has no known API key (e.g. custom
        endpoints) are always considered valid.
    """
    api_key_name, provider_name = get_required_api_key_for_model(model)

    if not api_key_name:
        # No specific API key required (might be custom endpoint)
        return True, None

    # Tolerate a missing or null "env" section and null / non-string values,
    # which previously crashed on .get() / .strip().
    env_config = user_config.get("env") or {}
    api_key = env_config.get(api_key_name)

    if not isinstance(api_key, str) or not api_key.strip():
        return False, (
            f"No API key found for {provider_name}.\n"
            "Please run 'tunacode --setup' to configure your API key."
        )

    return True, None
67
+
68
+
69
def get_configured_providers(user_config: "UserConfig") -> list[str]:
    """
    Get list of providers that have API keys configured.

    A provider counts as configured when its key in ``user_config["env"]`` is
    a string containing non-whitespace characters. A missing or null "env"
    section, and null or non-string values, are treated as "not configured"
    rather than raising.

    Args:
        user_config: User configuration containing env variables

    Returns:
        List of provider names that have API keys set, in the fixed order
        openrouter, openai, anthropic, google.
    """
    env_config = user_config.get("env") or {}

    provider_map = {
        "OPENROUTER_API_KEY": "openrouter",
        "OPENAI_API_KEY": "openai",
        "ANTHROPIC_API_KEY": "anthropic",
        "GEMINI_API_KEY": "google",
    }

    configured: list[str] = []
    for key_name, provider in provider_map.items():
        value = env_config.get(key_name)
        if isinstance(value, str) and value.strip():
            configured.append(provider)

    return configured
@@ -0,0 +1,563 @@
1
+ """Models.dev integration for model discovery and validation."""
2
+
3
+ import json
4
+ from dataclasses import dataclass, field
5
+ from datetime import datetime, timedelta
6
+ from difflib import SequenceMatcher
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List, Optional
9
+ from urllib.error import URLError
10
+ from urllib.request import urlopen
11
+
12
+
13
@dataclass
class ModelCapabilities:
    """Per-model feature flags.

    Every boolean flag defaults to False except ``temperature``, which
    defaults to True. ``knowledge`` is an optional free-form string
    (the fallback data uses values like "2024-04").
    """

    attachment: bool = field(default=False)
    reasoning: bool = field(default=False)
    tool_call: bool = field(default=False)
    temperature: bool = field(default=True)  # the only flag that defaults to True
    knowledge: Optional[str] = field(default=None)
22
+
23
+
24
@dataclass
class ModelCost:
    """Pricing for one model; any component may be None when unknown."""

    input: Optional[float] = None
    output: Optional[float] = None
    cache: Optional[float] = None

    def format_cost(self) -> str:
        """Render "$<input>/<output> per 1M tokens".

        Returns a placeholder when either the input or output price is missing.
        ``cache`` is not included in the rendered string.
        """
        both_known = self.input is not None and self.output is not None
        if not both_known:
            return "Pricing not available"
        return f"${self.input}/{self.output} per 1M tokens"
37
+
38
+
39
@dataclass
class ModelLimits:
    """Token limits for one model; either value may be None when unknown."""

    context: Optional[int] = None
    output: Optional[int] = None

    def format_limits(self) -> str:
        """Render the known limits, e.g. "128,000 context, 4,096 output".

        Falsy values (None or 0) are omitted; if nothing remains a
        placeholder string is returned.
        """
        labelled = ((self.context, "context"), (self.output, "output"))
        pieces = [f"{amount:,} {label}" for amount, label in labelled if amount]
        if not pieces:
            return "Limits not specified"
        return ", ".join(pieces)
54
+
55
+
56
@dataclass
class ModelInfo:
    """All registry data for a single model of a single provider."""

    id: str
    name: str
    provider: str
    capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
    cost: ModelCost = field(default_factory=ModelCost)
    limits: ModelLimits = field(default_factory=ModelLimits)
    release_date: Optional[str] = None
    last_updated: Optional[str] = None
    open_weights: bool = False
    modalities: Dict[str, List[str]] = field(default_factory=dict)

    @property
    def full_id(self) -> str:
        """Identifier in ``"<provider>:<id>"`` form."""
        return f"{self.provider}:{self.id}"

    def format_display(self, include_details: bool = True) -> str:
        """Return ``"<full_id> - <name>"``, optionally followed by details.

        With ``include_details``, appends "(pricing, Nk context)" for whichever
        parts are available. Pricing is shown whenever the input price is
        known, and the context size is shown in thousands (integer-divided).
        """
        headline = f"{self.full_id} - {self.name}"
        if not include_details:
            return headline

        extras: List[str] = []
        if self.cost.input is not None:
            extras.append(self.cost.format_cost())
        if self.limits.context:
            extras.append(f"{self.limits.context // 1000}k context")

        return f"{headline} ({', '.join(extras)})" if extras else headline

    def matches_search(self, query: str) -> float:
        """Score how well ``query`` matches this model, in the range 0 to 1.

        A case-insensitive substring hit scores 1.0 for the id, 0.9 for the
        name, or 0.8 for the provider, checked in that order. Otherwise the
        best ``SequenceMatcher`` ratio over the same three fields is returned.
        """
        needle = query.lower()

        for score, haystack in ((1.0, self.id), (0.9, self.name), (0.8, self.provider)):
            if needle in haystack.lower():
                return score

        # Fuzzy fallback: ratios are never negative, so the max is at least 0.0.
        return max(
            SequenceMatcher(None, needle, candidate.lower()).ratio()
            for candidate in (self.id, self.name, self.provider)
        )
108
+
109
+
110
@dataclass
class ProviderInfo:
    """Registry metadata for one model provider."""

    id: str
    name: str
    # Environment variable names listed for the provider (fresh list per instance).
    env: List[str] = field(default_factory=list)
    npm: Optional[str] = field(default=None)
    doc: Optional[str] = field(default=None)  # documentation URL when supplied
119
+
120
+
121
class ModelsRegistry:
    """Registry for managing models from models.dev.

    ``load`` populates the registry from, in order of preference: a fresh
    on-disk cache, the models.dev API, a stale cache, and finally a small
    hardcoded fallback list. Models are keyed by ``"provider:model-id"``.
    """

    API_URL = "https://models.dev/api.json"
    CACHE_FILE = "models_cache.json"
    CACHE_TTL = timedelta(hours=24)

    # Suffixes stripped (in this order, one pass) when computing a base model name.
    _BASE_NAME_SUFFIXES = (
        "-latest",
        "-preview",
        "-turbo",
        "-instruct",
        "-chat",
        "-base",
        "-20240229",
        "-20240307",
        "-20240620",
        "-20241022",
        "-20250514",
        "-0613",
        "-0125",
        "-0301",
        "-1106",
        "-2024",
        "-2025",
    )

    def __init__(self, cache_dir: Optional[Path] = None):
        """Initialize the models registry.

        Args:
            cache_dir: Directory for the JSON cache; defaults to ~/.tunacode/cache.
        """
        if cache_dir is None:
            cache_dir = Path.home() / ".tunacode" / "cache"
        self.cache_dir = cache_dir
        try:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
        except OSError:
            # The cache is only an optimisation: an unwritable location must not
            # prevent construction. Cache reads and writes below tolerate OSError.
            pass
        self.cache_file = self.cache_dir / self.CACHE_FILE

        self.models: Dict[str, ModelInfo] = {}
        self.providers: Dict[str, ProviderInfo] = {}
        self._loaded = False

    def _is_cache_valid(self) -> bool:
        """Return True if the cache file exists and is younger than CACHE_TTL."""
        try:
            mtime = self.cache_file.stat().st_mtime
        except OSError:
            # Missing or unreadable cache file. A single stat() avoids the
            # exists()/stat() race.
            return False

        cache_age = datetime.now() - datetime.fromtimestamp(mtime)
        return cache_age < self.CACHE_TTL

    def _load_from_cache(self) -> bool:
        """Load models from the cache file; return False if it is missing or unusable."""
        try:
            with open(self.cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
            self._parse_data(data)
            return True
        except (ValueError, OSError, KeyError):
            # ValueError covers json.JSONDecodeError, UnicodeDecodeError and the
            # "not a JSON object" error raised by _parse_data.
            return False

    def _fetch_from_api(self) -> bool:
        """Fetch models from the models.dev API and refresh the cache.

        Returns True when fresh data was fetched and parsed. A failure to write
        the cache does not count as a failure: the fetched data is still used.

        NOTE(review): this performs blocking network I/O (up to a 10 s timeout)
        and is called from the async ``load`` without offloading it to a thread.
        """
        # Local imports: only needed on the network path.
        import urllib.request
        from http.client import HTTPException

        try:
            # Add User-Agent header to avoid blocking
            req = urllib.request.Request(self.API_URL, headers={"User-Agent": "TunaCode-CLI/1.0"})
            with urlopen(req, timeout=10) as response:  # nosec B310 - Using trusted models.dev API
                data = json.loads(response.read())
            # Parse before caching so an unusable payload is never written to disk.
            self._parse_data(data)
        except (URLError, HTTPException, ValueError, OSError):
            # Network failure, malformed HTTP response, or invalid JSON/payload.
            # The caller falls back to the cache or the built-in model list.
            return False

        self._write_cache(data)
        return True

    def _write_cache(self, data: Dict[str, Any]) -> None:
        """Best-effort write of ``data`` to the cache file via a temp file + rename."""
        tmp_file = self.cache_file.with_name(self.cache_file.name + ".tmp")
        try:
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            # Replace in one step so an interrupted write never leaves a truncated cache.
            tmp_file.replace(self.cache_file)
        except OSError:
            # The cache is optional; the data is already parsed into memory.
            pass

    def _load_fallback_models(self) -> None:
        """Load hardcoded popular models as fallback."""
        # NOTE(review): several figures below look approximate (e.g. gpt-4's
        # 128000 context and the uniform "2024-04" knowledge dates). Also,
        # API-key validation elsewhere expects GEMINI_API_KEY for Google, while
        # this list names GOOGLE_API_KEY. Verify before relying on them.
        fallback_data = {
            "openai": {
                "name": "OpenAI",
                "env": ["OPENAI_API_KEY"],
                "npm": "@ai-sdk/openai",
                "doc": "https://platform.openai.com/docs",
                "models": {
                    "gpt-4": {
                        "name": "GPT-4",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 30.0, "output": 60.0},
                        "limit": {"context": 128000, "output": 4096},
                    },
                    "gpt-4-turbo": {
                        "name": "GPT-4 Turbo",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 10.0, "output": 30.0},
                        "limit": {"context": 128000, "output": 4096},
                    },
                    "gpt-3.5-turbo": {
                        "name": "GPT-3.5 Turbo",
                        "attachment": False,
                        "reasoning": False,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-01",
                        "cost": {"input": 0.5, "output": 1.5},
                        "limit": {"context": 16000, "output": 4096},
                    },
                },
            },
            "anthropic": {
                "name": "Anthropic",
                "env": ["ANTHROPIC_API_KEY"],
                "npm": "@ai-sdk/anthropic",
                "doc": "https://docs.anthropic.com",
                "models": {
                    "claude-3-opus-20240229": {
                        "name": "Claude 3 Opus",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 15.0, "output": 75.0},
                        "limit": {"context": 200000, "output": 4096},
                    },
                    "claude-3-sonnet-20240229": {
                        "name": "Claude 3 Sonnet",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 3.0, "output": 15.0},
                        "limit": {"context": 200000, "output": 4096},
                    },
                    "claude-3-haiku-20240307": {
                        "name": "Claude 3 Haiku",
                        "attachment": True,
                        "reasoning": False,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 0.25, "output": 1.25},
                        "limit": {"context": 200000, "output": 4096},
                    },
                },
            },
            "google": {
                "name": "Google",
                "env": ["GOOGLE_API_KEY"],
                "npm": "@ai-sdk/google",
                "doc": "https://ai.google.dev",
                "models": {
                    "gemini-1.5-pro": {
                        "name": "Gemini 1.5 Pro",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 3.5, "output": 10.5},
                        "limit": {"context": 2000000, "output": 8192},
                    },
                    "gemini-1.5-flash": {
                        "name": "Gemini 1.5 Flash",
                        "attachment": True,
                        "reasoning": True,
                        "tool_call": True,
                        "temperature": True,
                        "knowledge": "2024-04",
                        "cost": {"input": 0.075, "output": 0.3},
                        "limit": {"context": 1000000, "output": 8192},
                    },
                },
            },
        }

        self._parse_data(fallback_data)

    def _parse_data(self, data: Dict[str, Any]) -> None:
        """Replace the registry contents with providers/models parsed from ``data``.

        ``data`` maps provider ids to provider dicts that carry a ``"models"``
        dict mapping model ids to model dicts. Entries not shaped like that are
        skipped rather than raising.

        Raises:
            ValueError: if ``data`` is not a dict. In that case the existing
                contents are left untouched.
        """
        if not isinstance(data, dict):
            raise ValueError(f"Expected a JSON object of providers, got {type(data).__name__}")

        self.models.clear()
        self.providers.clear()

        for provider_id, provider_data in data.items():
            # Skip non-provider keys and providers without a usable models mapping.
            if not isinstance(provider_data, dict):
                continue
            models_data = provider_data.get("models")
            if not isinstance(models_data, dict):
                continue

            self.providers[provider_id] = ProviderInfo(
                id=provider_id,
                name=provider_data.get("name") or provider_id,
                env=provider_data.get("env") or [],
                npm=provider_data.get("npm"),
                doc=provider_data.get("doc"),
            )

            for model_id, model_data in models_data.items():
                if not isinstance(model_data, dict):
                    continue
                model = self._parse_model(provider_id, model_id, model_data)
                # Store with full ID as key
                self.models[model.full_id] = model

    @staticmethod
    def _parse_model(provider_id: str, model_id: str, model_data: Dict[str, Any]) -> ModelInfo:
        """Build a ModelInfo from a single model entry of the registry data."""
        capabilities = ModelCapabilities(
            attachment=model_data.get("attachment", False),
            reasoning=model_data.get("reasoning", False),
            tool_call=model_data.get("tool_call", False),
            temperature=model_data.get("temperature", True),
            knowledge=model_data.get("knowledge"),
        )

        cost_data = model_data.get("cost")
        if not isinstance(cost_data, dict):
            cost_data = {}
        cost = ModelCost(
            input=cost_data.get("input"),
            output=cost_data.get("output"),
            cache=cost_data.get("cache"),
        )

        limit_data = model_data.get("limit")
        if not isinstance(limit_data, dict):
            limit_data = {}
        limits = ModelLimits(
            context=limit_data.get("context"),
            output=limit_data.get("output"),
        )

        return ModelInfo(
            id=model_id,
            # A null name would break searching (.lower()), so fall back to the id.
            name=model_data.get("name") or model_id,
            provider=provider_id,
            capabilities=capabilities,
            cost=cost,
            limits=limits,
            release_date=model_data.get("release_date"),
            last_updated=model_data.get("last_updated"),
            open_weights=model_data.get("open_weights", False),
            modalities=model_data.get("modalities") or {},
        )

    async def load(self, force_refresh: bool = False) -> bool:
        """Load models data, using cache if available.

        Sources are tried in this order: a fresh cache (skipped when
        ``force_refresh``), the API, a stale cache, and then the built-in
        fallback list. A warning is shown when one of the last two is used.
        Always returns True, because the fallback list guarantees some data.
        """
        if self._loaded and not force_refresh:
            return True

        # Try cache first
        if not force_refresh and self._is_cache_valid() and self._load_from_cache():
            self._loaded = True
            return True

        # Fetch from API
        if self._fetch_from_api():
            self._loaded = True
            return True

        # Import ui locally to avoid circular imports
        from ..ui import console as ui

        # Try cache as fallback even if expired
        if self._load_from_cache():
            self._loaded = True
            await ui.warning("Using cached models data (API unavailable)")
            return True

        # Use fallback models as last resort
        await ui.warning("models.dev API unavailable, using fallback model list")
        self._load_fallback_models()
        self._loaded = True
        return True

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """Get a model by full ``provider:id`` or, failing that, by bare model id.

        For a bare id the first matching model (in registry order) is returned.
        """
        if model_id in self.models:
            return self.models[model_id]

        return next((model for model in self.models.values() if model.id == model_id), None)

    def validate_model(self, model_id: str) -> bool:
        """Check if a model ID is valid."""
        return self.get_model(model_id) is not None

    def search_models(
        self, query: str = "", provider: Optional[str] = None, min_score: float = 0.3
    ) -> List[ModelInfo]:
        """Search for models matching ``query``.

        Args:
            query: Search text; empty means "all models" (min_score is ignored).
            provider: Optional exact provider id to restrict results to.
            min_score: Minimum ``ModelInfo.matches_search`` score to include.

        Returns:
            Models sorted by descending score, then by name.
        """
        results = []

        for model in self.models.values():
            # Filter by provider if specified
            if provider and model.provider != provider:
                continue

            if not query:
                # No query, include all
                results.append((1.0, model))
                continue

            score = model.matches_search(query)
            if score >= min_score:
                results.append((score, model))

        # Sort by score (descending) and name
        results.sort(key=lambda item: (-item[0], item[1].name))

        return [model for _, model in results]

    def get_providers(self) -> List[ProviderInfo]:
        """Get list of all providers, sorted by display name."""
        return sorted(self.providers.values(), key=lambda p: p.name)

    def get_models_by_provider(self, provider: str) -> List[ModelInfo]:
        """Get all models for a specific provider."""
        return [m for m in self.models.values() if m.provider == provider]

    def _extract_base_model_name(self, model: ModelInfo) -> str:
        """Extract the base model name from a model (e.g., 'gpt-4o' from 'openai:gpt-4o')."""
        base_name = model.id.lower()

        # Remove common suffixes (single pass, in _BASE_NAME_SUFFIXES order)
        for suffix in self._BASE_NAME_SUFFIXES:
            if base_name.endswith(suffix):
                base_name = base_name[: -len(suffix)]

        # Handle versioned models (e.g., 'claude-3-5-sonnet' -> 'claude-3-sonnet')
        if "claude-3-5" in base_name:
            base_name = base_name.replace("claude-3-5", "claude-3")
        elif "claude-3-7" in base_name:
            base_name = base_name.replace("claude-3-7", "claude-3")

        # Handle OpenRouter nested paths (e.g., 'openai/gpt-4o' -> 'gpt-4o')
        if "/" in base_name:
            base_name = base_name.split("/")[-1]

        return base_name

    @staticmethod
    def _variant_sort_key(model: ModelInfo) -> tuple:
        """Order variants: free models first, then ascending input cost, then provider/id.

        Models with unknown input cost sort after all priced models. The
        previous ``cost.input or 999`` turned a cost of 0 into 999, so free
        models were never ranked first.
        """
        cost = model.cost.input
        if cost is None:
            cost = float("inf")
        return (cost != 0, cost, model.provider, model.id)

    def get_model_variants(self, base_model_name: str) -> List[ModelInfo]:
        """Get all variants of a base model across different providers.

        A model matches when ``base_model_name`` (case-insensitive) occurs in
        its normalised base name or in its raw id. The raw-id check keeps
        names such as "gpt-3.5-turbo" or "o1-preview" matchable even though
        normalisation strips their "-turbo" / "-preview" suffix.
        """
        base_name = base_model_name.lower()
        variants = [
            model
            for model in self.models.values()
            if base_name in self._extract_base_model_name(model) or base_name in model.id.lower()
        ]

        variants.sort(key=self._variant_sort_key)
        return variants

    def find_base_models(self, query: str) -> Dict[str, List[ModelInfo]]:
        """Find models matching ``query`` and group them by base model name.

        A model matches when the lower-cased query occurs in its id, name,
        base name or provider. Each group is sorted with ``_variant_sort_key``.
        """
        query_lower = query.lower()
        base_models: Dict[str, List[ModelInfo]] = {}

        for model in self.models.values():
            base_name = self._extract_base_model_name(model)
            searchable = (model.id.lower(), model.name.lower(), base_name, model.provider.lower())
            if any(query_lower in text for text in searchable):
                base_models.setdefault(base_name, []).append(model)

        for variants in base_models.values():
            variants.sort(key=self._variant_sort_key)

        return base_models

    def get_popular_base_models(self) -> List[str]:
        """Get the popular base model names that have at least one variant loaded."""
        popular_patterns = [
            "gpt-4o",
            "gpt-4",
            "gpt-3.5-turbo",
            "claude-3-opus",
            "claude-3-sonnet",
            "claude-3-haiku",
            "gemini-2",
            "gemini-1.5-pro",
            "gemini-1.5-flash",
            "o1-preview",
            "o1-mini",
            "o3",
            "o3-mini",
        ]

        return [pattern for pattern in popular_patterns if self.get_model_variants(pattern)]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tunacode-cli
3
- Version: 0.0.70
3
+ Version: 0.0.72
4
4
  Summary: Your agentic CLI developer.
5
5
  Project-URL: Homepage, https://tunacode.xyz/
6
6
  Project-URL: Repository, https://github.com/alchemiststudiosDOTai/tunacode
@@ -100,7 +100,7 @@ See the [Hatch Build System Guide](documentation/development/hatch-build-system.
100
100
 
101
101
  ## Configuration
102
102
 
103
- Choose your AI provider and set your API key. For more details, see the [Configuration Section](documentation/user/getting-started.md#2-configuration) in the Getting Started Guide.
103
+ Choose your AI provider and set your API key. For more details, see the [Configuration Section](documentation/user/getting-started.md#2-configuration) in the Getting Started Guide. For local models (LM Studio, Ollama, etc.), see the [Local Models Setup Guide](documentation/configuration/local-models.md).
104
104
 
105
105
  ### Recommended Models
106
106