sqlsaber 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of sqlsaber has been flagged as potentially problematic; see the package's registry page for details.

@@ -6,7 +6,7 @@ import platform
6
6
  import stat
7
7
  from dataclasses import dataclass
8
8
  from pathlib import Path
9
- from typing import Any, Dict, List, Optional
9
+ from typing import Any
10
10
  from urllib.parse import quote_plus
11
11
 
12
12
  import keyring
@@ -19,16 +19,16 @@ class DatabaseConfig:
19
19
 
20
20
  name: str
21
21
  type: str # postgresql, mysql, sqlite, csv
22
- host: Optional[str]
23
- port: Optional[int]
22
+ host: str | None
23
+ port: int | None
24
24
  database: str
25
- username: Optional[str]
26
- password: Optional[str] = None
27
- ssl_mode: Optional[str] = None
28
- ssl_ca: Optional[str] = None
29
- ssl_cert: Optional[str] = None
30
- ssl_key: Optional[str] = None
31
- schema: Optional[str] = None
25
+ username: str | None
26
+ password: str | None = None
27
+ ssl_mode: str | None = None
28
+ ssl_ca: str | None = None
29
+ ssl_cert: str | None = None
30
+ ssl_key: str | None = None
31
+ schema: str | None = None
32
32
 
33
33
  def to_connection_string(self) -> str:
34
34
  """Convert config to database connection string."""
@@ -115,7 +115,7 @@ class DatabaseConfig:
115
115
  else:
116
116
  raise ValueError(f"Unsupported database type: {self.type}")
117
117
 
118
- def _get_password_from_keyring(self) -> Optional[str]:
118
+ def _get_password_from_keyring(self) -> str | None:
119
119
  """Get password from OS keyring."""
120
120
  try:
121
121
  return keyring.get_password("sqlsaber", f"{self.name}_{self.username}")
@@ -133,7 +133,7 @@ class DatabaseConfig:
133
133
  except Exception:
134
134
  pass
135
135
 
136
- def to_dict(self) -> Dict[str, Any]:
136
+ def to_dict(self) -> dict[str, Any]:
137
137
  """Convert to dictionary for JSON serialization."""
138
138
  return {
139
139
  "name": self.name,
@@ -150,7 +150,7 @@ class DatabaseConfig:
150
150
  }
151
151
 
152
152
  @classmethod
153
- def from_dict(cls, data: Dict[str, Any]) -> "DatabaseConfig":
153
+ def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
154
154
  """Create from dictionary."""
155
155
  return cls(
156
156
  name=data["name"],
@@ -202,7 +202,7 @@ class DatabaseConfigManager:
202
202
  # The directory/file creation should still work
203
203
  pass
204
204
 
205
- def _load_config(self) -> Dict[str, Any]:
205
+ def _load_config(self) -> dict[str, Any]:
206
206
  """Load configuration from file."""
207
207
  if not self.config_file.exists():
208
208
  return {"default": None, "connections": {}}
@@ -213,7 +213,7 @@ class DatabaseConfigManager:
213
213
  except (json.JSONDecodeError, IOError):
214
214
  return {"default": None, "connections": {}}
215
215
 
216
- def _save_config(self, config: Dict[str, Any]) -> None:
216
+ def _save_config(self, config: dict[str, Any]) -> None:
217
217
  """Save configuration to file."""
218
218
  with open(self.config_file, "w") as f:
219
219
  json.dump(config, f, indent=2)
@@ -222,7 +222,7 @@ class DatabaseConfigManager:
222
222
  self._set_secure_permissions(self.config_file, is_directory=False)
223
223
 
224
224
  def add_database(
225
- self, db_config: DatabaseConfig, password: Optional[str] = None
225
+ self, db_config: DatabaseConfig, password: str | None = None
226
226
  ) -> None:
227
227
  """Add a database configuration."""
228
228
  config = self._load_config()
@@ -244,7 +244,7 @@ class DatabaseConfigManager:
244
244
 
245
245
  self._save_config(config)
246
246
 
247
- def get_database(self, name: str) -> Optional[DatabaseConfig]:
247
+ def get_database(self, name: str) -> DatabaseConfig | None:
248
248
  """Get a database configuration by name."""
249
249
  config = self._load_config()
250
250
 
@@ -253,7 +253,7 @@ class DatabaseConfigManager:
253
253
 
254
254
  return DatabaseConfig.from_dict(config["connections"][name])
255
255
 
256
- def get_default_database(self) -> Optional[DatabaseConfig]:
256
+ def get_default_database(self) -> DatabaseConfig | None:
257
257
  """Get the default database configuration."""
258
258
  config = self._load_config()
259
259
 
@@ -263,7 +263,7 @@ class DatabaseConfigManager:
263
263
 
264
264
  return self.get_database(default_name)
265
265
 
266
- def list_databases(self) -> List[DatabaseConfig]:
266
+ def list_databases(self) -> list[DatabaseConfig]:
267
267
  """List all database configurations."""
268
268
  config = self._load_config()
269
269
 
@@ -313,7 +313,7 @@ class DatabaseConfigManager:
313
313
  config = self._load_config()
314
314
  return len(config["connections"]) > 0
315
315
 
316
- def get_default_name(self) -> Optional[str]:
316
+ def get_default_name(self) -> str | None:
317
317
  """Get the name of the default database."""
318
318
  config = self._load_config()
319
319
  return config.get("default")
@@ -0,0 +1,274 @@
1
+ """Synchronous OAuth flow management for Anthropic Claude Pro authentication."""
2
+
3
+ import base64
4
+ import hashlib
5
+ import logging
6
+ import secrets
7
+ import urllib.parse
8
+ import webbrowser
9
+ from datetime import datetime, timezone
10
+
11
+ import httpx
12
+ import questionary
13
+ from rich.console import Console
14
+ from rich.progress import Progress, SpinnerColumn, TextColumn
15
+
16
+ from .oauth_tokens import OAuthToken, OAuthTokenManager
17
+
18
# Rich console used for all interactive, user-facing output of the OAuth flow.
console = Console()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# OAuth client identifier; AnthropicOAuthFlow copies it into self.client_id and
# sends it with the authorization URL and with both token-endpoint requests.
CLIENT_ID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
23
+
24
+
25
class AnthropicOAuthFlow:
    """Handles the complete OAuth flow for Anthropic Claude Pro authentication.

    Implements an OAuth 2.0 authorization-code grant with PKCE (``S256``):
    the user authorizes in a browser, pastes the code shown on the callback
    page into the CLI, and that code is exchanged for access/refresh tokens,
    which are persisted through ``OAuthTokenManager``.
    """

    # Endpoints and request parameters shared by the authorize, exchange and
    # refresh requests, kept in one place so they cannot drift apart.
    AUTHORIZE_URL = "https://claude.ai/oauth/authorize"
    TOKEN_URL = "https://console.anthropic.com/v1/oauth/token"
    REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback"
    SCOPES = "org:create_api_key user:profile user:inference"
    # Provider key under which tokens are stored by OAuthTokenManager.
    PROVIDER = "anthropic"

    def __init__(self):
        self.client_id = CLIENT_ID
        self.token_manager = OAuthTokenManager()

    def _generate_pkce(self) -> tuple[str, str]:
        """Generate a PKCE ``(code_verifier, code_challenge)`` pair.

        The verifier is 32 random bytes, base64url-encoded without padding;
        the challenge is the unpadded base64url SHA-256 digest of the
        verifier, matching ``code_challenge_method=S256``.
        """
        verifier = (
            base64.urlsafe_b64encode(secrets.token_bytes(32))
            .decode("utf-8")
            .rstrip("=")
        )
        challenge = (
            base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("utf-8")).digest())
            .decode("utf-8")
            .rstrip("=")
        )
        return verifier, challenge

    def _create_authorization_url(self) -> tuple[str, str]:
        """Build the browser authorization URL.

        Returns:
            ``(url, verifier)``; the verifier must be kept for the token
            exchange in ``_exchange_code_for_tokens``.
        """
        verifier, challenge = self._generate_pkce()

        params = {
            "code": "true",
            "client_id": self.client_id,
            "response_type": "code",
            "redirect_uri": self.REDIRECT_URI,
            "scope": self.SCOPES,
            "code_challenge": challenge,
            "code_challenge_method": "S256",
            # NOTE(review): the PKCE verifier doubles as the OAuth ``state``,
            # which exposes it in the URL (browser history, proxies, logs),
            # while PKCE assumes the verifier stays secret. Consider a separate
            # secrets.token_urlsafe() state once the server is confirmed to
            # accept one.
            "state": verifier,
        }

        url = f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"
        return url, verifier

    def _post_token_request(self, data: dict[str, str], action: str) -> dict[str, str]:
        """POST *data* as JSON to the token endpoint and return the decoded body.

        Args:
            data: Form of the token request (grant type, client id, ...).
            action: Human-readable label used in the error message.

        Raises:
            Exception: if the endpoint answers with a non-success status.
        """
        with httpx.Client() as client:
            response = client.post(
                self.TOKEN_URL,
                headers={"Content-Type": "application/json"},
                json=data,
            )

        if not response.is_success:
            error_msg = f"{action} failed: {response.status_code} {response.text}"
            logger.error(error_msg)
            raise Exception(error_msg)

        return response.json()

    def _exchange_code_for_tokens(self, code: str, verifier: str) -> dict[str, str]:
        """Exchange an authorization code for access and refresh tokens.

        Args:
            code: Value pasted by the user. It may have the form
                ``"<code>#<state>"``; the part after ``#`` is sent as ``state``.
            verifier: PKCE code verifier returned by
                ``_create_authorization_url``.

        Returns:
            The decoded JSON token response.

        Raises:
            Exception: if the token endpoint returns a non-success status.
        """
        code_parts = code.split("#")
        auth_code = code_parts[0]
        # Without a (non-empty) state suffix, fall back to the value that was
        # sent as ``state`` in the authorization URL (the verifier).
        state = code_parts[1] if len(code_parts) > 1 and code_parts[1] else verifier

        data = {
            "code": auth_code,
            "state": state,
            "grant_type": "authorization_code",
            "client_id": self.client_id,
            "redirect_uri": self.REDIRECT_URI,
            "code_verifier": verifier,
        }
        return self._post_token_request(data, "Token exchange")

    def refresh_access_token(self, refresh_token: str) -> dict[str, str]:
        """Refresh the access token using *refresh_token*.

        Returns:
            The decoded JSON token response.

        Raises:
            Exception: if the token endpoint returns a non-success status.
        """
        data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": self.client_id,
        }
        return self._post_token_request(data, "Token refresh")

    @staticmethod
    def _expires_at_from(tokens: dict[str, str]) -> str | None:
        """Turn a response's ``expires_in`` (seconds) into an ISO-8601 UTC timestamp.

        Returns None when the response carries no ``expires_in``.
        """
        if "expires_in" not in tokens:
            return None
        expires_ts = datetime.now(timezone.utc).timestamp() + int(tokens["expires_in"])
        return datetime.fromtimestamp(expires_ts, timezone.utc).isoformat()

    @staticmethod
    def _open_browser(url: str) -> bool:
        """Try to open *url* in a web browser; return whether that succeeded."""
        try:
            return bool(webbrowser.open(url))
        except webbrowser.Error:
            return False

    def authenticate(self) -> bool:
        """Run the interactive browser-based OAuth flow and store the tokens.

        Returns:
            True when tokens were obtained and stored. False when the user
            cancelled, the exchange failed, or the tokens could not be stored.
            Errors are reported on the console rather than raised.
        """
        console.print(
            "\n[bold blue]Claude Pro/Max Subscription Authentication[/bold blue]"
        )
        console.print(
            "This will open your web browser to authenticate with your Claude subscription.\n"
        )

        # Check if user wants to proceed
        if not questionary.confirm(
            "Continue with browser-based authentication?", default=True
        ).ask():
            console.print("[yellow]Authentication cancelled.[/yellow]")
            return False

        try:
            # Step 1: create the authorization URL and open it in the browser.
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                console=console,
            ) as progress:
                task = progress.add_task("Preparing authentication...", total=None)
                auth_url, verifier = self._create_authorization_url()
                progress.update(task, description="Opening browser...")
                browser_opened = self._open_browser(auth_url)

            if browser_opened:
                console.print("\n[green]✓[/green] Browser opened for authentication")
            else:
                console.print(
                    "\n[yellow]Could not open a browser automatically.[/yellow]"
                )
            console.print(
                "[dim]If your browser didn't open automatically, visit this URL:[/dim]"
            )
            console.print(f"[dim]{auth_url}[/dim]\n")

            # Step 2: read the code the user copies from the callback page.
            console.print("After authorizing, you'll be redirected to a callback URL.")
            console.print(
                "Copy the 'code' that shows up on your screen and paste it here."
            )
            auth_code = questionary.text(
                "Enter the authorization code:",
                validate=lambda x: len(x.strip()) > 0
                or "Authorization code is required",
            ).ask()

            # An empty or aborted prompt counts as a cancellation.
            if not auth_code:
                console.print("[yellow]Authentication cancelled.[/yellow]")
                return False

            # Step 3: exchange the code for tokens.
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                console=console,
            ) as progress:
                progress.add_task("Exchanging code for tokens...", total=None)
                tokens = self._exchange_code_for_tokens(auth_code.strip(), verifier)

            missing = [
                key for key in ("access_token", "refresh_token") if key not in tokens
            ]
            if missing:
                raise Exception(f"Token response is missing: {', '.join(missing)}")

            # Step 4: persist the tokens.
            oauth_token = OAuthToken(
                access_token=tokens["access_token"],
                refresh_token=tokens["refresh_token"],
                expires_at=self._expires_at_from(tokens),
            )

            if not self.token_manager.store_oauth_token(self.PROVIDER, oauth_token):
                console.print("[red]✗ Failed to store authentication tokens.[/red]")
                return False

            console.print("\n[bold green]✓ Authentication successful![/bold green]")
            console.print(
                "Your Claude Pro/Max subscription is now configured for SQLSaber."
            )
            return True

        except KeyboardInterrupt:
            console.print("\n[yellow]Authentication cancelled by user.[/yellow]")
            return False
        except Exception as e:
            logger.error("OAuth authentication failed: %s", e)
            # markup=False: the error text may contain server output with
            # square brackets, which Rich would otherwise parse as markup.
            console.print(f"✗ Authentication failed: {e}", style="red", markup=False)
            return False

    def refresh_token_if_needed(self) -> OAuthToken | None:
        """Return a usable stored token, refreshing it if expired or expiring soon.

        Returns:
            The current token if it is still fresh. The refreshed token after a
            successful refresh, even if persisting it failed (it is still valid
            for this session). The current token if refresh failed but the
            token has not yet expired. None if no token is stored, or if the
            token is expired and could not be refreshed.
        """
        current_token = self.token_manager.get_oauth_token(self.PROVIDER)
        if not current_token:
            return None

        if not current_token.is_expired() and not current_token.expires_soon():
            return current_token

        try:
            console.print("Refreshing OAuth token...", style="dim")
            new_tokens = self.refresh_access_token(current_token.refresh_token)
            refreshed_token = OAuthToken(
                access_token=new_tokens["access_token"],
                # Keep the old refresh token when the server does not issue a new one.
                refresh_token=new_tokens.get(
                    "refresh_token", current_token.refresh_token
                ),
                expires_at=self._expires_at_from(new_tokens),
            )
        except Exception as e:
            logger.warning("Token refresh failed: %s", e)
            console.print(
                "Token refresh failed. You may need to re-authenticate.", style="yellow"
            )
            # An expiring-soon token is still usable; an expired one is not.
            return None if current_token.is_expired() else current_token

        if self.token_manager.store_oauth_token(self.PROVIDER, refreshed_token):
            console.print("OAuth token refreshed successfully", style="green")
        else:
            # Still hand back the new token: it is valid for this session, and
            # the old refresh token may have been invalidated by the refresh.
            console.print("Failed to store refreshed token", style="yellow")
        return refreshed_token

    def remove_authentication(self) -> bool:
        """Remove stored OAuth authentication."""
        return self.token_manager.remove_oauth_token(self.PROVIDER)

    def has_valid_authentication(self) -> bool:
        """Check whether a stored, non-expired OAuth token exists."""
        token = self.token_manager.get_oauth_token(self.PROVIDER)
        return token is not None and not token.is_expired()
@@ -0,0 +1,175 @@
1
+ """OAuth token management for SQLSaber."""
2
+
3
+ import json
4
+ import logging
5
+ from datetime import datetime, timedelta, timezone
6
+ from typing import Any
7
+
8
+ import keyring
9
+ from rich.console import Console
10
+
11
# Rich console used for the token manager's status and warning messages.
console = Console()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
13
+
14
+
15
+ class OAuthToken:
16
+ """Represents an OAuth token with metadata."""
17
+
18
+ def __init__(
19
+ self,
20
+ access_token: str,
21
+ refresh_token: str,
22
+ expires_at: str | None = None,
23
+ token_type: str = "Bearer",
24
+ ):
25
+ self.access_token = access_token
26
+ self.refresh_token = refresh_token
27
+ self.expires_at = expires_at
28
+ self.token_type = token_type
29
+
30
+ @classmethod
31
+ def from_dict(cls, data: dict[str, Any]) -> "OAuthToken":
32
+ """Create token from dictionary."""
33
+ return cls(
34
+ access_token=data["access_token"],
35
+ refresh_token=data["refresh_token"],
36
+ expires_at=data.get("expires_at"),
37
+ token_type=data.get("token_type", "Bearer"),
38
+ )
39
+
40
+ def to_dict(self) -> dict[str, Any]:
41
+ """Convert token to dictionary."""
42
+ return {
43
+ "access_token": self.access_token,
44
+ "refresh_token": self.refresh_token,
45
+ "expires_at": self.expires_at,
46
+ "token_type": self.token_type,
47
+ }
48
+
49
+ def is_expired(self) -> bool:
50
+ """Check if the token is expired."""
51
+ if not self.expires_at:
52
+ return False
53
+
54
+ try:
55
+ expires_dt = datetime.fromisoformat(self.expires_at.replace("Z", "+00:00"))
56
+ return datetime.now(timezone.utc) >= expires_dt
57
+ except (ValueError, AttributeError):
58
+ # If we can't parse the expiration, assume expired for safety
59
+ return True
60
+
61
+ def expires_soon(self, buffer_seconds: int = 300) -> bool:
62
+ """Check if token expires within buffer_seconds."""
63
+ if not self.expires_at:
64
+ return False
65
+
66
+ try:
67
+ expires_dt = datetime.fromisoformat(self.expires_at.replace("Z", "+00:00"))
68
+
69
+ return (
70
+ datetime.now(timezone.utc) + timedelta(seconds=buffer_seconds)
71
+ ) >= expires_dt
72
+ except (ValueError, AttributeError):
73
+ return True
74
+
75
+
76
class OAuthTokenManager:
    """Keep per-provider OAuth tokens in the operating-system keyring.

    Each token is serialized to JSON and written under the keyring service
    ``"<service_prefix>-<provider>-oauth"``, using the provider name as the
    account name. Keyring and decoding failures are not raised: they are
    logged and/or printed, and the method returns None or False.
    """

    def __init__(self):
        self.service_prefix = "sqlsaber"

    def get_oauth_token(self, provider: str) -> OAuthToken | None:
        """Load the stored token for *provider*.

        An expired token is still returned, so that callers can use its
        refresh token. A notice is printed when the token is expired or about
        to expire. Returns None when nothing is stored, or when the stored
        entry cannot be read or decoded; the failure is logged as a warning.
        """
        service = self._get_service_name(provider)

        try:
            serialized = keyring.get_password(service, provider)
            if not serialized:
                return None

            token = OAuthToken.from_dict(json.loads(serialized))

            # Expiry notices are informational only; the token is returned either way.
            if token.is_expired():
                notice = f"OAuth token for {provider} has expired and needs refresh"
            elif token.expires_soon():
                notice = f"OAuth token for {provider} expires soon, consider refreshing"
            else:
                notice = None

            if notice is not None:
                console.print(notice, style="dim yellow")
            return token
        except Exception as exc:
            logger.warning(f"Failed to retrieve OAuth token for {provider}: {exc}")
            return None

    def store_oauth_token(self, provider: str, token: OAuthToken) -> bool:
        """Serialize *token* to JSON and save it in the keyring.

        Returns True on success. On any failure it logs an error, prints a
        warning and returns False.
        """
        service = self._get_service_name(provider)

        try:
            keyring.set_password(service, provider, json.dumps(token.to_dict()))
            console.print(f"OAuth token for {provider} stored securely", style="green")
        except Exception as exc:
            logger.error(f"Failed to store OAuth token for {provider}: {exc}")
            console.print(
                f"Warning: Could not store OAuth token in keyring: {exc}",
                style="yellow",
            )
            return False
        return True

    def update_oauth_token(
        self, provider: str, access_token: str, expires_at: str | None = None
    ) -> bool:
        """Replace the access token for *provider*, keeping its refresh token.

        ``expires_at`` overwrites the previous expiry; None means "no expiry".
        The existing token type is kept. Returns False, with a console notice,
        when no token is stored yet; otherwise returns the result of
        ``store_oauth_token``.
        """
        current = self.get_oauth_token(provider)
        if current is None:
            console.print(
                f"No existing OAuth token found for {provider}", style="yellow"
            )
            return False

        replacement = OAuthToken(
            access_token=access_token,
            refresh_token=current.refresh_token,
            expires_at=expires_at,
            token_type=current.token_type,
        )
        return self.store_oauth_token(provider, replacement)

    def remove_oauth_token(self, provider: str) -> bool:
        """Delete *provider*'s token from the keyring.

        A missing entry counts as success. Any other failure is logged and
        printed as a warning, and the method returns False.
        """
        service = self._get_service_name(provider)

        try:
            keyring.delete_password(service, provider)
            console.print(f"OAuth token for {provider} removed", style="green")
            return True
        except keyring.errors.PasswordDeleteError:
            # Nothing was stored under this name, so there is nothing to remove.
            return True
        except Exception as exc:
            logger.error(f"Failed to remove OAuth token for {provider}: {exc}")
            console.print(f"Warning: Could not remove OAuth token: {exc}", style="yellow")
            return False

    def has_oauth_token(self, provider: str) -> bool:
        """Report whether a readable token is stored for *provider*."""
        token = self.get_oauth_token(provider)
        return token is not None

    def _get_service_name(self, provider: str) -> str:
        """Build the keyring service name under which *provider*'s token lives."""
        return f"{self.service_prefix}-{provider}-oauth"
@@ -5,11 +5,13 @@ import os
5
5
  import platform
6
6
  import stat
7
7
  from pathlib import Path
8
- from typing import Any, Dict, Optional
8
+ from typing import Any
9
9
 
10
10
  import platformdirs
11
11
 
12
12
  from sqlsaber.config.api_keys import APIKeyManager
13
+ from sqlsaber.config.auth import AuthConfigManager, AuthMethod
14
+ from sqlsaber.config.oauth_flow import AnthropicOAuthFlow
13
15
 
14
16
 
15
17
  class ModelConfigManager:
@@ -40,7 +42,7 @@ class ModelConfigManager:
40
42
  except (OSError, PermissionError):
41
43
  pass
42
44
 
43
- def _load_config(self) -> Dict[str, Any]:
45
+ def _load_config(self) -> dict[str, Any]:
44
46
  """Load configuration from file."""
45
47
  if not self.config_file.exists():
46
48
  return {"model": self.DEFAULT_MODEL}
@@ -55,7 +57,7 @@ class ModelConfigManager:
55
57
  except (json.JSONDecodeError, IOError):
56
58
  return {"model": self.DEFAULT_MODEL}
57
59
 
58
- def _save_config(self, config: Dict[str, Any]) -> None:
60
+ def _save_config(self, config: dict[str, Any]) -> None:
59
61
  """Save configuration to file."""
60
62
  with open(self.config_file, "w") as f:
61
63
  json.dump(config, f, indent=2)
@@ -81,35 +83,48 @@ class Config:
81
83
  self.model_config_manager = ModelConfigManager()
82
84
  self.model_name = self.model_config_manager.get_model()
83
85
  self.api_key_manager = APIKeyManager()
84
- self.api_key = self._get_api_key()
86
+ self.auth_config_manager = AuthConfigManager()
87
+ self.oauth_flow = AnthropicOAuthFlow()
88
+
89
+ # Get authentication credentials based on configured method
90
+ self.auth_method = self.auth_config_manager.get_auth_method()
91
+ self.api_key = None
92
+ self.oauth_token = None
93
+
94
+ if self.auth_method == AuthMethod.CLAUDE_PRO:
95
+ # Try to get OAuth token and refresh if needed
96
+ try:
97
+ token = self.oauth_flow.refresh_token_if_needed()
98
+ if token:
99
+ self.oauth_token = token.access_token
100
+ except Exception:
101
+ # OAuth token unavailable, will need to re-authenticate
102
+ pass
103
+ else:
104
+ # Use API key authentication (default or explicitly configured)
105
+ self.api_key = self._get_api_key()
85
106
 
86
- def _get_api_key(self) -> Optional[str]:
107
+ def _get_api_key(self) -> str | None:
87
108
  """Get API key for the model provider using cascading logic."""
88
109
  model = self.model_name
89
-
90
- if model.startswith("openai:"):
91
- return self.api_key_manager.get_api_key("openai")
92
- elif model.startswith("anthropic:"):
110
+ if model.startswith("anthropic:"):
93
111
  return self.api_key_manager.get_api_key("anthropic")
94
- else:
95
- # For other providers, use generic key
96
- return self.api_key_manager.get_api_key("generic")
97
112
 
98
113
  def set_model(self, model: str) -> None:
99
114
  """Set the model and update configuration."""
100
115
  self.model_config_manager.set_model(model)
101
116
  self.model_name = model
102
- # Update API key for new model
103
- self.api_key = self._get_api_key()
104
117
 
105
118
  def validate(self):
106
119
  """Validate that necessary configuration is present."""
120
+ # 1. Claude-Pro flow → require OAuth token only
121
+ if self.auth_method == AuthMethod.CLAUDE_PRO:
122
+ if not self.oauth_token:
123
+ raise ValueError(
124
+ "OAuth token not available. Run 'saber auth setup' to authenticate with Claude Pro."
125
+ )
126
+ return # OAuth path satisfied – nothing more to check
127
+
128
+ # 2. Default / API-key flow → require API key
107
129
  if not self.api_key:
108
- model = self.model_name
109
- provider = "generic"
110
- if model.startswith("openai:"):
111
- provider = "OpenAI"
112
- elif model.startswith("anthropic:"):
113
- provider = "Anthropic"
114
-
115
- raise ValueError(f"{provider} API key not found.")
130
+ raise ValueError("Anthropic API key not found.")