sqlsaber 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

Files changed (38):
  1. sqlsaber/agents/__init__.py +2 -2
  2. sqlsaber/agents/base.py +1 -1
  3. sqlsaber/agents/mcp.py +1 -1
  4. sqlsaber/agents/pydantic_ai_agent.py +207 -135
  5. sqlsaber/application/__init__.py +1 -0
  6. sqlsaber/application/auth_setup.py +164 -0
  7. sqlsaber/application/db_setup.py +223 -0
  8. sqlsaber/application/model_selection.py +98 -0
  9. sqlsaber/application/prompts.py +115 -0
  10. sqlsaber/cli/auth.py +22 -50
  11. sqlsaber/cli/commands.py +22 -28
  12. sqlsaber/cli/completers.py +2 -0
  13. sqlsaber/cli/database.py +25 -86
  14. sqlsaber/cli/display.py +29 -9
  15. sqlsaber/cli/interactive.py +150 -127
  16. sqlsaber/cli/models.py +18 -28
  17. sqlsaber/cli/onboarding.py +325 -0
  18. sqlsaber/cli/streaming.py +15 -17
  19. sqlsaber/cli/threads.py +10 -6
  20. sqlsaber/config/api_keys.py +2 -2
  21. sqlsaber/config/settings.py +25 -2
  22. sqlsaber/database/__init__.py +55 -1
  23. sqlsaber/database/base.py +124 -0
  24. sqlsaber/database/csv.py +133 -0
  25. sqlsaber/database/duckdb.py +313 -0
  26. sqlsaber/database/mysql.py +345 -0
  27. sqlsaber/database/postgresql.py +328 -0
  28. sqlsaber/database/schema.py +66 -963
  29. sqlsaber/database/sqlite.py +258 -0
  30. sqlsaber/mcp/mcp.py +1 -1
  31. sqlsaber/tools/sql_tools.py +1 -1
  32. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/METADATA +43 -9
  33. sqlsaber-0.27.0.dist-info/RECORD +58 -0
  34. sqlsaber/database/connection.py +0 -535
  35. sqlsaber-0.25.0.dist-info/RECORD +0 -47
  36. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/WHEEL +0 -0
  37. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/entry_points.txt +0 -0
  38. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,223 @@
1
+ """Shared database setup logic for onboarding and CLI."""
2
+
3
+ import getpass
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+
7
+ from rich.console import Console
8
+
9
+ from sqlsaber.application.prompts import Prompter
10
+ from sqlsaber.config.database import DatabaseConfig, DatabaseConfigManager
11
+
12
# Module-level Rich console used for the status and error messages printed
# by the setup helpers in this module.
console = Console()
13
+
14
+
15
+ @dataclass
16
+ class DatabaseInput:
17
+ """Input data for database configuration."""
18
+
19
+ name: str
20
+ type: str
21
+ host: str
22
+ port: int
23
+ database: str
24
+ username: str
25
+ password: str | None
26
+ ssl_mode: str | None = None
27
+ ssl_ca: str | None = None
28
+ ssl_cert: str | None = None
29
+ ssl_key: str | None = None
30
+
31
+
32
def _validate_port(value: str) -> bool | str:
    """Validate a TCP port answer; usable as a questionary ``validate`` callback.

    Returns ``True`` for an integer in 1..65535, otherwise an error message
    to display to the user.
    """
    try:
        port = int(value)
    except ValueError:
        return "Port must be a whole number."
    if not 1 <= port <= 65535:
        return "Port must be between 1 and 65535."
    return True


def _blank_to_none(value: str | None) -> str | None:
    """Strip ``value`` and map a missing or empty answer to ``None``."""
    if value is None:
        return None
    return value.strip() or None


async def _collect_ssl_settings(
    prompter: Prompter, db_type: str
) -> tuple[str | None, str | None, str | None, str | None]:
    """Ask the optional SSL/TLS questions for a server database.

    Returns:
        ``(ssl_mode, ssl_ca, ssl_cert, ssl_key)``. An element is ``None`` when
        the user declines or cancels that part of the setup.
    """
    ssl_mode = ssl_ca = ssl_cert = ssl_key = None

    if not await prompter.confirm("Configure SSL/TLS settings?", default=False):
        return ssl_mode, ssl_ca, ssl_cert, ssl_key

    if db_type == "postgresql":
        ssl_mode = await prompter.select(
            "SSL mode for PostgreSQL:",
            choices=[
                "disable",
                "allow",
                "prefer",
                "require",
                "verify-ca",
                "verify-full",
            ],
            default="prefer",
        )
    elif db_type == "mysql":
        ssl_mode = await prompter.select(
            "SSL mode for MySQL:",
            choices=[
                "DISABLED",
                "PREFERRED",
                "REQUIRED",
                "VERIFY_CA",
                "VERIFY_IDENTITY",
            ],
            default="PREFERRED",
        )

    # Certificate files are irrelevant once SSL is explicitly disabled.
    if ssl_mode and ssl_mode not in {"disable", "DISABLED"}:
        if await prompter.confirm("Specify SSL certificate files?", default=False):
            # Empty answers become None so "" is never stored as a file path.
            ssl_ca = _blank_to_none(await prompter.path("SSL CA certificate file:"))
            if await prompter.confirm("Specify client certificate?", default=False):
                ssl_cert = _blank_to_none(
                    await prompter.path("SSL client certificate file:")
                )
                ssl_key = _blank_to_none(
                    await prompter.path("SSL client private key file:")
                )

    return ssl_mode, ssl_ca, ssl_cert, ssl_key


async def collect_db_input(
    prompter: Prompter,
    name: str,
    db_type: str = "postgresql",
    include_ssl: bool = True,
) -> DatabaseInput | None:
    """Collect database connection details interactively.

    Args:
        prompter: Prompter instance for interaction.
        name: Database connection name.
        db_type: Initially highlighted database type (the user may change it).
        include_ssl: Whether to prompt for SSL configuration (server
            databases only).

    Returns:
        DatabaseInput with the collected values, or None if the user
        cancelled any required prompt (including Ctrl-C at the password prompt).
    """
    db_type = await prompter.select(
        "Database type:",
        choices=["postgresql", "mysql", "sqlite", "duckdb"],
        default=db_type,
    )
    if db_type is None:
        return None

    # File-based databases only need a path.
    if db_type in {"sqlite", "duckdb"}:
        while True:
            database_path = await prompter.path(
                f"{db_type.upper()} file path:", only_directories=False
            )
            if database_path is None:
                return None
            # An empty path would resolve to the current working directory,
            # which is never a valid database file, so ask again.
            if database_path.strip():
                break
            console.print("[red]A file path is required.[/red]")

        return DatabaseInput(
            name=name,
            type=db_type,
            host="localhost",
            port=0,
            database=str(Path(database_path.strip()).expanduser().resolve()),
            username=db_type,
            password="",
        )

    # PostgreSQL/MySQL need full connection details.
    host = await prompter.text("Host:", default="localhost")
    if host is None:
        return None

    default_port = 5432 if db_type == "postgresql" else 3306
    port_str = await prompter.text(
        "Port:", default=str(default_port), validate=_validate_port
    )
    if port_str is None:
        return None
    # Re-check in case a Prompter implementation does not honour ``validate``.
    if _validate_port(port_str) is True:
        port = int(port_str)
    else:
        console.print("[red]Invalid port number. Using default.[/red]")
        port = default_port

    database = await prompter.text("Database name:")
    if database is None:
        return None

    username = await prompter.text("Username:")
    if username is None:
        return None

    try:
        password = getpass.getpass("Password (stored in your OS keychain): ")
    except (KeyboardInterrupt, EOFError):
        # Treat Ctrl-C / Ctrl-D here like cancelling any other prompt.
        return None

    ssl_mode = ssl_ca = ssl_cert = ssl_key = None
    if include_ssl:
        ssl_mode, ssl_ca, ssl_cert, ssl_key = await _collect_ssl_settings(
            prompter, db_type
        )

    return DatabaseInput(
        name=name,
        type=db_type,
        host=host,
        port=port,
        database=database,
        username=username,
        password=password,
        ssl_mode=ssl_mode,
        ssl_ca=ssl_ca,
        ssl_cert=ssl_cert,
        ssl_key=ssl_key,
    )
172
+
173
+
174
def build_config(db_input: DatabaseInput) -> DatabaseConfig:
    """Translate a collected ``DatabaseInput`` into a ``DatabaseConfig``.

    Every field except ``password`` is copied across by name; the password is
    not part of the config and is handed to ``save_database`` separately.
    """
    copied_fields = (
        "name",
        "type",
        "host",
        "port",
        "database",
        "username",
        "ssl_mode",
        "ssl_ca",
        "ssl_cert",
        "ssl_key",
    )
    return DatabaseConfig(**{field: getattr(db_input, field) for field in copied_fields})
188
+
189
+
190
async def test_connection(config: DatabaseConfig, password: str | None) -> bool:
    """Test a database connection by running ``SELECT 1``.

    Args:
        config: DatabaseConfig to test.
        password: Password for the connection (not stored in config yet).
            NOTE(review): this argument is currently unused. The connection
            string comes solely from ``config.to_connection_string()``. If that
            relies on a password already saved to the keyring, testing before
            ``save_database`` may fail for password-protected servers. Confirm.

    Returns:
        True if the connection and query succeed, False otherwise (the error
        is printed to the console).
    """
    import contextlib

    from sqlsaber.database import DatabaseConnection

    db_conn = None
    try:
        db_conn = DatabaseConnection(config.to_connection_string())
        await db_conn.execute_query("SELECT 1 as test")
    except Exception as e:
        console.print(f"[bold red]Connection failed:[/bold red] {e}", style="red")
        return False
    finally:
        # Always release the connection, including when the query failed.
        # Cleanup is best-effort so a close error cannot mask the check result.
        if db_conn is not None:
            with contextlib.suppress(Exception):
                await db_conn.close()
    return True
211
+
212
+
213
def save_database(
    config_manager: DatabaseConfigManager, config: DatabaseConfig, password: str | None
) -> None:
    """Persist ``config`` through ``config_manager``.

    Args:
        config_manager: DatabaseConfigManager instance that stores the config.
        config: DatabaseConfig to save.
        password: Password to store in the keyring. An empty string is
            normalised to None, so only a non-empty secret is passed on.
    """
    secret = password or None
    config_manager.add_database(config, secret)
@@ -0,0 +1,98 @@
1
+ """Shared model selection logic for onboarding and CLI."""
2
+
3
+ from questionary import Choice
4
+ from rich.console import Console
5
+
6
+ from sqlsaber.application.prompts import Prompter
7
+ from sqlsaber.cli.models import ModelManager
8
+
9
# Module-level Rich console used by choose_model for "no models" warnings.
console = Console()
10
+
11
+
12
async def fetch_models(
    model_manager: ModelManager, providers: list[str] | None = None
) -> list[dict]:
    """Return the available models reported by ``model_manager``.

    Thin async wrapper over ``ModelManager.fetch_available_models`` (described
    upstream as querying the models.dev API). ``providers`` is forwarded unchanged.
    """
    available = await model_manager.fetch_available_models(providers=providers)
    return available
17
+
18
+
19
async def choose_model(
    prompter: Prompter,
    models: list[dict],
    restrict_provider: str | None = None,
    use_search_filter: bool = True,
) -> str | None:
    """Interactively pick a model, listing the recommended one first.

    Args:
        prompter: Prompter instance for interaction.
        models: Model dicts from ``fetch_models``. Each needs an ``"id"`` key
            and may carry ``"name"``, ``"description"`` and ``"provider"``.
        restrict_provider: If set, only models whose ``"provider"`` equals this
            value are offered. That provider's entry in
            ``ModelManager.RECOMMENDED_MODELS`` (if any) is highlighted and
            moved to the top.
        use_search_filter: Forwarded to the prompter's ``select``.

    Returns:
        The selected model ID. If the prompt is cancelled, the first option
        shown is returned instead. That is the recommended model when it is
        among ``models``, otherwise the first model. None is returned only
        when there are no models to offer.
    """
    if not models:
        console.print("[yellow]No models available[/yellow]")
        return None

    if restrict_provider:
        models = [m for m in models if m.get("provider") == restrict_provider]
        if not models:
            console.print(
                f"[yellow]No models available for {restrict_provider}[/yellow]"
            )
            return None

    recommended_id = None
    if restrict_provider and restrict_provider in ModelManager.RECOMMENDED_MODELS:
        recommended_id = ModelManager.RECOMMENDED_MODELS[restrict_provider]

    # (label, model_id) pairs in display order.
    entries: list[tuple[str, str]] = []
    recommended_index: int | None = None
    for model in models:
        model_id = model["id"]
        # IDs look like "provider:model"; tolerate IDs without the prefix.
        _, sep, rest = model_id.partition(":")
        bare_id = rest if sep else model_id
        label = model.get("name") or model_id
        description = model.get("description") or ""

        if recommended_id is not None and bare_id == recommended_id:
            label += " (Recommended)"
            if recommended_index is None:
                recommended_index = len(entries)
        elif description:
            short = description[:40]
            label += f" ({short}...)" if len(description) > 40 else f" ({short})"

        entries.append((label, model_id))

    # Move the recommended model to the top if it is not already there.
    if recommended_index is not None and recommended_index > 0:
        entries.insert(0, entries.pop(recommended_index))

    choices = [Choice(label, value=model_id) for label, model_id in entries]
    selected_model = await prompter.select(
        "Select a model:",
        choices=choices,
        use_search_filter=use_search_filter,
    )
    if selected_model:
        return selected_model

    # Cancelled: fall back to the first option shown. Unlike a bare
    # recommended ID, this is guaranteed to be one of the offered models.
    return entries[0][1]
94
+
95
+
96
def set_model(model_manager: ModelManager, model_id: str) -> bool:
    """Make ``model_id`` the active model via ``model_manager``.

    Returns whatever ``ModelManager.set_model`` returns (annotated as bool).
    """
    outcome = model_manager.set_model(model_id)
    return outcome
@@ -0,0 +1,115 @@
1
+ """Prompter abstraction for sync/async questionary interactions."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Callable
5
+
6
+ import questionary
7
+ from questionary import Choice
8
+
9
+
10
class Prompter(ABC):
    """Abstract base class for interactive prompting.

    Implementations return None when the user cancels a prompt.
    """

    @abstractmethod
    async def text(
        self,
        message: str,
        default: str = "",
        validate: Callable[[str], bool | str] | None = None,
    ) -> str | None:
        """Prompt for text input.

        ``validate`` returns True to accept an answer, or an error message
        string to reject it.
        """

    @abstractmethod
    async def select(
        self,
        message: str,
        choices: list[str] | list[Choice] | list[dict],
        default: Any = None,
        use_search_filter: bool = True,
        use_jk_keys: bool = False,
    ) -> Any:
        """Prompt for a selection from ``choices``.

        The defaults match both concrete implementations. j/k navigation is
        off whenever search filtering is on because questionary does not allow
        the two together (typed j/k would be ambiguous with the search prefix).
        """

    @abstractmethod
    async def confirm(self, message: str, default: bool = False) -> bool | None:
        """Prompt for yes/no confirmation."""

    @abstractmethod
    async def path(self, message: str, only_directories: bool = False) -> str | None:
        """Prompt for a file or directory path."""
44
+
45
+
46
class AsyncPrompter(Prompter):
    """Prompter that awaits questionary's ``ask_async()``; used by onboarding."""

    async def text(
        self,
        message: str,
        default: str = "",
        validate: Callable[[str], bool | str] | None = None,
    ) -> str | None:
        question = questionary.text(message, default=default, validate=validate)
        return await question.ask_async()

    async def select(
        self,
        message: str,
        choices: list[str] | list[Choice] | list[dict],
        default: Any = None,
        use_search_filter: bool = True,
        use_jk_keys: bool = False,
    ) -> Any:
        question = questionary.select(
            message,
            choices=choices,
            default=default,
            use_search_filter=use_search_filter,
            use_jk_keys=use_jk_keys,
        )
        return await question.ask_async()

    async def confirm(self, message: str, default: bool = False) -> bool | None:
        question = questionary.confirm(message, default=default)
        return await question.ask_async()

    async def path(self, message: str, only_directories: bool = False) -> str | None:
        question = questionary.path(message, only_directories=only_directories)
        return await question.ask_async()
82
+
83
+
84
class SyncPrompter(Prompter):
    """Prompter backed by questionary's blocking ``ask()``; used by CLI commands.

    The methods are ``async`` only to satisfy the Prompter interface. Each
    call blocks until the user answers.
    """

    async def text(
        self,
        message: str,
        default: str = "",
        validate: Callable[[str], bool | str] | None = None,
    ) -> str | None:
        question = questionary.text(message, default=default, validate=validate)
        return question.ask()

    async def select(
        self,
        message: str,
        choices: list[str] | list[Choice] | list[dict],
        default: Any = None,
        use_search_filter: bool = True,
        use_jk_keys: bool = False,
    ) -> Any:
        question = questionary.select(
            message,
            choices=choices,
            default=default,
            use_search_filter=use_search_filter,
            use_jk_keys=use_jk_keys,
        )
        return question.ask()

    async def confirm(self, message: str, default: bool = False) -> bool | None:
        question = questionary.confirm(message, default=default)
        return question.ask()

    async def path(self, message: str, only_directories: bool = False) -> str | None:
        question = questionary.path(message, only_directories=only_directories)
        return question.ask()
sqlsaber/cli/auth.py CHANGED
@@ -10,7 +10,6 @@ from rich.console import Console
10
10
  from sqlsaber.config import providers
11
11
  from sqlsaber.config.api_keys import APIKeyManager
12
12
  from sqlsaber.config.auth import AuthConfigManager, AuthMethod
13
- from sqlsaber.config.oauth_flow import AnthropicOAuthFlow
14
13
  from sqlsaber.config.oauth_tokens import OAuthTokenManager
15
14
 
16
15
  # Global instances for CLI commands
@@ -27,60 +26,33 @@ auth_app = cyclopts.App(
27
26
  @auth_app.command
28
27
  def setup():
29
28
  """Configure authentication for SQLsaber (API keys and Anthropic OAuth)."""
30
- console.print("\n[bold]SQLsaber Authentication Setup[/bold]\n")
29
+ import asyncio
31
30
 
32
- provider = questionary.select(
33
- "Select provider to configure:",
34
- choices=providers.all_keys(),
35
- ).ask()
31
+ from sqlsaber.application.auth_setup import setup_auth
32
+ from sqlsaber.application.prompts import AsyncPrompter
36
33
 
37
- if provider is None:
38
- console.print("[yellow]Setup cancelled.[/yellow]")
39
- return
34
+ console.print("\n[bold]SQLsaber Authentication Setup[/bold]\n")
40
35
 
41
- if provider == "anthropic":
42
- # Let user choose API key or OAuth
43
- method_choice = questionary.select(
44
- "Select Anthropic authentication method:",
45
- choices=[
46
- {"name": "API key", "value": AuthMethod.API_KEY},
47
- {"name": "Claude Pro/Max (OAuth)", "value": AuthMethod.CLAUDE_PRO},
48
- ],
49
- ).ask()
50
-
51
- if method_choice == AuthMethod.CLAUDE_PRO:
52
- flow = AnthropicOAuthFlow()
53
- if flow.authenticate():
54
- config_manager.set_auth_method(AuthMethod.CLAUDE_PRO)
55
- console.print(
56
- "\n[bold green]✓ Anthropic OAuth configured successfully![/bold green]"
57
- )
58
- else:
59
- console.print("\n[red]✗ Anthropic OAuth setup failed.[/red]")
60
- console.print(
61
- "You can change this anytime by running [cyan]saber auth setup[/cyan] again."
62
- )
63
- return
64
-
65
- # API key flow (all providers + Anthropic when selected above)
66
- api_key_manager = APIKeyManager()
67
- env_var = api_key_manager._get_env_var_name(provider)
68
- console.print("\nTo configure your API key, you can either:")
69
- console.print(f"• Set the {env_var} environment variable")
70
- console.print("• Let SQLsaber prompt you for the key when needed (stored securely)")
71
-
72
- # Fetch/store key (cascades env -> keyring -> prompt)
73
- api_key = api_key_manager.get_api_key(provider)
74
- if api_key:
75
- config_manager.set_auth_method(AuthMethod.API_KEY)
76
- console.print(
77
- f"\n[bold green]✓ {provider.title()} API key configured successfully![/bold green]"
36
+ async def run_setup():
37
+ prompter = AsyncPrompter()
38
+ api_key_manager = APIKeyManager()
39
+ success, provider = await setup_auth(
40
+ prompter=prompter,
41
+ auth_manager=config_manager,
42
+ api_key_manager=api_key_manager,
43
+ allow_oauth=True,
44
+ default_provider="anthropic",
45
+ run_oauth_in_thread=False,
78
46
  )
79
- else:
80
- console.print("\n[yellow]No API key configured.[/yellow]")
47
+ return success, provider
48
+
49
+ success, _ = asyncio.run(run_setup())
50
+
51
+ if not success:
52
+ console.print("\n[yellow]No authentication configured.[/yellow]")
81
53
 
82
54
  console.print(
83
- "You can change this anytime by running [cyan]saber auth setup[/cyan] again."
55
+ "\nYou can change this anytime by running [cyan]saber auth setup[/cyan] again."
84
56
  )
85
57
 
86
58
 
@@ -109,7 +81,7 @@ def status():
109
81
  # Include OAuth status
110
82
  if OAuthTokenManager().has_oauth_token("anthropic"):
111
83
  console.print("> anthropic (oauth): [green]configured[/green]")
112
- env_var = api_key_manager._get_env_var_name(provider)
84
+ env_var = api_key_manager.get_env_var_name(provider)
113
85
  service = api_key_manager._get_service_name(provider)
114
86
  from_env = bool(os.getenv(env_var))
115
87
  from_keyring = bool(keyring.get_password(service, provider))
sqlsaber/cli/commands.py CHANGED
@@ -11,6 +11,7 @@ from sqlsaber.cli.auth import create_auth_app
11
11
  from sqlsaber.cli.database import create_db_app
12
12
  from sqlsaber.cli.memory import create_memory_app
13
13
  from sqlsaber.cli.models import create_models_app
14
+ from sqlsaber.cli.onboarding import needs_onboarding, run_onboarding
14
15
  from sqlsaber.cli.threads import create_threads_app
15
16
 
16
17
  # Lazy imports - only import what's needed for CLI parsing
@@ -75,7 +76,7 @@ def query(
75
76
  query_text: Annotated[
76
77
  str | None,
77
78
  cyclopts.Parameter(
78
- help="SQL query in natural language (if not provided, reads from stdin or starts interactive mode)",
79
+ help="Question in natural language (if not provided, reads from stdin or starts interactive mode)",
79
80
  ),
80
81
  ] = None,
81
82
  database: Annotated[
@@ -85,6 +86,7 @@ def query(
85
86
  help="Database connection name, file path (CSV/SQLite/DuckDB), or connection string (postgresql://, mysql://, duckdb://) (uses default if not specified)",
86
87
  ),
87
88
  ] = None,
89
+ thinking: bool = False,
88
90
  ):
89
91
  """Run a query against the database or start interactive mode.
90
92
 
@@ -109,16 +111,11 @@ def query(
109
111
  async def run_session():
110
112
  # Import heavy dependencies only when actually running a query
111
113
  # This is only done to speed up startup time
112
- from sqlsaber.agents import build_sqlsaber_agent
114
+ from sqlsaber.agents import SQLSaberAgent
113
115
  from sqlsaber.cli.interactive import InteractiveSession
114
116
  from sqlsaber.cli.streaming import StreamingQueryHandler
115
- from sqlsaber.database.connection import (
116
- CSVConnection,
117
+ from sqlsaber.database import (
117
118
  DatabaseConnection,
118
- DuckDBConnection,
119
- MySQLConnection,
120
- PostgreSQLConnection,
121
- SQLiteConnection,
122
119
  )
123
120
  from sqlsaber.database.resolver import DatabaseResolutionError, resolve_database
124
121
  from sqlsaber.threads import ThreadStorage
@@ -132,6 +129,16 @@ def query(
132
129
  # If stdin was empty, fall back to interactive mode
133
130
  actual_query = None
134
131
 
132
+ # Check if onboarding is needed (only for interactive mode or when no database is configured)
133
+ if needs_onboarding(database):
134
+ # Run onboarding flow
135
+ onboarding_success = await run_onboarding()
136
+ if not onboarding_success:
137
+ # User cancelled or onboarding failed
138
+ raise CLIError(
139
+ "Setup incomplete. Please configure your database and try again."
140
+ )
141
+
135
142
  # Resolve database from CLI input
136
143
  try:
137
144
  resolved = resolve_database(database, config_manager)
@@ -147,45 +154,32 @@ def query(
147
154
  raise CLIError(f"Error creating database connection: {e}")
148
155
 
149
156
  # Create pydantic-ai agent instance with database name for memory context
150
- agent = build_sqlsaber_agent(db_conn, db_name)
157
+ sqlsaber_agent = SQLSaberAgent(db_conn, db_name, thinking_enabled=thinking)
151
158
 
152
159
  try:
153
160
  if actual_query:
154
161
  # Single query mode with streaming
155
162
  streaming_handler = StreamingQueryHandler(console)
156
- # Compute DB type for the greeting line
157
- if isinstance(db_conn, PostgreSQLConnection):
158
- db_type = "PostgreSQL"
159
- elif isinstance(db_conn, MySQLConnection):
160
- db_type = "MySQL"
161
- elif isinstance(db_conn, DuckDBConnection):
162
- db_type = "DuckDB"
163
- elif isinstance(db_conn, SQLiteConnection):
164
- db_type = "SQLite"
165
- elif isinstance(db_conn, CSVConnection):
166
- db_type = "DuckDB"
167
- else:
168
- db_type = "database"
163
+ db_type = sqlsaber_agent.db_type
169
164
  console.print(
170
165
  f"[bold blue]Connected to:[/bold blue] {db_name} ({db_type})\n"
171
166
  )
172
167
  run = await streaming_handler.execute_streaming_query(
173
- actual_query, agent
168
+ actual_query, sqlsaber_agent
174
169
  )
175
170
  # Persist non-interactive run as a thread snapshot so it can be resumed later
176
171
  try:
177
172
  if run is not None:
178
173
  threads = ThreadStorage()
179
- # Extract title and model name
180
- title = actual_query
181
- model_name: str | None = agent.model.model_name
182
174
 
183
175
  thread_id = await threads.save_snapshot(
184
176
  messages_json=run.all_messages_json(),
185
177
  database_name=db_name,
186
178
  )
187
179
  await threads.save_metadata(
188
- thread_id=thread_id, title=title, model_name=model_name
180
+ thread_id=thread_id,
181
+ title=actual_query,
182
+ model_name=sqlsaber_agent.agent.model.model_name,
189
183
  )
190
184
  await threads.end_thread(thread_id)
191
185
  console.print(
@@ -198,7 +192,7 @@ def query(
198
192
  await threads.prune_threads()
199
193
  else:
200
194
  # Interactive mode
201
- session = InteractiveSession(console, agent, db_conn, db_name)
195
+ session = InteractiveSession(console, sqlsaber_agent, db_conn, db_name)
202
196
  await session.run()
203
197
 
204
198
  finally:
@@ -19,6 +19,8 @@ class SlashCommandCompleter(Completer):
19
19
  ("clear", "Clear conversation history"),
20
20
  ("exit", "Exit the interactive session"),
21
21
  ("quit", "Exit the interactive session"),
22
+ ("thinking on", "Enable extended thinking/reasoning"),
23
+ ("thinking off", "Disable extended thinking/reasoning"),
22
24
  ]
23
25
 
24
26
  # Yield completions that match the partial command