sqlsaber 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/__init__.py +2 -2
- sqlsaber/agents/base.py +1 -1
- sqlsaber/agents/mcp.py +1 -1
- sqlsaber/agents/pydantic_ai_agent.py +207 -135
- sqlsaber/cli/commands.py +11 -28
- sqlsaber/cli/completers.py +2 -0
- sqlsaber/cli/database.py +1 -1
- sqlsaber/cli/display.py +29 -9
- sqlsaber/cli/interactive.py +22 -15
- sqlsaber/cli/streaming.py +15 -17
- sqlsaber/cli/threads.py +10 -6
- sqlsaber/config/settings.py +25 -2
- sqlsaber/database/__init__.py +55 -1
- sqlsaber/database/base.py +124 -0
- sqlsaber/database/csv.py +133 -0
- sqlsaber/database/duckdb.py +313 -0
- sqlsaber/database/mysql.py +345 -0
- sqlsaber/database/postgresql.py +328 -0
- sqlsaber/database/schema.py +66 -963
- sqlsaber/database/sqlite.py +258 -0
- sqlsaber/mcp/mcp.py +1 -1
- sqlsaber/tools/sql_tools.py +1 -1
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/METADATA +43 -9
- sqlsaber-0.26.0.dist-info/RECORD +52 -0
- sqlsaber/database/connection.py +0 -535
- sqlsaber-0.25.0.dist-info/RECORD +0 -47
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/agents/__init__.py
CHANGED
sqlsaber/agents/base.py
CHANGED
sqlsaber/agents/mcp.py
CHANGED
|
@@ -6,15 +6,16 @@ function tools, and streaming event types directly.
|
|
|
6
6
|
|
|
7
7
|
import httpx
|
|
8
8
|
from pydantic_ai import Agent, RunContext
|
|
9
|
-
from pydantic_ai.models.anthropic import AnthropicModel
|
|
10
|
-
from pydantic_ai.models.google import GoogleModel
|
|
11
|
-
from pydantic_ai.models.
|
|
9
|
+
from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings
|
|
10
|
+
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
|
|
11
|
+
from pydantic_ai.models.groq import GroqModelSettings
|
|
12
|
+
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
|
|
12
13
|
from pydantic_ai.providers.anthropic import AnthropicProvider
|
|
13
14
|
from pydantic_ai.providers.google import GoogleProvider
|
|
14
15
|
|
|
15
16
|
from sqlsaber.config import providers
|
|
16
17
|
from sqlsaber.config.settings import Config
|
|
17
|
-
from sqlsaber.database
|
|
18
|
+
from sqlsaber.database import (
|
|
18
19
|
BaseDatabaseConnection,
|
|
19
20
|
CSVConnection,
|
|
20
21
|
DuckDBConnection,
|
|
@@ -28,47 +29,119 @@ from sqlsaber.tools.registry import tool_registry
|
|
|
28
29
|
from sqlsaber.tools.sql_tools import SQLTool
|
|
29
30
|
|
|
30
31
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
model_name_only = (
|
|
54
|
-
cfg.model_name.split(":", 1)[1] if ":" in cfg.model_name else cfg.model_name
|
|
55
|
-
)
|
|
56
|
-
|
|
57
|
-
provider = providers.provider_from_model(cfg.model_name) or ""
|
|
58
|
-
if provider == "google":
|
|
59
|
-
model_obj = GoogleModel(
|
|
60
|
-
model_name_only, provider=GoogleProvider(api_key=cfg.api_key)
|
|
32
|
+
class SQLSaberAgent:
|
|
33
|
+
"""Pydantic-AI Agent wrapper for SQLSaber with enhanced state management."""
|
|
34
|
+
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
db_connection: BaseDatabaseConnection,
|
|
38
|
+
database_name: str | None = None,
|
|
39
|
+
memory_manager: MemoryManager | None = None,
|
|
40
|
+
thinking_enabled: bool | None = None,
|
|
41
|
+
):
|
|
42
|
+
self.db_connection = db_connection
|
|
43
|
+
self.database_name = database_name
|
|
44
|
+
self.config = Config()
|
|
45
|
+
self.memory_manager = memory_manager or MemoryManager()
|
|
46
|
+
self.instruction_builder = InstructionBuilder(tool_registry)
|
|
47
|
+
self.db_type = self._get_database_type_name()
|
|
48
|
+
|
|
49
|
+
# Thinking configuration (CLI override or config default)
|
|
50
|
+
self.thinking_enabled = (
|
|
51
|
+
thinking_enabled
|
|
52
|
+
if thinking_enabled is not None
|
|
53
|
+
else self.config.thinking_enabled
|
|
61
54
|
)
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
55
|
+
|
|
56
|
+
# Configure SQL tools with the database connection
|
|
57
|
+
self._configure_sql_tools()
|
|
58
|
+
|
|
59
|
+
# Create the pydantic-ai agent
|
|
60
|
+
self.agent = self._build_agent()
|
|
61
|
+
|
|
62
|
+
def _configure_sql_tools(self) -> None:
|
|
63
|
+
"""Ensure SQL tools receive the active database connection."""
|
|
64
|
+
for tool_name in tool_registry.list_tools(category="sql"):
|
|
65
|
+
tool = tool_registry.get_tool(tool_name)
|
|
66
|
+
if isinstance(tool, SQLTool):
|
|
67
|
+
tool.set_connection(self.db_connection)
|
|
68
|
+
|
|
69
|
+
def _build_agent(self) -> Agent:
|
|
70
|
+
"""Create and configure the pydantic-ai Agent."""
|
|
71
|
+
self.config.validate()
|
|
72
|
+
|
|
73
|
+
model_name_only = (
|
|
74
|
+
self.config.model_name.split(":", 1)[1]
|
|
75
|
+
if ":" in self.config.model_name
|
|
76
|
+
else self.config.model_name
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
provider = providers.provider_from_model(self.config.model_name) or ""
|
|
80
|
+
self.is_oauth = provider == "anthropic" and bool(
|
|
81
|
+
getattr(self.config, "oauth_token", None)
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
agent = self._create_agent_for_provider(provider, model_name_only)
|
|
85
|
+
self._setup_system_prompt(agent)
|
|
86
|
+
self._register_tools(agent)
|
|
87
|
+
|
|
88
|
+
return agent
|
|
89
|
+
|
|
90
|
+
def _create_agent_for_provider(self, provider: str, model_name: str) -> Agent:
|
|
91
|
+
"""Create the agent based on the provider type."""
|
|
92
|
+
if provider == "google":
|
|
93
|
+
model_obj = GoogleModel(
|
|
94
|
+
model_name, provider=GoogleProvider(api_key=self.config.api_key)
|
|
95
|
+
)
|
|
96
|
+
if self.thinking_enabled:
|
|
97
|
+
settings = GoogleModelSettings(
|
|
98
|
+
google_thinking_config={"include_thoughts": True}
|
|
99
|
+
)
|
|
100
|
+
return Agent(model_obj, name="sqlsaber", model_settings=settings)
|
|
101
|
+
return Agent(model_obj, name="sqlsaber")
|
|
102
|
+
elif provider == "anthropic" and self.is_oauth:
|
|
103
|
+
return self._create_oauth_anthropic_agent(model_name)
|
|
104
|
+
elif provider == "anthropic":
|
|
105
|
+
if self.thinking_enabled:
|
|
106
|
+
settings = AnthropicModelSettings(
|
|
107
|
+
anthropic_thinking={
|
|
108
|
+
"type": "enabled",
|
|
109
|
+
"budget_tokens": 2048,
|
|
110
|
+
},
|
|
111
|
+
max_tokens=8192,
|
|
112
|
+
)
|
|
113
|
+
return Agent(
|
|
114
|
+
self.config.model_name, name="sqlsaber", model_settings=settings
|
|
115
|
+
)
|
|
116
|
+
return Agent(self.config.model_name, name="sqlsaber")
|
|
117
|
+
elif provider == "openai":
|
|
118
|
+
model_obj = OpenAIResponsesModel(model_name)
|
|
119
|
+
if self.thinking_enabled:
|
|
120
|
+
settings = OpenAIResponsesModelSettings(
|
|
121
|
+
openai_reasoning_effort="medium",
|
|
122
|
+
openai_reasoning_summary="auto",
|
|
123
|
+
)
|
|
124
|
+
return Agent(model_obj, name="sqlsaber", model_settings=settings)
|
|
125
|
+
return Agent(model_obj, name="sqlsaber")
|
|
126
|
+
elif provider == "groq":
|
|
127
|
+
if self.thinking_enabled:
|
|
128
|
+
settings = GroqModelSettings(groq_reasoning_format="parsed")
|
|
129
|
+
return Agent(
|
|
130
|
+
self.config.model_name, name="sqlsaber", model_settings=settings
|
|
131
|
+
)
|
|
132
|
+
return Agent(self.config.model_name, name="sqlsaber")
|
|
133
|
+
else:
|
|
134
|
+
return Agent(self.config.model_name, name="sqlsaber")
|
|
135
|
+
|
|
136
|
+
def _create_oauth_anthropic_agent(self, model_name: str) -> Agent:
|
|
137
|
+
"""Create an Anthropic agent with OAuth configuration."""
|
|
138
|
+
|
|
65
139
|
async def add_oauth_headers(request: httpx.Request) -> None: # type: ignore[override]
|
|
66
|
-
# Remove API-key header if present and add OAuth headers
|
|
67
140
|
if "x-api-key" in request.headers:
|
|
68
141
|
del request.headers["x-api-key"]
|
|
69
142
|
request.headers.update(
|
|
70
143
|
{
|
|
71
|
-
"Authorization": f"Bearer {
|
|
144
|
+
"Authorization": f"Bearer {self.config.oauth_token}",
|
|
72
145
|
"anthropic-version": "2023-06-01",
|
|
73
146
|
"anthropic-beta": "oauth-2025-04-20",
|
|
74
147
|
"User-Agent": "ClaudeCode/1.0 (Anthropic Claude Code CLI)",
|
|
@@ -79,100 +152,99 @@ def build_sqlsaber_agent(
|
|
|
79
152
|
|
|
80
153
|
http_client = httpx.AsyncClient(event_hooks={"request": [add_oauth_headers]})
|
|
81
154
|
provider_obj = AnthropicProvider(api_key="placeholder", http_client=http_client)
|
|
82
|
-
model_obj = AnthropicModel(
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
if not is_oauth:
|
|
98
|
-
|
|
99
|
-
@agent.system_prompt(dynamic=True)
|
|
100
|
-
async def sqlsaber_system_prompt(ctx: RunContext) -> str:
|
|
101
|
-
db_type = _get_database_type_name(db_connection)
|
|
102
|
-
instructions = instruction_builder.build_instructions(db_type=db_type)
|
|
103
|
-
|
|
104
|
-
# Add memory context if available
|
|
105
|
-
if database_name:
|
|
106
|
-
mem = memory_manager.format_memories_for_prompt(database_name)
|
|
107
|
-
else:
|
|
108
|
-
mem = ""
|
|
155
|
+
model_obj = AnthropicModel(model_name, provider=provider_obj)
|
|
156
|
+
if self.thinking_enabled:
|
|
157
|
+
settings = AnthropicModelSettings(
|
|
158
|
+
anthropic_thinking={
|
|
159
|
+
"type": "enabled",
|
|
160
|
+
"budget_tokens": 2048,
|
|
161
|
+
},
|
|
162
|
+
max_tokens=8192,
|
|
163
|
+
)
|
|
164
|
+
return Agent(model_obj, name="sqlsaber", model_settings=settings)
|
|
165
|
+
return Agent(model_obj, name="sqlsaber")
|
|
166
|
+
|
|
167
|
+
def _setup_system_prompt(self, agent: Agent) -> None:
|
|
168
|
+
"""Set up the dynamic system prompt for the agent."""
|
|
169
|
+
if not self.is_oauth:
|
|
109
170
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
"""
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
171
|
+
@agent.system_prompt(dynamic=True)
|
|
172
|
+
async def sqlsaber_system_prompt(ctx: RunContext) -> str:
|
|
173
|
+
instructions = self.instruction_builder.build_instructions(
|
|
174
|
+
db_type=self.db_type
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
# Add memory context if available
|
|
178
|
+
mem = ""
|
|
179
|
+
if self.database_name:
|
|
180
|
+
mem = self.memory_manager.format_memories_for_prompt(
|
|
181
|
+
self.database_name
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
parts = [p for p in (instructions, mem) if p and p.strip()]
|
|
185
|
+
return "\n\n".join(parts) if parts else ""
|
|
186
|
+
else:
|
|
187
|
+
|
|
188
|
+
@agent.system_prompt(dynamic=True)
|
|
189
|
+
async def sqlsaber_system_prompt(ctx: RunContext) -> str:
|
|
190
|
+
return "You are Claude Code, Anthropic's official CLI for Claude."
|
|
191
|
+
|
|
192
|
+
def _register_tools(self, agent: Agent) -> None:
|
|
193
|
+
"""Register all the SQL tools with the agent."""
|
|
194
|
+
|
|
195
|
+
@agent.tool(name="list_tables")
|
|
196
|
+
async def list_tables(ctx: RunContext) -> str:
|
|
197
|
+
"""
|
|
198
|
+
Get a list of all tables in the database with row counts.
|
|
199
|
+
Use this first to discover available tables.
|
|
200
|
+
"""
|
|
201
|
+
tool = tool_registry.get_tool("list_tables")
|
|
202
|
+
return await tool.execute()
|
|
203
|
+
|
|
204
|
+
@agent.tool(name="introspect_schema")
|
|
205
|
+
async def introspect_schema(
|
|
206
|
+
ctx: RunContext, table_pattern: str | None = None
|
|
207
|
+
) -> str:
|
|
208
|
+
"""
|
|
209
|
+
Introspect database schema to understand table structures.
|
|
210
|
+
|
|
211
|
+
Args:
|
|
212
|
+
table_pattern: Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')
|
|
213
|
+
"""
|
|
214
|
+
tool = tool_registry.get_tool("introspect_schema")
|
|
215
|
+
return await tool.execute(table_pattern=table_pattern)
|
|
216
|
+
|
|
217
|
+
@agent.tool(name="execute_sql")
|
|
218
|
+
async def execute_sql(
|
|
219
|
+
ctx: RunContext, query: str, limit: int | None = 100
|
|
220
|
+
) -> str:
|
|
221
|
+
"""
|
|
222
|
+
Execute a SQL query and return the results.
|
|
223
|
+
|
|
224
|
+
Args:
|
|
225
|
+
query: SQL query to execute
|
|
226
|
+
limit: Maximum number of rows to return (default: 100)
|
|
227
|
+
"""
|
|
228
|
+
tool = tool_registry.get_tool("execute_sql")
|
|
229
|
+
return await tool.execute(query=query, limit=limit)
|
|
230
|
+
|
|
231
|
+
def set_thinking(self, enabled: bool) -> None:
|
|
232
|
+
"""Update thinking settings and rebuild the agent."""
|
|
233
|
+
self.thinking_enabled = enabled
|
|
234
|
+
# Rebuild agent with new thinking settings
|
|
235
|
+
self.agent = self._build_agent()
|
|
236
|
+
|
|
237
|
+
def _get_database_type_name(self) -> str:
|
|
238
|
+
"""Get the human-readable database type name."""
|
|
239
|
+
if isinstance(self.db_connection, PostgreSQLConnection):
|
|
240
|
+
return "PostgreSQL"
|
|
241
|
+
elif isinstance(self.db_connection, MySQLConnection):
|
|
242
|
+
return "MySQL"
|
|
243
|
+
elif isinstance(self.db_connection, SQLiteConnection):
|
|
244
|
+
return "SQLite"
|
|
245
|
+
elif isinstance(self.db_connection, DuckDBConnection):
|
|
246
|
+
return "DuckDB"
|
|
247
|
+
elif isinstance(self.db_connection, CSVConnection):
|
|
248
|
+
return "DuckDB"
|
|
249
|
+
else:
|
|
250
|
+
return "database"
|
sqlsaber/cli/commands.py
CHANGED
|
@@ -75,7 +75,7 @@ def query(
|
|
|
75
75
|
query_text: Annotated[
|
|
76
76
|
str | None,
|
|
77
77
|
cyclopts.Parameter(
|
|
78
|
-
help="
|
|
78
|
+
help="Question in natural language (if not provided, reads from stdin or starts interactive mode)",
|
|
79
79
|
),
|
|
80
80
|
] = None,
|
|
81
81
|
database: Annotated[
|
|
@@ -85,6 +85,7 @@ def query(
|
|
|
85
85
|
help="Database connection name, file path (CSV/SQLite/DuckDB), or connection string (postgresql://, mysql://, duckdb://) (uses default if not specified)",
|
|
86
86
|
),
|
|
87
87
|
] = None,
|
|
88
|
+
thinking: bool = False,
|
|
88
89
|
):
|
|
89
90
|
"""Run a query against the database or start interactive mode.
|
|
90
91
|
|
|
@@ -109,16 +110,11 @@ def query(
|
|
|
109
110
|
async def run_session():
|
|
110
111
|
# Import heavy dependencies only when actually running a query
|
|
111
112
|
# This is only done to speed up startup time
|
|
112
|
-
from sqlsaber.agents import
|
|
113
|
+
from sqlsaber.agents import SQLSaberAgent
|
|
113
114
|
from sqlsaber.cli.interactive import InteractiveSession
|
|
114
115
|
from sqlsaber.cli.streaming import StreamingQueryHandler
|
|
115
|
-
from sqlsaber.database
|
|
116
|
-
CSVConnection,
|
|
116
|
+
from sqlsaber.database import (
|
|
117
117
|
DatabaseConnection,
|
|
118
|
-
DuckDBConnection,
|
|
119
|
-
MySQLConnection,
|
|
120
|
-
PostgreSQLConnection,
|
|
121
|
-
SQLiteConnection,
|
|
122
118
|
)
|
|
123
119
|
from sqlsaber.database.resolver import DatabaseResolutionError, resolve_database
|
|
124
120
|
from sqlsaber.threads import ThreadStorage
|
|
@@ -147,45 +143,32 @@ def query(
|
|
|
147
143
|
raise CLIError(f"Error creating database connection: {e}")
|
|
148
144
|
|
|
149
145
|
# Create pydantic-ai agent instance with database name for memory context
|
|
150
|
-
|
|
146
|
+
sqlsaber_agent = SQLSaberAgent(db_conn, db_name, thinking_enabled=thinking)
|
|
151
147
|
|
|
152
148
|
try:
|
|
153
149
|
if actual_query:
|
|
154
150
|
# Single query mode with streaming
|
|
155
151
|
streaming_handler = StreamingQueryHandler(console)
|
|
156
|
-
|
|
157
|
-
if isinstance(db_conn, PostgreSQLConnection):
|
|
158
|
-
db_type = "PostgreSQL"
|
|
159
|
-
elif isinstance(db_conn, MySQLConnection):
|
|
160
|
-
db_type = "MySQL"
|
|
161
|
-
elif isinstance(db_conn, DuckDBConnection):
|
|
162
|
-
db_type = "DuckDB"
|
|
163
|
-
elif isinstance(db_conn, SQLiteConnection):
|
|
164
|
-
db_type = "SQLite"
|
|
165
|
-
elif isinstance(db_conn, CSVConnection):
|
|
166
|
-
db_type = "DuckDB"
|
|
167
|
-
else:
|
|
168
|
-
db_type = "database"
|
|
152
|
+
db_type = sqlsaber_agent.db_type
|
|
169
153
|
console.print(
|
|
170
154
|
f"[bold blue]Connected to:[/bold blue] {db_name} ({db_type})\n"
|
|
171
155
|
)
|
|
172
156
|
run = await streaming_handler.execute_streaming_query(
|
|
173
|
-
actual_query,
|
|
157
|
+
actual_query, sqlsaber_agent
|
|
174
158
|
)
|
|
175
159
|
# Persist non-interactive run as a thread snapshot so it can be resumed later
|
|
176
160
|
try:
|
|
177
161
|
if run is not None:
|
|
178
162
|
threads = ThreadStorage()
|
|
179
|
-
# Extract title and model name
|
|
180
|
-
title = actual_query
|
|
181
|
-
model_name: str | None = agent.model.model_name
|
|
182
163
|
|
|
183
164
|
thread_id = await threads.save_snapshot(
|
|
184
165
|
messages_json=run.all_messages_json(),
|
|
185
166
|
database_name=db_name,
|
|
186
167
|
)
|
|
187
168
|
await threads.save_metadata(
|
|
188
|
-
thread_id=thread_id,
|
|
169
|
+
thread_id=thread_id,
|
|
170
|
+
title=actual_query,
|
|
171
|
+
model_name=sqlsaber_agent.agent.model.model_name,
|
|
189
172
|
)
|
|
190
173
|
await threads.end_thread(thread_id)
|
|
191
174
|
console.print(
|
|
@@ -198,7 +181,7 @@ def query(
|
|
|
198
181
|
await threads.prune_threads()
|
|
199
182
|
else:
|
|
200
183
|
# Interactive mode
|
|
201
|
-
session = InteractiveSession(console,
|
|
184
|
+
session = InteractiveSession(console, sqlsaber_agent, db_conn, db_name)
|
|
202
185
|
await session.run()
|
|
203
186
|
|
|
204
187
|
finally:
|
sqlsaber/cli/completers.py
CHANGED
|
@@ -19,6 +19,8 @@ class SlashCommandCompleter(Completer):
|
|
|
19
19
|
("clear", "Clear conversation history"),
|
|
20
20
|
("exit", "Exit the interactive session"),
|
|
21
21
|
("quit", "Exit the interactive session"),
|
|
22
|
+
("thinking on", "Enable extended thinking/reasoning"),
|
|
23
|
+
("thinking off", "Disable extended thinking/reasoning"),
|
|
22
24
|
]
|
|
23
25
|
|
|
24
26
|
# Yield completions that match the partial command
|
sqlsaber/cli/database.py
CHANGED
|
@@ -354,7 +354,7 @@ def test(
|
|
|
354
354
|
|
|
355
355
|
async def test_connection():
|
|
356
356
|
# Lazy import to keep CLI startup fast
|
|
357
|
-
from sqlsaber.database
|
|
357
|
+
from sqlsaber.database import DatabaseConnection
|
|
358
358
|
|
|
359
359
|
if name:
|
|
360
360
|
db_config = config_manager.get_database(name)
|
sqlsaber/cli/display.py
CHANGED
|
@@ -8,7 +8,7 @@ rendered with Live.
|
|
|
8
8
|
import json
|
|
9
9
|
from typing import Sequence, Type
|
|
10
10
|
|
|
11
|
-
from pydantic_ai.messages import ModelResponsePart, TextPart
|
|
11
|
+
from pydantic_ai.messages import ModelResponsePart, TextPart, ThinkingPart
|
|
12
12
|
from rich.columns import Columns
|
|
13
13
|
from rich.console import Console, ConsoleOptions, RenderResult
|
|
14
14
|
from rich.live import Live
|
|
@@ -75,7 +75,7 @@ class LiveMarkdownRenderer:
|
|
|
75
75
|
self.end()
|
|
76
76
|
self.paragraph_break()
|
|
77
77
|
|
|
78
|
-
self._start()
|
|
78
|
+
self._start(kind)
|
|
79
79
|
self._current_kind = kind
|
|
80
80
|
|
|
81
81
|
def append(self, text: str | None) -> None:
|
|
@@ -87,7 +87,13 @@ class LiveMarkdownRenderer:
|
|
|
87
87
|
self.ensure_segment(TextPart)
|
|
88
88
|
|
|
89
89
|
self._buffer += text
|
|
90
|
-
|
|
90
|
+
|
|
91
|
+
# Apply dim styling for thinking segments
|
|
92
|
+
if self._current_kind == ThinkingPart:
|
|
93
|
+
content = Markdown(self._buffer, style="dim")
|
|
94
|
+
self._live.update(content)
|
|
95
|
+
else:
|
|
96
|
+
self._live.update(Markdown(self._buffer))
|
|
91
97
|
|
|
92
98
|
def end(self) -> None:
|
|
93
99
|
"""Finalize and stop the current Live segment, if any."""
|
|
@@ -95,13 +101,17 @@ class LiveMarkdownRenderer:
|
|
|
95
101
|
return
|
|
96
102
|
# Persist the *final* render exactly once, then shut Live down.
|
|
97
103
|
buf = self._buffer
|
|
104
|
+
kind = self._current_kind
|
|
98
105
|
self._live.stop()
|
|
99
106
|
self._live = None
|
|
100
107
|
self._buffer = ""
|
|
101
108
|
self._current_kind = None
|
|
102
109
|
# Print the complete markdown to scroll-back for permanent reference
|
|
103
110
|
if buf:
|
|
104
|
-
|
|
111
|
+
if kind == ThinkingPart:
|
|
112
|
+
self.console.print(Text(buf, style="dim"))
|
|
113
|
+
else:
|
|
114
|
+
self.console.print(Markdown(buf))
|
|
105
115
|
|
|
106
116
|
def end_if_active(self) -> None:
|
|
107
117
|
self.end()
|
|
@@ -153,10 +163,20 @@ class LiveMarkdownRenderer:
|
|
|
153
163
|
text = Text(f" {message}", style="yellow")
|
|
154
164
|
return Columns([spinner, text], expand=False)
|
|
155
165
|
|
|
156
|
-
def _start(
|
|
166
|
+
def _start(
|
|
167
|
+
self, kind: Type[ModelResponsePart] | None = None, initial_markdown: str = ""
|
|
168
|
+
) -> None:
|
|
157
169
|
if self._live is not None:
|
|
158
170
|
self.end()
|
|
159
171
|
self._buffer = initial_markdown or ""
|
|
172
|
+
|
|
173
|
+
# Add visual styling for thinking segments
|
|
174
|
+
if kind == ThinkingPart:
|
|
175
|
+
if self.console.is_terminal:
|
|
176
|
+
self.console.print("[dim]💭 Thinking...[/dim]")
|
|
177
|
+
else:
|
|
178
|
+
self.console.print("*Thinking...*\n")
|
|
179
|
+
|
|
160
180
|
# NOTE: Use transient=True so the live widget disappears on exit,
|
|
161
181
|
# giving a clean transition to the final printed result.
|
|
162
182
|
live = Live(
|
|
@@ -219,7 +239,9 @@ class DisplayManager:
|
|
|
219
239
|
if self.console.is_terminal:
|
|
220
240
|
self.console.print("[dim bold]:gear: Executing SQL:[/dim bold]")
|
|
221
241
|
self.show_newline()
|
|
222
|
-
syntax = Syntax(
|
|
242
|
+
syntax = Syntax(
|
|
243
|
+
query, "sql", background_color="default", word_wrap=True
|
|
244
|
+
)
|
|
223
245
|
self.console.print(syntax)
|
|
224
246
|
else:
|
|
225
247
|
self.console.print("**Executing SQL:**\n")
|
|
@@ -271,9 +293,7 @@ class DisplayManager:
|
|
|
271
293
|
f"[yellow]... and {len(results) - 20} more rows[/yellow]"
|
|
272
294
|
)
|
|
273
295
|
else:
|
|
274
|
-
self.console.print(
|
|
275
|
-
f"*... and {len(results) - 20} more rows*\n"
|
|
276
|
-
)
|
|
296
|
+
self.console.print(f"*... and {len(results) - 20} more rows*\n")
|
|
277
297
|
|
|
278
298
|
def show_error(self, error_message: str):
|
|
279
299
|
"""Display error message."""
|