model-forge-llm 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_forge_llm-0.2.0.dist-info/METADATA +327 -0
- model_forge_llm-0.2.0.dist-info/RECORD +14 -0
- model_forge_llm-0.2.0.dist-info/WHEEL +5 -0
- model_forge_llm-0.2.0.dist-info/entry_points.txt +2 -0
- model_forge_llm-0.2.0.dist-info/licenses/LICENSE +21 -0
- model_forge_llm-0.2.0.dist-info/top_level.txt +1 -0
- modelforge/__init__.py +7 -0
- modelforge/auth.py +503 -0
- modelforge/cli.py +720 -0
- modelforge/config.py +211 -0
- modelforge/exceptions.py +29 -0
- modelforge/logging_config.py +69 -0
- modelforge/modelsdev.py +364 -0
- modelforge/registry.py +272 -0
modelforge/auth.py
ADDED
@@ -0,0 +1,503 @@
|
|
1
|
+
"""Authentication strategies for ModelForge providers."""
|
2
|
+
|
3
|
+
import getpass
|
4
|
+
import time
|
5
|
+
import webbrowser
|
6
|
+
from abc import ABC, abstractmethod
|
7
|
+
from datetime import UTC, datetime, timedelta
|
8
|
+
from typing import Any
|
9
|
+
|
10
|
+
import requests
|
11
|
+
|
12
|
+
from .config import get_config, save_config
|
13
|
+
from .exceptions import AuthenticationError, ConfigurationError
|
14
|
+
from .logging_config import get_logger
|
15
|
+
|
16
|
+
logger = get_logger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
class AuthStrategy(ABC):
    """Base class for provider authentication strategies.

    Subclasses implement the three abstract operations; this class supplies
    helpers that read and write the provider's ``auth_data`` entry in the
    configuration returned by ``get_config()``.
    """

    def __init__(self, provider_name: str) -> None:
        """Bind the strategy to a provider.

        Args:
            provider_name: Key of the provider inside the ``providers``
                section of the configuration.
        """
        self.provider_name = provider_name

    @abstractmethod
    def authenticate(self) -> dict[str, Any] | None:
        """Run the authentication procedure and return the credentials."""

    @abstractmethod
    def get_credentials(self) -> dict[str, Any] | None:
        """Return the credentials that are currently stored, if any."""

    @abstractmethod
    def clear_credentials(self) -> None:
        """Remove any credentials stored for the provider."""

    def _get_auth_data(self) -> dict[str, Any]:
        """Return a copy of the provider's stored ``auth_data`` mapping.

        Returns:
            A new dict; empty when the provider or its ``auth_data`` entry
            is absent from the configuration.
        """
        config_data, _ = get_config()
        provider_entry = config_data.get("providers", {}).get(self.provider_name, {})
        return dict(provider_entry.get("auth_data", {}))

    def _save_auth_data(self, auth_data: dict[str, Any]) -> None:
        """Store ``auth_data`` under the provider and persist the config.

        Args:
            auth_data: Mapping written as the provider's ``auth_data`` entry,
                replacing any previous value.
        """
        config_data, _ = get_config()

        # Create the "providers" table and the provider's own entry on demand.
        provider_entry = config_data.setdefault("providers", {}).setdefault(
            self.provider_name, {}
        )
        provider_entry["auth_data"] = auth_data

        save_config(config_data)
        logger.info("Successfully saved auth data for %s", self.provider_name)

    def _clear_auth_data(self) -> None:
        """Drop the provider's ``auth_data`` entry and persist the config.

        Nothing is written when the provider is not present in the config.
        """
        config_data, _ = get_config()
        provider_table = config_data.get("providers", {})

        if self.provider_name not in provider_table:
            return

        provider_table[self.provider_name].pop("auth_data", None)
        save_config(config_data)
        logger.info("Cleared stored auth data for %s", self.provider_name)
|
76
|
+
|
77
|
+
|
78
|
+
class ApiKeyAuth(AuthStrategy):
    """Strategy that authenticates with an API key kept in the config."""

    def authenticate(self) -> dict[str, Any] | None:
        """Ask the user for an API key on the terminal and store it.

        Returns:
            ``{"api_key": <key>}`` when a key was entered, otherwise ``None``.
        """
        entered_key = getpass.getpass(f"Enter API key for {self.provider_name}: ")
        if not entered_key:
            logger.warning("No API key provided for %s", self.provider_name)
            return None

        credentials = {"api_key": entered_key}
        self._save_auth_data(credentials)
        logger.info("API key stored for %s", self.provider_name)
        return credentials

    def store_api_key(self, api_key: str) -> None:
        """Persist ``api_key`` for the provider without prompting.

        Args:
            api_key: The key to store.
        """
        self._save_auth_data({"api_key": api_key})
        logger.info("API key stored for %s", self.provider_name)

    def get_credentials(self) -> dict[str, Any] | None:
        """Return the stored auth data when it contains an API key.

        Returns:
            The stored mapping, or ``None`` when reading the config failed
            or no ``api_key`` entry is present.
        """
        try:
            stored = self._get_auth_data()
        except Exception:
            logger.exception("Failed to retrieve API key for %s", self.provider_name)
            return None

        if "api_key" not in stored:
            logger.warning("No stored API key found for %s", self.provider_name)
            return None

        logger.debug("Retrieved API key for %s", self.provider_name)
        return stored

    def clear_credentials(self) -> None:
        """Remove the stored API key; failures are logged, not raised."""
        try:
            self._clear_auth_data()
        except Exception:
            logger.exception(
                "An unexpected error occurred while clearing API key for %s",
                self.provider_name,
            )
|
121
|
+
|
122
|
+
|
123
|
+
class DeviceFlowAuth(AuthStrategy):
    """OAuth device flow authentication strategy.

    The flow requests a device code, shows the user a verification URL and
    code, then polls the token endpoint until the user approves, the request
    is rejected, or the device code expires. Tokens are persisted through the
    base class's config helpers and refreshed on demand.
    """

    def __init__(
        self,
        provider_name: str,
        client_id: str,
        device_code_url: str,
        token_url: str,
        scope: str,
    ) -> None:
        """Initialize device flow authentication.

        Args:
            provider_name: The name of the provider
            client_id: OAuth client ID
            device_code_url: URL to request device code
            token_url: URL to exchange device code for token
            scope: OAuth scope
        """
        super().__init__(provider_name)
        self.client_id = client_id
        self.device_code_url = device_code_url
        self.token_url = token_url
        self.scope = scope

    @staticmethod
    def _error_code(response: Any) -> str | None:
        """Best-effort extraction of the OAuth ``error`` field from a response.

        Returns ``None`` when there is no response, the body is not JSON, or
        the body is not a JSON object.
        """
        if response is None:
            return None
        try:
            body = response.json()
        except ValueError:
            # requests' JSONDecodeError derives from ValueError.
            return None
        return body.get("error") if isinstance(body, dict) else None

    def authenticate(self) -> dict[str, Any] | None:
        """Perform device flow authentication.

        Returns:
            The token data returned by the provider (also saved to config).

        Raises:
            AuthenticationError: If the device code cannot be obtained, the
                response lacks required fields, or polling ends without a
                token.
            ConfigurationError: If the obtained token cannot be saved.
        """
        logger.info("Starting device flow authentication for %s", self.provider_name)

        # Step 1: Request device code
        try:
            device_code_data = self._request_device_code()
        except requests.exceptions.RequestException as e:
            logger.exception(
                "Network error requesting device code from %s",
                self.provider_name,
            )
            raise AuthenticationError from e

        # Fail with a domain error instead of a bare KeyError further down.
        missing = [
            field
            for field in ("device_code", "user_code", "verification_uri")
            if field not in device_code_data
        ]
        if missing:
            msg = (
                "Device code response is missing required fields: "
                f"{', '.join(missing)}"
            )
            logger.error("%s (provider %s)", msg, self.provider_name)
            raise AuthenticationError(msg)

        logger.info("Device code obtained for %s", self.provider_name)

        # Step 2: Show user instructions
        verification_uri = device_code_data["verification_uri"]
        print("\n--- Device Authentication ---")
        print(f"Please open the following URL in your browser: {verification_uri}")
        print(f"And enter this code: {device_code_data['user_code']}")
        print("Waiting for authentication...")

        # Try to open browser automatically. webbrowser.open() reports failure
        # through its return value as well as through exceptions.
        try:
            opened = webbrowser.open(verification_uri)
        except Exception:
            opened = False
        if opened:
            print("Browser opened automatically. If not, use the URL above.")
        else:
            print("Please open the URL manually in your browser.")

        # Step 3: Poll for token
        return self._poll_for_token(device_code_data)

    def _request_device_code(self) -> dict[str, Any]:
        """Request device code from the provider.

        Returns:
            The decoded JSON object returned by the device code endpoint.

        Raises:
            AuthenticationError: On network failure, an HTTP error status,
                or a body that is not a JSON object.
        """
        headers = {"Accept": "application/json"}
        data = {
            "client_id": self.client_id,
            "scope": self.scope,
        }

        try:
            response = requests.post(
                self.device_code_url, data=data, headers=headers, timeout=30
            )
            response.raise_for_status()
            body = response.json()
        except requests.exceptions.JSONDecodeError as e:
            logger.exception(
                "Invalid response from %s device code endpoint",
                self.provider_name,
            )
            raise AuthenticationError from e
        except requests.exceptions.HTTPError as e:
            logger.exception(
                "HTTP error requesting device code from %s: %s",
                self.provider_name,
                self._error_code(e.response),
            )
            raise AuthenticationError from e
        except requests.exceptions.RequestException as e:
            logger.exception(
                "Network error requesting device code from %s",
                self.provider_name,
            )
            raise AuthenticationError from e

        if not isinstance(body, dict):
            logger.error(
                "Invalid response from %s device code endpoint",
                self.provider_name,
            )
            msg = "Device code response was not a JSON object."
            raise AuthenticationError(msg)
        return body

    def _poll_for_token(
        self, device_code_data: dict[str, Any]
    ) -> dict[str, Any] | None:
        """Poll for access token after device code is obtained.

        The OAuth ``error`` field is inspected regardless of the HTTP status,
        because some providers report ``authorization_pending`` and friends
        with a 200 response. Only ``authorization_pending`` and ``slow_down``
        keep the loop going; every other outcome without an access token is
        terminal.

        Args:
            device_code_data: Device code response; ``device_code`` is
                required, ``interval`` (default 5) and ``expires_in`` are
                honoured when present.

        Returns:
            The token data, after it has been saved to config.

        Raises:
            AuthenticationError: On network/JSON failure, a terminal error
                code, an unrecognized response, or device code expiry.
            ConfigurationError: If the token cannot be saved.
        """
        logger.info("Polling for access token for %s", self.provider_name)

        interval = device_code_data.get("interval", 5)

        # Client-side safety net so the loop cannot run forever if the
        # provider never reports expiry itself.
        expires_in = device_code_data.get("expires_in")
        deadline = None
        if isinstance(expires_in, (int, float)) and not isinstance(expires_in, bool):
            deadline = time.monotonic() + expires_in

        token_payload = {
            "client_id": self.client_id,
            "device_code": device_code_data["device_code"],
            "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
        }
        headers = {"Accept": "application/json"}

        while True:
            if deadline is not None and time.monotonic() >= deadline:
                logger.error(
                    "Device code for %s expired before authorization completed",
                    self.provider_name,
                )
                msg = "Authentication failed: expired_token"
                raise AuthenticationError(msg)

            time.sleep(interval)

            try:
                token_response = requests.post(
                    self.token_url, data=token_payload, headers=headers, timeout=30
                )
            except requests.exceptions.RequestException as e:
                logger.exception(
                    "Network error while polling for token from %s",
                    self.provider_name,
                )
                raise AuthenticationError from e

            try:
                token_data = token_response.json()
            except requests.exceptions.JSONDecodeError as e:
                logger.exception(
                    "Invalid JSON response while polling for token from %s",
                    self.provider_name,
                )
                raise AuthenticationError from e

            if not isinstance(token_data, dict):
                logger.error("Unexpected error format from %s", self.provider_name)
                msg = "Token response was not a JSON object."
                raise AuthenticationError(msg)

            if token_response.ok and "access_token" in token_data:
                logger.info(
                    "Successfully obtained access token for %s", self.provider_name
                )
                self._save_token_info(token_data)
                return dict(token_data)

            error_code = token_data.get("error")

            if error_code == "authorization_pending":
                # Expected while the user has not approved yet; keep polling.
                logger.debug(
                    "Authorization still pending for %s", self.provider_name
                )
                continue

            if error_code == "slow_down":
                # Increase interval and continue
                interval += 5
                logger.info(
                    "Slowing down polling to %s seconds for %s",
                    interval,
                    self.provider_name,
                )
                continue

            # Anything else (expired_token, access_denied, unknown codes, or
            # a response with neither a token nor an error) is terminal.
            reason = error_code or f"unexpected response (HTTP {token_response.status_code})"
            logger.error(
                "Unrecoverable error from %s: %s",
                self.provider_name,
                reason,
            )
            msg = f"Authentication failed: {reason}"
            raise AuthenticationError(msg)

    def _save_token_info(self, token_data: dict[str, Any]) -> None:
        """Save token information to config file.

        When ``expires_in`` is present and numeric, an ISO-8601 UTC
        ``expires_at`` entry is added to ``token_data`` (in place) before it
        is saved.

        Args:
            token_data: Token response to persist; mutated as described.

        Raises:
            ConfigurationError: If the data cannot be written to config.
        """
        # Calculate expiry time and add to token_data
        if "expires_in" in token_data:
            try:
                lifetime = timedelta(seconds=float(token_data["expires_in"]))
            except (TypeError, ValueError):
                # Without a usable lifetime the token is stored without
                # expiry information rather than failing the whole login.
                logger.warning(
                    "Ignoring non-numeric expires_in for %s", self.provider_name
                )
            else:
                expires_at = datetime.now(UTC) + lifetime
                token_data["expires_at"] = expires_at.isoformat()

        try:
            self._save_auth_data(token_data)
        except Exception:
            logger.exception("Failed to save token for %s", self.provider_name)
            msg = "Could not save token information."
            raise ConfigurationError(msg) from None

    def get_credentials(self) -> dict[str, Any] | None:
        """Retrieve stored token info. If expired, try to refresh.

        Returns:
            The stored token info when it is still valid or has no expiry
            information, the refreshed token data when a refresh succeeded,
            or ``None`` when nothing is stored or the refresh failed.
        """
        token_info = self.get_token_info()
        if not token_info:
            logger.debug("No token info found for %s.", self.provider_name)
            return None

        # Check for expiry, with a 60-second buffer
        expires_at_str = token_info.get("expires_at")
        if not expires_at_str:
            logger.warning(
                "Token for %s has no expiration info. Assuming it's valid.",
                self.provider_name,
            )
            return token_info

        try:
            expires_at = datetime.fromisoformat(expires_at_str)
        except (TypeError, ValueError):
            # An unreadable timestamp cannot prove the token is valid, so
            # treat it like an expired one.
            logger.warning(
                "Stored expiry for %s is unreadable. Attempting refresh.",
                self.provider_name,
            )
            return self._refresh_token()

        if expires_at.tzinfo is None:
            # _save_token_info always writes UTC-aware values; a naive value
            # is assumed to be UTC so the comparison below cannot raise.
            expires_at = expires_at.replace(tzinfo=UTC)

        if datetime.now(UTC) >= (expires_at - timedelta(seconds=60)):
            logger.info(
                "Access token for %s is expired or nearing expiry. Attempting refresh.",
                self.provider_name,
            )
            return self._refresh_token()

        logger.debug("Access token for %s is still valid.", self.provider_name)
        return token_info

    def _refresh_token(self) -> dict[str, Any] | None:
        """Use a refresh token to get a new access token.

        Stored credentials are cleared only when the provider definitively
        rejects the refresh (HTTP 4xx, or a response without an access
        token). Timeouts, connection errors and 5xx responses leave them in
        place so a later attempt can succeed.

        Returns:
            The new token data (saved to config), or ``None`` on failure.
        """
        token_info = self.get_token_info()
        if not token_info or "refresh_token" not in token_info:
            logger.warning(
                "No refresh token found for %s. Cannot refresh.", self.provider_name
            )
            return None

        logger.info("Attempting to refresh token for %s", self.provider_name)
        payload = {
            "client_id": self.client_id,
            "refresh_token": token_info["refresh_token"],
            "grant_type": "refresh_token",
        }
        headers = {"Accept": "application/json"}
        try:
            response = requests.post(
                self.token_url, data=payload, headers=headers, timeout=30
            )
            response.raise_for_status()
            new_token_data = response.json()
        except requests.exceptions.RequestException as e:
            failed_response = getattr(e, "response", None)
            if (
                failed_response is not None
                and 400 <= failed_response.status_code < 500
            ):
                logger.exception(
                    "Failed to refresh token for %s. Re-authentication will be required.",
                    self.provider_name,
                )
                self.clear_credentials()
            else:
                logger.exception(
                    "Transient error refreshing token for %s. Stored credentials were kept.",
                    self.provider_name,
                )
            return None

        if not isinstance(new_token_data, dict) or "access_token" not in new_token_data:
            # Some providers report refresh failures in a 200 body; never
            # persist such a response as if it were a token.
            logger.error(
                "Refresh response for %s contained no access token. "
                "Re-authentication will be required.",
                self.provider_name,
            )
            self.clear_credentials()
            return None

        # Providers may omit the refresh token when it is unchanged.
        new_token_data.setdefault("refresh_token", token_info["refresh_token"])

        self._save_token_info(new_token_data)
        logger.info("Successfully refreshed token for %s", self.provider_name)
        return dict(new_token_data)

    def get_token_info(self) -> dict[str, Any] | None:
        """Retrieve token information from config file.

        Returns:
            The stored auth data, or ``None`` when it is empty or could not
            be read.
        """
        try:
            auth_data = self._get_auth_data()
        except Exception:
            logger.exception("Could not retrieve token for %s", self.provider_name)
            return None
        else:
            return auth_data if auth_data else None

    def clear_credentials(self) -> None:
        """Clear stored token from config file; failures are logged, not raised."""
        try:
            self._clear_auth_data()
        except Exception:
            logger.exception(
                "An unexpected error occurred while clearing token for %s",
                self.provider_name,
            )
|
401
|
+
|
402
|
+
|
403
|
+
def get_auth_strategy(
    provider_name: str,
    provider_data: dict[str, Any],
    model_alias: str | None = None,  # noqa: ARG001
) -> AuthStrategy:
    """
    Factory function to get the correct authentication strategy for a provider.

    The strategy is selected by ``provider_data["auth_strategy"]``: a missing
    or empty value yields ``NoAuth``, ``"api_key"`` yields ``ApiKeyAuth`` and
    ``"device_flow"`` yields ``DeviceFlowAuth`` built from
    ``provider_data["auth_details"]``.

    Args:
        provider_name: The name of the provider.
        provider_data: The configuration data for the provider.
        model_alias: The model alias (currently unused but for future use).

    Returns:
        An instance of an AuthStrategy subclass.

    Raises:
        ConfigurationError: If the provider is not found or misconfigured
            (unknown strategy, or device flow settings absent/incomplete).
    """
    if not provider_data:
        msg = f"Provider '{provider_name}' not found in configuration."
        raise ConfigurationError(msg)

    strategy_name = provider_data.get("auth_strategy")
    if not strategy_name:
        return NoAuth(provider_name)

    if strategy_name == "api_key":
        return ApiKeyAuth(provider_name)

    if strategy_name == "device_flow":
        auth_details = provider_data.get("auth_details")
        if not auth_details:
            raise ConfigurationError(
                f"Provider '{provider_name}' is missing required device flow settings."
            )

        # Report incomplete settings as a configuration problem rather than
        # letting a bare KeyError escape from the lookups below.
        required_keys = ("client_id", "device_code_url", "token_url", "scope")
        missing = [key for key in required_keys if key not in auth_details]
        if missing:
            msg = (
                f"Provider '{provider_name}' is missing required device flow "
                f"settings: {', '.join(missing)}."
            )
            raise ConfigurationError(msg)

        return DeviceFlowAuth(
            provider_name,
            auth_details["client_id"],
            auth_details["device_code_url"],
            auth_details["token_url"],
            auth_details["scope"],
        )

    raise ConfigurationError(
        f"Unknown auth strategy '{strategy_name}' for provider '{provider_name}'."
    )
|
451
|
+
|
452
|
+
|
453
|
+
def get_credentials(
    provider_name: str,
    model_alias: str,
    provider_data: dict[str, Any],
    verbose: bool = False,
) -> dict[str, Any] | None:
    """
    Return credentials for a provider, authenticating when none are stored.

    Stored credentials are used when the provider's strategy returns a
    non-empty result; otherwise the strategy's ``authenticate()`` is run and
    its result returned. Every exception is logged and turned into ``None``.

    Args:
        provider_name: The name of the provider.
        model_alias: The model alias, forwarded to ``get_auth_strategy``.
        provider_data: The configuration data for the provider.
        verbose: When true, switch this module's logger to DEBUG.

    Returns:
        The credentials mapping, or ``None`` when authentication failed.
    """
    if verbose:
        logger.setLevel("DEBUG")

    outcome: dict[str, Any] | None = None
    try:
        strategy = get_auth_strategy(provider_name, provider_data, model_alias)
        stored = strategy.get_credentials()
        if not stored:
            logger.info(
                "No valid credentials found for %s. Initiating authentication.",
                provider_name,
            )
            outcome = strategy.authenticate()
        else:
            logger.info("Successfully retrieved credentials for %s", provider_name)
            outcome = stored
    except (ConfigurationError, AuthenticationError):
        logger.exception("Authentication failed for %s", provider_name)
    except Exception:
        logger.exception(
            "An unexpected error occurred during authentication for %s", provider_name
        )
    return outcome
|
489
|
+
|
490
|
+
|
491
|
+
class NoAuth(AuthStrategy):
    """Placeholder strategy for providers that require no authentication."""

    def authenticate(self) -> dict[str, Any] | None:
        """Return an empty credentials mapping; there is nothing to do."""
        return {}

    def get_credentials(self) -> dict[str, Any] | None:
        """Return an empty credentials mapping; nothing is ever stored."""
        return {}

    def clear_credentials(self) -> None:
        """Do nothing; there are no credentials to remove."""
        return None
|