sqlsaber 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- sqlsaber/agents/__init__.py +2 -2
- sqlsaber/agents/base.py +5 -2
- sqlsaber/agents/mcp.py +1 -1
- sqlsaber/agents/pydantic_ai_agent.py +208 -133
- sqlsaber/cli/commands.py +17 -26
- sqlsaber/cli/completers.py +2 -0
- sqlsaber/cli/database.py +18 -7
- sqlsaber/cli/display.py +29 -9
- sqlsaber/cli/interactive.py +28 -16
- sqlsaber/cli/streaming.py +15 -17
- sqlsaber/cli/threads.py +10 -6
- sqlsaber/config/database.py +3 -1
- sqlsaber/config/settings.py +25 -2
- sqlsaber/database/__init__.py +55 -1
- sqlsaber/database/base.py +124 -0
- sqlsaber/database/csv.py +133 -0
- sqlsaber/database/duckdb.py +313 -0
- sqlsaber/database/mysql.py +345 -0
- sqlsaber/database/postgresql.py +328 -0
- sqlsaber/database/resolver.py +7 -3
- sqlsaber/database/schema.py +69 -742
- sqlsaber/database/sqlite.py +258 -0
- sqlsaber/mcp/mcp.py +1 -1
- sqlsaber/tools/sql_tools.py +1 -1
- {sqlsaber-0.24.0.dist-info → sqlsaber-0.26.0.dist-info}/METADATA +45 -10
- sqlsaber-0.26.0.dist-info/RECORD +52 -0
- sqlsaber/database/connection.py +0 -511
- sqlsaber-0.24.0.dist-info/RECORD +0 -47
- {sqlsaber-0.24.0.dist-info → sqlsaber-0.26.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.24.0.dist-info → sqlsaber-0.26.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.24.0.dist-info → sqlsaber-0.26.0.dist-info}/licenses/LICENSE +0 -0
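
The monolithic sqlsaber/database/connection.py (removed, −511 lines) is split into per-backend modules (base.py, postgresql.py, mysql.py, sqlite.py, duckdb.py, csv.py), and the diffs below import connection classes directly from sqlsaber.database. A sketch of that import surface, assuming the new database/__init__.py (+55 lines, not shown in this view) re-exports these names:

```python
# Assumed re-export surface of sqlsaber.database in 0.26.0; every name below is
# imported from sqlsaber.database somewhere in the diffs that follow.
from sqlsaber.database import (
    BaseDatabaseConnection,
    CSVConnection,
    DatabaseConnection,
    DuckDBConnection,
    MySQLConnection,
    PostgreSQLConnection,
    SQLiteConnection,
)
```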
sqlsaber/agents/__init__.py
CHANGED
sqlsaber/agents/base.py
CHANGED
@@ -5,9 +5,10 @@ import json
 from abc import ABC, abstractmethod
 from typing import Any, AsyncIterator
 
-from sqlsaber.database
+from sqlsaber.database import (
     BaseDatabaseConnection,
     CSVConnection,
+    DuckDBConnection,
     MySQLConnection,
     PostgreSQLConnection,
     SQLiteConnection,
@@ -51,7 +52,9 @@ class BaseSQLAgent(ABC):
         elif isinstance(self.db, SQLiteConnection):
             return "SQLite"
         elif isinstance(self.db, CSVConnection):
-            return "
+            return "DuckDB"
+        elif isinstance(self.db, DuckDBConnection):
+            return "DuckDB"
         else:
             return "database"  # Fallback
 
sqlsaber/agents/mcp.py
CHANGED
sqlsaber/agents/pydantic_ai_agent.py
CHANGED

@@ -6,17 +6,19 @@ function tools, and streaming event types directly.
 
 import httpx
 from pydantic_ai import Agent, RunContext
-from pydantic_ai.models.anthropic import AnthropicModel
-from pydantic_ai.models.google import GoogleModel
-from pydantic_ai.models.
+from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings
+from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
+from pydantic_ai.models.groq import GroqModelSettings
+from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
 from pydantic_ai.providers.anthropic import AnthropicProvider
 from pydantic_ai.providers.google import GoogleProvider
 
 from sqlsaber.config import providers
 from sqlsaber.config.settings import Config
-from sqlsaber.database
+from sqlsaber.database import (
     BaseDatabaseConnection,
     CSVConnection,
+    DuckDBConnection,
     MySQLConnection,
     PostgreSQLConnection,
     SQLiteConnection,
@@ -27,47 +29,119 @@ from sqlsaber.tools.registry import tool_registry
 from sqlsaber.tools.sql_tools import SQLTool
 
 
-  [... removed lines not captured in this diff view ...]
-    model_name_only = (
-        cfg.model_name.split(":", 1)[1] if ":" in cfg.model_name else cfg.model_name
-    )
-
-    provider = providers.provider_from_model(cfg.model_name) or ""
-    if provider == "google":
-        model_obj = GoogleModel(
-            model_name_only, provider=GoogleProvider(api_key=cfg.api_key)
+class SQLSaberAgent:
+    """Pydantic-AI Agent wrapper for SQLSaber with enhanced state management."""
+
+    def __init__(
+        self,
+        db_connection: BaseDatabaseConnection,
+        database_name: str | None = None,
+        memory_manager: MemoryManager | None = None,
+        thinking_enabled: bool | None = None,
+    ):
+        self.db_connection = db_connection
+        self.database_name = database_name
+        self.config = Config()
+        self.memory_manager = memory_manager or MemoryManager()
+        self.instruction_builder = InstructionBuilder(tool_registry)
+        self.db_type = self._get_database_type_name()
+
+        # Thinking configuration (CLI override or config default)
+        self.thinking_enabled = (
+            thinking_enabled
+            if thinking_enabled is not None
+            else self.config.thinking_enabled
         )
-
-
-
+
+        # Configure SQL tools with the database connection
+        self._configure_sql_tools()
+
+        # Create the pydantic-ai agent
+        self.agent = self._build_agent()
+
+    def _configure_sql_tools(self) -> None:
+        """Ensure SQL tools receive the active database connection."""
+        for tool_name in tool_registry.list_tools(category="sql"):
+            tool = tool_registry.get_tool(tool_name)
+            if isinstance(tool, SQLTool):
+                tool.set_connection(self.db_connection)
+
+    def _build_agent(self) -> Agent:
+        """Create and configure the pydantic-ai Agent."""
+        self.config.validate()
+
+        model_name_only = (
+            self.config.model_name.split(":", 1)[1]
+            if ":" in self.config.model_name
+            else self.config.model_name
+        )
+
+        provider = providers.provider_from_model(self.config.model_name) or ""
+        self.is_oauth = provider == "anthropic" and bool(
+            getattr(self.config, "oauth_token", None)
+        )
+
+        agent = self._create_agent_for_provider(provider, model_name_only)
+        self._setup_system_prompt(agent)
+        self._register_tools(agent)
+
+        return agent
+
+    def _create_agent_for_provider(self, provider: str, model_name: str) -> Agent:
+        """Create the agent based on the provider type."""
+        if provider == "google":
+            model_obj = GoogleModel(
+                model_name, provider=GoogleProvider(api_key=self.config.api_key)
+            )
+            if self.thinking_enabled:
+                settings = GoogleModelSettings(
+                    google_thinking_config={"include_thoughts": True}
+                )
+                return Agent(model_obj, name="sqlsaber", model_settings=settings)
+            return Agent(model_obj, name="sqlsaber")
+        elif provider == "anthropic" and self.is_oauth:
+            return self._create_oauth_anthropic_agent(model_name)
+        elif provider == "anthropic":
+            if self.thinking_enabled:
+                settings = AnthropicModelSettings(
+                    anthropic_thinking={
+                        "type": "enabled",
+                        "budget_tokens": 2048,
+                    },
+                    max_tokens=8192,
+                )
+                return Agent(
+                    self.config.model_name, name="sqlsaber", model_settings=settings
+                )
+            return Agent(self.config.model_name, name="sqlsaber")
+        elif provider == "openai":
+            model_obj = OpenAIResponsesModel(model_name)
+            if self.thinking_enabled:
+                settings = OpenAIResponsesModelSettings(
+                    openai_reasoning_effort="medium",
+                    openai_reasoning_summary="auto",
+                )
+                return Agent(model_obj, name="sqlsaber", model_settings=settings)
+            return Agent(model_obj, name="sqlsaber")
+        elif provider == "groq":
+            if self.thinking_enabled:
+                settings = GroqModelSettings(groq_reasoning_format="parsed")
+                return Agent(
+                    self.config.model_name, name="sqlsaber", model_settings=settings
+                )
+            return Agent(self.config.model_name, name="sqlsaber")
+        else:
+            return Agent(self.config.model_name, name="sqlsaber")
+
+    def _create_oauth_anthropic_agent(self, model_name: str) -> Agent:
+        """Create an Anthropic agent with OAuth configuration."""
+
         async def add_oauth_headers(request: httpx.Request) -> None:  # type: ignore[override]
-            # Remove API-key header if present and add OAuth headers
             if "x-api-key" in request.headers:
                 del request.headers["x-api-key"]
             request.headers.update(
                 {
-                    "Authorization": f"Bearer {
+                    "Authorization": f"Bearer {self.config.oauth_token}",
                     "anthropic-version": "2023-06-01",
                     "anthropic-beta": "oauth-2025-04-20",
                     "User-Agent": "ClaudeCode/1.0 (Anthropic Claude Code CLI)",
@@ -78,98 +152,99 @@ def build_sqlsaber_agent(
 
         http_client = httpx.AsyncClient(event_hooks={"request": [add_oauth_headers]})
         provider_obj = AnthropicProvider(api_key="placeholder", http_client=http_client)
-    model_obj = AnthropicModel(
-  [... removed lines not captured in this diff view ...]
-    if not is_oauth:
-
-        @agent.system_prompt(dynamic=True)
-        async def sqlsaber_system_prompt(ctx: RunContext) -> str:
-            db_type = _get_database_type_name(db_connection)
-            instructions = instruction_builder.build_instructions(db_type=db_type)
-
-            # Add memory context if available
-            if database_name:
-                mem = memory_manager.format_memories_for_prompt(database_name)
-            else:
-                mem = ""
+        model_obj = AnthropicModel(model_name, provider=provider_obj)
+        if self.thinking_enabled:
+            settings = AnthropicModelSettings(
+                anthropic_thinking={
+                    "type": "enabled",
+                    "budget_tokens": 2048,
+                },
+                max_tokens=8192,
+            )
+            return Agent(model_obj, name="sqlsaber", model_settings=settings)
+        return Agent(model_obj, name="sqlsaber")
+
+    def _setup_system_prompt(self, agent: Agent) -> None:
+        """Set up the dynamic system prompt for the agent."""
+        if not self.is_oauth:
 
-  [... removed lines not captured in this diff view ...]
-            """
-  [... removed lines not captured in this diff view ...]
+            @agent.system_prompt(dynamic=True)
+            async def sqlsaber_system_prompt(ctx: RunContext) -> str:
+                instructions = self.instruction_builder.build_instructions(
+                    db_type=self.db_type
+                )
+
+                # Add memory context if available
+                mem = ""
+                if self.database_name:
+                    mem = self.memory_manager.format_memories_for_prompt(
+                        self.database_name
+                    )
+
+                parts = [p for p in (instructions, mem) if p and p.strip()]
+                return "\n\n".join(parts) if parts else ""
+        else:
+
+            @agent.system_prompt(dynamic=True)
+            async def sqlsaber_system_prompt(ctx: RunContext) -> str:
+                return "You are Claude Code, Anthropic's official CLI for Claude."
+
+    def _register_tools(self, agent: Agent) -> None:
+        """Register all the SQL tools with the agent."""
+
+        @agent.tool(name="list_tables")
+        async def list_tables(ctx: RunContext) -> str:
+            """
+            Get a list of all tables in the database with row counts.
+            Use this first to discover available tables.
+            """
+            tool = tool_registry.get_tool("list_tables")
+            return await tool.execute()
+
+        @agent.tool(name="introspect_schema")
+        async def introspect_schema(
+            ctx: RunContext, table_pattern: str | None = None
+        ) -> str:
+            """
+            Introspect database schema to understand table structures.
+
+            Args:
+                table_pattern: Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')
+            """
+            tool = tool_registry.get_tool("introspect_schema")
+            return await tool.execute(table_pattern=table_pattern)
+
+        @agent.tool(name="execute_sql")
+        async def execute_sql(
+            ctx: RunContext, query: str, limit: int | None = 100
+        ) -> str:
+            """
+            Execute a SQL query and return the results.
+
+            Args:
+                query: SQL query to execute
+                limit: Maximum number of rows to return (default: 100)
+            """
+            tool = tool_registry.get_tool("execute_sql")
+            return await tool.execute(query=query, limit=limit)
+
+    def set_thinking(self, enabled: bool) -> None:
+        """Update thinking settings and rebuild the agent."""
+        self.thinking_enabled = enabled
+        # Rebuild agent with new thinking settings
+        self.agent = self._build_agent()
+
+    def _get_database_type_name(self) -> str:
+        """Get the human-readable database type name."""
+        if isinstance(self.db_connection, PostgreSQLConnection):
+            return "PostgreSQL"
+        elif isinstance(self.db_connection, MySQLConnection):
+            return "MySQL"
+        elif isinstance(self.db_connection, SQLiteConnection):
+            return "SQLite"
+        elif isinstance(self.db_connection, DuckDBConnection):
+            return "DuckDB"
+        elif isinstance(self.db_connection, CSVConnection):
+            return "DuckDB"
+        else:
+            return "database"
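
Taken together, these hunks replace the module-level build_sqlsaber_agent helper with the SQLSaberAgent wrapper class. A minimal usage sketch based only on the constructor, set_thinking, and attributes shown above; the DuckDBConnection constructor argument and the surrounding setup are assumptions, not confirmed by this diff:

```python
from sqlsaber.agents import SQLSaberAgent
from sqlsaber.database import DuckDBConnection

# Hypothetical connection setup; the DuckDBConnection constructor signature
# lives in sqlsaber/database/duckdb.py, which is not shown in this view.
db_conn = DuckDBConnection("duckdb:///data.duckdb")

sqlsaber_agent = SQLSaberAgent(
    db_conn,
    database_name="analytics",   # hypothetical name, used for memory context in the system prompt
    thinking_enabled=True,       # overrides the config default
)

print(sqlsaber_agent.db_type)    # "DuckDB", per _get_database_type_name above

# Toggling thinking rebuilds the wrapped pydantic-ai Agent (sqlsaber_agent.agent).
sqlsaber_agent.set_thinking(False)
```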
sqlsaber/cli/commands.py
CHANGED
@@ -46,7 +46,7 @@ def meta_handler(
         str | None,
         cyclopts.Parameter(
             ["--database", "-d"],
-            help="Database connection name, file path (CSV/SQLite), or connection string (postgresql://, mysql://) (uses default if not specified)",
+            help="Database connection name, file path (CSV/SQLite/DuckDB), or connection string (postgresql://, mysql://, duckdb://) (uses default if not specified)",
         ),
     ] = None,
 ):
@@ -59,8 +59,10 @@ def meta_handler(
         saber -d mydb "show me users"  # Run a query with specific database
         saber -d data.csv "show me users"  # Run a query with ad-hoc CSV file
         saber -d data.db "show me users"  # Run a query with ad-hoc SQLite file
+        saber -d data.duckdb "show me users"  # Run a query with ad-hoc DuckDB file
         saber -d "postgresql://user:pass@host:5432/db" "show users"  # PostgreSQL connection string
         saber -d "mysql://user:pass@host:3306/db" "show users"  # MySQL connection string
+        saber -d "duckdb:///data.duckdb" "show users"  # DuckDB connection string
         echo "show me all users" | saber  # Read query from stdin
         cat query.txt | saber  # Read query from file via stdin
     """
@@ -73,16 +75,17 @@ def query(
     query_text: Annotated[
         str | None,
         cyclopts.Parameter(
-            help="
+            help="Question in natural language (if not provided, reads from stdin or starts interactive mode)",
         ),
     ] = None,
     database: Annotated[
         str | None,
         cyclopts.Parameter(
             ["--database", "-d"],
-            help="Database connection name, file path (CSV/SQLite), or connection string (postgresql://, mysql://) (uses default if not specified)",
+            help="Database connection name, file path (CSV/SQLite/DuckDB), or connection string (postgresql://, mysql://, duckdb://) (uses default if not specified)",
         ),
     ] = None,
+    thinking: bool = False,
 ):
     """Run a query against the database or start interactive mode.
 
@@ -97,23 +100,21 @@ def query(
         saber "show me all users"  # Run a single query
         saber -d data.csv "show users"  # Run a query with ad-hoc CSV file
         saber -d data.db "show users"  # Run a query with ad-hoc SQLite file
+        saber -d data.duckdb "show users"  # Run a query with ad-hoc DuckDB file
         saber -d "postgresql://user:pass@host:5432/db" "show users"  # PostgreSQL connection string
         saber -d "mysql://user:pass@host:3306/db" "show users"  # MySQL connection string
+        saber -d "duckdb:///data.duckdb" "show users"  # DuckDB connection string
         echo "show me all users" | saber  # Read query from stdin
     """
 
     async def run_session():
         # Import heavy dependencies only when actually running a query
         # This is only done to speed up startup time
-        from sqlsaber.agents import
+        from sqlsaber.agents import SQLSaberAgent
         from sqlsaber.cli.interactive import InteractiveSession
         from sqlsaber.cli.streaming import StreamingQueryHandler
-        from sqlsaber.database
-            CSVConnection,
+        from sqlsaber.database import (
             DatabaseConnection,
-            MySQLConnection,
-            PostgreSQLConnection,
-            SQLiteConnection,
         )
         from sqlsaber.database.resolver import DatabaseResolutionError, resolve_database
         from sqlsaber.threads import ThreadStorage
@@ -142,42 +143,32 @@ def query(
             raise CLIError(f"Error creating database connection: {e}")
 
         # Create pydantic-ai agent instance with database name for memory context
-
+        sqlsaber_agent = SQLSaberAgent(db_conn, db_name, thinking_enabled=thinking)
 
         try:
             if actual_query:
                 # Single query mode with streaming
                 streaming_handler = StreamingQueryHandler(console)
-
-                db_type = (
-                    "PostgreSQL"
-                    if isinstance(db_conn, PostgreSQLConnection)
-                    else "MySQL"
-                    if isinstance(db_conn, MySQLConnection)
-                    else "SQLite"
-                    if isinstance(db_conn, (SQLiteConnection, CSVConnection))
-                    else "database"
-                )
+                db_type = sqlsaber_agent.db_type
                 console.print(
                     f"[bold blue]Connected to:[/bold blue] {db_name} ({db_type})\n"
                 )
                 run = await streaming_handler.execute_streaming_query(
-                    actual_query,
+                    actual_query, sqlsaber_agent
                 )
                 # Persist non-interactive run as a thread snapshot so it can be resumed later
                 try:
                     if run is not None:
                         threads = ThreadStorage()
-                        # Extract title and model name
-                        title = actual_query
-                        model_name: str | None = agent.model.model_name
 
                         thread_id = await threads.save_snapshot(
                             messages_json=run.all_messages_json(),
                             database_name=db_name,
                         )
                         await threads.save_metadata(
-                            thread_id=thread_id,
+                            thread_id=thread_id,
+                            title=actual_query,
+                            model_name=sqlsaber_agent.agent.model.model_name,
                         )
                         await threads.end_thread(thread_id)
                         console.print(
@@ -190,7 +181,7 @@ def query(
                 await threads.prune_threads()
             else:
                 # Interactive mode
-                session = InteractiveSession(console,
+                session = InteractiveSession(console, sqlsaber_agent, db_conn, db_name)
                 await session.run()
 
         finally:
sqlsaber/cli/completers.py
CHANGED
@@ -19,6 +19,8 @@ class SlashCommandCompleter(Completer):
         ("clear", "Clear conversation history"),
         ("exit", "Exit the interactive session"),
         ("quit", "Exit the interactive session"),
+        ("thinking on", "Enable extended thinking/reasoning"),
+        ("thinking off", "Disable extended thinking/reasoning"),
     ]
 
     # Yield completions that match the partial command
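
These completions pair with the thinking toggle on SQLSaberAgent. How the interactive session dispatches them is not part of this view (sqlsaber/cli/interactive.py changed by +28 −16 but its diff is not shown), so the following is only an illustrative sketch of one possible wiring:

```python
# Hypothetical dispatcher; the real handler lives in sqlsaber/cli/interactive.py,
# whose diff is not included here.
def handle_thinking_command(command: str, sqlsaber_agent) -> bool:
    if command == "/thinking on":
        sqlsaber_agent.set_thinking(True)   # rebuilds the underlying agent with thinking enabled
        return True
    if command == "/thinking off":
        sqlsaber_agent.set_thinking(False)
        return True
    return False
```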
sqlsaber/cli/database.py
CHANGED
@@ -31,7 +31,7 @@ def add(
         str,
         cyclopts.Parameter(
             ["--type", "-t"],
-            help="Database type (postgresql, mysql, sqlite)",
+            help="Database type (postgresql, mysql, sqlite, duckdb)",
         ),
     ] = "postgresql",
     host: Annotated[
@@ -87,17 +87,17 @@ def add(
     if not type or type == "postgresql":
         type = questionary.select(
             "Database type:",
-            choices=["postgresql", "mysql", "sqlite"],
+            choices=["postgresql", "mysql", "sqlite", "duckdb"],
             default="postgresql",
         ).ask()
 
-    if type
-        # SQLite only
+    if type in {"sqlite", "duckdb"}:
+        # SQLite/DuckDB only need database file path
         database = database or questionary.path("Database file path:").ask()
         database = str(Path(database).expanduser().resolve())
         host = "localhost"
         port = 0
-        username =
+        username = type
         password = ""
     else:
         # PostgreSQL/MySQL need connection details
@@ -182,6 +182,17 @@ def add(
         port = 0
         username = "sqlite"
         password = ""
+    elif type == "duckdb":
+        if not database:
+            console.print(
+                "[bold red]Error:[/bold red] Database file path is required for DuckDB"
+            )
+            sys.exit(1)
+        database = str(Path(database).expanduser().resolve())
+        host = "localhost"
+        port = 0
+        username = "duckdb"
+        password = ""
     else:
         if not all([host, database, username]):
             console.print(
@@ -264,7 +275,7 @@ def list():
             if db.ssl_ca or db.ssl_cert:
                 ssl_status += " (certs)"
         else:
-            ssl_status = "disabled" if db.type
+            ssl_status = "disabled" if db.type not in {"sqlite", "duckdb"} else "N/A"
 
         table.add_row(
             db.name,
@@ -343,7 +354,7 @@ def test(
 
     async def test_connection():
         # Lazy import to keep CLI startup fast
-        from sqlsaber.database
+        from sqlsaber.database import DatabaseConnection
 
         if name:
             db_config = config_manager.get_database(name)
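
For reference, the duckdb branch added to add() above fills the same fields as the sqlite branch; a sketch of the resulting values, where the field names are taken from the local variables above and the surrounding config object and its serialization are assumptions not shown in this diff:

```python
# Values produced by the duckdb branch of add(); the storage format is an assumption.
duckdb_connection = {
    "name": "analytics",                   # hypothetical connection name
    "type": "duckdb",
    "database": "/home/user/data.duckdb",  # expanded and resolved file path
    "host": "localhost",
    "port": 0,
    "username": "duckdb",
    "password": "",
}
```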
|