model-forge-llm 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
modelforge/config.py ADDED
@@ -0,0 +1,211 @@
1
+ """Configuration management for ModelForge."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ from .exceptions import ConfigurationError
8
+ from .logging_config import get_logger
9
+
10
logger = get_logger(__name__)

# Configuration file locations. Both paths are computed once, when this module
# is imported; the local one is relative to the working directory at that time.
GLOBAL_CONFIG_FILE = Path.home().joinpath(".config", "model-forge", "config.json")
LOCAL_CONFIG_FILE = Path.cwd().joinpath(".model-forge", "config.json")
15
+
16
+
17
def get_config_path(local: bool = False) -> Path:
    """
    Resolve which configuration file should be used.

    Args:
        local: When truthy, always pick the project-local config file

    Returns:
        The local config path if ``local`` is set or the local file already
        exists; otherwise the global config path.
    """
    # ``or`` short-circuits, so the filesystem is only checked when needed.
    use_local = local or LOCAL_CONFIG_FILE.exists()
    return LOCAL_CONFIG_FILE if use_local else GLOBAL_CONFIG_FILE
34
+
35
+
36
def get_config() -> tuple[dict[str, Any], Path]:
    """
    Load the active configuration.

    The project-local file is used whenever it exists, otherwise the global
    one (see ``get_config_path``); the two files are not merged.

    Returns:
        Tuple of (config_data, config_path)
    """
    return get_config_from_path(get_config_path())
45
+
46
+
47
def get_config_from_path(path: Path) -> tuple[dict[str, Any], Path]:
    """
    Load configuration from a specific path.

    A missing file is not an error: it yields an empty configuration.

    Args:
        path: Path to the configuration file

    Returns:
        Tuple of (config_data, config_path); ``config_data`` is ``{}`` when
        the file does not exist.

    Raises:
        ConfigurationError: If the file cannot be read, is not valid UTF-8
            JSON, or its top-level value is not a JSON object
    """
    if not path.exists():
        return {}, path

    try:
        # JSON text is UTF-8 by specification; don't depend on the platform's
        # locale encoding.
        with path.open(encoding="utf-8") as f:
            config_data = json.load(f)
    except ValueError as e:
        # Covers both json.JSONDecodeError and UnicodeDecodeError.
        logger.exception("Invalid JSON in configuration file %s", path)
        msg = f"Invalid JSON in configuration file {path}"
        raise ConfigurationError(msg) from e
    except OSError as e:
        logger.exception("Could not read configuration file %s", path)
        msg = f"Could not read configuration file {path}"
        raise ConfigurationError(msg) from e

    # Callers treat the result as a dict (e.g. ``.get``), so reject other
    # top-level JSON values such as lists or strings here.
    if not isinstance(config_data, dict):
        msg = (
            f"Configuration file {path} must contain a JSON object, "
            f"got {type(config_data).__name__}"
        )
        logger.error("%s", msg)
        raise ConfigurationError(msg)

    logger.debug("Successfully loaded configuration from: %s", path)
    return config_data, path
74
+
75
+
76
def save_config(config_data: dict[str, Any], local: bool = False) -> None:
    """
    Save configuration data to the appropriate config file.

    The target file is chosen by ``get_config_path(local=local)``. Data is
    serialized before the file is opened, so unserializable data leaves the
    existing file untouched.

    Args:
        config_data: The configuration data to save
        local: If True, saves to local project config

    Raises:
        ConfigurationError: If the directory or file cannot be written
        TypeError / ValueError: Propagated from ``json.dumps`` when
            ``config_data`` cannot be encoded as JSON
    """
    config_path = get_config_path(local=local)

    # Serialize first: opening with mode "w" truncates the file immediately,
    # so an encoding failure inside json.dump would otherwise leave an empty
    # or half-written config behind.
    payload = json.dumps(config_data, indent=4)

    try:
        config_path.parent.mkdir(parents=True, exist_ok=True)
        config_path.write_text(payload, encoding="utf-8")
    except OSError as e:
        logger.exception("Could not save config file to %s", config_path)
        msg = f"Could not save config file to {config_path}"
        raise ConfigurationError(msg) from e

    logger.debug("Successfully saved configuration to: %s", config_path)
98
+
99
+
100
def set_current_model(provider: str, model: str, local: bool = False) -> bool:
    """
    Set the current model for the given provider.

    The provider and model must already be defined in the configuration file
    that is read and rewritten (the one chosen by
    ``get_config_path(local=local)``).

    Args:
        provider: The provider name
        model: The model alias
        local: If True, modifies the local configuration

    Returns:
        True if successful, False otherwise (an error message is printed)
    """
    # Read the same single file that save_config() will write back.
    target_config_path = get_config_path(local=local)
    config_data, _ = get_config_from_path(target_config_path)

    # get_config_path(local=False) still resolves to the local file when it
    # exists, so report the scope of the file actually used, not the flag.
    scope = "local" if target_config_path == LOCAL_CONFIG_FILE else "global"

    # Check if provider and model exist in the configuration; ``or {}``
    # tolerates keys that are present but null.
    providers = config_data.get("providers") or {}
    if provider not in providers:
        print(f"Error: Provider '{provider}' not found in {scope} configuration.")
        print("Please add it using 'modelforge config add' first.")
        return False

    models = providers[provider].get("models") or {}
    if model not in models:
        print(
            f"Error: Model '{model}' for provider '{provider}' not found "
            f"in {scope} configuration."
        )
        print("Please add it using 'modelforge config add' first.")
        return False

    # Set the current model
    config_data["current_model"] = {"provider": provider, "model": model}
    save_config(config_data, local=local)
    success_message = (
        f"Successfully set '{model}' from provider '{provider}' as the current "
        f"model in the {scope} config."
    )
    logger.info(success_message)
    print(success_message)
    return True
146
+
147
+
148
def get_current_model() -> dict[str, str] | None:
    """
    Return the model selection stored under ``current_model``.

    The value is read from the configuration returned by ``get_config()``.

    Returns:
        Dictionary with 'provider' and 'model' keys, or None if not set
    """
    config_data = get_config()[0]
    return config_data.get("current_model")
157
+
158
+
159
def migrate_old_config() -> None:
    """
    Copy a legacy ``models.json`` into the current global config location.

    Old location: ~/.config/model-forge/models.json
    New location: ~/.config/model-forge/config.json

    The legacy file is left in place, and an existing global config is never
    overwritten (the user is asked to merge by hand instead). Any error is
    logged and printed rather than raised.
    """
    legacy_file = Path.home() / ".config" / "model-forge" / "models.json"

    # Nothing to migrate: tell the user where we looked.
    if not legacy_file.exists():
        print("No old configuration file found to migrate.")
        print(f"Looking for: {legacy_file}")
        return

    try:
        with legacy_file.open() as src:
            legacy_data = json.load(src)

        if GLOBAL_CONFIG_FILE.exists():
            # Refuse to clobber a config the user already has.
            logger.info(
                "Old configuration file found, but new global configuration "
                "already exists"
            )
            print(
                "Old configuration file found, but a new global configuration "
                "already exists."
            )
            print(f" - Old: {legacy_file}")
            print(f" - New: {GLOBAL_CONFIG_FILE}")
            print("Please manually review and merge if needed.")
            return

        GLOBAL_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
        with GLOBAL_CONFIG_FILE.open("w") as dst:
            json.dump(legacy_data, dst, indent=4)

        logger.info(
            "Migrated configuration from %s to %s", legacy_file, GLOBAL_CONFIG_FILE
        )
        print(f"Configuration migrated from {legacy_file}")
        print(f" to: {GLOBAL_CONFIG_FILE}")
        print("You can safely delete the old file if migration was successful.")
    except Exception as e:  # user-facing command: report any failure
        logger.exception("Error during configuration migration")
        print(f"Error during migration: {e}")
@@ -0,0 +1,29 @@
1
+ """Custom exceptions for ModelForge."""
2
+
3
+
4
class ModelForgeError(Exception):
    """Common base for every ModelForge-specific exception."""


class ConfigurationError(ModelForgeError):
    """Signals a problem with ModelForge configuration."""


class AuthenticationError(ModelForgeError):
    """Signals that authentication did not succeed."""


class ProviderError(ModelForgeError):
    """Signals a problem with a model provider."""


class ModelNotFoundError(ModelForgeError):
    """Signals that a requested model could not be found."""


class TokenExpiredError(AuthenticationError):
    """Signals that an authentication token is no longer valid because it expired."""


class InvalidProviderError(ProviderError):
    """Signals that the specified provider is not valid."""
@@ -0,0 +1,69 @@
1
+ """Logging configuration for ModelForge."""
2
+
3
+ import logging
4
+ import sys
5
+
6
+
7
+ def setup_logging(
8
+ level: str = "INFO",
9
+ format_string: str | None = None,
10
+ filename: str | None = None,
11
+ ) -> logging.Logger:
12
+ """
13
+ Set up logging configuration for ModelForge.
14
+
15
+ Args:
16
+ level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
17
+ format_string: Custom format string for log messages
18
+ filename: Optional log file path
19
+
20
+ Returns:
21
+ Configured logger instance
22
+ """
23
+ if format_string is None:
24
+ format_string = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
25
+
26
+ # Create logger
27
+ logger = logging.getLogger("modelforge")
28
+ logger.setLevel(getattr(logging, level.upper()))
29
+
30
+ # Clear existing handlers
31
+ logger.handlers.clear()
32
+
33
+ # Create formatter
34
+ formatter = logging.Formatter(format_string)
35
+
36
+ # Console handler
37
+ console_handler = logging.StreamHandler(sys.stdout)
38
+ console_handler.setLevel(getattr(logging, level.upper()))
39
+ console_handler.setFormatter(formatter)
40
+ logger.addHandler(console_handler)
41
+
42
+ # File handler (optional)
43
+ if filename:
44
+ file_handler = logging.FileHandler(filename)
45
+ file_handler.setLevel(getattr(logging, level.upper()))
46
+ file_handler.setFormatter(formatter)
47
+ logger.addHandler(file_handler)
48
+
49
+ # Prevent propagation to root logger
50
+ logger.propagate = False
51
+
52
+ return logger
53
+
54
+
55
def get_logger(name: str) -> logging.Logger:
    """
    Get a logger under the ``modelforge`` logger hierarchy.

    Modules in this package call ``get_logger(__name__)``, and their
    ``__name__`` already starts with ``modelforge.``. Such names are used
    unchanged, so no ``modelforge.modelforge.*`` duplicate is created. Any
    other name is prefixed with ``modelforge.``. Either way the result is a
    child of the logger configured by ``setup_logging``.

    Args:
        name: Name of the logger (typically __name__)

    Returns:
        Logger instance
    """
    if name == "modelforge" or name.startswith("modelforge."):
        return logging.getLogger(name)
    return logging.getLogger(f"modelforge.{name}")
66
+
67
+
68
# Configure the package-level "modelforge" logger once, at import time,
# using setup_logging's defaults (INFO level, stdout output, no log file).
default_logger = setup_logging(level="INFO")
@@ -0,0 +1,364 @@
1
+ """Models.dev API integration for ModelForge."""
2
+
3
+ import json
4
+ import time
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import requests
9
+
10
+ from .logging_config import get_logger
11
+
12
# Module-level logger for the models.dev client.
logger = get_logger(name=__name__)
13
+
14
+
15
+ class ModelsDevClient:
16
+ """Client for interacting with models.dev API."""
17
+
18
+ BASE_URL = "https://models.dev/api/v1"
19
+ CACHE_DIR = Path.home() / ".cache" / "model-forge" / "modelsdev"
20
+
21
+ # Cache TTL in seconds
22
+ CACHE_TTL = {
23
+ "providers": 24 * 3600, # 24 hours
24
+ "models": 24 * 3600, # 24 hours
25
+ "model_info": 7 * 24 * 3600, # 7 days
26
+ "provider_config": 7 * 24 * 3600, # 7 days
27
+ }
28
+
29
+ def __init__(self, api_key: str | None = None) -> None:
30
+ """Initialize the models.dev client.
31
+
32
+ Args:
33
+ api_key: Optional API key for authenticated requests
34
+ """
35
+ self.api_key = api_key
36
+ self.session = requests.Session()
37
+ if api_key:
38
+ self.session.headers.update({"Authorization": f"Bearer {api_key}"})
39
+
40
+ # Ensure cache directory exists
41
+ self.CACHE_DIR.mkdir(parents=True, exist_ok=True)
42
+
43
+ def _get_cache_path(self, endpoint: str, *args: str) -> Path:
44
+ """Get cache file path for an endpoint."""
45
+ filename = f"{endpoint}_{'_'.join(args)}.json" if args else f"{endpoint}.json"
46
+ return self.CACHE_DIR / filename
47
+
48
+ def _is_cache_valid(self, cache_path: Path, ttl: int) -> bool:
49
+ """Check if cached data is still valid."""
50
+ if not cache_path.exists():
51
+ return False
52
+
53
+ try:
54
+ cache_time = cache_path.stat().st_mtime
55
+ return time.time() - cache_time < ttl
56
+ except OSError:
57
+ return False
58
+
59
+ def _load_from_cache(self, cache_path: Path) -> dict[str, Any] | None:
60
+ """Load data from cache."""
61
+ try:
62
+ with cache_path.open() as f:
63
+ data = json.load(f)
64
+ if isinstance(data, dict):
65
+ return data
66
+ if isinstance(data, list):
67
+ return {"data": data}
68
+ return {"data": data}
69
+ except OSError:
70
+ logger.warning("Failed to read cache file: %s", cache_path)
71
+ return None
72
+ except json.JSONDecodeError as e:
73
+ logger.warning("Invalid JSON in cache file %s: %s", cache_path, e)
74
+ return None
75
+
76
+ def _save_to_cache(self, cache_path: Path, data: dict[str, Any]) -> None:
77
+ """Save data to cache."""
78
+ try:
79
+ with cache_path.open("w") as f:
80
+ json.dump(data, f, indent=2)
81
+ except OSError as e:
82
+ logger.warning("Failed to write cache file %s: %s", cache_path, e)
83
+
84
+ def get_providers(self, force_refresh: bool = False) -> list[dict[str, Any]]:
85
+ """Get list of supported providers from models.dev.
86
+
87
+ Args:
88
+ force_refresh: Force refresh from API even if cache is valid
89
+
90
+ Returns:
91
+ List of provider information dictionaries
92
+ """
93
+ cache_path = self._get_cache_path("providers")
94
+
95
+ if not force_refresh and self._is_cache_valid(
96
+ cache_path, self.CACHE_TTL["providers"]
97
+ ):
98
+ cached_data = self._load_from_cache(cache_path)
99
+ if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
100
+ data = cached_data["data"]
101
+ if isinstance(data, list):
102
+ return data
103
+ return [data]
104
+
105
+ try:
106
+ response = self.session.get(f"{self.BASE_URL}/providers")
107
+ response.raise_for_status()
108
+ providers_data = response.json()
109
+ providers: list[dict[str, Any]] = (
110
+ [providers_data]
111
+ if isinstance(providers_data, dict)
112
+ else providers_data
113
+ if isinstance(providers_data, list)
114
+ else [providers_data]
115
+ )
116
+ except requests.exceptions.ConnectionError:
117
+ logger.exception("Connection error while fetching providers")
118
+ cached_data = self._load_from_cache(cache_path)
119
+ if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
120
+ logger.info("Using stale cached providers data")
121
+ return cached_data["data"]
122
+ raise
123
+ except requests.exceptions.HTTPError:
124
+ logger.exception("HTTP error while fetching providers")
125
+ cached_data = self._load_from_cache(cache_path)
126
+ if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
127
+ logger.info("Using stale cached providers data")
128
+ return cached_data["data"]
129
+ raise
130
+ except requests.exceptions.Timeout:
131
+ logger.exception("Timeout while fetching providers")
132
+ cached_data = self._load_from_cache(cache_path)
133
+ if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
134
+ logger.info("Using stale cached providers data")
135
+ return cached_data["data"]
136
+ raise
137
+ else:
138
+ self._save_to_cache(cache_path, {"data": providers})
139
+ return providers
140
+
141
+ def get_models(
142
+ self, provider: str | None = None, force_refresh: bool = False
143
+ ) -> list[dict[str, Any]]:
144
+ """Get list of models from models.dev.
145
+
146
+ Args:
147
+ provider: Optional provider filter
148
+ force_refresh: Force refresh from API even if cache is valid
149
+
150
+ Returns:
151
+ List of model information dictionaries
152
+ """
153
+ cache_path = self._get_cache_path("models", provider or "all")
154
+
155
+ if not force_refresh and self._is_cache_valid(
156
+ cache_path, self.CACHE_TTL["models"]
157
+ ):
158
+ cached_data = self._load_from_cache(cache_path)
159
+ if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
160
+ data = cached_data["data"]
161
+ if isinstance(data, list):
162
+ return data
163
+ return [data]
164
+
165
+ try:
166
+ url = f"{self.BASE_URL}/models"
167
+ params = {}
168
+ if provider:
169
+ params["provider"] = provider
170
+
171
+ response = self.session.get(url, params=params)
172
+ response.raise_for_status()
173
+ models_data = response.json()
174
+ models: list[dict[str, Any]] = (
175
+ models_data if isinstance(models_data, list) else [models_data]
176
+ )
177
+ except requests.exceptions.ConnectionError:
178
+ logger.exception("Connection error while fetching models")
179
+ cached_data = self._load_from_cache(cache_path)
180
+ if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
181
+ logger.info("Using stale cached models data")
182
+ return cached_data["data"]
183
+ raise
184
+ except requests.exceptions.HTTPError:
185
+ logger.exception("HTTP error while fetching models")
186
+ cached_data = self._load_from_cache(cache_path)
187
+ if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
188
+ logger.info("Using stale cached models data")
189
+ return cached_data["data"]
190
+ raise
191
+ except requests.exceptions.Timeout:
192
+ logger.exception("Timeout while fetching models")
193
+ cached_data = self._load_from_cache(cache_path)
194
+ if cached_data and isinstance(cached_data, dict) and "data" in cached_data:
195
+ logger.info("Using stale cached models data")
196
+ return cached_data["data"]
197
+ raise
198
+ else:
199
+ self._save_to_cache(cache_path, {"data": models})
200
+ return models
201
+
202
+ def get_model_info(
203
+ self, provider: str, model: str, force_refresh: bool = False
204
+ ) -> dict[str, Any]:
205
+ """Get detailed information about a specific model.
206
+
207
+ Args:
208
+ provider: Provider name
209
+ model: Model identifier
210
+ force_refresh: Force refresh from API even if cache is valid
211
+
212
+ Returns:
213
+ Model information dictionary
214
+ """
215
+ cache_path = self._get_cache_path("model_info", provider, model)
216
+
217
+ if not force_refresh and self._is_cache_valid(
218
+ cache_path, self.CACHE_TTL["model_info"]
219
+ ):
220
+ cached_data = self._load_from_cache(cache_path)
221
+ if cached_data and isinstance(cached_data, dict):
222
+ return cached_data
223
+
224
+ try:
225
+ url = f"{self.BASE_URL}/models/{provider}/{model}"
226
+ response = self.session.get(url)
227
+ response.raise_for_status()
228
+ model_info = response.json()
229
+ if not isinstance(model_info, dict):
230
+ model_info = {"model": model_info}
231
+ except requests.exceptions.ConnectionError:
232
+ logger.exception("Connection error while fetching model info")
233
+ cached_data = self._load_from_cache(cache_path)
234
+ if cached_data and isinstance(cached_data, dict):
235
+ logger.info("Using stale cached model info")
236
+ return cached_data
237
+ raise
238
+ except requests.exceptions.HTTPError:
239
+ logger.exception("HTTP error while fetching model info")
240
+ cached_data = self._load_from_cache(cache_path)
241
+ if cached_data and isinstance(cached_data, dict):
242
+ logger.info("Using stale cached model info")
243
+ return cached_data
244
+ raise
245
+ except requests.exceptions.Timeout:
246
+ logger.exception("Timeout while fetching model info")
247
+ cached_data = self._load_from_cache(cache_path)
248
+ if cached_data and isinstance(cached_data, dict):
249
+ logger.info("Using stale cached model info")
250
+ return cached_data
251
+ raise
252
+ else:
253
+ self._save_to_cache(cache_path, model_info)
254
+ return model_info
255
+
256
+ def get_provider_config(
257
+ self, provider: str, force_refresh: bool = False
258
+ ) -> dict[str, Any]:
259
+ """Get configuration template for a provider.
260
+
261
+ Args:
262
+ provider: Provider name
263
+ force_refresh: Force refresh from API even if cache is valid
264
+
265
+ Returns:
266
+ Provider configuration template
267
+ """
268
+ cache_path = self._get_cache_path("provider_config", provider)
269
+
270
+ if not force_refresh and self._is_cache_valid(
271
+ cache_path, self.CACHE_TTL["provider_config"]
272
+ ):
273
+ cached_data = self._load_from_cache(cache_path)
274
+ if cached_data and isinstance(cached_data, dict):
275
+ return cached_data
276
+
277
+ try:
278
+ url = f"{self.BASE_URL}/providers/{provider}/config"
279
+ response = self.session.get(url)
280
+ response.raise_for_status()
281
+ provider_config = response.json()
282
+ if not isinstance(provider_config, dict):
283
+ provider_config = {"config": provider_config}
284
+ return provider_config
285
+ except requests.exceptions.ConnectionError:
286
+ logger.exception("Connection error while fetching provider config")
287
+ cached_data = self._load_from_cache(cache_path)
288
+ if cached_data and isinstance(cached_data, dict):
289
+ logger.info("Using stale cached provider config")
290
+ return cached_data
291
+ raise
292
+ except requests.exceptions.HTTPError:
293
+ logger.exception("HTTP error while fetching provider config")
294
+ cached_data = self._load_from_cache(cache_path)
295
+ if cached_data and isinstance(cached_data, dict):
296
+ logger.info("Using stale cached provider config")
297
+ return cached_data
298
+ raise
299
+ except requests.exceptions.Timeout:
300
+ logger.exception("Timeout while fetching provider config")
301
+ cached_data = self._load_from_cache(cache_path)
302
+ if cached_data and isinstance(cached_data, dict):
303
+ logger.info("Using stale cached provider config")
304
+ return cached_data
305
+ raise
306
+ else:
307
+ self._save_to_cache(cache_path, provider_config)
308
+ return provider_config
309
+
310
+ def clear_cache(self) -> None:
311
+ """Clear all cached data."""
312
+ try:
313
+ for cache_file in self.CACHE_DIR.glob("*.json"):
314
+ cache_file.unlink()
315
+ logger.info("Cleared models.dev cache")
316
+ except OSError:
317
+ logger.exception("Failed to clear cache")
318
+
319
+ def search_models(
320
+ self,
321
+ query: str,
322
+ provider: str | None = None,
323
+ capabilities: list[str] | None = None,
324
+ max_price: float | None = None,
325
+ force_refresh: bool = False,
326
+ ) -> list[dict[str, Any]]:
327
+ """Search models based on criteria.
328
+
329
+ Args:
330
+ query: Search query string
331
+ provider: Optional provider filter
332
+ capabilities: Optional list of required capabilities
333
+ max_price: Optional maximum price filter
334
+ force_refresh: Force refresh from API
335
+
336
+ Returns:
337
+ List of matching models
338
+ """
339
+ models = self.get_models(provider=provider, force_refresh=force_refresh)
340
+
341
+ results = []
342
+ for model in models:
343
+ # Text search in name and description
344
+ model_text = (
345
+ f"{model.get('name', '')} {model.get('description', '')}".lower()
346
+ )
347
+ if query.lower() not in model_text:
348
+ continue
349
+
350
+ # Capabilities filter
351
+ if capabilities:
352
+ model_caps = set(model.get("capabilities", []))
353
+ if not set(capabilities).issubset(model_caps):
354
+ continue
355
+
356
+ # Price filter
357
+ if max_price is not None:
358
+ price = model.get("pricing", {}).get("input_per_1k_tokens")
359
+ if price is not None and price > max_price:
360
+ continue
361
+
362
+ results.append(model)
363
+
364
+ return results