model-forge-llm 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_forge_llm-0.2.0.dist-info/METADATA +327 -0
- model_forge_llm-0.2.0.dist-info/RECORD +14 -0
- model_forge_llm-0.2.0.dist-info/WHEEL +5 -0
- model_forge_llm-0.2.0.dist-info/entry_points.txt +2 -0
- model_forge_llm-0.2.0.dist-info/licenses/LICENSE +21 -0
- model_forge_llm-0.2.0.dist-info/top_level.txt +1 -0
- modelforge/__init__.py +7 -0
- modelforge/auth.py +503 -0
- modelforge/cli.py +720 -0
- modelforge/config.py +211 -0
- modelforge/exceptions.py +29 -0
- modelforge/logging_config.py +69 -0
- modelforge/modelsdev.py +364 -0
- modelforge/registry.py +272 -0
modelforge/config.py
ADDED
@@ -0,0 +1,211 @@
|
|
1
|
+
"""Configuration management for ModelForge."""
|
2
|
+
|
3
|
+
import json
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
from .exceptions import ConfigurationError
|
8
|
+
from .logging_config import get_logger
|
9
|
+
|
10
|
+
logger = get_logger(__name__)

# Configuration file paths.
# The local (per-project) file takes precedence over the global (per-user) one.
# NOTE: both paths are resolved once at import time; in particular
# Path.cwd() is captured here, so a later os.chdir() is not reflected.
GLOBAL_CONFIG_FILE = Path.home() / ".config" / "model-forge" / "config.json"
LOCAL_CONFIG_FILE = Path.cwd() / ".model-forge" / "config.json"
|
15
|
+
|
16
|
+
|
17
|
+
def get_config_path(local: bool = False) -> Path:
    """
    Determines which config file to use based on the local flag.

    Args:
        local: If True, returns path to local project config

    Returns:
        Path to the configuration file to use
    """
    # An explicit request for the local config always wins; otherwise an
    # existing local config shadows the global one.
    if local or LOCAL_CONFIG_FILE.exists():
        return LOCAL_CONFIG_FILE
    return GLOBAL_CONFIG_FILE
|
34
|
+
|
35
|
+
|
36
|
+
def get_config() -> tuple[dict[str, Any], Path]:
    """
    Gets the configuration, with local taking precedence over global.

    Returns:
        Tuple of (config_data, config_path)
    """
    # Path selection (local vs. global) is delegated to get_config_path().
    return get_config_from_path(get_config_path())
|
45
|
+
|
46
|
+
|
47
|
+
def get_config_from_path(path: Path) -> tuple[dict[str, Any], Path]:
    """
    Load configuration from a specific path.

    Args:
        path: Path to the configuration file

    Returns:
        Tuple of (config_data, config_path). A missing file yields an
        empty dict rather than an error.

    Raises:
        ConfigurationError: If the file cannot be read or is invalid JSON
    """
    # A missing file is not an error: treat it as an empty configuration.
    if not path.exists():
        return {}, path

    try:
        with path.open() as f:
            config_data = json.load(f)
        logger.debug("Successfully loaded configuration from: %s", path)
        return config_data, path
    except json.JSONDecodeError as e:
        logger.exception("Invalid JSON in configuration file %s", path)
        # Attach the offending path so callers that only see the exception
        # (not the log) still get an actionable message.
        raise ConfigurationError(f"Invalid JSON in configuration file: {path}") from e
    except OSError as e:
        logger.exception("Could not read configuration file %s", path)
        raise ConfigurationError(f"Could not read configuration file: {path}") from e
|
74
|
+
|
75
|
+
|
76
|
+
def save_config(config_data: dict[str, Any], local: bool = False) -> None:
    """
    Save configuration data to the appropriate config file.

    Args:
        config_data: The configuration data to save
        local: If True, saves to local project config

    Raises:
        ConfigurationError: If the file cannot be written
    """
    config_path = get_config_path(local=local)
    try:
        # Create the parent directory on demand before writing.
        config_path.parent.mkdir(parents=True, exist_ok=True)
        with config_path.open("w") as f:
            json.dump(config_data, f, indent=4)
        logger.debug("Successfully saved configuration to: %s", config_path)
    except OSError as e:
        logger.exception("Could not save config file to %s", config_path)
        raise ConfigurationError from e
|
98
|
+
|
99
|
+
|
100
|
+
def set_current_model(provider: str, model: str, local: bool = False) -> bool:
    """
    Set the current model for the given provider.

    Args:
        provider: The provider name
        model: The model alias
        local: If True, modifies the local configuration

    Returns:
        True if successful, False otherwise
    """
    scope = "local" if local else "global"

    # Read the specific config file (not the merged view) so the update is
    # validated against, and written back to, the file the caller targeted.
    target_config_path = get_config_path(local=local)
    config_data, _ = get_config_from_path(target_config_path)

    providers = config_data.get("providers", {})

    # Guard: the provider must already be configured.
    if provider not in providers:
        print(f"Error: Provider '{provider}' not found in {scope} configuration.")
        print("Please add it using 'modelforge config add' first.")
        return False

    # Guard: the model must already be configured under that provider.
    if model not in providers[provider].get("models", {}):
        print(
            f"Error: Model '{model}' for provider '{provider}' not found "
            f"in {scope} configuration."
        )
        print("Please add it using 'modelforge config add' first.")
        return False

    config_data["current_model"] = {"provider": provider, "model": model}
    save_config(config_data, local=local)

    success_message = (
        f"Successfully set '{model}' from provider '{provider}' as the current "
        f"model in the {scope} config."
    )
    logger.info(success_message)
    print(success_message)
    return True
|
146
|
+
|
147
|
+
|
148
|
+
def get_current_model() -> dict[str, str] | None:
    """
    Get the currently selected model.

    Returns:
        Dictionary with 'provider' and 'model' keys, or None if not set
    """
    data, _ = get_config()
    return data.get("current_model")
|
157
|
+
|
158
|
+
|
159
|
+
def migrate_old_config() -> None:
    """
    Migrate configuration from the old location to the new global location.

    Old location: ~/.config/model-forge/models.json
    New location: ~/.config/model-forge/config.json

    The old file is never deleted automatically; the user is told to remove
    it once they have verified the migration. If a config already exists at
    the new location, nothing is overwritten.
    """
    # Old pre-rename config path; the new path is GLOBAL_CONFIG_FILE.
    old_config_file = Path.home() / ".config" / "model-forge" / "models.json"

    if old_config_file.exists():
        try:
            # Read old configuration
            with old_config_file.open() as f:
                old_config_data = json.load(f)

            # Only migrate if the new configuration does not already exist,
            # so an existing config is never silently clobbered.
            if not GLOBAL_CONFIG_FILE.exists():
                # Create new config directory
                GLOBAL_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)

                # Write old data to new location
                with GLOBAL_CONFIG_FILE.open("w") as f:
                    json.dump(old_config_data, f, indent=4)

                logger.info(
                    "Migrated configuration from %s to %s",
                    old_config_file,
                    GLOBAL_CONFIG_FILE,
                )
                print(f"Configuration migrated from {old_config_file}")
                print(f"  to: {GLOBAL_CONFIG_FILE}")
                print("You can safely delete the old file if migration was successful.")

            else:
                # Both files exist: leave them alone and ask the user to merge.
                logger.info(
                    "Old configuration file found, but new global configuration "
                    "already exists"
                )
                print(
                    "Old configuration file found, but a new global configuration "
                    "already exists."
                )
                print(f"  - Old: {old_config_file}")
                print(f"  - New: {GLOBAL_CONFIG_FILE}")
                print("Please manually review and merge if needed.")

        # Broad catch is deliberate: migration is a best-effort CLI helper and
        # should report any failure rather than crash the command.
        except Exception as e:
            logger.exception("Error during configuration migration")
            print(f"Error during migration: {e}")
    else:
        print("No old configuration file found to migrate.")
        print(f"Looking for: {old_config_file}")
|
modelforge/exceptions.py
ADDED
@@ -0,0 +1,29 @@
|
|
1
|
+
"""Custom exceptions for ModelForge."""
|
2
|
+
|
3
|
+
|
4
|
+
class ModelForgeError(Exception):
    """Base exception class for ModelForge; all package errors derive from it."""
|
6
|
+
|
7
|
+
|
8
|
+
class ConfigurationError(ModelForgeError):
    """Raised when there's an issue with configuration (e.g. an unreadable
    config file or invalid JSON)."""
|
10
|
+
|
11
|
+
|
12
|
+
class AuthenticationError(ModelForgeError):
    """Raised when authentication fails."""
|
14
|
+
|
15
|
+
|
16
|
+
class ProviderError(ModelForgeError):
    """Raised when there's an issue with a provider."""
|
18
|
+
|
19
|
+
|
20
|
+
class ModelNotFoundError(ModelForgeError):
    """Raised when a requested model is not found."""
|
22
|
+
|
23
|
+
|
24
|
+
class TokenExpiredError(AuthenticationError):
    """Raised when an authentication token has expired (a specialized
    authentication failure)."""
|
26
|
+
|
27
|
+
|
28
|
+
class InvalidProviderError(ProviderError):
    """Raised when an invalid provider is specified."""
|
@@ -0,0 +1,69 @@
|
|
1
|
+
"""Logging configuration for ModelForge."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import sys
|
5
|
+
|
6
|
+
|
7
|
+
def setup_logging(
|
8
|
+
level: str = "INFO",
|
9
|
+
format_string: str | None = None,
|
10
|
+
filename: str | None = None,
|
11
|
+
) -> logging.Logger:
|
12
|
+
"""
|
13
|
+
Set up logging configuration for ModelForge.
|
14
|
+
|
15
|
+
Args:
|
16
|
+
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
17
|
+
format_string: Custom format string for log messages
|
18
|
+
filename: Optional log file path
|
19
|
+
|
20
|
+
Returns:
|
21
|
+
Configured logger instance
|
22
|
+
"""
|
23
|
+
if format_string is None:
|
24
|
+
format_string = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
25
|
+
|
26
|
+
# Create logger
|
27
|
+
logger = logging.getLogger("modelforge")
|
28
|
+
logger.setLevel(getattr(logging, level.upper()))
|
29
|
+
|
30
|
+
# Clear existing handlers
|
31
|
+
logger.handlers.clear()
|
32
|
+
|
33
|
+
# Create formatter
|
34
|
+
formatter = logging.Formatter(format_string)
|
35
|
+
|
36
|
+
# Console handler
|
37
|
+
console_handler = logging.StreamHandler(sys.stdout)
|
38
|
+
console_handler.setLevel(getattr(logging, level.upper()))
|
39
|
+
console_handler.setFormatter(formatter)
|
40
|
+
logger.addHandler(console_handler)
|
41
|
+
|
42
|
+
# File handler (optional)
|
43
|
+
if filename:
|
44
|
+
file_handler = logging.FileHandler(filename)
|
45
|
+
file_handler.setLevel(getattr(logging, level.upper()))
|
46
|
+
file_handler.setFormatter(formatter)
|
47
|
+
logger.addHandler(file_handler)
|
48
|
+
|
49
|
+
# Prevent propagation to root logger
|
50
|
+
logger.propagate = False
|
51
|
+
|
52
|
+
return logger
|
53
|
+
|
54
|
+
|
55
|
+
def get_logger(name: str) -> logging.Logger:
    """
    Get a logger instance with the specified name.

    Args:
        name: Name of the logger (typically __name__)

    Returns:
        Logger instance
    """
    # Namespace every logger under the package root so configuration applied
    # to the "modelforge" logger by setup_logging() covers all children.
    return logging.getLogger("modelforge." + name)
|
66
|
+
|
67
|
+
|
68
|
+
# Default logger setup
# NOTE(review): this installs a handler on the "modelforge" logger as a side
# effect of importing this module; confirm that import-time logging
# configuration is intended for library consumers.
default_logger = setup_logging()
|
modelforge/modelsdev.py
ADDED
@@ -0,0 +1,364 @@
|
|
1
|
+
"""Models.dev API integration for ModelForge."""
|
2
|
+
|
3
|
+
import json
|
4
|
+
import time
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
import requests
|
9
|
+
|
10
|
+
from .logging_config import get_logger
|
11
|
+
|
12
|
+
# Module-level logger, namespaced under the "modelforge" hierarchy.
logger = get_logger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class ModelsDevClient:
|
16
|
+
"""Client for interacting with models.dev API."""
|
17
|
+
|
18
|
+
BASE_URL = "https://models.dev/api/v1"
|
19
|
+
CACHE_DIR = Path.home() / ".cache" / "model-forge" / "modelsdev"
|
20
|
+
|
21
|
+
# Cache TTL in seconds
|
22
|
+
CACHE_TTL = {
|
23
|
+
"providers": 24 * 3600, # 24 hours
|
24
|
+
"models": 24 * 3600, # 24 hours
|
25
|
+
"model_info": 7 * 24 * 3600, # 7 days
|
26
|
+
"provider_config": 7 * 24 * 3600, # 7 days
|
27
|
+
}
|
28
|
+
|
29
|
+
def __init__(self, api_key: str | None = None) -> None:
|
30
|
+
"""Initialize the models.dev client.
|
31
|
+
|
32
|
+
Args:
|
33
|
+
api_key: Optional API key for authenticated requests
|
34
|
+
"""
|
35
|
+
self.api_key = api_key
|
36
|
+
self.session = requests.Session()
|
37
|
+
if api_key:
|
38
|
+
self.session.headers.update({"Authorization": f"Bearer {api_key}"})
|
39
|
+
|
40
|
+
# Ensure cache directory exists
|
41
|
+
self.CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
42
|
+
|
43
|
+
def _get_cache_path(self, endpoint: str, *args: str) -> Path:
|
44
|
+
"""Get cache file path for an endpoint."""
|
45
|
+
filename = f"{endpoint}_{'_'.join(args)}.json" if args else f"{endpoint}.json"
|
46
|
+
return self.CACHE_DIR / filename
|
47
|
+
|
48
|
+
def _is_cache_valid(self, cache_path: Path, ttl: int) -> bool:
|
49
|
+
"""Check if cached data is still valid."""
|
50
|
+
if not cache_path.exists():
|
51
|
+
return False
|
52
|
+
|
53
|
+
try:
|
54
|
+
cache_time = cache_path.stat().st_mtime
|
55
|
+
return time.time() - cache_time < ttl
|
56
|
+
except OSError:
|
57
|
+
return False
|
58
|
+
|
59
|
+
def _load_from_cache(self, cache_path: Path) -> dict[str, Any] | None:
|
60
|
+
"""Load data from cache."""
|
61
|
+
try:
|
62
|
+
with cache_path.open() as f:
|
63
|
+
data = json.load(f)
|
64
|
+
if isinstance(data, dict):
|
65
|
+
return data
|
66
|
+
if isinstance(data, list):
|
67
|
+
return {"data": data}
|
68
|
+
return {"data": data}
|
69
|
+
except OSError:
|
70
|
+
logger.warning("Failed to read cache file: %s", cache_path)
|
71
|
+
return None
|
72
|
+
except json.JSONDecodeError as e:
|
73
|
+
logger.warning("Invalid JSON in cache file %s: %s", cache_path, e)
|
74
|
+
return None
|
75
|
+
|
76
|
+
def _save_to_cache(self, cache_path: Path, data: dict[str, Any]) -> None:
|
77
|
+
"""Save data to cache."""
|
78
|
+
try:
|
79
|
+
with cache_path.open("w") as f:
|
80
|
+
json.dump(data, f, indent=2)
|
81
|
+
except OSError as e:
|
82
|
+
logger.warning("Failed to write cache file %s: %s", cache_path, e)
|
83
|
+
|
84
|
+
def get_providers(self, force_refresh: bool = False) -> list[dict[str, Any]]:
|
85
|
+
"""Get list of supported providers from models.dev.
|
86
|
+
|
87
|
+
Args:
|
88
|
+
force_refresh: Force refresh from API even if cache is valid
|
89
|
+
|
90
|
+
Returns:
|
91
|
+
List of provider information dictionaries
|
92
|
+
"""
|
93
|
+
cache_path = self._get_cache_path("providers")
|
94
|
+
|
95
|
+
if not force_refresh and self._is_cache_valid(
|
96
|
+
cache_path, self.CACHE_TTL["providers"]
|
97
|
+
):
|
98
|
+
cached_data = self._load_from_cache(cache_path)
|
99
|
+
if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
|
100
|
+
data = cached_data["data"]
|
101
|
+
if isinstance(data, list):
|
102
|
+
return data
|
103
|
+
return [data]
|
104
|
+
|
105
|
+
try:
|
106
|
+
response = self.session.get(f"{self.BASE_URL}/providers")
|
107
|
+
response.raise_for_status()
|
108
|
+
providers_data = response.json()
|
109
|
+
providers: list[dict[str, Any]] = (
|
110
|
+
[providers_data]
|
111
|
+
if isinstance(providers_data, dict)
|
112
|
+
else providers_data
|
113
|
+
if isinstance(providers_data, list)
|
114
|
+
else [providers_data]
|
115
|
+
)
|
116
|
+
except requests.exceptions.ConnectionError:
|
117
|
+
logger.exception("Connection error while fetching providers")
|
118
|
+
cached_data = self._load_from_cache(cache_path)
|
119
|
+
if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
|
120
|
+
logger.info("Using stale cached providers data")
|
121
|
+
return cached_data["data"]
|
122
|
+
raise
|
123
|
+
except requests.exceptions.HTTPError:
|
124
|
+
logger.exception("HTTP error while fetching providers")
|
125
|
+
cached_data = self._load_from_cache(cache_path)
|
126
|
+
if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
|
127
|
+
logger.info("Using stale cached providers data")
|
128
|
+
return cached_data["data"]
|
129
|
+
raise
|
130
|
+
except requests.exceptions.Timeout:
|
131
|
+
logger.exception("Timeout while fetching providers")
|
132
|
+
cached_data = self._load_from_cache(cache_path)
|
133
|
+
if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
|
134
|
+
logger.info("Using stale cached providers data")
|
135
|
+
return cached_data["data"]
|
136
|
+
raise
|
137
|
+
else:
|
138
|
+
self._save_to_cache(cache_path, {"data": providers})
|
139
|
+
return providers
|
140
|
+
|
141
|
+
def get_models(
|
142
|
+
self, provider: str | None = None, force_refresh: bool = False
|
143
|
+
) -> list[dict[str, Any]]:
|
144
|
+
"""Get list of models from models.dev.
|
145
|
+
|
146
|
+
Args:
|
147
|
+
provider: Optional provider filter
|
148
|
+
force_refresh: Force refresh from API even if cache is valid
|
149
|
+
|
150
|
+
Returns:
|
151
|
+
List of model information dictionaries
|
152
|
+
"""
|
153
|
+
cache_path = self._get_cache_path("models", provider or "all")
|
154
|
+
|
155
|
+
if not force_refresh and self._is_cache_valid(
|
156
|
+
cache_path, self.CACHE_TTL["models"]
|
157
|
+
):
|
158
|
+
cached_data = self._load_from_cache(cache_path)
|
159
|
+
if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
|
160
|
+
data = cached_data["data"]
|
161
|
+
if isinstance(data, list):
|
162
|
+
return data
|
163
|
+
return [data]
|
164
|
+
|
165
|
+
try:
|
166
|
+
url = f"{self.BASE_URL}/models"
|
167
|
+
params = {}
|
168
|
+
if provider:
|
169
|
+
params["provider"] = provider
|
170
|
+
|
171
|
+
response = self.session.get(url, params=params)
|
172
|
+
response.raise_for_status()
|
173
|
+
models_data = response.json()
|
174
|
+
models: list[dict[str, Any]] = (
|
175
|
+
models_data if isinstance(models_data, list) else [models_data]
|
176
|
+
)
|
177
|
+
except requests.exceptions.ConnectionError:
|
178
|
+
logger.exception("Connection error while fetching models")
|
179
|
+
cached_data = self._load_from_cache(cache_path)
|
180
|
+
if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
|
181
|
+
logger.info("Using stale cached models data")
|
182
|
+
return cached_data["data"]
|
183
|
+
raise
|
184
|
+
except requests.exceptions.HTTPError:
|
185
|
+
logger.exception("HTTP error while fetching models")
|
186
|
+
cached_data = self._load_from_cache(cache_path)
|
187
|
+
if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
|
188
|
+
logger.info("Using stale cached models data")
|
189
|
+
return cached_data["data"]
|
190
|
+
raise
|
191
|
+
except requests.exceptions.Timeout:
|
192
|
+
logger.exception("Timeout while fetching models")
|
193
|
+
cached_data = self._load_from_cache(cache_path)
|
194
|
+
if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
|
195
|
+
logger.info("Using stale cached models data")
|
196
|
+
return cached_data["data"]
|
197
|
+
raise
|
198
|
+
else:
|
199
|
+
self._save_to_cache(cache_path, {"data": models})
|
200
|
+
return models
|
201
|
+
|
202
|
+
def get_model_info(
|
203
|
+
self, provider: str, model: str, force_refresh: bool = False
|
204
|
+
) -> dict[str, Any]:
|
205
|
+
"""Get detailed information about a specific model.
|
206
|
+
|
207
|
+
Args:
|
208
|
+
provider: Provider name
|
209
|
+
model: Model identifier
|
210
|
+
force_refresh: Force refresh from API even if cache is valid
|
211
|
+
|
212
|
+
Returns:
|
213
|
+
Model information dictionary
|
214
|
+
"""
|
215
|
+
cache_path = self._get_cache_path("model_info", provider, model)
|
216
|
+
|
217
|
+
if not force_refresh and self._is_cache_valid(
|
218
|
+
cache_path, self.CACHE_TTL["model_info"]
|
219
|
+
):
|
220
|
+
cached_data = self._load_from_cache(cache_path)
|
221
|
+
if cached_data and isinstance(cached_data, dict):
|
222
|
+
return cached_data
|
223
|
+
|
224
|
+
try:
|
225
|
+
url = f"{self.BASE_URL}/models/{provider}/{model}"
|
226
|
+
response = self.session.get(url)
|
227
|
+
response.raise_for_status()
|
228
|
+
model_info = response.json()
|
229
|
+
if not isinstance(model_info, dict):
|
230
|
+
model_info = {"model": model_info}
|
231
|
+
except requests.exceptions.ConnectionError:
|
232
|
+
logger.exception("Connection error while fetching model info")
|
233
|
+
cached_data = self._load_from_cache(cache_path)
|
234
|
+
if cached_data and isinstance(cached_data, dict):
|
235
|
+
logger.info("Using stale cached model info")
|
236
|
+
return cached_data
|
237
|
+
raise
|
238
|
+
except requests.exceptions.HTTPError:
|
239
|
+
logger.exception("HTTP error while fetching model info")
|
240
|
+
cached_data = self._load_from_cache(cache_path)
|
241
|
+
if cached_data and isinstance(cached_data, dict):
|
242
|
+
logger.info("Using stale cached model info")
|
243
|
+
return cached_data
|
244
|
+
raise
|
245
|
+
except requests.exceptions.Timeout:
|
246
|
+
logger.exception("Timeout while fetching model info")
|
247
|
+
cached_data = self._load_from_cache(cache_path)
|
248
|
+
if cached_data and isinstance(cached_data, dict):
|
249
|
+
logger.info("Using stale cached model info")
|
250
|
+
return cached_data
|
251
|
+
raise
|
252
|
+
else:
|
253
|
+
self._save_to_cache(cache_path, model_info)
|
254
|
+
return model_info
|
255
|
+
|
256
|
+
def get_provider_config(
|
257
|
+
self, provider: str, force_refresh: bool = False
|
258
|
+
) -> dict[str, Any]:
|
259
|
+
"""Get configuration template for a provider.
|
260
|
+
|
261
|
+
Args:
|
262
|
+
provider: Provider name
|
263
|
+
force_refresh: Force refresh from API even if cache is valid
|
264
|
+
|
265
|
+
Returns:
|
266
|
+
Provider configuration template
|
267
|
+
"""
|
268
|
+
cache_path = self._get_cache_path("provider_config", provider)
|
269
|
+
|
270
|
+
if not force_refresh and self._is_cache_valid(
|
271
|
+
cache_path, self.CACHE_TTL["provider_config"]
|
272
|
+
):
|
273
|
+
cached_data = self._load_from_cache(cache_path)
|
274
|
+
if cached_data and isinstance(cached_data, dict):
|
275
|
+
return cached_data
|
276
|
+
|
277
|
+
try:
|
278
|
+
url = f"{self.BASE_URL}/providers/{provider}/config"
|
279
|
+
response = self.session.get(url)
|
280
|
+
response.raise_for_status()
|
281
|
+
provider_config = response.json()
|
282
|
+
if not isinstance(provider_config, dict):
|
283
|
+
provider_config = {"config": provider_config}
|
284
|
+
return provider_config
|
285
|
+
except requests.exceptions.ConnectionError:
|
286
|
+
logger.exception("Connection error while fetching provider config")
|
287
|
+
cached_data = self._load_from_cache(cache_path)
|
288
|
+
if cached_data and isinstance(cached_data, dict):
|
289
|
+
logger.info("Using stale cached provider config")
|
290
|
+
return cached_data
|
291
|
+
raise
|
292
|
+
except requests.exceptions.HTTPError:
|
293
|
+
logger.exception("HTTP error while fetching provider config")
|
294
|
+
cached_data = self._load_from_cache(cache_path)
|
295
|
+
if cached_data and isinstance(cached_data, dict):
|
296
|
+
logger.info("Using stale cached provider config")
|
297
|
+
return cached_data
|
298
|
+
raise
|
299
|
+
except requests.exceptions.Timeout:
|
300
|
+
logger.exception("Timeout while fetching provider config")
|
301
|
+
cached_data = self._load_from_cache(cache_path)
|
302
|
+
if cached_data and isinstance(cached_data, dict):
|
303
|
+
logger.info("Using stale cached provider config")
|
304
|
+
return cached_data
|
305
|
+
raise
|
306
|
+
else:
|
307
|
+
self._save_to_cache(cache_path, provider_config)
|
308
|
+
return provider_config
|
309
|
+
|
310
|
+
def clear_cache(self) -> None:
|
311
|
+
"""Clear all cached data."""
|
312
|
+
try:
|
313
|
+
for cache_file in self.CACHE_DIR.glob("*.json"):
|
314
|
+
cache_file.unlink()
|
315
|
+
logger.info("Cleared models.dev cache")
|
316
|
+
except OSError:
|
317
|
+
logger.exception("Failed to clear cache")
|
318
|
+
|
319
|
+
def search_models(
|
320
|
+
self,
|
321
|
+
query: str,
|
322
|
+
provider: str | None = None,
|
323
|
+
capabilities: list[str] | None = None,
|
324
|
+
max_price: float | None = None,
|
325
|
+
force_refresh: bool = False,
|
326
|
+
) -> list[dict[str, Any]]:
|
327
|
+
"""Search models based on criteria.
|
328
|
+
|
329
|
+
Args:
|
330
|
+
query: Search query string
|
331
|
+
provider: Optional provider filter
|
332
|
+
capabilities: Optional list of required capabilities
|
333
|
+
max_price: Optional maximum price filter
|
334
|
+
force_refresh: Force refresh from API
|
335
|
+
|
336
|
+
Returns:
|
337
|
+
List of matching models
|
338
|
+
"""
|
339
|
+
models = self.get_models(provider=provider, force_refresh=force_refresh)
|
340
|
+
|
341
|
+
results = []
|
342
|
+
for model in models:
|
343
|
+
# Text search in name and description
|
344
|
+
model_text = (
|
345
|
+
f"{model.get('name', '')} {model.get('description', '')}".lower()
|
346
|
+
)
|
347
|
+
if query.lower() not in model_text:
|
348
|
+
continue
|
349
|
+
|
350
|
+
# Capabilities filter
|
351
|
+
if capabilities:
|
352
|
+
model_caps = set(model.get("capabilities", []))
|
353
|
+
if not set(capabilities).issubset(model_caps):
|
354
|
+
continue
|
355
|
+
|
356
|
+
# Price filter
|
357
|
+
if max_price is not None:
|
358
|
+
price = model.get("pricing", {}).get("input_per_1k_tokens")
|
359
|
+
if price is not None and price > max_price:
|
360
|
+
continue
|
361
|
+
|
362
|
+
results.append(model)
|
363
|
+
|
364
|
+
return results
|