sqlsaber 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic; see the package's registry page for more details.
- sqlsaber/agents/anthropic.py +283 -176
- sqlsaber/agents/base.py +11 -11
- sqlsaber/agents/streaming.py +3 -3
- sqlsaber/cli/auth.py +142 -0
- sqlsaber/cli/commands.py +9 -4
- sqlsaber/cli/completers.py +3 -5
- sqlsaber/cli/database.py +9 -10
- sqlsaber/cli/display.py +5 -7
- sqlsaber/cli/interactive.py +2 -3
- sqlsaber/cli/memory.py +7 -9
- sqlsaber/cli/models.py +1 -2
- sqlsaber/cli/streaming.py +5 -31
- sqlsaber/clients/__init__.py +6 -0
- sqlsaber/clients/anthropic.py +285 -0
- sqlsaber/clients/base.py +31 -0
- sqlsaber/clients/exceptions.py +117 -0
- sqlsaber/clients/models.py +282 -0
- sqlsaber/clients/streaming.py +257 -0
- sqlsaber/config/api_keys.py +2 -3
- sqlsaber/config/auth.py +86 -0
- sqlsaber/config/database.py +20 -20
- sqlsaber/config/oauth_flow.py +274 -0
- sqlsaber/config/oauth_tokens.py +175 -0
- sqlsaber/config/settings.py +37 -22
- sqlsaber/database/connection.py +9 -9
- sqlsaber/database/schema.py +25 -25
- sqlsaber/mcp/mcp.py +3 -4
- sqlsaber/memory/manager.py +3 -5
- sqlsaber/memory/storage.py +7 -8
- sqlsaber/models/events.py +4 -4
- sqlsaber/models/types.py +10 -10
- {sqlsaber-0.7.0.dist-info → sqlsaber-0.8.1.dist-info}/METADATA +1 -1
- sqlsaber-0.8.1.dist-info/RECORD +46 -0
- sqlsaber-0.7.0.dist-info/RECORD +0 -36
- {sqlsaber-0.7.0.dist-info → sqlsaber-0.8.1.dist-info}/WHEEL +0 -0
- {sqlsaber-0.7.0.dist-info → sqlsaber-0.8.1.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.7.0.dist-info → sqlsaber-0.8.1.dist-info}/licenses/LICENSE +0 -0
sqlsaber/config/database.py
CHANGED
|
@@ -6,7 +6,7 @@ import platform
|
|
|
6
6
|
import stat
|
|
7
7
|
from dataclasses import dataclass
|
|
8
8
|
from pathlib import Path
|
|
9
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
10
10
|
from urllib.parse import quote_plus
|
|
11
11
|
|
|
12
12
|
import keyring
|
|
@@ -19,16 +19,16 @@ class DatabaseConfig:
|
|
|
19
19
|
|
|
20
20
|
name: str
|
|
21
21
|
type: str # postgresql, mysql, sqlite, csv
|
|
22
|
-
host:
|
|
23
|
-
port:
|
|
22
|
+
host: str | None
|
|
23
|
+
port: int | None
|
|
24
24
|
database: str
|
|
25
|
-
username:
|
|
26
|
-
password:
|
|
27
|
-
ssl_mode:
|
|
28
|
-
ssl_ca:
|
|
29
|
-
ssl_cert:
|
|
30
|
-
ssl_key:
|
|
31
|
-
schema:
|
|
25
|
+
username: str | None
|
|
26
|
+
password: str | None = None
|
|
27
|
+
ssl_mode: str | None = None
|
|
28
|
+
ssl_ca: str | None = None
|
|
29
|
+
ssl_cert: str | None = None
|
|
30
|
+
ssl_key: str | None = None
|
|
31
|
+
schema: str | None = None
|
|
32
32
|
|
|
33
33
|
def to_connection_string(self) -> str:
|
|
34
34
|
"""Convert config to database connection string."""
|
|
@@ -115,7 +115,7 @@ class DatabaseConfig:
|
|
|
115
115
|
else:
|
|
116
116
|
raise ValueError(f"Unsupported database type: {self.type}")
|
|
117
117
|
|
|
118
|
-
def _get_password_from_keyring(self) ->
|
|
118
|
+
def _get_password_from_keyring(self) -> str | None:
|
|
119
119
|
"""Get password from OS keyring."""
|
|
120
120
|
try:
|
|
121
121
|
return keyring.get_password("sqlsaber", f"{self.name}_{self.username}")
|
|
@@ -133,7 +133,7 @@ class DatabaseConfig:
|
|
|
133
133
|
except Exception:
|
|
134
134
|
pass
|
|
135
135
|
|
|
136
|
-
def to_dict(self) ->
|
|
136
|
+
def to_dict(self) -> dict[str, Any]:
|
|
137
137
|
"""Convert to dictionary for JSON serialization."""
|
|
138
138
|
return {
|
|
139
139
|
"name": self.name,
|
|
@@ -150,7 +150,7 @@ class DatabaseConfig:
|
|
|
150
150
|
}
|
|
151
151
|
|
|
152
152
|
@classmethod
|
|
153
|
-
def from_dict(cls, data:
|
|
153
|
+
def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
|
|
154
154
|
"""Create from dictionary."""
|
|
155
155
|
return cls(
|
|
156
156
|
name=data["name"],
|
|
@@ -202,7 +202,7 @@ class DatabaseConfigManager:
|
|
|
202
202
|
# The directory/file creation should still work
|
|
203
203
|
pass
|
|
204
204
|
|
|
205
|
-
def _load_config(self) ->
|
|
205
|
+
def _load_config(self) -> dict[str, Any]:
|
|
206
206
|
"""Load configuration from file."""
|
|
207
207
|
if not self.config_file.exists():
|
|
208
208
|
return {"default": None, "connections": {}}
|
|
@@ -213,7 +213,7 @@ class DatabaseConfigManager:
|
|
|
213
213
|
except (json.JSONDecodeError, IOError):
|
|
214
214
|
return {"default": None, "connections": {}}
|
|
215
215
|
|
|
216
|
-
def _save_config(self, config:
|
|
216
|
+
def _save_config(self, config: dict[str, Any]) -> None:
|
|
217
217
|
"""Save configuration to file."""
|
|
218
218
|
with open(self.config_file, "w") as f:
|
|
219
219
|
json.dump(config, f, indent=2)
|
|
@@ -222,7 +222,7 @@ class DatabaseConfigManager:
|
|
|
222
222
|
self._set_secure_permissions(self.config_file, is_directory=False)
|
|
223
223
|
|
|
224
224
|
def add_database(
|
|
225
|
-
self, db_config: DatabaseConfig, password:
|
|
225
|
+
self, db_config: DatabaseConfig, password: str | None = None
|
|
226
226
|
) -> None:
|
|
227
227
|
"""Add a database configuration."""
|
|
228
228
|
config = self._load_config()
|
|
@@ -244,7 +244,7 @@ class DatabaseConfigManager:
|
|
|
244
244
|
|
|
245
245
|
self._save_config(config)
|
|
246
246
|
|
|
247
|
-
def get_database(self, name: str) ->
|
|
247
|
+
def get_database(self, name: str) -> DatabaseConfig | None:
|
|
248
248
|
"""Get a database configuration by name."""
|
|
249
249
|
config = self._load_config()
|
|
250
250
|
|
|
@@ -253,7 +253,7 @@ class DatabaseConfigManager:
|
|
|
253
253
|
|
|
254
254
|
return DatabaseConfig.from_dict(config["connections"][name])
|
|
255
255
|
|
|
256
|
-
def get_default_database(self) ->
|
|
256
|
+
def get_default_database(self) -> DatabaseConfig | None:
|
|
257
257
|
"""Get the default database configuration."""
|
|
258
258
|
config = self._load_config()
|
|
259
259
|
|
|
@@ -263,7 +263,7 @@ class DatabaseConfigManager:
|
|
|
263
263
|
|
|
264
264
|
return self.get_database(default_name)
|
|
265
265
|
|
|
266
|
-
def list_databases(self) ->
|
|
266
|
+
def list_databases(self) -> list[DatabaseConfig]:
|
|
267
267
|
"""List all database configurations."""
|
|
268
268
|
config = self._load_config()
|
|
269
269
|
|
|
@@ -313,7 +313,7 @@ class DatabaseConfigManager:
|
|
|
313
313
|
config = self._load_config()
|
|
314
314
|
return len(config["connections"]) > 0
|
|
315
315
|
|
|
316
|
-
def get_default_name(self) ->
|
|
316
|
+
def get_default_name(self) -> str | None:
|
|
317
317
|
"""Get the name of the default database."""
|
|
318
318
|
config = self._load_config()
|
|
319
319
|
return config.get("default")
|
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
"""Synchronous OAuth flow management for Anthropic Claude Pro authentication."""
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import hashlib
|
|
5
|
+
import logging
|
|
6
|
+
import secrets
|
|
7
|
+
import urllib.parse
|
|
8
|
+
import webbrowser
|
|
9
|
+
from datetime import datetime, timezone
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
import questionary
|
|
13
|
+
from rich.console import Console
|
|
14
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn
|
|
15
|
+
|
|
16
|
+
from .oauth_tokens import OAuthToken, OAuthTokenManager
|
|
17
|
+
|
|
18
|
+
# Rich console used for all user-facing status output printed by this module.
console = Console()
# Module logger; AnthropicOAuthFlow logs token-endpoint and auth failures here.
logger = logging.getLogger(__name__)


# Hard-coded OAuth client_id sent with the authorization URL, the
# authorization-code exchange and the refresh-token request.
CLIENT_ID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AnthropicOAuthFlow:
    """Handles the complete OAuth flow for Anthropic Claude Pro authentication.

    Covers the PKCE authorization-code login (browser + pasted code), token
    refresh, and removal of stored credentials. Tokens are persisted through
    ``OAuthTokenManager`` under the ``"anthropic"`` provider key.
    """

    # Provider key under which tokens are stored by OAuthTokenManager.
    _PROVIDER = "anthropic"
    _AUTHORIZE_URL = "https://claude.ai/oauth/authorize"
    _TOKEN_URL = "https://console.anthropic.com/v1/oauth/token"
    _REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback"

    def __init__(self):
        self.client_id = CLIENT_ID
        self.token_manager = OAuthTokenManager()

    def _generate_pkce(self) -> tuple[str, str]:
        """Generate a PKCE ``(code_verifier, code_challenge)`` pair.

        The verifier is 32 random bytes, base64url-encoded without padding; the
        challenge is the unpadded base64url SHA-256 digest of the verifier
        (the ``S256`` method).
        """
        verifier = (
            base64.urlsafe_b64encode(secrets.token_bytes(32))
            .decode("utf-8")
            .rstrip("=")
        )
        digest = hashlib.sha256(verifier.encode("utf-8")).digest()
        challenge = base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=")
        return verifier, challenge

    def _create_authorization_url(self) -> tuple[str, str]:
        """Build the browser authorization URL.

        Returns:
            ``(url, verifier)`` — the verifier must be kept for the token
            exchange step.
        """
        verifier, challenge = self._generate_pkce()

        params = {
            "code": "true",
            "client_id": self.client_id,
            "response_type": "code",
            "redirect_uri": self._REDIRECT_URI,
            "scope": "org:create_api_key user:profile user:inference",
            "code_challenge": challenge,
            "code_challenge_method": "S256",
            # NOTE(review): the PKCE verifier doubles as the OAuth `state`, so it
            # travels through the browser URL. PKCE assumes the verifier stays
            # secret; confirm whether the endpoint requires this before changing.
            "state": verifier,
        }

        url = self._AUTHORIZE_URL + "?" + urllib.parse.urlencode(params)
        return url, verifier

    def _post_token_request(
        self, data: dict[str, str], action: str
    ) -> dict[str, str]:
        """POST *data* as JSON to the token endpoint and return the parsed body.

        Raises:
            RuntimeError: if the response status is not 2xx. The message is
                ``"<action> failed: <status> <body>"``.
        """
        with httpx.Client() as client:
            response = client.post(
                self._TOKEN_URL,
                headers={"Content-Type": "application/json"},
                json=data,
            )

            if not response.is_success:
                error_msg = (
                    f"{action} failed: {response.status_code} {response.text}"
                )
                logger.error(error_msg)
                raise RuntimeError(error_msg)

            return response.json()

    @staticmethod
    def _expires_at_from(tokens: dict[str, str]) -> str | None:
        """Convert a token response's relative ``expires_in`` (seconds) into an
        absolute ISO-8601 UTC timestamp; ``None`` if the key is absent."""
        if "expires_in" not in tokens:
            return None
        expires_in = int(tokens["expires_in"])
        expires_ts = datetime.now(timezone.utc).timestamp() + expires_in
        return datetime.fromtimestamp(expires_ts, timezone.utc).isoformat()

    def _exchange_code_for_tokens(self, code: str, verifier: str) -> dict[str, str]:
        """Exchange an authorization code for access and refresh tokens.

        The pasted code may have the form ``"<code>#<state>"``; when no ``#``
        part is present the verifier is sent as the state.

        Raises:
            RuntimeError: if the token endpoint rejects the request.
        """
        code_parts = code.split("#")
        auth_code = code_parts[0]
        state = code_parts[1] if len(code_parts) > 1 else verifier

        data = {
            "code": auth_code,
            "state": state,
            "grant_type": "authorization_code",
            "client_id": self.client_id,
            "redirect_uri": self._REDIRECT_URI,
            "code_verifier": verifier,
        }
        return self._post_token_request(data, "Token exchange")

    def refresh_access_token(self, refresh_token: str) -> dict[str, str]:
        """Request a new access token using *refresh_token*.

        Raises:
            RuntimeError: if the token endpoint rejects the request.
        """
        data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": self.client_id,
        }
        return self._post_token_request(data, "Token refresh")

    def authenticate(self) -> bool:
        """Run the interactive OAuth login and store the resulting tokens.

        Returns:
            True when tokens were obtained and stored; False if the user
            cancelled, any step failed, or storing the tokens failed.
        """
        console.print(
            "\n[bold blue]Claude Pro/Max Subscription Authentication[/bold blue]"
        )
        console.print(
            "This will open your web browser to authenticate with your Claude subscription.\n"
        )

        # questionary returns None on Ctrl-C, which is treated as "no".
        if not questionary.confirm(
            "Continue with browser-based authentication?", default=True
        ).ask():
            console.print("[yellow]Authentication cancelled.[/yellow]")
            return False

        try:
            # Step 1: Create authorization URL
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                console=console,
            ) as progress:
                task = progress.add_task("Preparing authentication...", total=None)

                auth_url, verifier = self._create_authorization_url()
                progress.update(task, description="Opening browser...")

                webbrowser.open(auth_url)

            console.print("\n[green]✓[/green] Browser opened for authentication")
            console.print(
                "[dim]If your browser didn't open automatically, visit this URL:[/dim]"
            )
            console.print(f"[dim]{auth_url}[/dim]\n")

            console.print("After authorizing, you'll be redirected to a callback URL.")
            console.print(
                "Copy the 'code' that shows up on your screen and paste it here."
            )

            auth_code = questionary.text(
                "Enter the authorization code:",
                validate=lambda x: len(x.strip()) > 0
                or "Authorization code is required",
            ).ask()

            if not auth_code:
                console.print("[yellow]Authentication cancelled.[/yellow]")
                return False

            # Step 2: Exchange code for tokens
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                console=console,
            ) as progress:
                progress.add_task("Exchanging code for tokens...", total=None)

                tokens = self._exchange_code_for_tokens(auth_code.strip(), verifier)

                oauth_token = OAuthToken(
                    access_token=tokens["access_token"],
                    refresh_token=tokens["refresh_token"],
                    expires_at=self._expires_at_from(tokens),
                )

                if self.token_manager.store_oauth_token(self._PROVIDER, oauth_token):
                    console.print(
                        "\n[bold green]✓ Authentication successful![/bold green]"
                    )
                    console.print(
                        "Your Claude Pro/Max subscription is now configured for SQLSaber."
                    )
                    return True
                else:
                    console.print("[red]✗ Failed to store authentication tokens.[/red]")
                    return False

        except KeyboardInterrupt:
            console.print("\n[yellow]Authentication cancelled by user.[/yellow]")
            return False
        except Exception as e:
            logger.error("OAuth authentication failed: %s", e)
            console.print(f"[red]✗ Authentication failed: {str(e)}[/red]")
            return False

    def refresh_token_if_needed(self) -> OAuthToken | None:
        """Return a usable OAuth token, refreshing it if expired or expiring soon.

        Returns:
            - None if no token is stored.
            - The stored token unchanged if it is neither expired nor expiring soon.
            - The refreshed token after a successful refresh, even if persisting
              it to the keyring failed: it is still valid for this session.
            - On refresh failure: the stored token if it has not yet expired,
              otherwise None, so callers prompt for re-authentication instead
              of using a dead token.
        """
        current_token = self.token_manager.get_oauth_token(self._PROVIDER)
        if not current_token:
            return None

        if not current_token.is_expired() and not current_token.expires_soon():
            return current_token

        try:
            console.print("Refreshing OAuth token...", style="dim")
            new_tokens = self.refresh_access_token(current_token.refresh_token)

            refreshed_token = OAuthToken(
                access_token=new_tokens["access_token"],
                # The server may not rotate the refresh token; keep the old one then.
                refresh_token=new_tokens.get(
                    "refresh_token", current_token.refresh_token
                ),
                expires_at=self._expires_at_from(new_tokens),
            )
        except Exception as e:
            logger.warning("Token refresh failed: %s", e)
            console.print(
                "Token refresh failed. You may need to re-authenticate.", style="yellow"
            )
            # A token that is only "expiring soon" still works for now.
            return None if current_token.is_expired() else current_token

        if self.token_manager.store_oauth_token(self._PROVIDER, refreshed_token):
            console.print("OAuth token refreshed successfully", style="green")
        else:
            console.print("Failed to store refreshed token", style="yellow")
        return refreshed_token

    def remove_authentication(self) -> bool:
        """Remove stored OAuth authentication."""
        return self.token_manager.remove_oauth_token(self._PROVIDER)

    def has_valid_authentication(self) -> bool:
        """Return True if a stored OAuth token exists and is not expired."""
        token = self.token_manager.get_oauth_token(self._PROVIDER)
        return token is not None and not token.is_expired()
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""OAuth token management for SQLSaber."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from datetime import datetime, timedelta, timezone
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import keyring
|
|
9
|
+
from rich.console import Console
|
|
10
|
+
|
|
11
|
+
# Rich console for the user-facing status messages printed by OAuthTokenManager.
console = Console()
# Module logger; OAuthTokenManager reports keyring failures here.
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class OAuthToken:
|
|
16
|
+
"""Represents an OAuth token with metadata."""
|
|
17
|
+
|
|
18
|
+
def __init__(
|
|
19
|
+
self,
|
|
20
|
+
access_token: str,
|
|
21
|
+
refresh_token: str,
|
|
22
|
+
expires_at: str | None = None,
|
|
23
|
+
token_type: str = "Bearer",
|
|
24
|
+
):
|
|
25
|
+
self.access_token = access_token
|
|
26
|
+
self.refresh_token = refresh_token
|
|
27
|
+
self.expires_at = expires_at
|
|
28
|
+
self.token_type = token_type
|
|
29
|
+
|
|
30
|
+
@classmethod
|
|
31
|
+
def from_dict(cls, data: dict[str, Any]) -> "OAuthToken":
|
|
32
|
+
"""Create token from dictionary."""
|
|
33
|
+
return cls(
|
|
34
|
+
access_token=data["access_token"],
|
|
35
|
+
refresh_token=data["refresh_token"],
|
|
36
|
+
expires_at=data.get("expires_at"),
|
|
37
|
+
token_type=data.get("token_type", "Bearer"),
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
def to_dict(self) -> dict[str, Any]:
|
|
41
|
+
"""Convert token to dictionary."""
|
|
42
|
+
return {
|
|
43
|
+
"access_token": self.access_token,
|
|
44
|
+
"refresh_token": self.refresh_token,
|
|
45
|
+
"expires_at": self.expires_at,
|
|
46
|
+
"token_type": self.token_type,
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
def is_expired(self) -> bool:
|
|
50
|
+
"""Check if the token is expired."""
|
|
51
|
+
if not self.expires_at:
|
|
52
|
+
return False
|
|
53
|
+
|
|
54
|
+
try:
|
|
55
|
+
expires_dt = datetime.fromisoformat(self.expires_at.replace("Z", "+00:00"))
|
|
56
|
+
return datetime.now(timezone.utc) >= expires_dt
|
|
57
|
+
except (ValueError, AttributeError):
|
|
58
|
+
# If we can't parse the expiration, assume expired for safety
|
|
59
|
+
return True
|
|
60
|
+
|
|
61
|
+
def expires_soon(self, buffer_seconds: int = 300) -> bool:
|
|
62
|
+
"""Check if token expires within buffer_seconds."""
|
|
63
|
+
if not self.expires_at:
|
|
64
|
+
return False
|
|
65
|
+
|
|
66
|
+
try:
|
|
67
|
+
expires_dt = datetime.fromisoformat(self.expires_at.replace("Z", "+00:00"))
|
|
68
|
+
|
|
69
|
+
return (
|
|
70
|
+
datetime.now(timezone.utc) + timedelta(seconds=buffer_seconds)
|
|
71
|
+
) >= expires_dt
|
|
72
|
+
except (ValueError, AttributeError):
|
|
73
|
+
return True
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class OAuthTokenManager:
    """Keyring-backed storage for per-provider OAuth tokens.

    Each provider's token is serialized to JSON and stored in the OS keyring
    under the service name ``"<service_prefix>-<provider>-oauth"``, with the
    provider name used as the keyring username.
    """

    def __init__(self):
        self.service_prefix = "sqlsaber"

    def get_oauth_token(self, provider: str) -> OAuthToken | None:
        """Load the stored token for *provider*.

        Returns None when nothing is stored or when reading or parsing fails;
        failures are logged as warnings. Expired or nearly-expired tokens are
        still returned, after a console notice, so callers can attempt a refresh.
        """
        keyring_service = self._get_service_name(provider)

        try:
            raw_json = keyring.get_password(keyring_service, provider)
            if not raw_json:
                return None

            loaded = OAuthToken.from_dict(json.loads(raw_json))

            if loaded.is_expired():
                notice = f"OAuth token for {provider} has expired and needs refresh"
            elif loaded.expires_soon():
                notice = f"OAuth token for {provider} expires soon, consider refreshing"
            else:
                notice = None

            if notice is not None:
                console.print(notice, style="dim yellow")

            return loaded
        except Exception as exc:
            logger.warning(f"Failed to retrieve OAuth token for {provider}: {exc}")
            return None

    def store_oauth_token(self, provider: str, token: OAuthToken) -> bool:
        """Serialize *token* to JSON and save it in the keyring.

        Returns True on success. On any failure, logs it, prints a warning, and
        returns False.
        """
        keyring_service = self._get_service_name(provider)

        try:
            serialized = json.dumps(token.to_dict())
            keyring.set_password(keyring_service, provider, serialized)
            console.print(f"OAuth token for {provider} stored securely", style="green")
        except Exception as exc:
            logger.error(f"Failed to store OAuth token for {provider}: {exc}")
            console.print(
                f"Warning: Could not store OAuth token in keyring: {exc}",
                style="yellow",
            )
            return False
        return True

    def update_oauth_token(
        self, provider: str, access_token: str, expires_at: str | None = None
    ) -> bool:
        """Replace the access token and expiry, keeping the stored refresh token
        and token type. Returns False if no token is stored for *provider*."""
        current = self.get_oauth_token(provider)
        if current is None:
            console.print(
                f"No existing OAuth token found for {provider}", style="yellow"
            )
            return False

        replacement = OAuthToken(
            access_token=access_token,
            refresh_token=current.refresh_token,
            expires_at=expires_at,
            token_type=current.token_type,
        )
        return self.store_oauth_token(provider, replacement)

    def remove_oauth_token(self, provider: str) -> bool:
        """Delete the stored token for *provider*.

        Returns True if it was deleted or was not present. On any other keyring
        error, logs it, prints a warning, and returns False.
        """
        keyring_service = self._get_service_name(provider)

        try:
            keyring.delete_password(keyring_service, provider)
            console.print(f"OAuth token for {provider} removed", style="green")
        except keyring.errors.PasswordDeleteError:
            # Nothing stored for this provider, so removal trivially succeeds.
            pass
        except Exception as exc:
            logger.error(f"Failed to remove OAuth token for {provider}: {exc}")
            console.print(f"Warning: Could not remove OAuth token: {exc}", style="yellow")
            return False
        return True

    def has_oauth_token(self, provider: str) -> bool:
        """Return True if a token for *provider* can be loaded, even if expired."""
        token = self.get_oauth_token(provider)
        return token is not None

    def _get_service_name(self, provider: str) -> str:
        """Build the keyring service name that holds *provider*'s OAuth token."""
        return f"{self.service_prefix}-{provider}-oauth"
|
sqlsaber/config/settings.py
CHANGED
|
@@ -5,11 +5,13 @@ import os
|
|
|
5
5
|
import platform
|
|
6
6
|
import stat
|
|
7
7
|
from pathlib import Path
|
|
8
|
-
from typing import Any
|
|
8
|
+
from typing import Any
|
|
9
9
|
|
|
10
10
|
import platformdirs
|
|
11
11
|
|
|
12
12
|
from sqlsaber.config.api_keys import APIKeyManager
|
|
13
|
+
from sqlsaber.config.auth import AuthConfigManager, AuthMethod
|
|
14
|
+
from sqlsaber.config.oauth_flow import AnthropicOAuthFlow
|
|
13
15
|
|
|
14
16
|
|
|
15
17
|
class ModelConfigManager:
|
|
@@ -40,7 +42,7 @@ class ModelConfigManager:
|
|
|
40
42
|
except (OSError, PermissionError):
|
|
41
43
|
pass
|
|
42
44
|
|
|
43
|
-
def _load_config(self) ->
|
|
45
|
+
def _load_config(self) -> dict[str, Any]:
|
|
44
46
|
"""Load configuration from file."""
|
|
45
47
|
if not self.config_file.exists():
|
|
46
48
|
return {"model": self.DEFAULT_MODEL}
|
|
@@ -55,7 +57,7 @@ class ModelConfigManager:
|
|
|
55
57
|
except (json.JSONDecodeError, IOError):
|
|
56
58
|
return {"model": self.DEFAULT_MODEL}
|
|
57
59
|
|
|
58
|
-
def _save_config(self, config:
|
|
60
|
+
def _save_config(self, config: dict[str, Any]) -> None:
|
|
59
61
|
"""Save configuration to file."""
|
|
60
62
|
with open(self.config_file, "w") as f:
|
|
61
63
|
json.dump(config, f, indent=2)
|
|
@@ -81,35 +83,48 @@ class Config:
|
|
|
81
83
|
self.model_config_manager = ModelConfigManager()
|
|
82
84
|
self.model_name = self.model_config_manager.get_model()
|
|
83
85
|
self.api_key_manager = APIKeyManager()
|
|
84
|
-
self.
|
|
86
|
+
self.auth_config_manager = AuthConfigManager()
|
|
87
|
+
self.oauth_flow = AnthropicOAuthFlow()
|
|
88
|
+
|
|
89
|
+
# Get authentication credentials based on configured method
|
|
90
|
+
self.auth_method = self.auth_config_manager.get_auth_method()
|
|
91
|
+
self.api_key = None
|
|
92
|
+
self.oauth_token = None
|
|
93
|
+
|
|
94
|
+
if self.auth_method == AuthMethod.CLAUDE_PRO:
|
|
95
|
+
# Try to get OAuth token and refresh if needed
|
|
96
|
+
try:
|
|
97
|
+
token = self.oauth_flow.refresh_token_if_needed()
|
|
98
|
+
if token:
|
|
99
|
+
self.oauth_token = token.access_token
|
|
100
|
+
except Exception:
|
|
101
|
+
# OAuth token unavailable, will need to re-authenticate
|
|
102
|
+
pass
|
|
103
|
+
else:
|
|
104
|
+
# Use API key authentication (default or explicitly configured)
|
|
105
|
+
self.api_key = self._get_api_key()
|
|
85
106
|
|
|
86
|
-
def _get_api_key(self) ->
|
|
107
|
+
def _get_api_key(self) -> str | None:
|
|
87
108
|
"""Get API key for the model provider using cascading logic."""
|
|
88
109
|
model = self.model_name
|
|
89
|
-
|
|
90
|
-
if model.startswith("openai:"):
|
|
91
|
-
return self.api_key_manager.get_api_key("openai")
|
|
92
|
-
elif model.startswith("anthropic:"):
|
|
110
|
+
if model.startswith("anthropic:"):
|
|
93
111
|
return self.api_key_manager.get_api_key("anthropic")
|
|
94
|
-
else:
|
|
95
|
-
# For other providers, use generic key
|
|
96
|
-
return self.api_key_manager.get_api_key("generic")
|
|
97
112
|
|
|
98
113
|
def set_model(self, model: str) -> None:
|
|
99
114
|
"""Set the model and update configuration."""
|
|
100
115
|
self.model_config_manager.set_model(model)
|
|
101
116
|
self.model_name = model
|
|
102
|
-
# Update API key for new model
|
|
103
|
-
self.api_key = self._get_api_key()
|
|
104
117
|
|
|
105
118
|
def validate(self):
|
|
106
119
|
"""Validate that necessary configuration is present."""
|
|
120
|
+
# 1. Claude-Pro flow → require OAuth token only
|
|
121
|
+
if self.auth_method == AuthMethod.CLAUDE_PRO:
|
|
122
|
+
if not self.oauth_token:
|
|
123
|
+
raise ValueError(
|
|
124
|
+
"OAuth token not available. Run 'saber auth setup' to authenticate with Claude Pro."
|
|
125
|
+
)
|
|
126
|
+
return # OAuth path satisfied – nothing more to check
|
|
127
|
+
|
|
128
|
+
# 2. Default / API-key flow → require API key
|
|
107
129
|
if not self.api_key:
|
|
108
|
-
|
|
109
|
-
provider = "generic"
|
|
110
|
-
if model.startswith("openai:"):
|
|
111
|
-
provider = "OpenAI"
|
|
112
|
-
elif model.startswith("anthropic:"):
|
|
113
|
-
provider = "Anthropic"
|
|
114
|
-
|
|
115
|
-
raise ValueError(f"{provider} API key not found.")
|
|
130
|
+
raise ValueError("Anthropic API key not found.")
|