sqlsaber 0.20.0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

@@ -8,6 +8,7 @@ import httpx
8
8
  from pydantic_ai import Agent, RunContext
9
9
  from pydantic_ai.models.anthropic import AnthropicModel
10
10
  from pydantic_ai.models.google import GoogleModel
11
+ from pydantic_ai.models.openai import OpenAIResponsesModel
11
12
  from pydantic_ai.providers.anthropic import AnthropicProvider
12
13
  from pydantic_ai.providers.google import GoogleProvider
13
14
 
@@ -79,6 +80,10 @@ def build_sqlsaber_agent(
79
80
  provider_obj = AnthropicProvider(api_key="placeholder", http_client=http_client)
80
81
  model_obj = AnthropicModel(model_name_only, provider=provider_obj)
81
82
  agent = Agent(model_obj, name="sqlsaber")
83
+ elif provider == "openai":
84
+ # Use OpenAI Responses Model for structured output capabilities
85
+ model_obj = OpenAIResponsesModel(model_name_only)
86
+ agent = Agent(model_obj, name="sqlsaber")
82
87
  else:
83
88
  agent = Agent(cfg.model_name, name="sqlsaber")
84
89
 
sqlsaber/cli/display.py CHANGED
@@ -1,12 +1,167 @@
1
- """Display utilities for the CLI interface."""
1
+ """Display utilities for the CLI interface.
2
+
3
+ All rendering occurs on the event loop thread.
4
+ Streaming segments use Live Markdown; transient status and SQL blocks are also
5
+ rendered with Live.
6
+ """
2
7
 
3
8
  import json
9
+ from typing import Sequence, Type
4
10
 
5
- from rich.console import Console
6
- from rich.markdown import Markdown
11
+ from pydantic_ai.messages import ModelResponsePart, TextPart
12
+ from rich.columns import Columns
13
+ from rich.console import Console, ConsoleOptions, RenderResult
14
+ from rich.live import Live
15
+ from rich.markdown import CodeBlock, Markdown
7
16
  from rich.panel import Panel
17
+ from rich.spinner import Spinner
8
18
  from rich.syntax import Syntax
9
19
  from rich.table import Table
20
+ from rich.text import Text
21
+
22
+
23
class _SimpleCodeBlock(CodeBlock):
    """Markdown fence renderer that draws code on the terminal's own
    background with soft wrapping, instead of rich's default block style."""

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        # Drop trailing whitespace/newlines before handing off to Syntax.
        stripped_source = str(self.text).rstrip()
        highlighted = Syntax(
            stripped_source,
            self.lexer_name,
            theme=self.theme,
            background_color="default",
            word_wrap=True,
        )
        yield highlighted
35
+
36
+
37
class LiveMarkdownRenderer:
    """Handles Live markdown rendering with segment separation.

    Supports different segment kinds: 'assistant', 'thinking', 'sql'.
    Adds visible paragraph breaks between segments and renders code fences
    with nicer formatting.

    Invariant: ``_buffer`` holds the markdown of the currently active
    ``_live`` segment, and is empty whenever ``_live is None``.
    """

    # Class-level guard: the rich Markdown patch is applied at most once per
    # process, no matter how many renderer instances are constructed.
    _patched_fences = False

    def __init__(self, console: Console):
        self.console = console
        self._live: Live | None = None  # active streaming markdown segment
        self._status_live: Live | None = None  # transient spinner/status line
        self._buffer: str = ""  # accumulated markdown for the active segment
        self._current_kind: Type[ModelResponsePart] | None = None

    def prepare_code_blocks(self) -> None:
        """Patch rich Markdown fence rendering once for nicer code blocks."""
        if LiveMarkdownRenderer._patched_fences:
            return
        # Guard with class check to avoid re-patching if already applied
        if Markdown.elements.get("fence") is not _SimpleCodeBlock:
            Markdown.elements["fence"] = _SimpleCodeBlock
        LiveMarkdownRenderer._patched_fences = True

    def ensure_segment(self, kind: Type[ModelResponsePart]) -> None:
        """
        Ensure a markdown Live segment is active for the given kind.

        When switching kinds, end the previous segment and add a paragraph break.
        """
        # If a transient status is showing, clear it first (no paragraph break)
        if self._status_live is not None:
            self.end_status()
        if self._live is not None and self._current_kind == kind:
            return
        if self._live is not None:
            self.end()
            self.paragraph_break()

        self._start()
        self._current_kind = kind

    def append(self, text: str | None) -> None:
        """Append text to the current markdown segment and refresh."""
        if not text:
            return
        if self._live is None:
            # default to assistant if no segment was ensured
            self.ensure_segment(TextPart)

        self._buffer += text
        self._live.update(Markdown(self._buffer))

    def end(self) -> None:
        """Finalize and stop the current Live segment, if any."""
        if self._live is None:
            return
        if self._buffer:
            # One last render so the final state persists on screen.
            self._live.update(Markdown(self._buffer))
        self._live.stop()
        self._live = None
        self._buffer = ""
        self._current_kind = None

    def end_if_active(self) -> None:
        """Alias for :meth:`end`; reads better at call sites that may have
        no active segment."""
        self.end()

    def paragraph_break(self) -> None:
        """Print an empty line to visually separate segments."""
        self.console.print()

    def start_sql_block(self, sql: str) -> None:
        """Render a SQL block using a transient Live markdown segment."""
        if not sql or not isinstance(sql, str) or not sql.strip():
            return
        # Separate from surrounding content
        self.end_if_active()
        self.paragraph_break()
        self._buffer = f"```sql\n{sql}\n```"
        # Use context manager to auto-stop and persist final render
        with Live(
            Markdown(self._buffer),
            console=self.console,
            vertical_overflow="visible",
            refresh_per_second=12,
        ):
            pass
        # Fix: reset the buffer. After the with-block no Live segment is
        # active, so leaving the SQL text here would violate the class
        # invariant that _buffer is empty whenever _live is None.
        self._buffer = ""

    def start_status(self, message: str = "Crunching data...") -> None:
        """Show a transient status line with a spinner until streaming starts."""
        if self._status_live is not None:
            # Update existing status text
            self._status_live.update(self._status_renderable(message))
            return
        live = Live(
            self._status_renderable(message),
            console=self.console,
            transient=True,  # disappear when stopped
            refresh_per_second=12,
        )
        self._status_live = live
        live.start()

    def end_status(self) -> None:
        """Stop and clear the transient status line, if one is showing."""
        live = self._status_live
        if live is None:
            return
        live.stop()
        self._status_live = None

    def _status_renderable(self, message: str) -> Columns:
        """Build the spinner + message renderable used by the status line."""
        spinner = Spinner("dots", style="yellow")
        text = Text(f" {message}", style="yellow")
        return Columns([spinner, text], expand=False)

    def _start(self, initial_markdown: str = "") -> None:
        """Start a fresh Live markdown segment, ending any previous one."""
        if self._live is not None:
            self.end()
        self._buffer = initial_markdown or ""
        live = Live(
            Markdown(self._buffer),
            console=self.console,
            vertical_overflow="visible",
            refresh_per_second=12,
        )
        self._live = live
        live.start()
10
165
 
11
166
 
12
167
  class DisplayManager:
@@ -14,10 +169,11 @@ class DisplayManager:
14
169
 
15
170
  def __init__(self, console: Console):
16
171
  self.console = console
172
+ self.live = LiveMarkdownRenderer(console)
17
173
 
18
174
  def _create_table(
19
175
  self,
20
- columns: list,
176
+ columns: Sequence[str | dict[str, str]],
21
177
  header_style: str = "bold blue",
22
178
  title: str | None = None,
23
179
  ) -> Table:
@@ -34,17 +190,24 @@ class DisplayManager:
34
190
 
35
191
  def show_tool_executing(self, tool_name: str, tool_input: dict):
36
192
  """Display tool execution details."""
37
- self.console.print(f"\n[yellow]🔧 Using tool: {tool_name}[/yellow]")
193
+ # Normalized leading blank line before tool headers
194
+ self.show_newline()
38
195
  if tool_name == "list_tables":
39
- self.console.print("[dim] → Discovering available tables[/dim]")
196
+ self.console.print(
197
+ "[dim bold]:gear: Discovering available tables[/dim bold]"
198
+ )
40
199
  elif tool_name == "introspect_schema":
41
200
  pattern = tool_input.get("table_pattern", "all tables")
42
- self.console.print(f"[dim] → Examining schema for: {pattern}[/dim]")
201
+ self.console.print(
202
+ f"[dim bold]:gear: Examining schema for: {pattern}[/dim bold]"
203
+ )
43
204
  elif tool_name == "execute_sql":
205
+ # For streaming, we render SQL via LiveMarkdownRenderer; keep Syntax
206
+ # rendering for threads show/resume. Controlled by include_sql flag.
44
207
  query = tool_input.get("query", "")
45
- self.console.print("\n[bold green]Executing SQL:[/bold green]")
208
+ self.console.print("[dim bold]:gear: Executing SQL:[/dim bold]")
46
209
  self.show_newline()
47
- syntax = Syntax(query, "sql")
210
+ syntax = Syntax(query, "sql", background_color="default", word_wrap=True)
48
211
  self.console.print(syntax)
49
212
 
50
213
  def show_text_stream(self, text: str):
@@ -99,10 +262,12 @@ class DisplayManager:
99
262
  """Display a newline for spacing."""
100
263
  self.console.print()
101
264
 
102
- def show_table_list(self, tables_data: str):
265
+ def show_table_list(self, tables_data: str | dict):
103
266
  """Display the results from list_tables tool."""
104
267
  try:
105
- data = json.loads(tables_data)
268
+ data = (
269
+ json.loads(tables_data) if isinstance(tables_data, str) else tables_data
270
+ )
106
271
 
107
272
  # Handle error case
108
273
  if "error" in data:
@@ -143,10 +308,12 @@ class DisplayManager:
143
308
  except Exception as e:
144
309
  self.show_error(f"Error displaying table list: {str(e)}")
145
310
 
146
- def show_schema_info(self, schema_data: str):
311
+ def show_schema_info(self, schema_data: str | dict):
147
312
  """Display the results from introspect_schema tool."""
148
313
  try:
149
- data = json.loads(schema_data)
314
+ data = (
315
+ json.loads(schema_data) if isinstance(schema_data, str) else schema_data
316
+ )
150
317
 
151
318
  # Handle error case
152
319
  if "error" in data:
@@ -2,6 +2,7 @@
2
2
 
3
3
  import asyncio
4
4
  from pathlib import Path
5
+ from textwrap import dedent
5
6
 
6
7
  import platformdirs
7
8
  from prompt_toolkit import PromptSession
@@ -10,6 +11,7 @@ from prompt_toolkit.patch_stdout import patch_stdout
10
11
  from prompt_toolkit.styles import Style
11
12
  from pydantic_ai import Agent
12
13
  from rich.console import Console
14
+ from rich.markdown import Markdown
13
15
  from rich.panel import Panel
14
16
 
15
17
  from sqlsaber.cli.completers import (
@@ -102,14 +104,14 @@ class InteractiveSession:
102
104
  )
103
105
  )
104
106
  self.console.print(
105
- "\n",
106
- "[dim] > Use '/clear' to reset conversation",
107
- "[dim] > Use 'Ctrl+D', '/exit' or '/quit' to leave[/dim]",
108
- "[dim] > Use 'Ctrl+C' to interrupt and return to prompt\n\n",
109
- "[dim] > Start message with '#' to add something to agent's memory for this database",
110
- "[dim] > Type '@' to get table name completions",
111
- "[dim] > Press 'Esc-Enter' or 'Meta-Enter' to submit your question",
112
- sep="\n",
107
+ Markdown(
108
+ dedent("""
109
+ - Use `/` for slash commands
110
+ - Type `@` to get table name completions
111
+ - Start message with `#` to add something to agent's memory
112
+ - Use `Ctrl+C` to interrupt and `Ctrl+D` to exit
113
+ """)
114
+ )
113
115
  )
114
116
 
115
117
  self.console.print(
sqlsaber/cli/streaming.py CHANGED
@@ -1,7 +1,13 @@
1
- """Streaming query handling for the CLI (pydantic-ai based)."""
1
+ """Streaming query handling for the CLI (pydantic-ai based).
2
+
3
+ This module uses DisplayManager's LiveMarkdownRenderer to stream Markdown
4
+ incrementally as the agent outputs tokens. Tool calls and results are
5
+ rendered via DisplayManager helpers.
6
+ """
2
7
 
3
8
  import asyncio
4
9
  import json
10
+ from functools import singledispatchmethod
5
11
  from typing import AsyncIterable
6
12
 
7
13
  from pydantic_ai import Agent, RunContext
@@ -22,56 +28,98 @@ from sqlsaber.cli.display import DisplayManager
22
28
 
23
29
 
24
30
  class StreamingQueryHandler:
25
- """Handles streaming query execution and display using pydantic-ai events."""
31
+ """
32
+ Handles streaming query execution and display using pydantic-ai events.
33
+
34
+ Uses DisplayManager.live to render Markdown incrementally as text streams in.
35
+ """
26
36
 
27
37
  def __init__(self, console: Console):
28
38
  self.console = console
29
39
  self.display = DisplayManager(console)
30
40
 
31
- self.status = self.console.status(
32
- "[yellow]Crunching data...[/yellow]", spinner="bouncingBall"
33
- )
34
-
35
41
  async def _event_stream_handler(
36
42
  self, ctx: RunContext, event_stream: AsyncIterable[AgentStreamEvent]
37
43
  ) -> None:
44
+ """
45
+ Handle pydantic-ai streaming events and update Live Markdown via DisplayManager.
46
+ """
47
+
38
48
  async for event in event_stream:
39
- if isinstance(event, PartStartEvent):
40
- if isinstance(event.part, (TextPart, ThinkingPart)):
41
- self.status.stop()
42
- self.display.show_text_stream(event.part.content)
43
-
44
- elif isinstance(event, PartDeltaEvent):
45
- if isinstance(event.delta, (TextPartDelta, ThinkingPartDelta)):
46
- delta = event.delta.content_delta or ""
47
- if delta:
48
- self.status.stop()
49
- self.display.show_text_stream(delta)
50
-
51
- elif isinstance(event, FunctionToolCallEvent):
52
- # Show tool execution start
53
- self.status.stop()
54
- args = event.part.args_as_dict()
55
- self.display.show_newline()
56
- self.display.show_tool_executing(event.part.tool_name, args)
57
-
58
- elif isinstance(event, FunctionToolResultEvent):
59
- self.status.stop()
60
- # Route tool result to appropriate display
61
- tool_name = event.result.tool_name
62
- content = event.result.content
63
- if tool_name == "list_tables":
64
- self.display.show_table_list(content)
65
- elif tool_name == "introspect_schema":
66
- self.display.show_schema_info(content)
67
- elif tool_name == "execute_sql":
49
+ await self.on_event(event, ctx)
50
+
51
+ # --- Event routing via singledispatchmethod ---------------------------------------
52
+ @singledispatchmethod
53
+ async def on_event(
54
+ self, event: AgentStreamEvent, ctx: RunContext
55
+ ) -> None: # default
56
+ return
57
+
58
+ @on_event.register
59
+ async def _(self, event: PartStartEvent, ctx: RunContext) -> None:
60
+ if isinstance(event.part, TextPart):
61
+ self.display.live.ensure_segment(TextPart)
62
+ self.display.live.append(event.part.content)
63
+ elif isinstance(event.part, ThinkingPart):
64
+ self.display.live.ensure_segment(ThinkingPart)
65
+ self.display.live.append(event.part.content)
66
+
67
+ @on_event.register
68
+ async def _(self, event: PartDeltaEvent, ctx: RunContext) -> None:
69
+ d = event.delta
70
+ if isinstance(d, TextPartDelta):
71
+ delta = d.content_delta or ""
72
+ if delta:
73
+ self.display.live.ensure_segment(TextPart)
74
+ self.display.live.append(delta)
75
+ elif isinstance(d, ThinkingPartDelta):
76
+ delta = d.content_delta or ""
77
+ if delta:
78
+ self.display.live.ensure_segment(ThinkingPart)
79
+ self.display.live.append(delta)
80
+
81
+ @on_event.register
82
+ async def _(self, event: FunctionToolCallEvent, ctx: RunContext) -> None:
83
+ # Clear any status/markdown Live so tool output sits between
84
+ self.display.live.end_status()
85
+ self.display.live.end_if_active()
86
+ args = event.part.args_as_dict()
87
+
88
+ # Special handling: display SQL via Live as markdown code block
89
+ if event.part.tool_name == "execute_sql":
90
+ query = args.get("query") or ""
91
+ if isinstance(query, str) and query.strip():
92
+ self.display.live.start_sql_block(query)
93
+ else:
94
+ self.display.show_tool_executing(event.part.tool_name, args)
95
+
96
+ @on_event.register
97
+ async def _(self, event: FunctionToolResultEvent, ctx: RunContext) -> None:
98
+ # Route tool result to appropriate display
99
+ tool_name = event.result.tool_name
100
+ content = event.result.content
101
+ if tool_name == "list_tables":
102
+ self.display.show_table_list(content)
103
+ elif tool_name == "introspect_schema":
104
+ self.display.show_schema_info(content)
105
+ elif tool_name == "execute_sql":
106
+ data = {}
107
+ if isinstance(content, str):
108
+ try:
109
+ data = json.loads(content)
110
+ except (json.JSONDecodeError, TypeError) as exc:
68
111
  try:
69
- data = json.loads(content)
70
- if data.get("success") and data.get("results"):
71
- self.display.show_query_results(data["results"]) # type: ignore[arg-type]
72
- except json.JSONDecodeError:
73
- # If not JSON, ignore here
112
+ self.console.log(f"Malformed execute_sql result: {exc}")
113
+ except Exception:
74
114
  pass
115
+ elif isinstance(content, dict):
116
+ data = content
117
+ if isinstance(data, dict) and data.get("success") and data.get("results"):
118
+ self.display.show_query_results(data["results"]) # type: ignore[arg-type]
119
+ # Add a blank line after tool output to separate from next segment
120
+ self.display.show_newline()
121
+ # Show status while agent sends a follow-up request to the model
122
+ self.display.live.start_status("Crunching data...")
75
123
 
76
124
  async def execute_streaming_query(
77
125
  self,
@@ -80,7 +128,8 @@ class StreamingQueryHandler:
80
128
  cancellation_token: asyncio.Event | None = None,
81
129
  message_history: list | None = None,
82
130
  ):
83
- self.status.start()
131
+ # Prepare nicer code block rendering for Markdown
132
+ self.display.live.prepare_code_blocks()
84
133
  try:
85
134
  # If Anthropic OAuth, inject SQLsaber instructions before the first user prompt
86
135
  prepared_prompt: str | list[str] = user_query
@@ -104,30 +153,24 @@ class StreamingQueryHandler:
104
153
  injected = "\n\n".join(parts)
105
154
  prepared_prompt = [injected, user_query]
106
155
 
156
+ # Show a transient status until events start streaming
157
+ self.display.live.start_status("Crunching data...")
158
+
107
159
  # Run the agent with our event stream handler
108
160
  run = await agent.run(
109
161
  prepared_prompt,
110
162
  message_history=message_history,
111
163
  event_stream_handler=self._event_stream_handler,
112
164
  )
113
- # After the run completes, show the assistant's final text as markdown if available
114
- try:
115
- output = run.output
116
- if isinstance(output, str) and output.strip():
117
- self.display.show_newline()
118
- self.display.show_markdown_response(
119
- [{"type": "text", "text": output}]
120
- )
121
- except Exception as e:
122
- self.display.show_error(str(e))
123
- self.display.show_newline()
124
165
  return run
125
166
  except asyncio.CancelledError:
167
+ # Show interruption message outside of Live
126
168
  self.display.show_newline()
127
169
  self.console.print("[yellow]Query interrupted[/yellow]")
128
170
  return None
129
171
  finally:
172
+ # End any active status and live markdown segments
130
173
  try:
131
- self.status.stop()
132
- except Exception:
133
- pass
174
+ self.display.live.end_status()
175
+ finally:
176
+ self.display.live.end_if_active()
@@ -1,5 +1,6 @@
1
1
  """Database connection management."""
2
2
 
3
+ import asyncio
3
4
  import ssl
4
5
  from abc import ABC, abstractmethod
5
6
  from pathlib import Path
@@ -10,6 +11,17 @@ import aiomysql
10
11
  import aiosqlite
11
12
  import asyncpg
12
13
 
14
# Default query timeout to prevent runaway queries
DEFAULT_QUERY_TIMEOUT = 30.0  # seconds


class QueryTimeoutError(RuntimeError):
    """Exception raised when a query exceeds its timeout."""

    def __init__(self, seconds: float):
        # Expose the limit on the instance so callers can report or retry
        # with the same value.
        self.timeout = seconds
        message = f"Query exceeded timeout of {seconds}s"
        super().__init__(message)
24
+
13
25
 
14
26
  class BaseDatabaseConnection(ABC):
15
27
  """Abstract base class for database connections."""
@@ -29,11 +41,18 @@ class BaseDatabaseConnection(ABC):
29
41
  pass
30
42
 
31
43
  @abstractmethod
32
- async def execute_query(self, query: str, *args) -> list[dict[str, Any]]:
44
+ async def execute_query(
45
+ self, query: str, *args, timeout: float | None = None
46
+ ) -> list[dict[str, Any]]:
33
47
  """Execute a query and return results as list of dicts.
34
48
 
35
49
  All queries run in a transaction that is rolled back at the end,
36
50
  ensuring no changes are persisted to the database.
51
+
52
+ Args:
53
+ query: SQL query to execute
54
+ *args: Query parameters
55
+ timeout: Query timeout in seconds (overrides default_timeout)
37
56
  """
38
57
  pass
39
58
 
@@ -111,21 +130,40 @@ class PostgreSQLConnection(BaseDatabaseConnection):
111
130
  await self._pool.close()
112
131
  self._pool = None
113
132
 
114
- async def execute_query(self, query: str, *args) -> list[dict[str, Any]]:
133
+ async def execute_query(
134
+ self, query: str, *args, timeout: float | None = None
135
+ ) -> list[dict[str, Any]]:
115
136
  """Execute a query and return results as list of dicts.
116
137
 
117
138
  All queries run in a transaction that is rolled back at the end,
118
139
  ensuring no changes are persisted to the database.
119
140
  """
141
+ effective_timeout = timeout or DEFAULT_QUERY_TIMEOUT
120
142
  pool = await self.get_pool()
143
+
121
144
  async with pool.acquire() as conn:
122
145
  # Start a transaction that we'll always rollback
123
146
  transaction = conn.transaction()
124
147
  await transaction.start()
125
148
 
126
149
  try:
127
- rows = await conn.fetch(query, *args)
150
+ # Set server-side timeout if specified
151
+ if effective_timeout:
152
+ await conn.execute(
153
+ f"SET LOCAL statement_timeout = {int(effective_timeout * 1000)}"
154
+ )
155
+
156
+ # Execute query with client-side timeout
157
+ if effective_timeout:
158
+ rows = await asyncio.wait_for(
159
+ conn.fetch(query, *args), timeout=effective_timeout
160
+ )
161
+ else:
162
+ rows = await conn.fetch(query, *args)
163
+
128
164
  return [dict(row) for row in rows]
165
+ except asyncio.TimeoutError as exc:
166
+ raise QueryTimeoutError(effective_timeout or 0) from exc
129
167
  finally:
130
168
  # Always rollback to ensure no changes are committed
131
169
  await transaction.rollback()
@@ -216,21 +254,44 @@ class MySQLConnection(BaseDatabaseConnection):
216
254
  await self._pool.wait_closed()
217
255
  self._pool = None
218
256
 
219
- async def execute_query(self, query: str, *args) -> list[dict[str, Any]]:
257
+ async def execute_query(
258
+ self, query: str, *args, timeout: float | None = None
259
+ ) -> list[dict[str, Any]]:
220
260
  """Execute a query and return results as list of dicts.
221
261
 
222
262
  All queries run in a transaction that is rolled back at the end,
223
263
  ensuring no changes are persisted to the database.
224
264
  """
265
+ effective_timeout = timeout or DEFAULT_QUERY_TIMEOUT
225
266
  pool = await self.get_pool()
267
+
226
268
  async with pool.acquire() as conn:
227
269
  async with conn.cursor(aiomysql.DictCursor) as cursor:
228
270
  # Start transaction
229
271
  await conn.begin()
230
272
  try:
231
- await cursor.execute(query, args if args else None)
232
- rows = await cursor.fetchall()
273
+ # Set server-side timeout if specified
274
+ if effective_timeout:
275
+ await cursor.execute(
276
+ f"SET SESSION MAX_EXECUTION_TIME = {int(effective_timeout * 1000)}"
277
+ )
278
+
279
+ # Execute query with client-side timeout
280
+ if effective_timeout:
281
+ await asyncio.wait_for(
282
+ cursor.execute(query, args if args else None),
283
+ timeout=effective_timeout,
284
+ )
285
+ rows = await asyncio.wait_for(
286
+ cursor.fetchall(), timeout=effective_timeout
287
+ )
288
+ else:
289
+ await cursor.execute(query, args if args else None)
290
+ rows = await cursor.fetchall()
291
+
233
292
  return [dict(row) for row in rows]
293
+ except asyncio.TimeoutError as exc:
294
+ raise QueryTimeoutError(effective_timeout or 0) from exc
234
295
  finally:
235
296
  # Always rollback to ensure no changes are committed
236
297
  await conn.rollback()
@@ -252,12 +313,16 @@ class SQLiteConnection(BaseDatabaseConnection):
252
313
  """SQLite connections are created per query, no persistent pool to close."""
253
314
  pass
254
315
 
255
- async def execute_query(self, query: str, *args) -> list[dict[str, Any]]:
316
+ async def execute_query(
317
+ self, query: str, *args, timeout: float | None = None
318
+ ) -> list[dict[str, Any]]:
256
319
  """Execute a query and return results as list of dicts.
257
320
 
258
321
  All queries run in a transaction that is rolled back at the end,
259
322
  ensuring no changes are persisted to the database.
260
323
  """
324
+ effective_timeout = timeout or DEFAULT_QUERY_TIMEOUT
325
+
261
326
  async with aiosqlite.connect(self.database_path) as conn:
262
327
  # Enable row factory for dict-like access
263
328
  conn.row_factory = aiosqlite.Row
@@ -265,9 +330,22 @@ class SQLiteConnection(BaseDatabaseConnection):
265
330
  # Start transaction
266
331
  await conn.execute("BEGIN")
267
332
  try:
268
- cursor = await conn.execute(query, args if args else ())
269
- rows = await cursor.fetchall()
333
+ # Execute query with client-side timeout (SQLite has no server-side timeout)
334
+ if effective_timeout:
335
+ cursor = await asyncio.wait_for(
336
+ conn.execute(query, args if args else ()),
337
+ timeout=effective_timeout,
338
+ )
339
+ rows = await asyncio.wait_for(
340
+ cursor.fetchall(), timeout=effective_timeout
341
+ )
342
+ else:
343
+ cursor = await conn.execute(query, args if args else ())
344
+ rows = await cursor.fetchall()
345
+
270
346
  return [dict(row) for row in rows]
347
+ except asyncio.TimeoutError as exc:
348
+ raise QueryTimeoutError(effective_timeout or 0) from exc
271
349
  finally:
272
350
  # Always rollback to ensure no changes are committed
273
351
  await conn.rollback()
@@ -383,20 +461,35 @@ class CSVConnection(BaseDatabaseConnection):
383
461
  except Exception as e:
384
462
  raise ValueError(f"Error loading CSV file '{self.csv_path}': {str(e)}")
385
463
 
386
- async def execute_query(self, query: str, *args) -> list[dict[str, Any]]:
464
+ async def execute_query(
465
+ self, query: str, *args, timeout: float | None = None
466
+ ) -> list[dict[str, Any]]:
387
467
  """Execute a query and return results as list of dicts.
388
468
 
389
469
  All queries run in a transaction that is rolled back at the end,
390
470
  ensuring no changes are persisted to the database.
391
471
  """
472
+ effective_timeout = timeout or DEFAULT_QUERY_TIMEOUT
392
473
  conn = await self.get_pool()
393
474
 
394
475
  # Start transaction
395
476
  await conn.execute("BEGIN")
396
477
  try:
397
- cursor = await conn.execute(query, args if args else ())
398
- rows = await cursor.fetchall()
478
+ # Execute query with client-side timeout (CSV uses in-memory SQLite)
479
+ if effective_timeout:
480
+ cursor = await asyncio.wait_for(
481
+ conn.execute(query, args if args else ()), timeout=effective_timeout
482
+ )
483
+ rows = await asyncio.wait_for(
484
+ cursor.fetchall(), timeout=effective_timeout
485
+ )
486
+ else:
487
+ cursor = await conn.execute(query, args if args else ())
488
+ rows = await cursor.fetchall()
489
+
399
490
  return [dict(row) for row in rows]
491
+ except asyncio.TimeoutError as exc:
492
+ raise QueryTimeoutError(effective_timeout or 0) from exc
400
493
  finally:
401
494
  # Always rollback to ensure no changes are committed
402
495
  await conn.rollback()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sqlsaber
3
- Version: 0.20.0
3
+ Version: 0.22.0
4
4
  Summary: SQLsaber - Open-source agentic SQL assistant
5
5
  License-File: LICENSE
6
6
  Requires-Python: >=3.12
@@ -3,17 +3,17 @@ sqlsaber/__main__.py,sha256=RIHxWeWh2QvLfah-2OkhI5IJxojWfy4fXpMnVEJYvxw,78
3
3
  sqlsaber/agents/__init__.py,sha256=i_MI2eWMQaVzGikKU71FPCmSQxNDKq36Imq1PrYoIPU,130
4
4
  sqlsaber/agents/base.py,sha256=7zOZTHKxUuU0uMc-NTaCkkBfDnU3jtwbT8_eP1ZtJ2k,2615
5
5
  sqlsaber/agents/mcp.py,sha256=GcJTx7YDYH6aaxIADEIxSgcWAdWakUx395JIzVnf17U,768
6
- sqlsaber/agents/pydantic_ai_agent.py,sha256=dGdsgyxCZvfK-v-MH8KimKOr-xb2aSfSWY8CMcOUCT8,6795
6
+ sqlsaber/agents/pydantic_ai_agent.py,sha256=6RvG2O7G8P6NN9QaRXUodg5Q26QJ4ShGWoTGYbVQ5K4,7065
7
7
  sqlsaber/cli/__init__.py,sha256=qVSLVJLLJYzoC6aj6y9MFrzZvAwc4_OgxU9DlkQnZ4M,86
8
8
  sqlsaber/cli/auth.py,sha256=jTsRgbmlGPlASSuIKmdjjwfqtKvjfKd_cTYxX0-QqaQ,7400
9
9
  sqlsaber/cli/commands.py,sha256=mjLG9i1bXf0TEroxkIxq5O7Hhjufz3Ad72cyJz7vE1k,8128
10
10
  sqlsaber/cli/completers.py,sha256=HsUPjaZweLSeYCWkAcgMl8FylQ1xjWBWYTEL_9F6xfU,6430
11
11
  sqlsaber/cli/database.py,sha256=JKtHSN-BFzBa14REf0phFVQB7d67m1M5FFaD8N6DdrY,12966
12
- sqlsaber/cli/display.py,sha256=wa7BjTBwXwqLT145Q1AEL0C28pQJTrvDN10mnFMjqsg,8554
13
- sqlsaber/cli/interactive.py,sha256=suTZ-EvbaB21BsFsRc4MkjM89lZ2iJlYH4G1iYjW7PI,13213
12
+ sqlsaber/cli/display.py,sha256=bul9Yzw8KFYkof-kDzeajpx2TtG9CjTaUiwWaTv95dQ,14293
13
+ sqlsaber/cli/interactive.py,sha256=7uM4LoXbhPJr8o5yNjICSzL0uxZkp1psWrVq4G9V0OI,13118
14
14
  sqlsaber/cli/memory.py,sha256=OufHFJFwV0_GGn7LvKRTJikkWhV1IwNIUDOxFPHXOaQ,7794
15
15
  sqlsaber/cli/models.py,sha256=ZewtwGQwhd9b-yxBAPKePolvI1qQG-EkmeWAGMqtWNQ,8986
16
- sqlsaber/cli/streaming.py,sha256=WNqBYYbWtL5CNQkRg5YWhYpWKI8qz7JmqneB2DXTOHY,5259
16
+ sqlsaber/cli/streaming.py,sha256=BeG7H38-I1n8b9R8XSBV-IqkxDRZhsWFW6sdvtbVi3o,6879
17
17
  sqlsaber/cli/threads.py,sha256=XUnLcCUe2wa_85IKdKmryqfiHTQu_IylET2Qo8oy1nk,11324
18
18
  sqlsaber/config/__init__.py,sha256=olwC45k8Nc61yK0WmPUk7XHdbsZH9HuUAbwnmKe3IgA,100
19
19
  sqlsaber/config/api_keys.py,sha256=RqWQCko1tY7sES7YOlexgBH5Hd5ne_kGXHdBDNqcV2U,3649
@@ -24,7 +24,7 @@ sqlsaber/config/oauth_tokens.py,sha256=C9z35hyx-PvSAYdC1LNf3rg9_wsEIY56hkEczelba
24
24
  sqlsaber/config/providers.py,sha256=JFjeJv1K5Q93zWSlWq3hAvgch1TlgoF0qFa0KJROkKY,2957
25
25
  sqlsaber/config/settings.py,sha256=vgb_RXaM-7DgbxYDmWNw1cSyMqwys4j3qNCvM4bljwI,5586
26
26
  sqlsaber/database/__init__.py,sha256=a_gtKRJnZVO8-fEZI7g3Z8YnGa6Nio-5Y50PgVp07ss,176
27
- sqlsaber/database/connection.py,sha256=kwx18bnwr4kyTUfQT0OW-DXzJUNWIQJP54spJBqU_48,15243
27
+ sqlsaber/database/connection.py,sha256=1bDPEa6cmdh87gPfhNeBLpOdI0E2_2KlE74q_-4l_jI,18913
28
28
  sqlsaber/database/resolver.py,sha256=RPXF5EoKzvQDDLmPGNHYd2uG_oNICH8qvUjBp6iXmNY,3348
29
29
  sqlsaber/database/schema.py,sha256=r12qoN3tdtAXdO22EKlauAe7QwOm8lL2vTMM59XEMMY,26594
30
30
  sqlsaber/mcp/__init__.py,sha256=COdWq7wauPBp5Ew8tfZItFzbcLDSEkHBJSMhxzy8C9c,112
@@ -40,8 +40,8 @@ sqlsaber/tools/enums.py,sha256=CH32mL-0k9ZA18911xLpNtsgpV6tB85TktMj6uqGz54,411
40
40
  sqlsaber/tools/instructions.py,sha256=X-x8maVkkyi16b6Tl0hcAFgjiYceZaSwyWTfmrvx8U8,9024
41
41
  sqlsaber/tools/registry.py,sha256=HWOQMsNIdL4XZS6TeNUyrL-5KoSDH6PHsWd3X66o-18,3211
42
42
  sqlsaber/tools/sql_tools.py,sha256=hM6tKqW5MDhFUt6MesoqhTUqIpq_5baIIDoN1MjDCXY,9647
43
- sqlsaber-0.20.0.dist-info/METADATA,sha256=LJgfLWIWvI8ZgeLjZMLakKCmHn4EywL0MS9XeKLy63E,6178
44
- sqlsaber-0.20.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
45
- sqlsaber-0.20.0.dist-info/entry_points.txt,sha256=qEbOB7OffXPFgyJc7qEIJlMEX5RN9xdzLmWZa91zCQQ,162
46
- sqlsaber-0.20.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
47
- sqlsaber-0.20.0.dist-info/RECORD,,
43
+ sqlsaber-0.22.0.dist-info/METADATA,sha256=T9TBoCGfPVrZKM-RnUVROqOKaBaU1KJdYKqh3a8Arr8,6178
44
+ sqlsaber-0.22.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
45
+ sqlsaber-0.22.0.dist-info/entry_points.txt,sha256=qEbOB7OffXPFgyJc7qEIJlMEX5RN9xdzLmWZa91zCQQ,162
46
+ sqlsaber-0.22.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
47
+ sqlsaber-0.22.0.dist-info/RECORD,,