sqlsaber 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/__init__.py +2 -2
- sqlsaber/agents/base.py +1 -1
- sqlsaber/agents/mcp.py +1 -1
- sqlsaber/agents/pydantic_ai_agent.py +207 -135
- sqlsaber/application/__init__.py +1 -0
- sqlsaber/application/auth_setup.py +164 -0
- sqlsaber/application/db_setup.py +223 -0
- sqlsaber/application/model_selection.py +98 -0
- sqlsaber/application/prompts.py +115 -0
- sqlsaber/cli/auth.py +22 -50
- sqlsaber/cli/commands.py +22 -28
- sqlsaber/cli/completers.py +2 -0
- sqlsaber/cli/database.py +25 -86
- sqlsaber/cli/display.py +29 -9
- sqlsaber/cli/interactive.py +150 -127
- sqlsaber/cli/models.py +18 -28
- sqlsaber/cli/onboarding.py +325 -0
- sqlsaber/cli/streaming.py +15 -17
- sqlsaber/cli/threads.py +10 -6
- sqlsaber/config/api_keys.py +2 -2
- sqlsaber/config/settings.py +25 -2
- sqlsaber/database/__init__.py +55 -1
- sqlsaber/database/base.py +124 -0
- sqlsaber/database/csv.py +133 -0
- sqlsaber/database/duckdb.py +313 -0
- sqlsaber/database/mysql.py +345 -0
- sqlsaber/database/postgresql.py +328 -0
- sqlsaber/database/schema.py +66 -963
- sqlsaber/database/sqlite.py +258 -0
- sqlsaber/mcp/mcp.py +1 -1
- sqlsaber/tools/sql_tools.py +1 -1
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/METADATA +43 -9
- sqlsaber-0.27.0.dist-info/RECORD +58 -0
- sqlsaber/database/connection.py +0 -535
- sqlsaber-0.25.0.dist-info/RECORD +0 -47
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/agents/__init__.py
CHANGED
sqlsaber/agents/base.py
CHANGED
sqlsaber/agents/mcp.py
CHANGED
|
@@ -6,15 +6,16 @@ function tools, and streaming event types directly.
|
|
|
6
6
|
|
|
7
7
|
import httpx
|
|
8
8
|
from pydantic_ai import Agent, RunContext
|
|
9
|
-
from pydantic_ai.models.anthropic import AnthropicModel
|
|
10
|
-
from pydantic_ai.models.google import GoogleModel
|
|
11
|
-
from pydantic_ai.models.
|
|
9
|
+
from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings
|
|
10
|
+
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
|
|
11
|
+
from pydantic_ai.models.groq import GroqModelSettings
|
|
12
|
+
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
|
|
12
13
|
from pydantic_ai.providers.anthropic import AnthropicProvider
|
|
13
14
|
from pydantic_ai.providers.google import GoogleProvider
|
|
14
15
|
|
|
15
16
|
from sqlsaber.config import providers
|
|
16
17
|
from sqlsaber.config.settings import Config
|
|
17
|
-
from sqlsaber.database
|
|
18
|
+
from sqlsaber.database import (
|
|
18
19
|
BaseDatabaseConnection,
|
|
19
20
|
CSVConnection,
|
|
20
21
|
DuckDBConnection,
|
|
@@ -28,47 +29,119 @@ from sqlsaber.tools.registry import tool_registry
|
|
|
28
29
|
from sqlsaber.tools.sql_tools import SQLTool
|
|
29
30
|
|
|
30
31
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
model_name_only = (
|
|
54
|
-
cfg.model_name.split(":", 1)[1] if ":" in cfg.model_name else cfg.model_name
|
|
55
|
-
)
|
|
56
|
-
|
|
57
|
-
provider = providers.provider_from_model(cfg.model_name) or ""
|
|
58
|
-
if provider == "google":
|
|
59
|
-
model_obj = GoogleModel(
|
|
60
|
-
model_name_only, provider=GoogleProvider(api_key=cfg.api_key)
|
|
32
|
+
class SQLSaberAgent:
|
|
33
|
+
"""Pydantic-AI Agent wrapper for SQLSaber with enhanced state management."""
|
|
34
|
+
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
db_connection: BaseDatabaseConnection,
|
|
38
|
+
database_name: str | None = None,
|
|
39
|
+
memory_manager: MemoryManager | None = None,
|
|
40
|
+
thinking_enabled: bool | None = None,
|
|
41
|
+
):
|
|
42
|
+
self.db_connection = db_connection
|
|
43
|
+
self.database_name = database_name
|
|
44
|
+
self.config = Config()
|
|
45
|
+
self.memory_manager = memory_manager or MemoryManager()
|
|
46
|
+
self.instruction_builder = InstructionBuilder(tool_registry)
|
|
47
|
+
self.db_type = self._get_database_type_name()
|
|
48
|
+
|
|
49
|
+
# Thinking configuration (CLI override or config default)
|
|
50
|
+
self.thinking_enabled = (
|
|
51
|
+
thinking_enabled
|
|
52
|
+
if thinking_enabled is not None
|
|
53
|
+
else self.config.thinking_enabled
|
|
61
54
|
)
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
55
|
+
|
|
56
|
+
# Configure SQL tools with the database connection
|
|
57
|
+
self._configure_sql_tools()
|
|
58
|
+
|
|
59
|
+
# Create the pydantic-ai agent
|
|
60
|
+
self.agent = self._build_agent()
|
|
61
|
+
|
|
62
|
+
def _configure_sql_tools(self) -> None:
|
|
63
|
+
"""Ensure SQL tools receive the active database connection."""
|
|
64
|
+
for tool_name in tool_registry.list_tools(category="sql"):
|
|
65
|
+
tool = tool_registry.get_tool(tool_name)
|
|
66
|
+
if isinstance(tool, SQLTool):
|
|
67
|
+
tool.set_connection(self.db_connection)
|
|
68
|
+
|
|
69
|
+
def _build_agent(self) -> Agent:
|
|
70
|
+
"""Create and configure the pydantic-ai Agent."""
|
|
71
|
+
self.config.validate()
|
|
72
|
+
|
|
73
|
+
model_name_only = (
|
|
74
|
+
self.config.model_name.split(":", 1)[1]
|
|
75
|
+
if ":" in self.config.model_name
|
|
76
|
+
else self.config.model_name
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
provider = providers.provider_from_model(self.config.model_name) or ""
|
|
80
|
+
self.is_oauth = provider == "anthropic" and bool(
|
|
81
|
+
getattr(self.config, "oauth_token", None)
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
agent = self._create_agent_for_provider(provider, model_name_only)
|
|
85
|
+
self._setup_system_prompt(agent)
|
|
86
|
+
self._register_tools(agent)
|
|
87
|
+
|
|
88
|
+
return agent
|
|
89
|
+
|
|
90
|
+
def _create_agent_for_provider(self, provider: str, model_name: str) -> Agent:
|
|
91
|
+
"""Create the agent based on the provider type."""
|
|
92
|
+
if provider == "google":
|
|
93
|
+
model_obj = GoogleModel(
|
|
94
|
+
model_name, provider=GoogleProvider(api_key=self.config.api_key)
|
|
95
|
+
)
|
|
96
|
+
if self.thinking_enabled:
|
|
97
|
+
settings = GoogleModelSettings(
|
|
98
|
+
google_thinking_config={"include_thoughts": True}
|
|
99
|
+
)
|
|
100
|
+
return Agent(model_obj, name="sqlsaber", model_settings=settings)
|
|
101
|
+
return Agent(model_obj, name="sqlsaber")
|
|
102
|
+
elif provider == "anthropic" and self.is_oauth:
|
|
103
|
+
return self._create_oauth_anthropic_agent(model_name)
|
|
104
|
+
elif provider == "anthropic":
|
|
105
|
+
if self.thinking_enabled:
|
|
106
|
+
settings = AnthropicModelSettings(
|
|
107
|
+
anthropic_thinking={
|
|
108
|
+
"type": "enabled",
|
|
109
|
+
"budget_tokens": 2048,
|
|
110
|
+
},
|
|
111
|
+
max_tokens=8192,
|
|
112
|
+
)
|
|
113
|
+
return Agent(
|
|
114
|
+
self.config.model_name, name="sqlsaber", model_settings=settings
|
|
115
|
+
)
|
|
116
|
+
return Agent(self.config.model_name, name="sqlsaber")
|
|
117
|
+
elif provider == "openai":
|
|
118
|
+
model_obj = OpenAIResponsesModel(model_name)
|
|
119
|
+
if self.thinking_enabled:
|
|
120
|
+
settings = OpenAIResponsesModelSettings(
|
|
121
|
+
openai_reasoning_effort="medium",
|
|
122
|
+
openai_reasoning_summary="auto",
|
|
123
|
+
)
|
|
124
|
+
return Agent(model_obj, name="sqlsaber", model_settings=settings)
|
|
125
|
+
return Agent(model_obj, name="sqlsaber")
|
|
126
|
+
elif provider == "groq":
|
|
127
|
+
if self.thinking_enabled:
|
|
128
|
+
settings = GroqModelSettings(groq_reasoning_format="parsed")
|
|
129
|
+
return Agent(
|
|
130
|
+
self.config.model_name, name="sqlsaber", model_settings=settings
|
|
131
|
+
)
|
|
132
|
+
return Agent(self.config.model_name, name="sqlsaber")
|
|
133
|
+
else:
|
|
134
|
+
return Agent(self.config.model_name, name="sqlsaber")
|
|
135
|
+
|
|
136
|
+
def _create_oauth_anthropic_agent(self, model_name: str) -> Agent:
|
|
137
|
+
"""Create an Anthropic agent with OAuth configuration."""
|
|
138
|
+
|
|
65
139
|
async def add_oauth_headers(request: httpx.Request) -> None: # type: ignore[override]
|
|
66
|
-
# Remove API-key header if present and add OAuth headers
|
|
67
140
|
if "x-api-key" in request.headers:
|
|
68
141
|
del request.headers["x-api-key"]
|
|
69
142
|
request.headers.update(
|
|
70
143
|
{
|
|
71
|
-
"Authorization": f"Bearer {
|
|
144
|
+
"Authorization": f"Bearer {self.config.oauth_token}",
|
|
72
145
|
"anthropic-version": "2023-06-01",
|
|
73
146
|
"anthropic-beta": "oauth-2025-04-20",
|
|
74
147
|
"User-Agent": "ClaudeCode/1.0 (Anthropic Claude Code CLI)",
|
|
@@ -79,100 +152,99 @@ def build_sqlsaber_agent(
|
|
|
79
152
|
|
|
80
153
|
http_client = httpx.AsyncClient(event_hooks={"request": [add_oauth_headers]})
|
|
81
154
|
provider_obj = AnthropicProvider(api_key="placeholder", http_client=http_client)
|
|
82
|
-
model_obj = AnthropicModel(
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
if not is_oauth:
|
|
98
|
-
|
|
99
|
-
@agent.system_prompt(dynamic=True)
|
|
100
|
-
async def sqlsaber_system_prompt(ctx: RunContext) -> str:
|
|
101
|
-
db_type = _get_database_type_name(db_connection)
|
|
102
|
-
instructions = instruction_builder.build_instructions(db_type=db_type)
|
|
103
|
-
|
|
104
|
-
# Add memory context if available
|
|
105
|
-
if database_name:
|
|
106
|
-
mem = memory_manager.format_memories_for_prompt(database_name)
|
|
107
|
-
else:
|
|
108
|
-
mem = ""
|
|
155
|
+
model_obj = AnthropicModel(model_name, provider=provider_obj)
|
|
156
|
+
if self.thinking_enabled:
|
|
157
|
+
settings = AnthropicModelSettings(
|
|
158
|
+
anthropic_thinking={
|
|
159
|
+
"type": "enabled",
|
|
160
|
+
"budget_tokens": 2048,
|
|
161
|
+
},
|
|
162
|
+
max_tokens=8192,
|
|
163
|
+
)
|
|
164
|
+
return Agent(model_obj, name="sqlsaber", model_settings=settings)
|
|
165
|
+
return Agent(model_obj, name="sqlsaber")
|
|
166
|
+
|
|
167
|
+
def _setup_system_prompt(self, agent: Agent) -> None:
|
|
168
|
+
"""Set up the dynamic system prompt for the agent."""
|
|
169
|
+
if not self.is_oauth:
|
|
109
170
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
"""
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
171
|
+
@agent.system_prompt(dynamic=True)
|
|
172
|
+
async def sqlsaber_system_prompt(ctx: RunContext) -> str:
|
|
173
|
+
instructions = self.instruction_builder.build_instructions(
|
|
174
|
+
db_type=self.db_type
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
# Add memory context if available
|
|
178
|
+
mem = ""
|
|
179
|
+
if self.database_name:
|
|
180
|
+
mem = self.memory_manager.format_memories_for_prompt(
|
|
181
|
+
self.database_name
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
parts = [p for p in (instructions, mem) if p and p.strip()]
|
|
185
|
+
return "\n\n".join(parts) if parts else ""
|
|
186
|
+
else:
|
|
187
|
+
|
|
188
|
+
@agent.system_prompt(dynamic=True)
|
|
189
|
+
async def sqlsaber_system_prompt(ctx: RunContext) -> str:
|
|
190
|
+
return "You are Claude Code, Anthropic's official CLI for Claude."
|
|
191
|
+
|
|
192
|
+
def _register_tools(self, agent: Agent) -> None:
|
|
193
|
+
"""Register all the SQL tools with the agent."""
|
|
194
|
+
|
|
195
|
+
@agent.tool(name="list_tables")
|
|
196
|
+
async def list_tables(ctx: RunContext) -> str:
|
|
197
|
+
"""
|
|
198
|
+
Get a list of all tables in the database with row counts.
|
|
199
|
+
Use this first to discover available tables.
|
|
200
|
+
"""
|
|
201
|
+
tool = tool_registry.get_tool("list_tables")
|
|
202
|
+
return await tool.execute()
|
|
203
|
+
|
|
204
|
+
@agent.tool(name="introspect_schema")
|
|
205
|
+
async def introspect_schema(
|
|
206
|
+
ctx: RunContext, table_pattern: str | None = None
|
|
207
|
+
) -> str:
|
|
208
|
+
"""
|
|
209
|
+
Introspect database schema to understand table structures.
|
|
210
|
+
|
|
211
|
+
Args:
|
|
212
|
+
table_pattern: Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')
|
|
213
|
+
"""
|
|
214
|
+
tool = tool_registry.get_tool("introspect_schema")
|
|
215
|
+
return await tool.execute(table_pattern=table_pattern)
|
|
216
|
+
|
|
217
|
+
@agent.tool(name="execute_sql")
|
|
218
|
+
async def execute_sql(
|
|
219
|
+
ctx: RunContext, query: str, limit: int | None = 100
|
|
220
|
+
) -> str:
|
|
221
|
+
"""
|
|
222
|
+
Execute a SQL query and return the results.
|
|
223
|
+
|
|
224
|
+
Args:
|
|
225
|
+
query: SQL query to execute
|
|
226
|
+
limit: Maximum number of rows to return (default: 100)
|
|
227
|
+
"""
|
|
228
|
+
tool = tool_registry.get_tool("execute_sql")
|
|
229
|
+
return await tool.execute(query=query, limit=limit)
|
|
230
|
+
|
|
231
|
+
def set_thinking(self, enabled: bool) -> None:
|
|
232
|
+
"""Update thinking settings and rebuild the agent."""
|
|
233
|
+
self.thinking_enabled = enabled
|
|
234
|
+
# Rebuild agent with new thinking settings
|
|
235
|
+
self.agent = self._build_agent()
|
|
236
|
+
|
|
237
|
+
def _get_database_type_name(self) -> str:
|
|
238
|
+
"""Get the human-readable database type name."""
|
|
239
|
+
if isinstance(self.db_connection, PostgreSQLConnection):
|
|
240
|
+
return "PostgreSQL"
|
|
241
|
+
elif isinstance(self.db_connection, MySQLConnection):
|
|
242
|
+
return "MySQL"
|
|
243
|
+
elif isinstance(self.db_connection, SQLiteConnection):
|
|
244
|
+
return "SQLite"
|
|
245
|
+
elif isinstance(self.db_connection, DuckDBConnection):
|
|
246
|
+
return "DuckDB"
|
|
247
|
+
elif isinstance(self.db_connection, CSVConnection):
|
|
248
|
+
return "DuckDB"
|
|
249
|
+
else:
|
|
250
|
+
return "database"
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Application layer for SQLsaber - shared business logic and interactive flows."""
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""Shared auth setup logic for onboarding and CLI."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
|
|
5
|
+
from questionary import Choice
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
|
|
8
|
+
from sqlsaber.application.prompts import Prompter
|
|
9
|
+
from sqlsaber.config import providers
|
|
10
|
+
from sqlsaber.config.api_keys import APIKeyManager
|
|
11
|
+
from sqlsaber.config.auth import AuthConfigManager, AuthMethod
|
|
12
|
+
from sqlsaber.config.oauth_flow import AnthropicOAuthFlow
|
|
13
|
+
|
|
14
|
+
console = Console()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
async def select_provider(prompter: Prompter, default: str = "anthropic") -> str | None:
|
|
18
|
+
"""Interactive provider selection.
|
|
19
|
+
|
|
20
|
+
Args:
|
|
21
|
+
prompter: Prompter instance for interaction
|
|
22
|
+
default: Default provider to select
|
|
23
|
+
|
|
24
|
+
Returns:
|
|
25
|
+
Selected provider name or None if cancelled
|
|
26
|
+
"""
|
|
27
|
+
provider = await prompter.select(
|
|
28
|
+
"Select AI provider:", choices=providers.all_keys(), default=default
|
|
29
|
+
)
|
|
30
|
+
return provider
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
async def configure_oauth_anthropic(
|
|
34
|
+
auth_manager: AuthConfigManager, run_in_thread: bool = False
|
|
35
|
+
) -> bool:
|
|
36
|
+
"""Configure Anthropic OAuth.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
auth_manager: AuthConfigManager instance
|
|
40
|
+
run_in_thread: Whether to run OAuth flow in a separate thread (for onboarding)
|
|
41
|
+
|
|
42
|
+
Returns:
|
|
43
|
+
True if OAuth configured successfully, False otherwise
|
|
44
|
+
"""
|
|
45
|
+
flow = AnthropicOAuthFlow()
|
|
46
|
+
|
|
47
|
+
if run_in_thread:
|
|
48
|
+
# Run in thread to avoid event loop conflicts (onboarding)
|
|
49
|
+
oauth_success = await asyncio.to_thread(flow.authenticate)
|
|
50
|
+
else:
|
|
51
|
+
# Run directly (CLI)
|
|
52
|
+
oauth_success = flow.authenticate()
|
|
53
|
+
|
|
54
|
+
if oauth_success:
|
|
55
|
+
auth_manager.set_auth_method(AuthMethod.CLAUDE_PRO)
|
|
56
|
+
return True
|
|
57
|
+
|
|
58
|
+
return False
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
async def configure_api_key(
|
|
62
|
+
provider: str, api_key_manager: APIKeyManager, auth_manager: AuthConfigManager
|
|
63
|
+
) -> bool:
|
|
64
|
+
"""Configure API key for a provider.
|
|
65
|
+
|
|
66
|
+
Args:
|
|
67
|
+
provider: Provider name
|
|
68
|
+
api_key_manager: APIKeyManager instance
|
|
69
|
+
auth_manager: AuthConfigManager instance
|
|
70
|
+
|
|
71
|
+
Returns:
|
|
72
|
+
True if API key configured successfully, False otherwise
|
|
73
|
+
"""
|
|
74
|
+
# Get API key (cascades env -> keyring -> prompt)
|
|
75
|
+
api_key = api_key_manager.get_api_key(provider)
|
|
76
|
+
|
|
77
|
+
if api_key:
|
|
78
|
+
auth_manager.set_auth_method(AuthMethod.API_KEY)
|
|
79
|
+
return True
|
|
80
|
+
|
|
81
|
+
return False
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
async def setup_auth(
|
|
85
|
+
prompter: Prompter,
|
|
86
|
+
auth_manager: AuthConfigManager,
|
|
87
|
+
api_key_manager: APIKeyManager,
|
|
88
|
+
allow_oauth: bool = True,
|
|
89
|
+
default_provider: str = "anthropic",
|
|
90
|
+
run_oauth_in_thread: bool = False,
|
|
91
|
+
) -> tuple[bool, str | None]:
|
|
92
|
+
"""Interactive authentication setup.
|
|
93
|
+
|
|
94
|
+
Args:
|
|
95
|
+
prompter: Prompter instance for interaction
|
|
96
|
+
auth_manager: AuthConfigManager instance
|
|
97
|
+
api_key_manager: APIKeyManager instance
|
|
98
|
+
allow_oauth: Whether to offer OAuth option for Anthropic
|
|
99
|
+
default_provider: Default provider to select
|
|
100
|
+
run_oauth_in_thread: Whether to run OAuth in thread (for onboarding)
|
|
101
|
+
|
|
102
|
+
Returns:
|
|
103
|
+
Tuple of (success: bool, provider: str | None)
|
|
104
|
+
"""
|
|
105
|
+
# Check if auth is already configured
|
|
106
|
+
if auth_manager.has_auth_configured():
|
|
107
|
+
console.print("[green]✓ Authentication already configured![/green]")
|
|
108
|
+
return True, None
|
|
109
|
+
|
|
110
|
+
# Select provider
|
|
111
|
+
provider = await select_provider(prompter, default=default_provider)
|
|
112
|
+
|
|
113
|
+
if provider is None:
|
|
114
|
+
return False, None
|
|
115
|
+
|
|
116
|
+
# For Anthropic, offer OAuth or API key
|
|
117
|
+
if provider == "anthropic" and allow_oauth:
|
|
118
|
+
method_choice = await prompter.select(
|
|
119
|
+
"Authentication method:",
|
|
120
|
+
choices=[
|
|
121
|
+
Choice("API Key", value=AuthMethod.API_KEY),
|
|
122
|
+
Choice("Claude Pro/Max (OAuth)", value=AuthMethod.CLAUDE_PRO),
|
|
123
|
+
],
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
if method_choice is None:
|
|
127
|
+
return False, None
|
|
128
|
+
|
|
129
|
+
if method_choice == AuthMethod.CLAUDE_PRO:
|
|
130
|
+
console.print()
|
|
131
|
+
oauth_success = await configure_oauth_anthropic(
|
|
132
|
+
auth_manager, run_in_thread=run_oauth_in_thread
|
|
133
|
+
)
|
|
134
|
+
if oauth_success:
|
|
135
|
+
console.print(
|
|
136
|
+
"[green]✓ Anthropic OAuth configured successfully![/green]"
|
|
137
|
+
)
|
|
138
|
+
return True, provider
|
|
139
|
+
else:
|
|
140
|
+
console.print("[red]✗ Anthropic OAuth setup failed.[/red]")
|
|
141
|
+
return False, None
|
|
142
|
+
|
|
143
|
+
# API key flow
|
|
144
|
+
env_var = api_key_manager.get_env_var_name(provider)
|
|
145
|
+
|
|
146
|
+
console.print()
|
|
147
|
+
console.print(f"[dim]To use {provider.title()}, you need an API key.[/dim]")
|
|
148
|
+
console.print(f"[dim]You can set the {env_var} environment variable,[/dim]")
|
|
149
|
+
console.print("[dim]or enter it now to store securely in your OS keychain.[/dim]")
|
|
150
|
+
console.print()
|
|
151
|
+
|
|
152
|
+
# Configure API key
|
|
153
|
+
api_key_configured = await configure_api_key(
|
|
154
|
+
provider, api_key_manager, auth_manager
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
if api_key_configured:
|
|
158
|
+
console.print(
|
|
159
|
+
f"[green]✓ {provider.title()} API key configured successfully![/green]"
|
|
160
|
+
)
|
|
161
|
+
return True, provider
|
|
162
|
+
else:
|
|
163
|
+
console.print("[yellow]No API key provided.[/yellow]")
|
|
164
|
+
return False, None
|