sqlsaber 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

@@ -1,7 +1,7 @@
1
1
  """Agents module for SQLSaber."""
2
2
 
3
- from .pydantic_ai_agent import build_sqlsaber_agent
3
+ from .pydantic_ai_agent import SQLSaberAgent
4
4
 
5
5
  __all__ = [
6
- "build_sqlsaber_agent",
6
+ "SQLSaberAgent",
7
7
  ]
sqlsaber/agents/base.py CHANGED
@@ -5,9 +5,10 @@ import json
5
5
  from abc import ABC, abstractmethod
6
6
  from typing import Any, AsyncIterator
7
7
 
8
- from sqlsaber.database.connection import (
8
+ from sqlsaber.database import (
9
9
  BaseDatabaseConnection,
10
10
  CSVConnection,
11
+ DuckDBConnection,
11
12
  MySQLConnection,
12
13
  PostgreSQLConnection,
13
14
  SQLiteConnection,
@@ -51,7 +52,9 @@ class BaseSQLAgent(ABC):
51
52
  elif isinstance(self.db, SQLiteConnection):
52
53
  return "SQLite"
53
54
  elif isinstance(self.db, CSVConnection):
54
- return "SQLite" # we convert csv to in-memory sqlite
55
+ return "DuckDB"
56
+ elif isinstance(self.db, DuckDBConnection):
57
+ return "DuckDB"
55
58
  else:
56
59
  return "database" # Fallback
57
60
 
sqlsaber/agents/mcp.py CHANGED
@@ -3,7 +3,7 @@
3
3
  from typing import AsyncIterator
4
4
 
5
5
  from sqlsaber.agents.base import BaseSQLAgent
6
- from sqlsaber.database.connection import BaseDatabaseConnection
6
+ from sqlsaber.database import BaseDatabaseConnection
7
7
 
8
8
 
9
9
  class MCPSQLAgent(BaseSQLAgent):
@@ -6,17 +6,19 @@ function tools, and streaming event types directly.
6
6
 
7
7
  import httpx
8
8
  from pydantic_ai import Agent, RunContext
9
- from pydantic_ai.models.anthropic import AnthropicModel
10
- from pydantic_ai.models.google import GoogleModel
11
- from pydantic_ai.models.openai import OpenAIResponsesModel
9
+ from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings
10
+ from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
11
+ from pydantic_ai.models.groq import GroqModelSettings
12
+ from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
12
13
  from pydantic_ai.providers.anthropic import AnthropicProvider
13
14
  from pydantic_ai.providers.google import GoogleProvider
14
15
 
15
16
  from sqlsaber.config import providers
16
17
  from sqlsaber.config.settings import Config
17
- from sqlsaber.database.connection import (
18
+ from sqlsaber.database import (
18
19
  BaseDatabaseConnection,
19
20
  CSVConnection,
21
+ DuckDBConnection,
20
22
  MySQLConnection,
21
23
  PostgreSQLConnection,
22
24
  SQLiteConnection,
@@ -27,47 +29,119 @@ from sqlsaber.tools.registry import tool_registry
27
29
  from sqlsaber.tools.sql_tools import SQLTool
28
30
 
29
31
 
30
- def build_sqlsaber_agent(
31
- db_connection: BaseDatabaseConnection,
32
- database_name: str | None,
33
- ) -> Agent:
34
- """Create and configure a pydantic-ai Agent for SQLSaber.
35
-
36
- - Registers function tools that delegate to the existing tool registry
37
- - Attaches dynamic system prompt built from InstructionBuilder + MemoryManager
38
- - Ensures SQL tools have the active DB connection
39
- """
40
- # Ensure SQL tools receive the active connection
41
- for tool_name in tool_registry.list_tools(category="sql"):
42
- tool = tool_registry.get_tool(tool_name)
43
- if isinstance(tool, SQLTool):
44
- tool.set_connection(db_connection)
45
-
46
- cfg = Config()
47
- # Ensure provider env var is hydrated from keyring for current provider (Config.validate handles it)
48
- cfg.validate()
49
-
50
- # Build model/agent. For some providers (e.g., google), construct provider model explicitly to
51
- # allow arbitrary model IDs even if not in pydantic-ai's KnownModelName.
52
- model_name_only = (
53
- cfg.model_name.split(":", 1)[1] if ":" in cfg.model_name else cfg.model_name
54
- )
55
-
56
- provider = providers.provider_from_model(cfg.model_name) or ""
57
- if provider == "google":
58
- model_obj = GoogleModel(
59
- model_name_only, provider=GoogleProvider(api_key=cfg.api_key)
32
+ class SQLSaberAgent:
33
+ """Pydantic-AI Agent wrapper for SQLSaber with enhanced state management."""
34
+
35
+ def __init__(
36
+ self,
37
+ db_connection: BaseDatabaseConnection,
38
+ database_name: str | None = None,
39
+ memory_manager: MemoryManager | None = None,
40
+ thinking_enabled: bool | None = None,
41
+ ):
42
+ self.db_connection = db_connection
43
+ self.database_name = database_name
44
+ self.config = Config()
45
+ self.memory_manager = memory_manager or MemoryManager()
46
+ self.instruction_builder = InstructionBuilder(tool_registry)
47
+ self.db_type = self._get_database_type_name()
48
+
49
+ # Thinking configuration (CLI override or config default)
50
+ self.thinking_enabled = (
51
+ thinking_enabled
52
+ if thinking_enabled is not None
53
+ else self.config.thinking_enabled
60
54
  )
61
- agent = Agent(model_obj, name="sqlsaber")
62
- elif provider == "anthropic" and bool(getattr(cfg, "oauth_token", None)):
63
- # Build custom httpx client to inject OAuth headers for Anthropic
55
+
56
+ # Configure SQL tools with the database connection
57
+ self._configure_sql_tools()
58
+
59
+ # Create the pydantic-ai agent
60
+ self.agent = self._build_agent()
61
+
62
+ def _configure_sql_tools(self) -> None:
63
+ """Ensure SQL tools receive the active database connection."""
64
+ for tool_name in tool_registry.list_tools(category="sql"):
65
+ tool = tool_registry.get_tool(tool_name)
66
+ if isinstance(tool, SQLTool):
67
+ tool.set_connection(self.db_connection)
68
+
69
+ def _build_agent(self) -> Agent:
70
+ """Create and configure the pydantic-ai Agent."""
71
+ self.config.validate()
72
+
73
+ model_name_only = (
74
+ self.config.model_name.split(":", 1)[1]
75
+ if ":" in self.config.model_name
76
+ else self.config.model_name
77
+ )
78
+
79
+ provider = providers.provider_from_model(self.config.model_name) or ""
80
+ self.is_oauth = provider == "anthropic" and bool(
81
+ getattr(self.config, "oauth_token", None)
82
+ )
83
+
84
+ agent = self._create_agent_for_provider(provider, model_name_only)
85
+ self._setup_system_prompt(agent)
86
+ self._register_tools(agent)
87
+
88
+ return agent
89
+
90
+ def _create_agent_for_provider(self, provider: str, model_name: str) -> Agent:
91
+ """Create the agent based on the provider type."""
92
+ if provider == "google":
93
+ model_obj = GoogleModel(
94
+ model_name, provider=GoogleProvider(api_key=self.config.api_key)
95
+ )
96
+ if self.thinking_enabled:
97
+ settings = GoogleModelSettings(
98
+ google_thinking_config={"include_thoughts": True}
99
+ )
100
+ return Agent(model_obj, name="sqlsaber", model_settings=settings)
101
+ return Agent(model_obj, name="sqlsaber")
102
+ elif provider == "anthropic" and self.is_oauth:
103
+ return self._create_oauth_anthropic_agent(model_name)
104
+ elif provider == "anthropic":
105
+ if self.thinking_enabled:
106
+ settings = AnthropicModelSettings(
107
+ anthropic_thinking={
108
+ "type": "enabled",
109
+ "budget_tokens": 2048,
110
+ },
111
+ max_tokens=8192,
112
+ )
113
+ return Agent(
114
+ self.config.model_name, name="sqlsaber", model_settings=settings
115
+ )
116
+ return Agent(self.config.model_name, name="sqlsaber")
117
+ elif provider == "openai":
118
+ model_obj = OpenAIResponsesModel(model_name)
119
+ if self.thinking_enabled:
120
+ settings = OpenAIResponsesModelSettings(
121
+ openai_reasoning_effort="medium",
122
+ openai_reasoning_summary="auto",
123
+ )
124
+ return Agent(model_obj, name="sqlsaber", model_settings=settings)
125
+ return Agent(model_obj, name="sqlsaber")
126
+ elif provider == "groq":
127
+ if self.thinking_enabled:
128
+ settings = GroqModelSettings(groq_reasoning_format="parsed")
129
+ return Agent(
130
+ self.config.model_name, name="sqlsaber", model_settings=settings
131
+ )
132
+ return Agent(self.config.model_name, name="sqlsaber")
133
+ else:
134
+ return Agent(self.config.model_name, name="sqlsaber")
135
+
136
+ def _create_oauth_anthropic_agent(self, model_name: str) -> Agent:
137
+ """Create an Anthropic agent with OAuth configuration."""
138
+
64
139
  async def add_oauth_headers(request: httpx.Request) -> None: # type: ignore[override]
65
- # Remove API-key header if present and add OAuth headers
66
140
  if "x-api-key" in request.headers:
67
141
  del request.headers["x-api-key"]
68
142
  request.headers.update(
69
143
  {
70
- "Authorization": f"Bearer {cfg.oauth_token}",
144
+ "Authorization": f"Bearer {self.config.oauth_token}",
71
145
  "anthropic-version": "2023-06-01",
72
146
  "anthropic-beta": "oauth-2025-04-20",
73
147
  "User-Agent": "ClaudeCode/1.0 (Anthropic Claude Code CLI)",
@@ -78,98 +152,99 @@ def build_sqlsaber_agent(
78
152
 
79
153
  http_client = httpx.AsyncClient(event_hooks={"request": [add_oauth_headers]})
80
154
  provider_obj = AnthropicProvider(api_key="placeholder", http_client=http_client)
81
- model_obj = AnthropicModel(model_name_only, provider=provider_obj)
82
- agent = Agent(model_obj, name="sqlsaber")
83
- elif provider == "openai":
84
- # Use OpenAI Responses Model for structured output capabilities
85
- model_obj = OpenAIResponsesModel(model_name_only)
86
- agent = Agent(model_obj, name="sqlsaber")
87
- else:
88
- agent = Agent(cfg.model_name, name="sqlsaber")
89
-
90
- # Memory + dynamic system prompt
91
- memory_manager = MemoryManager()
92
- instruction_builder = InstructionBuilder(tool_registry)
93
-
94
- is_oauth = provider == "anthropic" and bool(getattr(cfg, "oauth_token", None))
95
-
96
- if not is_oauth:
97
-
98
- @agent.system_prompt(dynamic=True)
99
- async def sqlsaber_system_prompt(ctx: RunContext) -> str:
100
- db_type = _get_database_type_name(db_connection)
101
- instructions = instruction_builder.build_instructions(db_type=db_type)
102
-
103
- # Add memory context if available
104
- if database_name:
105
- mem = memory_manager.format_memories_for_prompt(database_name)
106
- else:
107
- mem = ""
155
+ model_obj = AnthropicModel(model_name, provider=provider_obj)
156
+ if self.thinking_enabled:
157
+ settings = AnthropicModelSettings(
158
+ anthropic_thinking={
159
+ "type": "enabled",
160
+ "budget_tokens": 2048,
161
+ },
162
+ max_tokens=8192,
163
+ )
164
+ return Agent(model_obj, name="sqlsaber", model_settings=settings)
165
+ return Agent(model_obj, name="sqlsaber")
166
+
167
+ def _setup_system_prompt(self, agent: Agent) -> None:
168
+ """Set up the dynamic system prompt for the agent."""
169
+ if not self.is_oauth:
108
170
 
109
- parts = [p for p in (instructions, mem) if p and p.strip()]
110
- return "\n\n".join(parts) if parts else ""
111
- else:
112
-
113
- @agent.system_prompt(dynamic=True)
114
- async def sqlsaber_system_prompt(ctx: RunContext) -> str:
115
- # Minimal system prompt in OAuth mode to match Claude Code identity
116
- return "You are Claude Code, Anthropic's official CLI for Claude."
117
-
118
- # Expose helpers and context on agent instance
119
- agent._sqlsaber_memory_manager = memory_manager # type: ignore[attr-defined]
120
- agent._sqlsaber_database_name = database_name # type: ignore[attr-defined]
121
- agent._sqlsaber_instruction_builder = instruction_builder # type: ignore[attr-defined]
122
- agent._sqlsaber_db_type = _get_database_type_name(db_connection) # type: ignore[attr-defined]
123
- agent._sqlsaber_is_oauth = is_oauth # type: ignore[attr-defined]
124
-
125
- # Tool wrappers that invoke the registered tools
126
- @agent.tool(name="list_tables")
127
- async def list_tables(ctx: RunContext) -> str:
128
- """
129
- Get a list of all tables in the database with row counts.
130
- Use this first to discover available tables.
131
- """
132
- tool = tool_registry.get_tool("list_tables")
133
- return await tool.execute()
134
-
135
- @agent.tool(name="introspect_schema")
136
- async def introspect_schema(
137
- ctx: RunContext, table_pattern: str | None = None
138
- ) -> str:
139
- """
140
- Introspect database schema to understand table structures.
141
-
142
- Args:
143
- table_pattern: Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')
144
- """
145
- tool = tool_registry.get_tool("introspect_schema")
146
- return await tool.execute(table_pattern=table_pattern)
147
-
148
- @agent.tool(name="execute_sql")
149
- async def execute_sql(ctx: RunContext, query: str, limit: int | None = 100) -> str:
150
- """
151
- Execute a SQL query and return the results.
152
-
153
- Args:
154
- query: SQL query to execute
155
- limit: Maximum number of rows to return (default: 100)
156
- """
157
- tool = tool_registry.get_tool("execute_sql")
158
- return await tool.execute(query=query, limit=limit)
159
-
160
- return agent
161
-
162
-
163
- def _get_database_type_name(db: BaseDatabaseConnection) -> str:
164
- """Get the human-readable database type name (mirrors BaseSQLAgent)."""
165
-
166
- if isinstance(db, PostgreSQLConnection):
167
- return "PostgreSQL"
168
- elif isinstance(db, MySQLConnection):
169
- return "MySQL"
170
- elif isinstance(db, SQLiteConnection):
171
- return "SQLite"
172
- elif isinstance(db, CSVConnection):
173
- return "SQLite"
174
- else:
175
- return "database"
171
+ @agent.system_prompt(dynamic=True)
172
+ async def sqlsaber_system_prompt(ctx: RunContext) -> str:
173
+ instructions = self.instruction_builder.build_instructions(
174
+ db_type=self.db_type
175
+ )
176
+
177
+ # Add memory context if available
178
+ mem = ""
179
+ if self.database_name:
180
+ mem = self.memory_manager.format_memories_for_prompt(
181
+ self.database_name
182
+ )
183
+
184
+ parts = [p for p in (instructions, mem) if p and p.strip()]
185
+ return "\n\n".join(parts) if parts else ""
186
+ else:
187
+
188
+ @agent.system_prompt(dynamic=True)
189
+ async def sqlsaber_system_prompt(ctx: RunContext) -> str:
190
+ return "You are Claude Code, Anthropic's official CLI for Claude."
191
+
192
+ def _register_tools(self, agent: Agent) -> None:
193
+ """Register all the SQL tools with the agent."""
194
+
195
+ @agent.tool(name="list_tables")
196
+ async def list_tables(ctx: RunContext) -> str:
197
+ """
198
+ Get a list of all tables in the database with row counts.
199
+ Use this first to discover available tables.
200
+ """
201
+ tool = tool_registry.get_tool("list_tables")
202
+ return await tool.execute()
203
+
204
+ @agent.tool(name="introspect_schema")
205
+ async def introspect_schema(
206
+ ctx: RunContext, table_pattern: str | None = None
207
+ ) -> str:
208
+ """
209
+ Introspect database schema to understand table structures.
210
+
211
+ Args:
212
+ table_pattern: Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')
213
+ """
214
+ tool = tool_registry.get_tool("introspect_schema")
215
+ return await tool.execute(table_pattern=table_pattern)
216
+
217
+ @agent.tool(name="execute_sql")
218
+ async def execute_sql(
219
+ ctx: RunContext, query: str, limit: int | None = 100
220
+ ) -> str:
221
+ """
222
+ Execute a SQL query and return the results.
223
+
224
+ Args:
225
+ query: SQL query to execute
226
+ limit: Maximum number of rows to return (default: 100)
227
+ """
228
+ tool = tool_registry.get_tool("execute_sql")
229
+ return await tool.execute(query=query, limit=limit)
230
+
231
+ def set_thinking(self, enabled: bool) -> None:
232
+ """Update thinking settings and rebuild the agent."""
233
+ self.thinking_enabled = enabled
234
+ # Rebuild agent with new thinking settings
235
+ self.agent = self._build_agent()
236
+
237
+ def _get_database_type_name(self) -> str:
238
+ """Get the human-readable database type name."""
239
+ if isinstance(self.db_connection, PostgreSQLConnection):
240
+ return "PostgreSQL"
241
+ elif isinstance(self.db_connection, MySQLConnection):
242
+ return "MySQL"
243
+ elif isinstance(self.db_connection, SQLiteConnection):
244
+ return "SQLite"
245
+ elif isinstance(self.db_connection, DuckDBConnection):
246
+ return "DuckDB"
247
+ elif isinstance(self.db_connection, CSVConnection):
248
+ return "DuckDB"
249
+ else:
250
+ return "database"
sqlsaber/cli/commands.py CHANGED
@@ -46,7 +46,7 @@ def meta_handler(
46
46
  str | None,
47
47
  cyclopts.Parameter(
48
48
  ["--database", "-d"],
49
- help="Database connection name, file path (CSV/SQLite), or connection string (postgresql://, mysql://) (uses default if not specified)",
49
+ help="Database connection name, file path (CSV/SQLite/DuckDB), or connection string (postgresql://, mysql://, duckdb://) (uses default if not specified)",
50
50
  ),
51
51
  ] = None,
52
52
  ):
@@ -59,8 +59,10 @@ def meta_handler(
59
59
  saber -d mydb "show me users" # Run a query with specific database
60
60
  saber -d data.csv "show me users" # Run a query with ad-hoc CSV file
61
61
  saber -d data.db "show me users" # Run a query with ad-hoc SQLite file
62
+ saber -d data.duckdb "show me users" # Run a query with ad-hoc DuckDB file
62
63
  saber -d "postgresql://user:pass@host:5432/db" "show users" # PostgreSQL connection string
63
64
  saber -d "mysql://user:pass@host:3306/db" "show users" # MySQL connection string
65
+ saber -d "duckdb:///data.duckdb" "show users" # DuckDB connection string
64
66
  echo "show me all users" | saber # Read query from stdin
65
67
  cat query.txt | saber # Read query from file via stdin
66
68
  """
@@ -73,16 +75,17 @@ def query(
73
75
  query_text: Annotated[
74
76
  str | None,
75
77
  cyclopts.Parameter(
76
- help="SQL query in natural language (if not provided, reads from stdin or starts interactive mode)",
78
+ help="Question in natural language (if not provided, reads from stdin or starts interactive mode)",
77
79
  ),
78
80
  ] = None,
79
81
  database: Annotated[
80
82
  str | None,
81
83
  cyclopts.Parameter(
82
84
  ["--database", "-d"],
83
- help="Database connection name, file path (CSV/SQLite), or connection string (postgresql://, mysql://) (uses default if not specified)",
85
+ help="Database connection name, file path (CSV/SQLite/DuckDB), or connection string (postgresql://, mysql://, duckdb://) (uses default if not specified)",
84
86
  ),
85
87
  ] = None,
88
+ thinking: bool = False,
86
89
  ):
87
90
  """Run a query against the database or start interactive mode.
88
91
 
@@ -97,23 +100,21 @@ def query(
97
100
  saber "show me all users" # Run a single query
98
101
  saber -d data.csv "show users" # Run a query with ad-hoc CSV file
99
102
  saber -d data.db "show users" # Run a query with ad-hoc SQLite file
103
+ saber -d data.duckdb "show users" # Run a query with ad-hoc DuckDB file
100
104
  saber -d "postgresql://user:pass@host:5432/db" "show users" # PostgreSQL connection string
101
105
  saber -d "mysql://user:pass@host:3306/db" "show users" # MySQL connection string
106
+ saber -d "duckdb:///data.duckdb" "show users" # DuckDB connection string
102
107
  echo "show me all users" | saber # Read query from stdin
103
108
  """
104
109
 
105
110
  async def run_session():
106
111
  # Import heavy dependencies only when actually running a query
107
112
  # This is only done to speed up startup time
108
- from sqlsaber.agents import build_sqlsaber_agent
113
+ from sqlsaber.agents import SQLSaberAgent
109
114
  from sqlsaber.cli.interactive import InteractiveSession
110
115
  from sqlsaber.cli.streaming import StreamingQueryHandler
111
- from sqlsaber.database.connection import (
112
- CSVConnection,
116
+ from sqlsaber.database import (
113
117
  DatabaseConnection,
114
- MySQLConnection,
115
- PostgreSQLConnection,
116
- SQLiteConnection,
117
118
  )
118
119
  from sqlsaber.database.resolver import DatabaseResolutionError, resolve_database
119
120
  from sqlsaber.threads import ThreadStorage
@@ -142,42 +143,32 @@ def query(
142
143
  raise CLIError(f"Error creating database connection: {e}")
143
144
 
144
145
  # Create pydantic-ai agent instance with database name for memory context
145
- agent = build_sqlsaber_agent(db_conn, db_name)
146
+ sqlsaber_agent = SQLSaberAgent(db_conn, db_name, thinking_enabled=thinking)
146
147
 
147
148
  try:
148
149
  if actual_query:
149
150
  # Single query mode with streaming
150
151
  streaming_handler = StreamingQueryHandler(console)
151
- # Compute DB type for the greeting line
152
- db_type = (
153
- "PostgreSQL"
154
- if isinstance(db_conn, PostgreSQLConnection)
155
- else "MySQL"
156
- if isinstance(db_conn, MySQLConnection)
157
- else "SQLite"
158
- if isinstance(db_conn, (SQLiteConnection, CSVConnection))
159
- else "database"
160
- )
152
+ db_type = sqlsaber_agent.db_type
161
153
  console.print(
162
154
  f"[bold blue]Connected to:[/bold blue] {db_name} ({db_type})\n"
163
155
  )
164
156
  run = await streaming_handler.execute_streaming_query(
165
- actual_query, agent
157
+ actual_query, sqlsaber_agent
166
158
  )
167
159
  # Persist non-interactive run as a thread snapshot so it can be resumed later
168
160
  try:
169
161
  if run is not None:
170
162
  threads = ThreadStorage()
171
- # Extract title and model name
172
- title = actual_query
173
- model_name: str | None = agent.model.model_name
174
163
 
175
164
  thread_id = await threads.save_snapshot(
176
165
  messages_json=run.all_messages_json(),
177
166
  database_name=db_name,
178
167
  )
179
168
  await threads.save_metadata(
180
- thread_id=thread_id, title=title, model_name=model_name
169
+ thread_id=thread_id,
170
+ title=actual_query,
171
+ model_name=sqlsaber_agent.agent.model.model_name,
181
172
  )
182
173
  await threads.end_thread(thread_id)
183
174
  console.print(
@@ -190,7 +181,7 @@ def query(
190
181
  await threads.prune_threads()
191
182
  else:
192
183
  # Interactive mode
193
- session = InteractiveSession(console, agent, db_conn, db_name)
184
+ session = InteractiveSession(console, sqlsaber_agent, db_conn, db_name)
194
185
  await session.run()
195
186
 
196
187
  finally:
@@ -19,6 +19,8 @@ class SlashCommandCompleter(Completer):
19
19
  ("clear", "Clear conversation history"),
20
20
  ("exit", "Exit the interactive session"),
21
21
  ("quit", "Exit the interactive session"),
22
+ ("thinking on", "Enable extended thinking/reasoning"),
23
+ ("thinking off", "Disable extended thinking/reasoning"),
22
24
  ]
23
25
 
24
26
  # Yield completions that match the partial command
sqlsaber/cli/database.py CHANGED
@@ -31,7 +31,7 @@ def add(
31
31
  str,
32
32
  cyclopts.Parameter(
33
33
  ["--type", "-t"],
34
- help="Database type (postgresql, mysql, sqlite)",
34
+ help="Database type (postgresql, mysql, sqlite, duckdb)",
35
35
  ),
36
36
  ] = "postgresql",
37
37
  host: Annotated[
@@ -87,17 +87,17 @@ def add(
87
87
  if not type or type == "postgresql":
88
88
  type = questionary.select(
89
89
  "Database type:",
90
- choices=["postgresql", "mysql", "sqlite"],
90
+ choices=["postgresql", "mysql", "sqlite", "duckdb"],
91
91
  default="postgresql",
92
92
  ).ask()
93
93
 
94
- if type == "sqlite":
95
- # SQLite only needs database path
94
+ if type in {"sqlite", "duckdb"}:
95
+ # SQLite/DuckDB only need database file path
96
96
  database = database or questionary.path("Database file path:").ask()
97
97
  database = str(Path(database).expanduser().resolve())
98
98
  host = "localhost"
99
99
  port = 0
100
- username = "sqlite"
100
+ username = type
101
101
  password = ""
102
102
  else:
103
103
  # PostgreSQL/MySQL need connection details
@@ -182,6 +182,17 @@ def add(
182
182
  port = 0
183
183
  username = "sqlite"
184
184
  password = ""
185
+ elif type == "duckdb":
186
+ if not database:
187
+ console.print(
188
+ "[bold red]Error:[/bold red] Database file path is required for DuckDB"
189
+ )
190
+ sys.exit(1)
191
+ database = str(Path(database).expanduser().resolve())
192
+ host = "localhost"
193
+ port = 0
194
+ username = "duckdb"
195
+ password = ""
185
196
  else:
186
197
  if not all([host, database, username]):
187
198
  console.print(
@@ -264,7 +275,7 @@ def list():
264
275
  if db.ssl_ca or db.ssl_cert:
265
276
  ssl_status += " (certs)"
266
277
  else:
267
- ssl_status = "disabled" if db.type != "sqlite" else "N/A"
278
+ ssl_status = "disabled" if db.type not in {"sqlite", "duckdb"} else "N/A"
268
279
 
269
280
  table.add_row(
270
281
  db.name,
@@ -343,7 +354,7 @@ def test(
343
354
 
344
355
  async def test_connection():
345
356
  # Lazy import to keep CLI startup fast
346
- from sqlsaber.database.connection import DatabaseConnection
357
+ from sqlsaber.database import DatabaseConnection
347
358
 
348
359
  if name:
349
360
  db_config = config_manager.get_database(name)