model-forge-llm 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
modelforge/auth.py ADDED
@@ -0,0 +1,503 @@
1
+ """Authentication strategies for ModelForge providers."""
2
+
3
+ import getpass
4
+ import time
5
+ import webbrowser
6
+ from abc import ABC, abstractmethod
7
+ from datetime import UTC, datetime, timedelta
8
+ from typing import Any
9
+
10
+ import requests
11
+
12
+ from .config import get_config, save_config
13
+ from .exceptions import AuthenticationError, ConfigurationError
14
+ from .logging_config import get_logger
15
+
16
# Module-level logger shared by every strategy class and helper function below.
logger = get_logger(__name__)
17
+
18
+
19
+ class AuthStrategy(ABC):
20
+ """Abstract base class for authentication strategies."""
21
+
22
+ def __init__(self, provider_name: str) -> None:
23
+ """Initialize the authentication strategy.
24
+ Args:
25
+ provider_name: The name of the provider this strategy is for
26
+ """
27
+ self.provider_name = provider_name
28
+
29
+ @abstractmethod
30
+ def authenticate(self) -> dict[str, Any] | None:
31
+ """Perform authentication and return credentials."""
32
+
33
+ @abstractmethod
34
+ def get_credentials(self) -> dict[str, Any] | None:
35
+ """Retrieve stored credentials."""
36
+
37
+ @abstractmethod
38
+ def clear_credentials(self) -> None:
39
+ """Clear any stored credentials for the provider."""
40
+
41
+ def _get_auth_data(self) -> dict[str, Any]:
42
+ """Get authentication data from config file."""
43
+ config_data, _ = get_config()
44
+ providers = config_data.get("providers", {})
45
+ provider_data = providers.get(self.provider_name, {})
46
+ return dict(provider_data.get("auth_data", {}))
47
+
48
+ def _save_auth_data(self, auth_data: dict[str, Any]) -> None:
49
+ """Save authentication data to config file."""
50
+ config_data, _ = get_config()
51
+
52
+ # Ensure providers section exists
53
+ if "providers" not in config_data:
54
+ config_data["providers"] = {}
55
+
56
+ # Ensure provider section exists
57
+ if self.provider_name not in config_data["providers"]:
58
+ config_data["providers"][self.provider_name] = {}
59
+
60
+ # Store auth data
61
+ config_data["providers"][self.provider_name]["auth_data"] = auth_data
62
+
63
+ # Save config
64
+ save_config(config_data)
65
+ logger.info("Successfully saved auth data for %s", self.provider_name)
66
+
67
+ def _clear_auth_data(self) -> None:
68
+ """Clear authentication data from config file."""
69
+ config_data, _ = get_config()
70
+ providers = config_data.get("providers", {})
71
+
72
+ if self.provider_name in providers:
73
+ providers[self.provider_name].pop("auth_data", None)
74
+ save_config(config_data)
75
+ logger.info("Cleared stored auth data for %s", self.provider_name)
76
+
77
+
78
class ApiKeyAuth(AuthStrategy):
    """API key authentication strategy.

    The key is stored in the provider's config ``auth_data`` entry as
    ``{"api_key": <key>}``.
    """

    def authenticate(self) -> dict[str, Any] | None:
        """Prompt (without echo) for an API key and store it in the config.

        Surrounding whitespace is stripped, so a stray space from a paste does
        not end up in the stored key and whitespace-only input counts as
        "no key".

        Returns:
            The stored ``{"api_key": ...}`` mapping, or ``None`` when the user
            entered nothing usable.
        """
        prompt = f"Enter API key for {self.provider_name}: "
        api_key = getpass.getpass(prompt).strip()
        if not api_key:
            logger.warning("No API key provided for %s", self.provider_name)
            return None
        self.store_api_key(api_key)
        return {"api_key": api_key}

    def store_api_key(self, api_key: str) -> None:
        """Store *api_key* for the provider without prompting.

        Leading/trailing whitespace is stripped before saving.
        """
        auth_data = {"api_key": api_key.strip()}
        self._save_auth_data(auth_data)
        logger.info("API key stored for %s", self.provider_name)

    def get_credentials(self) -> dict[str, Any] | None:
        """Return the stored ``auth_data`` if it contains an ``api_key``.

        Errors while reading the config are logged and reported as ``None``,
        the same as "no key stored", so callers can fall back to
        :meth:`authenticate`.
        """
        try:
            auth_data = self._get_auth_data()
        except Exception:
            logger.exception("Failed to retrieve API key for %s", self.provider_name)
            return None

        if auth_data and "api_key" in auth_data:
            logger.debug("Retrieved API key for %s", self.provider_name)
            return auth_data
        logger.warning("No stored API key found for %s", self.provider_name)
        return None

    def clear_credentials(self) -> None:
        """Remove the stored API key; failures are logged, not raised."""
        try:
            self._clear_auth_data()
        except Exception:
            logger.exception(
                "An unexpected error occurred while clearing API key for %s",
                self.provider_name,
            )
121
+
122
+
123
class DeviceFlowAuth(AuthStrategy):
    """OAuth 2.0 device authorization grant (RFC 8628) strategy.

    Token responses are stored in the provider's config ``auth_data`` entry.
    An absolute ``expires_at`` timestamp is added when the provider reports
    ``expires_in``. Expired tokens are refreshed on demand when a
    ``refresh_token`` is available.
    """

    def __init__(
        self,
        provider_name: str,
        client_id: str,
        device_code_url: str,
        token_url: str,
        scope: str,
    ) -> None:
        """Initialize device flow authentication.

        Args:
            provider_name: The name of the provider.
            client_id: OAuth client ID.
            device_code_url: URL to request a device code from.
            token_url: URL that exchanges the device code (or a refresh
                token) for an access token.
            scope: OAuth scope to request.
        """
        super().__init__(provider_name)
        self.client_id = client_id
        self.device_code_url = device_code_url
        self.token_url = token_url
        self.scope = scope

    @staticmethod
    def _oauth_error_code(response: requests.Response) -> str | None:
        """Return the OAuth ``error`` field of *response*'s JSON body, if any."""
        try:
            body = response.json()
        except requests.exceptions.JSONDecodeError:
            return None
        return body.get("error") if isinstance(body, dict) else None

    def authenticate(self) -> dict[str, Any] | None:
        """Run the device flow interactively and return the new token data.

        The flow has three steps:
        1. Request a device code.
        2. Show the verification URL and user code, opening a browser when
           possible.
        3. Poll the token endpoint until the user approves.

        Raises:
            AuthenticationError: If any step of the flow fails.
            ConfigurationError: If the obtained token cannot be saved.
        """
        logger.info("Starting device flow authentication for %s", self.provider_name)

        # Step 1: request a device code. The helper already converts network,
        # HTTP and format errors into AuthenticationError.
        device_code_data = self._request_device_code()
        logger.info("Device code obtained for %s", self.provider_name)

        verification_uri = device_code_data["verification_uri"]

        # Step 2: show the user what to do.
        print("\n--- Device Authentication ---")
        print(
            f"Please open the following URL in your browser: "
            f"{verification_uri}"
        )
        print(f"And enter this code: {device_code_data['user_code']}")
        print("Waiting for authentication...")

        # webbrowser.open() returns False when no browser could be launched,
        # so only report success when it says so. Any failure here is
        # cosmetic and must not abort the flow.
        try:
            opened = webbrowser.open(verification_uri)
        except Exception:
            opened = False
        if opened:
            print("Browser opened automatically. If not, use the URL above.")
        else:
            print("Please open the URL manually in your browser.")

        # Step 3: wait for the user to approve the request.
        return self._poll_for_token(device_code_data)

    def _request_device_code(self) -> dict[str, Any]:
        """Request a device code and user code from the provider.

        Returns:
            The decoded JSON response. It is guaranteed to contain
            ``device_code``, ``user_code`` and ``verification_uri``.

        Raises:
            AuthenticationError: On network errors, HTTP error statuses,
                non-JSON replies, or replies missing required fields.
        """
        headers = {"Accept": "application/json"}
        data = {
            "client_id": self.client_id,
            "scope": self.scope,
        }

        try:
            response = requests.post(
                self.device_code_url, data=data, headers=headers, timeout=30
            )
            response.raise_for_status()
            payload = response.json()
        except requests.exceptions.JSONDecodeError as e:
            # Must precede RequestException: JSONDecodeError subclasses it.
            logger.exception(
                "Invalid response from %s device code endpoint",
                self.provider_name,
            )
            raise AuthenticationError from e
        except requests.exceptions.HTTPError as e:
            logger.exception(
                "HTTP error requesting device code from %s: %s",
                self.provider_name,
                self._oauth_error_code(response),
            )
            raise AuthenticationError from e
        except requests.exceptions.RequestException as e:
            logger.exception(
                "Network error requesting device code from %s",
                self.provider_name,
            )
            raise AuthenticationError from e

        if not isinstance(payload, dict):
            payload = {}
        required = ("device_code", "user_code", "verification_uri")
        missing = [field for field in required if field not in payload]
        if missing:
            logger.error(
                "Device code response from %s is missing %s (error: %s)",
                self.provider_name,
                ", ".join(missing),
                payload.get("error"),
            )
            msg = (
                "Device code response is missing required fields: "
                f"{', '.join(missing)}"
            )
            raise AuthenticationError(msg)
        return dict(payload)

    def _poll_for_token(
        self, device_code_data: dict[str, Any]
    ) -> dict[str, Any] | None:
        """Poll the token endpoint until the user finishes authorizing.

        Follows the polling rules of RFC 8628 section 3.5:
        ``authorization_pending`` keeps polling, ``slow_down`` adds five
        seconds to the interval (written back into *device_code_data*), and
        any other error ends the flow. Error codes are honoured whether the
        provider sends them with an HTTP error status or inside a 200 body
        (some providers, e.g. GitHub, do the latter). Polling also stops once
        the device code's ``expires_in`` lifetime, when provided, has elapsed.

        Returns:
            The token response, with ``expires_at`` added when possible.

        Raises:
            AuthenticationError: On network errors, malformed responses,
                unrecoverable OAuth errors, or device-code expiry.
            ConfigurationError: If the token cannot be saved.
        """
        logger.info("Polling for access token for %s", self.provider_name)

        token_payload = {
            "client_id": self.client_id,
            "device_code": device_code_data["device_code"],
            "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
        }
        headers = {"Accept": "application/json"}
        expires_in = device_code_data.get("expires_in")
        deadline = time.monotonic() + float(expires_in) if expires_in else None

        while True:
            time.sleep(device_code_data.get("interval", 5))
            if deadline is not None and time.monotonic() >= deadline:
                logger.error(
                    "Device code for %s expired before authorization completed",
                    self.provider_name,
                )
                msg = "Authentication failed: expired_token"
                raise AuthenticationError(msg)

            try:
                token_response = requests.post(
                    self.token_url, data=token_payload, headers=headers, timeout=30
                )
                token_data = token_response.json()
            except requests.exceptions.JSONDecodeError as e:
                # Must precede RequestException: JSONDecodeError subclasses it.
                logger.exception(
                    "Invalid JSON response while polling for token from %s",
                    self.provider_name,
                )
                raise AuthenticationError from e
            except requests.exceptions.RequestException as e:
                logger.exception(
                    "Network error while polling for token from %s",
                    self.provider_name,
                )
                raise AuthenticationError from e

            if not isinstance(token_data, dict):
                logger.error("Unexpected error format from %s", self.provider_name)
                msg = "Authentication failed: unexpected token response"
                raise AuthenticationError(msg)

            if token_response.ok and "access_token" in token_data:
                logger.info(
                    "Successfully obtained access token for %s", self.provider_name
                )
                self._save_token_info(token_data)
                return dict(token_data)

            error_code = token_data.get("error")
            if error_code == "authorization_pending":
                # Expected until the user approves; not an error worth a traceback.
                logger.debug("Authorization still pending for %s", self.provider_name)
                continue
            if error_code == "slow_down":
                new_interval = device_code_data.get("interval", 5) + 5
                device_code_data["interval"] = new_interval
                logger.info(
                    "Slowing down polling to %s seconds for %s",
                    new_interval,
                    self.provider_name,
                )
                continue

            # expired_token, access_denied or anything unexpected: another
            # poll cannot succeed, so fail instead of looping forever.
            logger.error(
                "Unrecoverable error from %s: %s (HTTP %s)",
                self.provider_name,
                error_code,
                token_response.status_code,
            )
            msg = f"Authentication failed: {error_code or 'unexpected token response'}"
            raise AuthenticationError(msg)

    def _save_token_info(self, token_data: dict[str, Any]) -> None:
        """Persist *token_data*, adding an absolute ``expires_at`` timestamp.

        *token_data* is modified in place. When it carries ``expires_in``
        (seconds), an ISO-8601 UTC ``expires_at`` key is added, so later
        expiry checks do not depend on when the token was issued.

        Raises:
            ConfigurationError: If the config file cannot be written.
        """
        if "expires_in" in token_data:
            lifetime = timedelta(seconds=float(token_data["expires_in"]))
            token_data["expires_at"] = (datetime.now(UTC) + lifetime).isoformat()

        try:
            self._save_auth_data(token_data)
        except Exception:
            logger.exception("Failed to save token for %s", self.provider_name)
            msg = "Could not save token information."
            raise ConfigurationError(msg) from None

    def get_credentials(self) -> dict[str, Any] | None:
        """Return stored token info, refreshing it if expired or nearly so.

        Tokens within 60 seconds of ``expires_at`` are refreshed. Tokens
        without expiry information are assumed valid. An unreadable
        ``expires_at`` also triggers a refresh attempt.

        Returns:
            The token mapping, or ``None`` if nothing is stored or the refresh
            failed.
        """
        token_info = self.get_token_info()
        if not token_info:
            logger.debug("No token info found for %s.", self.provider_name)
            return None

        expires_at_str = token_info.get("expires_at")
        if not expires_at_str:
            logger.warning(
                "Token for %s has no expiration info. Assuming it's valid.",
                self.provider_name,
            )
            return token_info

        try:
            expires_at = datetime.fromisoformat(expires_at_str)
        except (TypeError, ValueError):
            logger.warning(
                "Token for %s has an unreadable expiry %r. Attempting refresh.",
                self.provider_name,
                expires_at_str,
            )
            return self._refresh_token()
        if expires_at.tzinfo is None:
            # _save_token_info writes aware UTC timestamps. A naive value
            # (e.g. a hand-edited config) is assumed to be UTC rather than
            # letting the aware/naive comparison raise TypeError.
            expires_at = expires_at.replace(tzinfo=UTC)

        if datetime.now(UTC) >= (expires_at - timedelta(seconds=60)):
            logger.info(
                "Access token for %s is expired or nearing expiry. Attempting refresh.",
                self.provider_name,
            )
            return self._refresh_token()

        logger.debug("Access token for %s is still valid.", self.provider_name)
        return token_info

    def _refresh_token(self) -> dict[str, Any] | None:
        """Exchange the stored refresh token for a new access token.

        On any failure (network error, HTTP error, or a response without an
        ``access_token``), stored credentials are cleared so the next call
        re-authenticates.

        Returns:
            The new token data (keeping the old refresh token if the provider
            did not send a new one), or ``None`` on failure.

        Raises:
            ConfigurationError: If the refreshed token cannot be saved.
        """
        token_info = self.get_token_info()
        if not token_info or "refresh_token" not in token_info:
            logger.warning(
                "No refresh token found for %s. Cannot refresh.", self.provider_name
            )
            return None

        logger.info("Attempting to refresh token for %s", self.provider_name)
        payload = {
            "client_id": self.client_id,
            "refresh_token": token_info["refresh_token"],
            "grant_type": "refresh_token",
        }
        headers = {"Accept": "application/json"}
        try:
            response = requests.post(
                self.token_url, data=payload, headers=headers, timeout=30
            )
            response.raise_for_status()
            new_token_data = response.json()
        except requests.exceptions.RequestException:
            logger.exception(
                "Failed to refresh token for %s. Re-authentication will be required.",
                self.provider_name,
            )
            self.clear_credentials()
            return None

        # Some providers report OAuth errors in a 200 body. Saving such a
        # body would leave "credentials" that contain no access token.
        if not isinstance(new_token_data, dict) or "access_token" not in new_token_data:
            error_code = (
                new_token_data.get("error") if isinstance(new_token_data, dict) else None
            )
            logger.error(
                "Token refresh for %s returned no access token (error: %s). "
                "Re-authentication will be required.",
                self.provider_name,
                error_code,
            )
            self.clear_credentials()
            return None

        new_token_data.setdefault("refresh_token", token_info["refresh_token"])
        self._save_token_info(new_token_data)
        logger.info("Successfully refreshed token for %s", self.provider_name)
        return dict(new_token_data)

    def get_token_info(self) -> dict[str, Any] | None:
        """Return stored token data, or ``None`` if empty or unreadable."""
        try:
            auth_data = self._get_auth_data()
        except Exception:
            logger.exception("Could not retrieve token for %s", self.provider_name)
            return None
        return auth_data if auth_data else None

    def clear_credentials(self) -> None:
        """Remove the stored token; failures are logged, not raised."""
        try:
            self._clear_auth_data()
        except Exception:
            logger.exception(
                "An unexpected error occurred while clearing token for %s",
                self.provider_name,
            )
401
+
402
+
403
def get_auth_strategy(
    provider_name: str,
    provider_data: dict[str, Any],
    model_alias: str | None = None,  # noqa: ARG001
) -> AuthStrategy:
    """
    Factory function to get the correct authentication strategy for a provider.

    ``provider_data["auth_strategy"]`` selects the strategy:

    - missing or empty: :class:`NoAuth`;
    - ``"api_key"``: :class:`ApiKeyAuth`;
    - ``"device_flow"``: :class:`DeviceFlowAuth`, built from
      ``provider_data["auth_details"]``. That mapping must contain
      ``client_id``, ``device_code_url``, ``token_url`` and ``scope``.

    Args:
        provider_name: The name of the provider.
        provider_data: The configuration data for the provider.
        model_alias: The model alias (currently unused but for future use).

    Returns:
        An instance of an AuthStrategy subclass.

    Raises:
        ConfigurationError: If the provider is not found or misconfigured,
            including an unknown strategy name or incomplete device flow
            settings.
    """
    if not provider_data:
        msg = f"Provider '{provider_name}' not found in configuration."
        raise ConfigurationError(msg)

    strategy_name = provider_data.get("auth_strategy")
    if not strategy_name:
        return NoAuth(provider_name)

    if strategy_name == "api_key":
        return ApiKeyAuth(provider_name)

    if strategy_name == "device_flow":
        auth_details = provider_data.get("auth_details")
        if not auth_details:
            msg = (
                f"Provider '{provider_name}' is missing required device flow "
                "settings."
            )
            raise ConfigurationError(msg)

        # Report incomplete settings as ConfigurationError (as documented)
        # instead of letting a bare KeyError escape.
        required = ("client_id", "device_code_url", "token_url", "scope")
        missing = [key for key in required if key not in auth_details]
        if missing:
            msg = (
                f"Provider '{provider_name}' device flow settings are missing: "
                f"{', '.join(missing)}."
            )
            raise ConfigurationError(msg)

        return DeviceFlowAuth(
            provider_name,
            client_id=auth_details["client_id"],
            device_code_url=auth_details["device_code_url"],
            token_url=auth_details["token_url"],
            scope=auth_details["scope"],
        )

    msg = f"Unknown auth strategy '{strategy_name}' for provider '{provider_name}'."
    raise ConfigurationError(msg)
451
+
452
+
453
def get_credentials(
    provider_name: str,
    model_alias: str,
    provider_data: dict[str, Any],
    verbose: bool = False,
) -> dict[str, Any] | None:
    """
    Get credentials for a given provider and model.

    Stored credentials are tried first. Only when the strategy returns
    ``None`` is its ``authenticate()`` step run. An empty mapping, such as
    the one :class:`NoAuth` returns, is valid credentials, not a miss.

    Args:
        provider_name: The name of the provider.
        model_alias: The model alias, forwarded to ``get_auth_strategy``.
        provider_data: The configuration data for the provider.
        verbose: When true, set this module's logger to DEBUG. The level is
            not restored afterwards.

    Returns:
        The credentials mapping, or ``None`` if configuration or
        authentication failed. All failures are logged rather than raised.
    """
    if verbose:
        logger.setLevel("DEBUG")

    try:
        strategy = get_auth_strategy(provider_name, provider_data, model_alias)
        creds = strategy.get_credentials()
        # Compare with None: an empty dict (NoAuth) is a successful result and
        # must not trigger an authentication round-trip.
        if creds is not None:
            logger.info("Successfully retrieved credentials for %s", provider_name)
            return creds

        logger.info(
            "No valid credentials found for %s. Initiating authentication.",
            provider_name,
        )
        return strategy.authenticate()

    except (ConfigurationError, AuthenticationError):
        logger.exception("Authentication failed for %s", provider_name)
        return None
    except Exception:
        logger.exception(
            "An unexpected error occurred during authentication for %s", provider_name
        )
        return None
489
+
490
+
491
class NoAuth(AuthStrategy):
    """Strategy for providers that require no credentials at all.

    Both lookup and authentication succeed with an empty credential
    mapping, and there is never anything stored to clear.
    """

    def authenticate(self) -> dict[str, Any] | None:
        """Succeed immediately with an empty credential mapping."""
        empty_credentials: dict[str, Any] = {}
        return empty_credentials

    def get_credentials(self) -> dict[str, Any] | None:
        """Report empty (but present) credentials."""
        return {}

    def clear_credentials(self) -> None:
        """Do nothing: this strategy never stores credentials."""