sqlsaber 0.20.0__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/pydantic_ai_agent.py +5 -0
- sqlsaber/cli/display.py +180 -13
- sqlsaber/cli/interactive.py +10 -8
- sqlsaber/cli/streaming.py +98 -55
- sqlsaber/database/connection.py +105 -12
- {sqlsaber-0.20.0.dist-info → sqlsaber-0.22.0.dist-info}/METADATA +1 -1
- {sqlsaber-0.20.0.dist-info → sqlsaber-0.22.0.dist-info}/RECORD +10 -10
- {sqlsaber-0.20.0.dist-info → sqlsaber-0.22.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.20.0.dist-info → sqlsaber-0.22.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.20.0.dist-info → sqlsaber-0.22.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -8,6 +8,7 @@ import httpx
|
|
|
8
8
|
from pydantic_ai import Agent, RunContext
|
|
9
9
|
from pydantic_ai.models.anthropic import AnthropicModel
|
|
10
10
|
from pydantic_ai.models.google import GoogleModel
|
|
11
|
+
from pydantic_ai.models.openai import OpenAIResponsesModel
|
|
11
12
|
from pydantic_ai.providers.anthropic import AnthropicProvider
|
|
12
13
|
from pydantic_ai.providers.google import GoogleProvider
|
|
13
14
|
|
|
@@ -79,6 +80,10 @@ def build_sqlsaber_agent(
|
|
|
79
80
|
provider_obj = AnthropicProvider(api_key="placeholder", http_client=http_client)
|
|
80
81
|
model_obj = AnthropicModel(model_name_only, provider=provider_obj)
|
|
81
82
|
agent = Agent(model_obj, name="sqlsaber")
|
|
83
|
+
elif provider == "openai":
|
|
84
|
+
# Use OpenAI Responses Model for structured output capabilities
|
|
85
|
+
model_obj = OpenAIResponsesModel(model_name_only)
|
|
86
|
+
agent = Agent(model_obj, name="sqlsaber")
|
|
82
87
|
else:
|
|
83
88
|
agent = Agent(cfg.model_name, name="sqlsaber")
|
|
84
89
|
|
sqlsaber/cli/display.py
CHANGED
|
@@ -1,12 +1,167 @@
|
|
|
1
|
-
"""Display utilities for the CLI interface.
|
|
1
|
+
"""Display utilities for the CLI interface.
|
|
2
|
+
|
|
3
|
+
All rendering occurs on the event loop thread.
|
|
4
|
+
Streaming segments use Live Markdown; transient status and SQL blocks are also
|
|
5
|
+
rendered with Live.
|
|
6
|
+
"""
|
|
2
7
|
|
|
3
8
|
import json
|
|
9
|
+
from typing import Sequence, Type
|
|
4
10
|
|
|
5
|
-
from
|
|
6
|
-
from rich.
|
|
11
|
+
from pydantic_ai.messages import ModelResponsePart, TextPart
|
|
12
|
+
from rich.columns import Columns
|
|
13
|
+
from rich.console import Console, ConsoleOptions, RenderResult
|
|
14
|
+
from rich.live import Live
|
|
15
|
+
from rich.markdown import CodeBlock, Markdown
|
|
7
16
|
from rich.panel import Panel
|
|
17
|
+
from rich.spinner import Spinner
|
|
8
18
|
from rich.syntax import Syntax
|
|
9
19
|
from rich.table import Table
|
|
20
|
+
from rich.text import Text
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class _SimpleCodeBlock(CodeBlock):
    """Markdown fence renderer without the padded background panel.

    The fenced code is rendered as word-wrapped ``Syntax`` on the terminal's
    default background, with trailing whitespace trimmed from the code.
    """

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        source_text = str(self.text)
        highlighted = Syntax(
            source_text.rstrip(),
            self.lexer_name,
            theme=self.theme,
            background_color="default",
            word_wrap=True,
        )
        yield highlighted
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class LiveMarkdownRenderer:
    """Handles Live markdown rendering with segment separation.

    A segment is one rich ``Live`` display tied to a part kind (for example
    ``TextPart`` or ``ThinkingPart``). Text appended to the active segment is
    accumulated in a buffer and re-rendered as ``Markdown`` on every update.
    Switching kinds finalizes the previous segment and prints a blank line
    between the two. Fenced code renders via ``_SimpleCodeBlock`` once
    ``prepare_code_blocks`` has been called.

    A second, transient ``Live`` shows a spinner status line. The renderer
    keeps at most one of its ``Live`` displays running at any moment, since
    rich may refuse (or awkwardly stack) concurrently active live displays.
    """

    # Class-wide: the Markdown fence patch is applied at most once per process.
    _patched_fences = False

    def __init__(self, console: Console):
        self.console = console
        # Markdown segment currently streaming, if any.
        self._live: Live | None = None
        # Transient spinner/status display, if any.
        self._status_live: Live | None = None
        # Markdown source accumulated for the active segment.
        self._buffer: str = ""
        # Part kind the active segment was opened for.
        self._current_kind: Type[ModelResponsePart] | None = None

    def prepare_code_blocks(self) -> None:
        """Patch rich Markdown fence rendering once for nicer code blocks.

        This replaces ``Markdown.elements["fence"]`` on the rich class itself,
        so it affects every ``Markdown`` rendered afterwards in the process.
        """
        if LiveMarkdownRenderer._patched_fences:
            return
        # Guard with class check to avoid re-patching if already applied
        if Markdown.elements.get("fence") is not _SimpleCodeBlock:
            Markdown.elements["fence"] = _SimpleCodeBlock
        LiveMarkdownRenderer._patched_fences = True

    def ensure_segment(self, kind: Type[ModelResponsePart]) -> None:
        """Ensure a markdown Live segment is active for the given kind.

        The current segment is kept when it already has ``kind``. Otherwise
        it is ended, a paragraph break is printed, and a new segment starts.
        """
        # A transient status is cleared first, without a paragraph break.
        self.end_status()
        if self._live is not None:
            if self._current_kind == kind:
                return
            self.end()
            self.paragraph_break()

        self._start()
        self._current_kind = kind

    def append(self, text: str | None) -> None:
        """Append text to the current markdown segment and refresh.

        Empty or ``None`` text is ignored. If no segment is active, a
        ``TextPart`` segment is started first.
        """
        if not text:
            return
        if self._live is None:
            # default to assistant if no segment was ensured
            self.ensure_segment(TextPart)

        self._buffer += text
        self._live.update(Markdown(self._buffer))

    def end(self) -> None:
        """Finalize and stop the current Live segment, if any.

        Clears the buffer and the current kind afterwards.
        """
        if self._live is None:
            return
        if self._buffer:
            self._live.update(Markdown(self._buffer))
        self._live.stop()
        self._live = None
        self._buffer = ""
        self._current_kind = None

    def end_if_active(self) -> None:
        """Same as ``end()``; a no-op when no segment is active."""
        self.end()

    def paragraph_break(self) -> None:
        """Print a blank line to separate rendered content."""
        self.console.print()

    def start_sql_block(self, sql: str) -> None:
        """Render ``sql`` once as a fenced SQL Markdown block.

        Blank or non-string input is ignored. Any status spinner and any
        streaming segment are closed first. This prints the block below them
        and ensures no two Live displays run at once.
        """
        if not isinstance(sql, str) or not sql.strip():
            return
        self.end_status()
        # Separate from surrounding content
        self.end_if_active()
        self.paragraph_break()
        # Kept local so the snippet never lingers in the segment buffer.
        block = f"```sql\n{sql}\n```"
        # Use context manager to auto-stop and persist final render
        with Live(
            Markdown(block),
            console=self.console,
            vertical_overflow="visible",
            refresh_per_second=12,
        ):
            pass

    def start_status(self, message: str = "Crunching data...") -> None:
        """Show a transient status line with a spinner until streaming starts.

        If a status is already showing, only its text is updated. Otherwise
        any streaming markdown segment is finalized first, so the spinner is
        the only active Live display.
        """
        if self._status_live is not None:
            # Update existing status text
            self._status_live.update(self._status_renderable(message))
            return
        self.end_if_active()
        live = Live(
            self._status_renderable(message),
            console=self.console,
            transient=True,  # disappear when stopped
            refresh_per_second=12,
        )
        self._status_live = live
        live.start()

    def end_status(self) -> None:
        """Stop and discard the transient status line, if one is showing."""
        live = self._status_live
        if live is None:
            return
        live.stop()
        self._status_live = None

    def _status_renderable(self, message: str) -> Columns:
        """Build the yellow spinner + message row used by the status line."""
        spinner = Spinner("dots", style="yellow")
        text = Text(f" {message}", style="yellow")
        return Columns([spinner, text], expand=False)

    def _start(self, initial_markdown: str = "") -> None:
        """Start a fresh markdown Live segment seeded with ``initial_markdown``."""
        if self._live is not None:
            self.end()
        self._buffer = initial_markdown or ""
        live = Live(
            Markdown(self._buffer),
            console=self.console,
            vertical_overflow="visible",
            refresh_per_second=12,
        )
        self._live = live
        live.start()
|
|
10
165
|
|
|
11
166
|
|
|
12
167
|
class DisplayManager:
|
|
@@ -14,10 +169,11 @@ class DisplayManager:
|
|
|
14
169
|
|
|
15
170
|
def __init__(self, console: Console):
|
|
16
171
|
self.console = console
|
|
172
|
+
self.live = LiveMarkdownRenderer(console)
|
|
17
173
|
|
|
18
174
|
def _create_table(
|
|
19
175
|
self,
|
|
20
|
-
columns:
|
|
176
|
+
columns: Sequence[str | dict[str, str]],
|
|
21
177
|
header_style: str = "bold blue",
|
|
22
178
|
title: str | None = None,
|
|
23
179
|
) -> Table:
|
|
@@ -34,17 +190,24 @@ class DisplayManager:
|
|
|
34
190
|
|
|
35
191
|
def show_tool_executing(self, tool_name: str, tool_input: dict):
    """Display a header describing the tool about to run.

    Args:
        tool_name: Name of the tool being invoked. ``list_tables``,
            ``introspect_schema`` and ``execute_sql`` get a header. Any other
            name only produces the leading blank line.
        tool_input: Tool arguments. ``table_pattern`` is read for
            ``introspect_schema`` and ``query`` for ``execute_sql``.
    """
    from rich.markup import escape

    # Normalized leading blank line before tool headers
    self.show_newline()
    if tool_name == "list_tables":
        self.console.print(
            "[dim bold]:gear: Discovering available tables[/dim bold]"
        )
    elif tool_name == "introspect_schema":
        pattern = tool_input.get("table_pattern", "all tables")
        # The pattern comes from tool arguments. Escape it so characters like
        # "[" print literally instead of being parsed as rich markup (a stray
        # closing tag would otherwise raise a MarkupError).
        safe_pattern = escape(str(pattern))
        self.console.print(
            f"[dim bold]:gear: Examining schema for: {safe_pattern}[/dim bold]"
        )
    elif tool_name == "execute_sql":
        # Streaming renders SQL via LiveMarkdownRenderer instead. This Syntax
        # rendering is kept for other callers (threads show/resume).
        query = tool_input.get("query", "")
        self.console.print("[dim bold]:gear: Executing SQL:[/dim bold]")
        self.show_newline()
        syntax = Syntax(query, "sql", background_color="default", word_wrap=True)
        self.console.print(syntax)
|
|
49
212
|
|
|
50
213
|
def show_text_stream(self, text: str):
|
|
@@ -99,10 +262,12 @@ class DisplayManager:
|
|
|
99
262
|
"""Display a newline for spacing."""
|
|
100
263
|
self.console.print()
|
|
101
264
|
|
|
102
|
-
def show_table_list(self, tables_data: str):
|
|
265
|
+
def show_table_list(self, tables_data: str | dict):
|
|
103
266
|
"""Display the results from list_tables tool."""
|
|
104
267
|
try:
|
|
105
|
-
data =
|
|
268
|
+
data = (
|
|
269
|
+
json.loads(tables_data) if isinstance(tables_data, str) else tables_data
|
|
270
|
+
)
|
|
106
271
|
|
|
107
272
|
# Handle error case
|
|
108
273
|
if "error" in data:
|
|
@@ -143,10 +308,12 @@ class DisplayManager:
|
|
|
143
308
|
except Exception as e:
|
|
144
309
|
self.show_error(f"Error displaying table list: {str(e)}")
|
|
145
310
|
|
|
146
|
-
def show_schema_info(self, schema_data: str):
|
|
311
|
+
def show_schema_info(self, schema_data: str | dict):
|
|
147
312
|
"""Display the results from introspect_schema tool."""
|
|
148
313
|
try:
|
|
149
|
-
data =
|
|
314
|
+
data = (
|
|
315
|
+
json.loads(schema_data) if isinstance(schema_data, str) else schema_data
|
|
316
|
+
)
|
|
150
317
|
|
|
151
318
|
# Handle error case
|
|
152
319
|
if "error" in data:
|
sqlsaber/cli/interactive.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
from pathlib import Path
|
|
5
|
+
from textwrap import dedent
|
|
5
6
|
|
|
6
7
|
import platformdirs
|
|
7
8
|
from prompt_toolkit import PromptSession
|
|
@@ -10,6 +11,7 @@ from prompt_toolkit.patch_stdout import patch_stdout
|
|
|
10
11
|
from prompt_toolkit.styles import Style
|
|
11
12
|
from pydantic_ai import Agent
|
|
12
13
|
from rich.console import Console
|
|
14
|
+
from rich.markdown import Markdown
|
|
13
15
|
from rich.panel import Panel
|
|
14
16
|
|
|
15
17
|
from sqlsaber.cli.completers import (
|
|
@@ -102,14 +104,14 @@ class InteractiveSession:
|
|
|
102
104
|
)
|
|
103
105
|
)
|
|
104
106
|
self.console.print(
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
107
|
+
Markdown(
|
|
108
|
+
dedent("""
|
|
109
|
+
- Use `/` for slash commands
|
|
110
|
+
- Type `@` to get table name completions
|
|
111
|
+
- Start message with `#` to add something to agent's memory
|
|
112
|
+
- Use `Ctrl+C` to interrupt and `Ctrl+D` to exit
|
|
113
|
+
""")
|
|
114
|
+
)
|
|
113
115
|
)
|
|
114
116
|
|
|
115
117
|
self.console.print(
|
sqlsaber/cli/streaming.py
CHANGED
|
@@ -1,7 +1,13 @@
|
|
|
1
|
-
"""Streaming query handling for the CLI (pydantic-ai based).
|
|
1
|
+
"""Streaming query handling for the CLI (pydantic-ai based).
|
|
2
|
+
|
|
3
|
+
This module uses DisplayManager's LiveMarkdownRenderer to stream Markdown
|
|
4
|
+
incrementally as the agent outputs tokens. Tool calls and results are
|
|
5
|
+
rendered via DisplayManager helpers.
|
|
6
|
+
"""
|
|
2
7
|
|
|
3
8
|
import asyncio
|
|
4
9
|
import json
|
|
10
|
+
from functools import singledispatchmethod
|
|
5
11
|
from typing import AsyncIterable
|
|
6
12
|
|
|
7
13
|
from pydantic_ai import Agent, RunContext
|
|
@@ -22,56 +28,98 @@ from sqlsaber.cli.display import DisplayManager
|
|
|
22
28
|
|
|
23
29
|
|
|
24
30
|
class StreamingQueryHandler:
|
|
25
|
-
"""
|
|
31
|
+
"""
|
|
32
|
+
Handles streaming query execution and display using pydantic-ai events.
|
|
33
|
+
|
|
34
|
+
Uses DisplayManager.live to render Markdown incrementally as text streams in.
|
|
35
|
+
"""
|
|
26
36
|
|
|
27
37
|
def __init__(self, console: Console):
    """Create a handler that renders agent output on ``console``.

    All rendering goes through a ``DisplayManager`` bound to the same console.
    """
    self.display = DisplayManager(console)
    self.console = console
|
35
41
|
async def _event_stream_handler(
    self, ctx: RunContext, event_stream: AsyncIterable[AgentStreamEvent]
) -> None:
    """Consume the agent's event stream, dispatching events in arrival order.

    Passed to ``agent.run`` as ``event_stream_handler``. Each event is handed
    to ``on_event``, which selects a handler by the event's runtime type.
    """
    stream = event_stream
    async for streamed_event in stream:
        await self.on_event(streamed_event, ctx)
|
|
50
|
+
|
|
51
|
+
# --- Event routing via singledispatchmethod ---------------------------------------
|
|
52
|
+
@singledispatchmethod
async def on_event(
    self, event: AgentStreamEvent, ctx: RunContext
) -> None:
    """Route one stream event to the handler registered for its type.

    This base implementation is the fallback for event types that have no
    handler registered via ``@on_event.register``; such events are ignored.
    """
    return None
|
|
57
|
+
|
|
58
|
+
@on_event.register
async def _(self, event: PartStartEvent, ctx: RunContext) -> None:
    """Open (or continue) a markdown segment for a new text/thinking part.

    Other part kinds are ignored here.
    """
    part = event.part
    for part_kind in (TextPart, ThinkingPart):
        if isinstance(part, part_kind):
            renderer = self.display.live
            renderer.ensure_segment(part_kind)
            renderer.append(part.content)
            break
|
|
66
|
+
|
|
67
|
+
@on_event.register
async def _(self, event: PartDeltaEvent, ctx: RunContext) -> None:
    """Stream a non-empty text/thinking delta into the matching segment."""
    delta = event.delta
    if isinstance(delta, TextPartDelta):
        segment_kind = TextPart
    elif isinstance(delta, ThinkingPartDelta):
        segment_kind = ThinkingPart
    else:
        return

    chunk = delta.content_delta or ""
    if not chunk:
        return
    renderer = self.display.live
    renderer.ensure_segment(segment_kind)
    renderer.append(chunk)
|
|
80
|
+
|
|
81
|
+
@on_event.register
async def _(self, event: FunctionToolCallEvent, ctx: RunContext) -> None:
    """Render a tool call between finished segments.

    ``execute_sql`` calls with a non-blank query are shown as a SQL code
    block. Other tools get the DisplayManager tool header.
    """
    renderer = self.display.live
    # Stop the spinner and any streaming markdown so the tool output is
    # printed after the finished text rather than inside a Live region.
    renderer.end_status()
    renderer.end_if_active()
    call = event.part
    args = call.args_as_dict()

    if call.tool_name != "execute_sql":
        self.display.show_tool_executing(call.tool_name, args)
        return

    query = args.get("query") or ""
    if isinstance(query, str) and query.strip():
        renderer.start_sql_block(query)
|
|
95
|
+
|
|
96
|
+
@on_event.register
async def _(self, event: FunctionToolResultEvent, ctx: RunContext) -> None:
    """Render a tool's result, then show the spinner while the model continues.

    ``list_tables`` and ``introspect_schema`` results go to their dedicated
    DisplayManager views. An ``execute_sql`` result arrives as a JSON string
    or a dict. It is shown as a results table only when it reports
    ``success`` with non-empty ``results``.
    """
    result = event.result
    name, payload = result.tool_name, result.content

    if name == "list_tables":
        self.display.show_table_list(payload)
    elif name == "introspect_schema":
        self.display.show_schema_info(payload)
    elif name == "execute_sql":
        parsed = {}
        if isinstance(payload, dict):
            parsed = payload
        elif isinstance(payload, str):
            try:
                parsed = json.loads(payload)
            except (json.JSONDecodeError, TypeError) as exc:
                # Logging is best-effort: a display hiccup must not abort the run.
                try:
                    self.console.log(f"Malformed execute_sql result: {exc}")
                except Exception:
                    pass
        succeeded = isinstance(parsed, dict) and parsed.get("success")
        if succeeded and parsed.get("results"):
            self.display.show_query_results(parsed["results"])  # type: ignore[arg-type]

    # Blank line separates the tool output from the next segment.
    self.display.show_newline()
    # Spinner covers the follow-up request the agent sends to the model.
    self.display.live.start_status("Crunching data...")
|
|
75
123
|
|
|
76
124
|
async def execute_streaming_query(
|
|
77
125
|
self,
|
|
@@ -80,7 +128,8 @@ class StreamingQueryHandler:
|
|
|
80
128
|
cancellation_token: asyncio.Event | None = None,
|
|
81
129
|
message_history: list | None = None,
|
|
82
130
|
):
|
|
83
|
-
|
|
131
|
+
# Prepare nicer code block rendering for Markdown
|
|
132
|
+
self.display.live.prepare_code_blocks()
|
|
84
133
|
try:
|
|
85
134
|
# If Anthropic OAuth, inject SQLsaber instructions before the first user prompt
|
|
86
135
|
prepared_prompt: str | list[str] = user_query
|
|
@@ -104,30 +153,24 @@ class StreamingQueryHandler:
|
|
|
104
153
|
injected = "\n\n".join(parts)
|
|
105
154
|
prepared_prompt = [injected, user_query]
|
|
106
155
|
|
|
156
|
+
# Show a transient status until events start streaming
|
|
157
|
+
self.display.live.start_status("Crunching data...")
|
|
158
|
+
|
|
107
159
|
# Run the agent with our event stream handler
|
|
108
160
|
run = await agent.run(
|
|
109
161
|
prepared_prompt,
|
|
110
162
|
message_history=message_history,
|
|
111
163
|
event_stream_handler=self._event_stream_handler,
|
|
112
164
|
)
|
|
113
|
-
# After the run completes, show the assistant's final text as markdown if available
|
|
114
|
-
try:
|
|
115
|
-
output = run.output
|
|
116
|
-
if isinstance(output, str) and output.strip():
|
|
117
|
-
self.display.show_newline()
|
|
118
|
-
self.display.show_markdown_response(
|
|
119
|
-
[{"type": "text", "text": output}]
|
|
120
|
-
)
|
|
121
|
-
except Exception as e:
|
|
122
|
-
self.display.show_error(str(e))
|
|
123
|
-
self.display.show_newline()
|
|
124
165
|
return run
|
|
125
166
|
except asyncio.CancelledError:
|
|
167
|
+
# Show interruption message outside of Live
|
|
126
168
|
self.display.show_newline()
|
|
127
169
|
self.console.print("[yellow]Query interrupted[/yellow]")
|
|
128
170
|
return None
|
|
129
171
|
finally:
|
|
172
|
+
# End any active status and live markdown segments
|
|
130
173
|
try:
|
|
131
|
-
self.
|
|
132
|
-
|
|
133
|
-
|
|
174
|
+
self.display.live.end_status()
|
|
175
|
+
finally:
|
|
176
|
+
self.display.live.end_if_active()
|
sqlsaber/database/connection.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Database connection management."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import ssl
|
|
4
5
|
from abc import ABC, abstractmethod
|
|
5
6
|
from pathlib import Path
|
|
@@ -10,6 +11,17 @@ import aiomysql
|
|
|
10
11
|
import aiosqlite
|
|
11
12
|
import asyncpg
|
|
12
13
|
|
|
14
|
+
# Timeout (seconds) used when a caller does not supply one, so a runaway
# query cannot block the session indefinitely.
DEFAULT_QUERY_TIMEOUT = 30.0


class QueryTimeoutError(RuntimeError):
    """Raised when a query runs longer than its allotted timeout.

    Attributes:
        timeout: The timeout, in seconds, that was exceeded.
    """

    def __init__(self, seconds: float):
        message = f"Query exceeded timeout of {seconds}s"
        super().__init__(message)
        self.timeout = seconds
|
|
24
|
+
|
|
13
25
|
|
|
14
26
|
class BaseDatabaseConnection(ABC):
|
|
15
27
|
"""Abstract base class for database connections."""
|
|
@@ -29,11 +41,18 @@ class BaseDatabaseConnection(ABC):
|
|
|
29
41
|
pass
|
|
30
42
|
|
|
31
43
|
@abstractmethod
|
|
32
|
-
async def execute_query(
|
|
44
|
+
async def execute_query(
|
|
45
|
+
self, query: str, *args, timeout: float | None = None
|
|
46
|
+
) -> list[dict[str, Any]]:
|
|
33
47
|
"""Execute a query and return results as list of dicts.
|
|
34
48
|
|
|
35
49
|
All queries run in a transaction that is rolled back at the end,
|
|
36
50
|
ensuring no changes are persisted to the database.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
query: SQL query to execute
|
|
54
|
+
*args: Query parameters
|
|
55
|
+
timeout: Query timeout in seconds (overrides default_timeout)
|
|
37
56
|
"""
|
|
38
57
|
pass
|
|
39
58
|
|
|
@@ -111,21 +130,40 @@ class PostgreSQLConnection(BaseDatabaseConnection):
|
|
|
111
130
|
await self._pool.close()
|
|
112
131
|
self._pool = None
|
|
113
132
|
|
|
114
|
-
async def execute_query(
|
|
133
|
+
async def execute_query(
    self, query: str, *args, timeout: float | None = None
) -> list[dict[str, Any]]:
    """Execute a query and return results as list of dicts.

    All queries run in a transaction that is rolled back at the end,
    ensuring no changes are persisted to the database.

    Args:
        query: SQL query to execute.
        *args: Positional query parameters, forwarded to ``conn.fetch``.
        timeout: Timeout in seconds. ``None`` or a non-positive value falls
            back to ``DEFAULT_QUERY_TIMEOUT``.

    Raises:
        QueryTimeoutError: If the query exceeds the timeout. This covers both
            the client-side ``asyncio.wait_for`` and the server-side
            ``statement_timeout``, whichever fires first.
    """
    effective_timeout = (
        timeout if timeout is not None and timeout > 0 else DEFAULT_QUERY_TIMEOUT
    )
    # statement_timeout is whole milliseconds and 0 means "no limit", so a
    # sub-millisecond timeout must never round down to 0.
    timeout_ms = max(1, int(effective_timeout * 1000))
    pool = await self.get_pool()

    async with pool.acquire() as conn:
        # Start a transaction that we'll always rollback
        transaction = conn.transaction()
        await transaction.start()

        try:
            # SET LOCAL scopes the server-side limit to this transaction.
            await conn.execute(f"SET LOCAL statement_timeout = {timeout_ms}")
            # Client-side backstop with the same budget.
            rows = await asyncio.wait_for(
                conn.fetch(query, *args), timeout=effective_timeout
            )
            return [dict(row) for row in rows]
        except asyncio.TimeoutError as exc:
            raise QueryTimeoutError(effective_timeout) from exc
        except asyncpg.exceptions.QueryCanceledError as exc:
            # statement_timeout expiry surfaces as QueryCanceledError. Map it so
            # callers see one timeout type regardless of which limit won the
            # race (an external pg_cancel_backend would also land here).
            raise QueryTimeoutError(effective_timeout) from exc
        finally:
            # Always rollback to ensure no changes are committed
            await transaction.rollback()
|
|
@@ -216,21 +254,44 @@ class MySQLConnection(BaseDatabaseConnection):
|
|
|
216
254
|
await self._pool.wait_closed()
|
|
217
255
|
self._pool = None
|
|
218
256
|
|
|
219
|
-
async def execute_query(
|
|
257
|
+
async def execute_query(
    self, query: str, *args, timeout: float | None = None
) -> list[dict[str, Any]]:
    """Execute a query and return results as list of dicts.

    All queries run in a transaction that is rolled back at the end,
    ensuring no changes are persisted to the database.

    Args:
        query: SQL query to execute.
        *args: Query parameters, forwarded to ``cursor.execute``.
        timeout: Timeout in seconds covering execution and fetching together.
            ``None`` or a non-positive value falls back to
            ``DEFAULT_QUERY_TIMEOUT``.

    Raises:
        QueryTimeoutError: If the client-side timeout expires or the server
            aborts the statement for exceeding MAX_EXECUTION_TIME.
    """
    effective_timeout = (
        timeout if timeout is not None and timeout > 0 else DEFAULT_QUERY_TIMEOUT
    )
    # MAX_EXECUTION_TIME is whole milliseconds and 0 disables it, so a
    # sub-millisecond timeout must never round down to 0.
    timeout_ms = max(1, int(effective_timeout * 1000))
    pool = await self.get_pool()

    async with pool.acquire() as conn:
        async with conn.cursor(aiomysql.DictCursor) as cursor:

            async def _run_and_fetch():
                # Execute and fetch share one timeout budget (previously each
                # step got the full timeout, allowing up to twice the limit).
                await cursor.execute(query, args if args else None)
                return await cursor.fetchall()

            # Start transaction
            await conn.begin()
            try:
                # Server-side limit. This is a SESSION variable, which the
                # rollback does not undo, so it is restored in ``finally``
                # before the pooled connection is handed to other callers.
                await cursor.execute(
                    f"SET SESSION MAX_EXECUTION_TIME = {timeout_ms}"
                )
                rows = await asyncio.wait_for(
                    _run_and_fetch(), timeout=effective_timeout
                )
                return [dict(row) for row in rows]
            except asyncio.TimeoutError as exc:
                raise QueryTimeoutError(effective_timeout) from exc
            except aiomysql.OperationalError as exc:
                # Error 3024 (ER_QUERY_TIMEOUT) means MAX_EXECUTION_TIME fired;
                # report it like the client-side timeout.
                if exc.args and exc.args[0] == 3024:
                    raise QueryTimeoutError(effective_timeout) from exc
                raise
            finally:
                # Always rollback to ensure no changes are committed
                await conn.rollback()
                # Put the session back on the server's global default.
                await cursor.execute("SET SESSION MAX_EXECUTION_TIME = DEFAULT")
|
|
@@ -252,12 +313,16 @@ class SQLiteConnection(BaseDatabaseConnection):
|
|
|
252
313
|
"""SQLite connections are created per query, no persistent pool to close."""
|
|
253
314
|
pass
|
|
254
315
|
|
|
255
|
-
async def execute_query(
|
|
316
|
+
async def execute_query(
|
|
317
|
+
self, query: str, *args, timeout: float | None = None
|
|
318
|
+
) -> list[dict[str, Any]]:
|
|
256
319
|
"""Execute a query and return results as list of dicts.
|
|
257
320
|
|
|
258
321
|
All queries run in a transaction that is rolled back at the end,
|
|
259
322
|
ensuring no changes are persisted to the database.
|
|
260
323
|
"""
|
|
324
|
+
effective_timeout = timeout or DEFAULT_QUERY_TIMEOUT
|
|
325
|
+
|
|
261
326
|
async with aiosqlite.connect(self.database_path) as conn:
|
|
262
327
|
# Enable row factory for dict-like access
|
|
263
328
|
conn.row_factory = aiosqlite.Row
|
|
@@ -265,9 +330,22 @@ class SQLiteConnection(BaseDatabaseConnection):
|
|
|
265
330
|
# Start transaction
|
|
266
331
|
await conn.execute("BEGIN")
|
|
267
332
|
try:
|
|
268
|
-
|
|
269
|
-
|
|
333
|
+
# Execute query with client-side timeout (SQLite has no server-side timeout)
|
|
334
|
+
if effective_timeout:
|
|
335
|
+
cursor = await asyncio.wait_for(
|
|
336
|
+
conn.execute(query, args if args else ()),
|
|
337
|
+
timeout=effective_timeout,
|
|
338
|
+
)
|
|
339
|
+
rows = await asyncio.wait_for(
|
|
340
|
+
cursor.fetchall(), timeout=effective_timeout
|
|
341
|
+
)
|
|
342
|
+
else:
|
|
343
|
+
cursor = await conn.execute(query, args if args else ())
|
|
344
|
+
rows = await cursor.fetchall()
|
|
345
|
+
|
|
270
346
|
return [dict(row) for row in rows]
|
|
347
|
+
except asyncio.TimeoutError as exc:
|
|
348
|
+
raise QueryTimeoutError(effective_timeout or 0) from exc
|
|
271
349
|
finally:
|
|
272
350
|
# Always rollback to ensure no changes are committed
|
|
273
351
|
await conn.rollback()
|
|
@@ -383,20 +461,35 @@ class CSVConnection(BaseDatabaseConnection):
|
|
|
383
461
|
except Exception as e:
|
|
384
462
|
raise ValueError(f"Error loading CSV file '{self.csv_path}': {str(e)}")
|
|
385
463
|
|
|
386
|
-
async def execute_query(
|
|
464
|
+
async def execute_query(
    self, query: str, *args, timeout: float | None = None
) -> list[dict[str, Any]]:
    """Execute a query and return results as list of dicts.

    All queries run in a transaction that is rolled back at the end,
    ensuring no changes are persisted to the database.

    Args:
        query: SQL query to execute against the CSV-backed database.
        *args: Query parameters, forwarded to ``conn.execute``.
        timeout: Timeout in seconds covering execution and fetching together.
            ``None`` or a non-positive value falls back to
            ``DEFAULT_QUERY_TIMEOUT``.

    Raises:
        QueryTimeoutError: If the query does not finish within the timeout.
    """
    effective_timeout = (
        timeout if timeout is not None and timeout > 0 else DEFAULT_QUERY_TIMEOUT
    )
    conn = await self.get_pool()

    async def _run_and_fetch():
        # Execute and fetch share one timeout budget (previously each step
        # got the full timeout, allowing up to twice the limit).
        cursor = await conn.execute(query, args if args else ())
        return await cursor.fetchall()

    # Start transaction
    await conn.execute("BEGIN")
    try:
        # CSV uses in-memory SQLite, which has no server-side statement
        # timeout, so the limit is enforced client-side only.
        rows = await asyncio.wait_for(_run_and_fetch(), timeout=effective_timeout)
        return [dict(row) for row in rows]
    except asyncio.TimeoutError as exc:
        # aiosqlite runs statements on a background thread. Cancelling the
        # await does not stop SQLite, and the rollback below would queue
        # behind the runaway statement. Ask SQLite to abort it.
        # NOTE(review): assumes get_pool() returns an aiosqlite.Connection
        # exposing ``interrupt()``; confirm.
        try:
            await conn.interrupt()
        except Exception:
            # Best effort: the timeout is reported either way.
            pass
        raise QueryTimeoutError(effective_timeout) from exc
    finally:
        # Always rollback to ensure no changes are committed
        await conn.rollback()
|
|
@@ -3,17 +3,17 @@ sqlsaber/__main__.py,sha256=RIHxWeWh2QvLfah-2OkhI5IJxojWfy4fXpMnVEJYvxw,78
|
|
|
3
3
|
sqlsaber/agents/__init__.py,sha256=i_MI2eWMQaVzGikKU71FPCmSQxNDKq36Imq1PrYoIPU,130
|
|
4
4
|
sqlsaber/agents/base.py,sha256=7zOZTHKxUuU0uMc-NTaCkkBfDnU3jtwbT8_eP1ZtJ2k,2615
|
|
5
5
|
sqlsaber/agents/mcp.py,sha256=GcJTx7YDYH6aaxIADEIxSgcWAdWakUx395JIzVnf17U,768
|
|
6
|
-
sqlsaber/agents/pydantic_ai_agent.py,sha256=
|
|
6
|
+
sqlsaber/agents/pydantic_ai_agent.py,sha256=6RvG2O7G8P6NN9QaRXUodg5Q26QJ4ShGWoTGYbVQ5K4,7065
|
|
7
7
|
sqlsaber/cli/__init__.py,sha256=qVSLVJLLJYzoC6aj6y9MFrzZvAwc4_OgxU9DlkQnZ4M,86
|
|
8
8
|
sqlsaber/cli/auth.py,sha256=jTsRgbmlGPlASSuIKmdjjwfqtKvjfKd_cTYxX0-QqaQ,7400
|
|
9
9
|
sqlsaber/cli/commands.py,sha256=mjLG9i1bXf0TEroxkIxq5O7Hhjufz3Ad72cyJz7vE1k,8128
|
|
10
10
|
sqlsaber/cli/completers.py,sha256=HsUPjaZweLSeYCWkAcgMl8FylQ1xjWBWYTEL_9F6xfU,6430
|
|
11
11
|
sqlsaber/cli/database.py,sha256=JKtHSN-BFzBa14REf0phFVQB7d67m1M5FFaD8N6DdrY,12966
|
|
12
|
-
sqlsaber/cli/display.py,sha256=
|
|
13
|
-
sqlsaber/cli/interactive.py,sha256=
|
|
12
|
+
sqlsaber/cli/display.py,sha256=bul9Yzw8KFYkof-kDzeajpx2TtG9CjTaUiwWaTv95dQ,14293
|
|
13
|
+
sqlsaber/cli/interactive.py,sha256=7uM4LoXbhPJr8o5yNjICSzL0uxZkp1psWrVq4G9V0OI,13118
|
|
14
14
|
sqlsaber/cli/memory.py,sha256=OufHFJFwV0_GGn7LvKRTJikkWhV1IwNIUDOxFPHXOaQ,7794
|
|
15
15
|
sqlsaber/cli/models.py,sha256=ZewtwGQwhd9b-yxBAPKePolvI1qQG-EkmeWAGMqtWNQ,8986
|
|
16
|
-
sqlsaber/cli/streaming.py,sha256=
|
|
16
|
+
sqlsaber/cli/streaming.py,sha256=BeG7H38-I1n8b9R8XSBV-IqkxDRZhsWFW6sdvtbVi3o,6879
|
|
17
17
|
sqlsaber/cli/threads.py,sha256=XUnLcCUe2wa_85IKdKmryqfiHTQu_IylET2Qo8oy1nk,11324
|
|
18
18
|
sqlsaber/config/__init__.py,sha256=olwC45k8Nc61yK0WmPUk7XHdbsZH9HuUAbwnmKe3IgA,100
|
|
19
19
|
sqlsaber/config/api_keys.py,sha256=RqWQCko1tY7sES7YOlexgBH5Hd5ne_kGXHdBDNqcV2U,3649
|
|
@@ -24,7 +24,7 @@ sqlsaber/config/oauth_tokens.py,sha256=C9z35hyx-PvSAYdC1LNf3rg9_wsEIY56hkEczelba
|
|
|
24
24
|
sqlsaber/config/providers.py,sha256=JFjeJv1K5Q93zWSlWq3hAvgch1TlgoF0qFa0KJROkKY,2957
|
|
25
25
|
sqlsaber/config/settings.py,sha256=vgb_RXaM-7DgbxYDmWNw1cSyMqwys4j3qNCvM4bljwI,5586
|
|
26
26
|
sqlsaber/database/__init__.py,sha256=a_gtKRJnZVO8-fEZI7g3Z8YnGa6Nio-5Y50PgVp07ss,176
|
|
27
|
-
sqlsaber/database/connection.py,sha256=
|
|
27
|
+
sqlsaber/database/connection.py,sha256=1bDPEa6cmdh87gPfhNeBLpOdI0E2_2KlE74q_-4l_jI,18913
|
|
28
28
|
sqlsaber/database/resolver.py,sha256=RPXF5EoKzvQDDLmPGNHYd2uG_oNICH8qvUjBp6iXmNY,3348
|
|
29
29
|
sqlsaber/database/schema.py,sha256=r12qoN3tdtAXdO22EKlauAe7QwOm8lL2vTMM59XEMMY,26594
|
|
30
30
|
sqlsaber/mcp/__init__.py,sha256=COdWq7wauPBp5Ew8tfZItFzbcLDSEkHBJSMhxzy8C9c,112
|
|
@@ -40,8 +40,8 @@ sqlsaber/tools/enums.py,sha256=CH32mL-0k9ZA18911xLpNtsgpV6tB85TktMj6uqGz54,411
|
|
|
40
40
|
sqlsaber/tools/instructions.py,sha256=X-x8maVkkyi16b6Tl0hcAFgjiYceZaSwyWTfmrvx8U8,9024
|
|
41
41
|
sqlsaber/tools/registry.py,sha256=HWOQMsNIdL4XZS6TeNUyrL-5KoSDH6PHsWd3X66o-18,3211
|
|
42
42
|
sqlsaber/tools/sql_tools.py,sha256=hM6tKqW5MDhFUt6MesoqhTUqIpq_5baIIDoN1MjDCXY,9647
|
|
43
|
-
sqlsaber-0.
|
|
44
|
-
sqlsaber-0.
|
|
45
|
-
sqlsaber-0.
|
|
46
|
-
sqlsaber-0.
|
|
47
|
-
sqlsaber-0.
|
|
43
|
+
sqlsaber-0.22.0.dist-info/METADATA,sha256=T9TBoCGfPVrZKM-RnUVROqOKaBaU1KJdYKqh3a8Arr8,6178
|
|
44
|
+
sqlsaber-0.22.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
45
|
+
sqlsaber-0.22.0.dist-info/entry_points.txt,sha256=qEbOB7OffXPFgyJc7qEIJlMEX5RN9xdzLmWZa91zCQQ,162
|
|
46
|
+
sqlsaber-0.22.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
47
|
+
sqlsaber-0.22.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|