sqlsaber 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

Files changed (38) hide show
  1. sqlsaber/agents/__init__.py +2 -2
  2. sqlsaber/agents/base.py +1 -1
  3. sqlsaber/agents/mcp.py +1 -1
  4. sqlsaber/agents/pydantic_ai_agent.py +207 -135
  5. sqlsaber/application/__init__.py +1 -0
  6. sqlsaber/application/auth_setup.py +164 -0
  7. sqlsaber/application/db_setup.py +223 -0
  8. sqlsaber/application/model_selection.py +98 -0
  9. sqlsaber/application/prompts.py +115 -0
  10. sqlsaber/cli/auth.py +22 -50
  11. sqlsaber/cli/commands.py +22 -28
  12. sqlsaber/cli/completers.py +2 -0
  13. sqlsaber/cli/database.py +25 -86
  14. sqlsaber/cli/display.py +29 -9
  15. sqlsaber/cli/interactive.py +150 -127
  16. sqlsaber/cli/models.py +18 -28
  17. sqlsaber/cli/onboarding.py +325 -0
  18. sqlsaber/cli/streaming.py +15 -17
  19. sqlsaber/cli/threads.py +10 -6
  20. sqlsaber/config/api_keys.py +2 -2
  21. sqlsaber/config/settings.py +25 -2
  22. sqlsaber/database/__init__.py +55 -1
  23. sqlsaber/database/base.py +124 -0
  24. sqlsaber/database/csv.py +133 -0
  25. sqlsaber/database/duckdb.py +313 -0
  26. sqlsaber/database/mysql.py +345 -0
  27. sqlsaber/database/postgresql.py +328 -0
  28. sqlsaber/database/schema.py +66 -963
  29. sqlsaber/database/sqlite.py +258 -0
  30. sqlsaber/mcp/mcp.py +1 -1
  31. sqlsaber/tools/sql_tools.py +1 -1
  32. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/METADATA +43 -9
  33. sqlsaber-0.27.0.dist-info/RECORD +58 -0
  34. sqlsaber/database/connection.py +0 -535
  35. sqlsaber-0.25.0.dist-info/RECORD +0 -47
  36. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/WHEEL +0 -0
  37. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/entry_points.txt +0 -0
  38. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/models.py CHANGED
@@ -25,11 +25,20 @@ models_app = cyclopts.App(
25
25
  class ModelManager:
26
26
  """Manages AI model configuration and fetching."""
27
27
 
28
- DEFAULT_MODEL = "anthropic:claude-sonnet-4-20250514"
28
+ DEFAULT_MODEL = "anthropic:claude-sonnet-4-5-20250929"
29
29
  MODELS_API_URL = "https://models.dev/api.json"
30
30
  # Providers come from central registry
31
31
  SUPPORTED_PROVIDERS = providers.all_keys()
32
32
 
33
+ RECOMMENDED_MODELS = {
34
+ "anthropic": "claude-sonnet-4-5-20250929",
35
+ "openai": "gpt-5",
36
+ "google": "gemini-2.5-pro",
37
+ "groq": "llama-3-3-70b-versatile",
38
+ "mistral": "mistral-large-latest",
39
+ "cohere": "command-r-plus",
40
+ }
41
+
33
42
  async def fetch_available_models(
34
43
  self, providers: list[str] | None = None
35
44
  ) -> list[dict]:
@@ -180,39 +189,20 @@ def set():
180
189
  """Set the AI model to use."""
181
190
 
182
191
  async def interactive_set():
192
+ from sqlsaber.application.model_selection import choose_model, fetch_models
193
+ from sqlsaber.application.prompts import AsyncPrompter
194
+
183
195
  console.print("[blue]Fetching available models...[/blue]")
184
- models = await model_manager.fetch_available_models()
196
+ models = await fetch_models(model_manager)
185
197
 
186
198
  if not models:
187
199
  console.print("[red]Failed to fetch models. Cannot set model.[/red]")
188
200
  sys.exit(1)
189
201
 
190
- # Create choices for questionary
191
- choices = []
192
- for model in models:
193
- # Format: "[provider] ID - Name (Description)"
194
- prov = model.get("provider", "?")
195
- choice_text = f"[{prov}] {model['id']} - {model['name']}"
196
- if model["description"]:
197
- choice_text += f" ({model['description'][:50]}{'...' if len(model['description']) > 50 else ''})"
198
-
199
- choices.append({"name": choice_text, "value": model["id"]})
200
-
201
- # Get current model to set as default
202
- current_model = model_manager.get_current_model()
203
- default_index = 0
204
- for i, choice in enumerate(choices):
205
- if choice["value"] == current_model:
206
- default_index = i
207
- break
208
-
209
- selected_model = await questionary.select(
210
- "Select a model:",
211
- choices=choices,
212
- use_search_filter=True,
213
- use_jk_keys=False, # Disable j/k keys when using search filter
214
- default=choices[default_index] if choices else None,
215
- ).ask_async()
202
+ prompter = AsyncPrompter()
203
+ selected_model = await choose_model(
204
+ prompter, models, restrict_provider=None, use_search_filter=True
205
+ )
216
206
 
217
207
  if selected_model:
218
208
  if model_manager.set_model(selected_model):
@@ -0,0 +1,325 @@
1
+ """Interactive onboarding flow for first-time SQLSaber users."""
2
+
3
+ import sys
4
+
5
+ from rich.console import Console
6
+ from rich.panel import Panel
7
+
8
+ from sqlsaber.cli.models import ModelManager
9
+ from sqlsaber.config.api_keys import APIKeyManager
10
+ from sqlsaber.config.auth import AuthConfigManager
11
+ from sqlsaber.config.database import DatabaseConfigManager
12
+
13
+ console = Console()
14
+
15
+
16
+ def needs_onboarding(database_arg: str | None = None) -> bool:
17
+ """Check if user needs onboarding.
18
+
19
+ Onboarding is needed if:
20
+ - No database is configured AND no database connection string provided via CLI
21
+ """
22
+ # If user provided a database argument, skip onboarding
23
+ if database_arg:
24
+ return False
25
+
26
+ # Check if databases are configured
27
+ db_manager = DatabaseConfigManager()
28
+ has_db = db_manager.has_databases()
29
+
30
+ return not has_db
31
+
32
+
33
+ def welcome_screen() -> None:
34
+ """Display welcome screen to new users."""
35
+ banner = """
36
+ ███████ ██████ ██ ███████ █████ ██████ ███████ ██████
37
+ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
38
+ ███████ ██ ██ ██ ███████ ███████ ██████ █████ ██████
39
+ ██ ██ ▄▄ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
40
+ ███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
41
+ ▀▀
42
+ """
43
+
44
+ console.print(Panel.fit(banner, style="bold blue"))
45
+ console.print()
46
+
47
+ welcome_message = """
48
+ [bold]Welcome to SQLsaber! 🎉[/bold]
49
+
50
+ SQLsaber is an agentic SQL assistant that lets you query your database using natural language.
51
+
52
+ Let's get you set up in just a few steps.
53
+ """
54
+
55
+ console.print(Panel(welcome_message.strip(), border_style="blue", padding=(1, 2)))
56
+ console.print()
57
+
58
+
59
+ async def setup_database_guided() -> str | None:
60
+ """Guide user through database setup.
61
+
62
+ Returns the name of the configured database or None if cancelled.
63
+ """
64
+ from sqlsaber.application.db_setup import (
65
+ build_config,
66
+ collect_db_input,
67
+ save_database,
68
+ test_connection,
69
+ )
70
+ from sqlsaber.application.prompts import AsyncPrompter
71
+
72
+ console.print("━" * 80, style="dim")
73
+ console.print("[bold cyan]Step 1 of 2: Database Connection[/bold cyan]")
74
+ console.print("━" * 80, style="dim")
75
+ console.print()
76
+
77
+ try:
78
+ # Ask for connection name
79
+ prompter = AsyncPrompter()
80
+ name = await prompter.text(
81
+ "What would you like to name this connection?",
82
+ default="mydb",
83
+ validate=lambda x: bool(x.strip()) or "Name cannot be empty",
84
+ )
85
+
86
+ if name is None:
87
+ return None
88
+
89
+ name = name.strip()
90
+
91
+ # Check if name already exists
92
+ db_manager = DatabaseConfigManager()
93
+ if db_manager.get_database(name):
94
+ console.print(
95
+ f"[yellow]Database connection '{name}' already exists.[/yellow]"
96
+ )
97
+ return name
98
+
99
+ # Collect database input (simplified - no SSL in onboarding)
100
+ db_input = await collect_db_input(
101
+ prompter=prompter, name=name, db_type="postgresql", include_ssl=False
102
+ )
103
+
104
+ if db_input is None:
105
+ return None
106
+
107
+ # Build config
108
+ db_config = build_config(db_input)
109
+
110
+ # Test the connection
111
+ console.print(f"[dim]Testing connection to '{name}'...[/dim]")
112
+ connection_success = await test_connection(db_config, db_input.password)
113
+
114
+ if not connection_success:
115
+ retry = await prompter.confirm(
116
+ "Would you like to try again with different settings?", default=True
117
+ )
118
+ if retry:
119
+ return await setup_database_guided()
120
+ else:
121
+ console.print(
122
+ "[yellow]You can add a database later using 'saber db add'[/yellow]"
123
+ )
124
+ return None
125
+
126
+ # Save the configuration
127
+ try:
128
+ save_database(db_manager, db_config, db_input.password)
129
+ console.print(f"[green]✓ Connection to '{name}' successful![/green]")
130
+ console.print()
131
+ return name
132
+ except Exception as e:
133
+ console.print(f"[bold red]Error saving database:[/bold red] {e}")
134
+ return None
135
+
136
+ except KeyboardInterrupt:
137
+ console.print("\n[yellow]Setup cancelled.[/yellow]")
138
+ return None
139
+ except Exception as e:
140
+ console.print(f"[bold red]Unexpected error:[/bold red] {e}")
141
+ return None
142
+
143
+
144
+ async def select_model_for_provider(provider: str) -> str | None:
145
+ """Fetch and let user select a model for the given provider.
146
+
147
+ Returns the selected model ID or None if cancelled/failed.
148
+ """
149
+ from sqlsaber.application.model_selection import choose_model, fetch_models
150
+ from sqlsaber.application.prompts import AsyncPrompter
151
+
152
+ try:
153
+ console.print()
154
+ console.print(f"[dim]Fetching available {provider.title()} models...[/dim]")
155
+
156
+ model_manager = ModelManager()
157
+ models = await fetch_models(model_manager, providers=[provider])
158
+
159
+ if not models:
160
+ console.print(
161
+ f"[yellow]Could not fetch models for {provider}. Using default.[/yellow]"
162
+ )
163
+ # Use provider-specific default or fallback to Anthropic
164
+ default_model_id = ModelManager.RECOMMENDED_MODELS.get(
165
+ provider, ModelManager.DEFAULT_MODEL
166
+ )
167
+ # Format it properly if we have a recommended model for this provider
168
+ if provider in ModelManager.RECOMMENDED_MODELS:
169
+ return f"{provider}:{ModelManager.RECOMMENDED_MODELS[provider]}"
170
+ return default_model_id
171
+
172
+ prompter = AsyncPrompter()
173
+ console.print()
174
+ selected_model = await choose_model(
175
+ prompter, models, restrict_provider=provider, use_search_filter=True
176
+ )
177
+
178
+ return selected_model
179
+
180
+ except KeyboardInterrupt:
181
+ console.print("\n[yellow]Model selection cancelled.[/yellow]")
182
+ return None
183
+ except Exception as e:
184
+ console.print(f"[yellow]Error selecting model: {e}. Using default.[/yellow]")
185
+ # Fallback to provider default
186
+ if provider in ModelManager.RECOMMENDED_MODELS:
187
+ return f"{provider}:{ModelManager.RECOMMENDED_MODELS[provider]}"
188
+ return ModelManager.DEFAULT_MODEL
189
+
190
+
191
+ async def setup_auth_guided() -> tuple[bool, str | None]:
192
+ """Guide user through auth setup.
193
+
194
+ Returns tuple of (success: bool, selected_model: str | None).
195
+ """
196
+ from sqlsaber.application.auth_setup import setup_auth
197
+ from sqlsaber.application.prompts import AsyncPrompter
198
+
199
+ console.print("━" * 80, style="dim")
200
+ console.print("[bold cyan]Step 2 of 2: Authentication[/bold cyan]")
201
+ console.print("━" * 80, style="dim")
202
+ console.print()
203
+
204
+ try:
205
+ # Run auth setup
206
+ prompter = AsyncPrompter()
207
+ auth_manager = AuthConfigManager()
208
+ api_key_manager = APIKeyManager()
209
+
210
+ success, provider = await setup_auth(
211
+ prompter=prompter,
212
+ auth_manager=auth_manager,
213
+ api_key_manager=api_key_manager,
214
+ allow_oauth=True,
215
+ default_provider="anthropic",
216
+ run_oauth_in_thread=True,
217
+ )
218
+
219
+ if not success:
220
+ console.print(
221
+ "[yellow]You can set it up later using 'saber auth setup'[/yellow]"
222
+ )
223
+ console.print()
224
+ return False, None
225
+
226
+ # If auth configured but we don't know the provider (already configured case)
227
+ if provider is None:
228
+ console.print()
229
+ return True, None
230
+
231
+ # Select model for this provider
232
+ selected_model = await select_model_for_provider(provider)
233
+ if selected_model:
234
+ model_manager = ModelManager()
235
+ model_manager.set_model(selected_model)
236
+ console.print(f"[green]✓ Model set to: {selected_model}[/green]")
237
+ console.print()
238
+ return True, selected_model
239
+
240
+ except KeyboardInterrupt:
241
+ console.print("\n[yellow]Setup cancelled.[/yellow]")
242
+ console.print()
243
+ return False, None
244
+ except Exception as e:
245
+ console.print(f"[bold red]Unexpected error:[/bold red] {e}")
246
+ console.print()
247
+ return False, None
248
+
249
+
250
+ def success_screen(
251
+ database_name: str | None, auth_configured: bool, model_name: str | None = None
252
+ ) -> None:
253
+ """Display success screen after onboarding."""
254
+ console.print("━" * 80, style="dim")
255
+ console.print("[bold green]You're all set! 🚀[/bold green]")
256
+ console.print("━" * 80, style="dim")
257
+ console.print()
258
+
259
+ if database_name and auth_configured:
260
+ console.print(
261
+ f"[green]✓ Database '{database_name}' connected and ready to use[/green]"
262
+ )
263
+ console.print("[green]✓ Authentication configured[/green]")
264
+ if model_name:
265
+ console.print(f"[green]✓ Model: {model_name}[/green]")
266
+ elif database_name:
267
+ console.print(
268
+ f"[green]✓ Database '{database_name}' connected and ready to use[/green]"
269
+ )
270
+ console.print(
271
+ "[yellow]⚠ AI authentication not configured - you'll be prompted when needed[/yellow]"
272
+ )
273
+ elif auth_configured:
274
+ console.print("[green]✓ AI authentication configured[/green]")
275
+ if model_name:
276
+ console.print(f"[green]✓ Model: {model_name}[/green]")
277
+ console.print(
278
+ "[yellow]⚠ No database configured - you'll need to provide one via -d flag[/yellow]"
279
+ )
280
+
281
+ console.print()
282
+ console.print("[dim]Starting interactive session...[/dim]")
283
+ console.print()
284
+
285
+
286
+ async def run_onboarding() -> bool:
287
+ """Run the complete onboarding flow.
288
+
289
+ Returns True if onboarding completed successfully (at least database configured),
290
+ False if user cancelled or onboarding failed.
291
+ """
292
+ try:
293
+ # Welcome screen
294
+ welcome_screen()
295
+
296
+ # Database setup
297
+ database_name = await setup_database_guided()
298
+
299
+ # If user cancelled database setup, exit
300
+ if database_name is None:
301
+ console.print("[yellow]Database setup is required to continue.[/yellow]")
302
+ console.print(
303
+ "[dim]You can also provide a connection string using: saber -d <connection-string>[/dim]"
304
+ )
305
+ return False
306
+
307
+ # Auth setup
308
+ auth_configured, model_name = await setup_auth_guided()
309
+
310
+ # Show success screen
311
+ success_screen(database_name, auth_configured, model_name)
312
+
313
+ return True
314
+
315
+ except KeyboardInterrupt:
316
+ console.print("\n[yellow]Onboarding cancelled.[/yellow]")
317
+ console.print(
318
+ "[dim]You can run setup commands manually:[/dim]\n"
319
+ "[dim] - saber db add <name> # Add database connection[/dim]\n"
320
+ "[dim] - saber auth setup # Configure authentication[/dim]"
321
+ )
322
+ sys.exit(0)
323
+ except Exception as e:
324
+ console.print(f"[bold red]Onboarding failed:[/bold red] {e}")
325
+ return False
sqlsaber/cli/streaming.py CHANGED
@@ -8,9 +8,9 @@ rendered via DisplayManager helpers.
8
8
  import asyncio
9
9
  import json
10
10
  from functools import singledispatchmethod
11
- from typing import AsyncIterable
11
+ from typing import TYPE_CHECKING, AsyncIterable
12
12
 
13
- from pydantic_ai import Agent, RunContext
13
+ from pydantic_ai import RunContext
14
14
  from pydantic_ai.messages import (
15
15
  AgentStreamEvent,
16
16
  FunctionToolCallEvent,
@@ -26,6 +26,9 @@ from rich.console import Console
26
26
 
27
27
  from sqlsaber.cli.display import DisplayManager
28
28
 
29
+ if TYPE_CHECKING:
30
+ from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
31
+
29
32
 
30
33
  class StreamingQueryHandler:
31
34
  """
@@ -130,7 +133,7 @@ class StreamingQueryHandler:
130
133
  async def execute_streaming_query(
131
134
  self,
132
135
  user_query: str,
133
- agent: Agent,
136
+ sqlsaber_agent: "SQLSaberAgent",
134
137
  cancellation_token: asyncio.Event | None = None,
135
138
  message_history: list | None = None,
136
139
  ):
@@ -139,21 +142,16 @@ class StreamingQueryHandler:
139
142
  try:
140
143
  # If Anthropic OAuth, inject SQLsaber instructions before the first user prompt
141
144
  prepared_prompt: str | list[str] = user_query
142
- is_oauth = bool(getattr(agent, "_sqlsaber_is_oauth", False))
143
145
  no_history = not message_history
144
- if is_oauth and no_history:
145
- ib = getattr(agent, "_sqlsaber_instruction_builder", None)
146
- mm = getattr(agent, "_sqlsaber_memory_manager", None)
147
- db_type = getattr(agent, "_sqlsaber_db_type", "database")
148
- db_name = getattr(agent, "_sqlsaber_database_name", None)
149
- instructions = (
150
- ib.build_instructions(db_type=db_type) if ib is not None else ""
151
- )
152
- mem = (
153
- mm.format_memories_for_prompt(db_name)
154
- if (mm is not None and db_name)
155
- else ""
146
+ if sqlsaber_agent.is_oauth and no_history:
147
+ instructions = sqlsaber_agent.instruction_builder.build_instructions(
148
+ db_type=sqlsaber_agent.db_type
156
149
  )
150
+ mem = ""
151
+ if sqlsaber_agent.database_name:
152
+ mem = sqlsaber_agent.memory_manager.format_memories_for_prompt(
153
+ sqlsaber_agent.database_name
154
+ )
157
155
  parts = [p for p in (instructions, mem) if p and str(p).strip()]
158
156
  if parts:
159
157
  injected = "\n\n".join(parts)
@@ -163,7 +161,7 @@ class StreamingQueryHandler:
163
161
  self.display.live.start_status("Crunching data...")
164
162
 
165
163
  # Run the agent with our event stream handler
166
- run = await agent.run(
164
+ run = await sqlsaber_agent.agent.run(
167
165
  prepared_prompt,
168
166
  message_history=message_history,
169
167
  event_stream_handler=self._event_stream_handler,
sqlsaber/cli/threads.py CHANGED
@@ -148,7 +148,9 @@ def _render_transcript(
148
148
  )
149
149
  else:
150
150
  if is_redirected:
151
- console.print(f"**Tool result ({name}):**\n\n{content_str}\n")
151
+ console.print(
152
+ f"**Tool result ({name}):**\n\n{content_str}\n"
153
+ )
152
154
  else:
153
155
  console.print(
154
156
  Panel.fit(
@@ -159,7 +161,9 @@ def _render_transcript(
159
161
  )
160
162
  except Exception:
161
163
  if is_redirected:
162
- console.print(f"**Tool result ({name}):**\n\n{content_str}\n")
164
+ console.print(
165
+ f"**Tool result ({name}):**\n\n{content_str}\n"
166
+ )
163
167
  else:
164
168
  console.print(
165
169
  Panel.fit(
@@ -258,10 +262,10 @@ def resume(
258
262
 
259
263
  async def _run() -> None:
260
264
  # Lazy imports to avoid heavy modules at CLI startup
261
- from sqlsaber.agents import build_sqlsaber_agent
265
+ from sqlsaber.agents import SQLSaberAgent
262
266
  from sqlsaber.cli.interactive import InteractiveSession
263
267
  from sqlsaber.config.database import DatabaseConfigManager
264
- from sqlsaber.database.connection import DatabaseConnection
268
+ from sqlsaber.database import DatabaseConnection
265
269
  from sqlsaber.database.resolver import (
266
270
  DatabaseResolutionError,
267
271
  resolve_database,
@@ -288,7 +292,7 @@ def resume(
288
292
 
289
293
  db_conn = DatabaseConnection(connection_string)
290
294
  try:
291
- agent = build_sqlsaber_agent(db_conn, db_name)
295
+ sqlsaber_agent = SQLSaberAgent(db_conn, db_name)
292
296
  history = await store.get_thread_messages(thread_id)
293
297
  if console.is_terminal:
294
298
  console.print(Panel.fit(f"Thread: {thread.id}", border_style="blue"))
@@ -297,7 +301,7 @@ def resume(
297
301
  _render_transcript(console, history, None)
298
302
  session = InteractiveSession(
299
303
  console=console,
300
- agent=agent,
304
+ sqlsaber_agent=sqlsaber_agent,
301
305
  db_conn=db_conn,
302
306
  database_name=db_name,
303
307
  initial_thread_id=thread_id,
@@ -19,7 +19,7 @@ class APIKeyManager:
19
19
 
20
20
  def get_api_key(self, provider: str) -> str | None:
21
21
  """Get API key for the specified provider using cascading logic."""
22
- env_var_name = self._get_env_var_name(provider)
22
+ env_var_name = self.get_env_var_name(provider)
23
23
  service_name = self._get_service_name(provider)
24
24
 
25
25
  # 1. Check environment variable first
@@ -41,7 +41,7 @@ class APIKeyManager:
41
41
  # 3. Prompt user for API key
42
42
  return self._prompt_and_store_key(provider, env_var_name, service_name)
43
43
 
44
- def _get_env_var_name(self, provider: str) -> str:
44
+ def get_env_var_name(self, provider: str) -> str:
45
45
  """Get the expected environment variable name for a provider."""
46
46
  # Normalize aliases to canonical provider keys
47
47
  key = providers.canonical(provider) or provider
@@ -46,7 +46,10 @@ class ModelConfigManager:
46
46
  def _load_config(self) -> dict[str, Any]:
47
47
  """Load configuration from file."""
48
48
  if not self.config_file.exists():
49
- return {"model": self.DEFAULT_MODEL}
49
+ return {
50
+ "model": self.DEFAULT_MODEL,
51
+ "thinking_enabled": False,
52
+ }
50
53
 
51
54
  try:
52
55
  with open(self.config_file, "r") as f:
@@ -54,9 +57,15 @@ class ModelConfigManager:
54
57
  # Ensure we have a model set
55
58
  if "model" not in config:
56
59
  config["model"] = self.DEFAULT_MODEL
60
+ # Set defaults for thinking if not present
61
+ if "thinking_enabled" not in config:
62
+ config["thinking_enabled"] = False
57
63
  return config
58
64
  except (json.JSONDecodeError, IOError):
59
- return {"model": self.DEFAULT_MODEL}
65
+ return {
66
+ "model": self.DEFAULT_MODEL,
67
+ "thinking_enabled": False,
68
+ }
60
69
 
61
70
  def _save_config(self, config: dict[str, Any]) -> None:
62
71
  """Save configuration to file."""
@@ -76,6 +85,17 @@ class ModelConfigManager:
76
85
  config["model"] = model
77
86
  self._save_config(config)
78
87
 
88
+ def get_thinking_enabled(self) -> bool:
89
+ """Get whether thinking is enabled."""
90
+ config = self._load_config()
91
+ return config.get("thinking_enabled", False)
92
+
93
+ def set_thinking_enabled(self, enabled: bool) -> None:
94
+ """Set whether thinking is enabled."""
95
+ config = self._load_config()
96
+ config["thinking_enabled"] = enabled
97
+ self._save_config(config)
98
+
79
99
 
80
100
  class Config:
81
101
  """Configuration class for SQLSaber."""
@@ -86,6 +106,9 @@ class Config:
86
106
  self.api_key_manager = APIKeyManager()
87
107
  self.auth_config_manager = AuthConfigManager()
88
108
 
109
+ # Thinking configuration
110
+ self.thinking_enabled = self.model_config_manager.get_thinking_enabled()
111
+
89
112
  # Authentication method (API key or Anthropic OAuth)
90
113
  self.auth_method = self.auth_config_manager.get_auth_method()
91
114
 
@@ -1,9 +1,63 @@
1
1
  """Database module for SQLSaber."""
2
2
 
3
- from .connection import DatabaseConnection
3
+ from .base import (
4
+ DEFAULT_QUERY_TIMEOUT,
5
+ BaseDatabaseConnection,
6
+ BaseSchemaIntrospector,
7
+ ColumnInfo,
8
+ ForeignKeyInfo,
9
+ IndexInfo,
10
+ QueryTimeoutError,
11
+ SchemaInfo,
12
+ )
13
+ from .csv import CSVConnection, CSVSchemaIntrospector
14
+ from .duckdb import DuckDBConnection, DuckDBSchemaIntrospector
15
+ from .mysql import MySQLConnection, MySQLSchemaIntrospector
16
+ from .postgresql import PostgreSQLConnection, PostgreSQLSchemaIntrospector
4
17
  from .schema import SchemaManager
18
+ from .sqlite import SQLiteConnection, SQLiteSchemaIntrospector
19
+
20
+
21
+ def DatabaseConnection(connection_string: str) -> BaseDatabaseConnection:
22
+ """Factory function to create appropriate database connection based on connection string."""
23
+ if connection_string.startswith("postgresql://"):
24
+ return PostgreSQLConnection(connection_string)
25
+ elif connection_string.startswith("mysql://"):
26
+ return MySQLConnection(connection_string)
27
+ elif connection_string.startswith("sqlite:///"):
28
+ return SQLiteConnection(connection_string)
29
+ elif connection_string.startswith("duckdb://"):
30
+ return DuckDBConnection(connection_string)
31
+ elif connection_string.startswith("csv:///"):
32
+ return CSVConnection(connection_string)
33
+ else:
34
+ raise ValueError(
35
+ f"Unsupported database type in connection string: {connection_string}"
36
+ )
37
+
5
38
 
6
39
  __all__ = [
40
+ # Base classes and types
41
+ "BaseDatabaseConnection",
42
+ "BaseSchemaIntrospector",
43
+ "ColumnInfo",
44
+ "DEFAULT_QUERY_TIMEOUT",
45
+ "ForeignKeyInfo",
46
+ "IndexInfo",
47
+ "QueryTimeoutError",
48
+ "SchemaInfo",
49
+ # Concrete implementations
50
+ "PostgreSQLConnection",
51
+ "MySQLConnection",
52
+ "SQLiteConnection",
53
+ "DuckDBConnection",
54
+ "CSVConnection",
55
+ "PostgreSQLSchemaIntrospector",
56
+ "MySQLSchemaIntrospector",
57
+ "SQLiteSchemaIntrospector",
58
+ "DuckDBSchemaIntrospector",
59
+ "CSVSchemaIntrospector",
60
+ # Factory function and manager
7
61
  "DatabaseConnection",
8
62
  "SchemaManager",
9
63
  ]