sqlsaber 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

@@ -3,13 +3,13 @@
3
3
  import asyncio
4
4
  from pathlib import Path
5
5
  from textwrap import dedent
6
+ from typing import TYPE_CHECKING
6
7
 
7
8
  import platformdirs
8
9
  from prompt_toolkit import PromptSession
9
10
  from prompt_toolkit.history import FileHistory
10
11
  from prompt_toolkit.patch_stdout import patch_stdout
11
12
  from prompt_toolkit.styles import Style
12
- from pydantic_ai import Agent
13
13
  from rich.console import Console
14
14
  from rich.markdown import Markdown
15
15
  from rich.panel import Panel
@@ -21,7 +21,7 @@ from sqlsaber.cli.completers import (
21
21
  )
22
22
  from sqlsaber.cli.display import DisplayManager
23
23
  from sqlsaber.cli.streaming import StreamingQueryHandler
24
- from sqlsaber.database.connection import (
24
+ from sqlsaber.database import (
25
25
  CSVConnection,
26
26
  DuckDBConnection,
27
27
  MySQLConnection,
@@ -31,6 +31,9 @@ from sqlsaber.database.connection import (
31
31
  from sqlsaber.database.schema import SchemaManager
32
32
  from sqlsaber.threads import ThreadStorage
33
33
 
34
+ if TYPE_CHECKING:
35
+ from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
36
+
34
37
 
35
38
  def bottom_toolbar():
36
39
  return [
@@ -55,7 +58,7 @@ class InteractiveSession:
55
58
  def __init__(
56
59
  self,
57
60
  console: Console,
58
- agent: Agent,
61
+ sqlsaber_agent: "SQLSaberAgent",
59
62
  db_conn,
60
63
  database_name: str,
61
64
  *,
@@ -63,7 +66,7 @@ class InteractiveSession:
63
66
  initial_history: list | None = None,
64
67
  ):
65
68
  self.console = console
66
- self.agent = agent
69
+ self.sqlsaber_agent = sqlsaber_agent
67
70
  self.db_conn = db_conn
68
71
  self.database_name = database_name
69
72
  self.display = DisplayManager(console)
@@ -176,7 +179,7 @@ class InteractiveSession:
176
179
  query_task = asyncio.create_task(
177
180
  self.streaming_handler.execute_streaming_query(
178
181
  user_query,
179
- self.agent,
182
+ self.sqlsaber_agent,
180
183
  self.cancellation_token,
181
184
  self.message_history,
182
185
  )
@@ -191,11 +194,6 @@ class InteractiveSession:
191
194
  # Use all_messages() so the system prompt and all prior turns are preserved
192
195
  self.message_history = run_result.all_messages()
193
196
 
194
- # Extract title (first user prompt) and model name
195
- if not self._thread_id:
196
- title = user_query
197
- model_name = self.agent.model.model_name
198
-
199
197
  # Persist snapshot to thread storage (create or overwrite)
200
198
  self._thread_id = await self._threads.save_snapshot(
201
199
  messages_json=run_result.all_messages_json(),
@@ -206,8 +204,8 @@ class InteractiveSession:
206
204
  if self.first_message:
207
205
  await self._threads.save_metadata(
208
206
  thread_id=self._thread_id,
209
- title=title,
210
- model_name=model_name,
207
+ title=user_query,
208
+ model_name=self.sqlsaber_agent.agent.model.model_name,
211
209
  )
212
210
  except Exception:
213
211
  pass
@@ -269,6 +267,17 @@ class InteractiveSession:
269
267
  self._thread_id = None
270
268
  continue
271
269
 
270
+ # Thinking commands
271
+ if user_query == "/thinking on":
272
+ self.sqlsaber_agent.set_thinking(enabled=True)
273
+ self.console.print("[green]✓ Thinking enabled[/green]\n")
274
+ continue
275
+
276
+ if user_query == "/thinking off":
277
+ self.sqlsaber_agent.set_thinking(enabled=False)
278
+ self.console.print("[green]✓ Thinking disabled[/green]\n")
279
+ continue
280
+
272
281
  if memory_text := user_query.strip():
273
282
  # Check if query starts with # for memory addition
274
283
  if memory_text.startswith("#"):
@@ -276,9 +285,7 @@ class InteractiveSession:
276
285
  if memory_content:
277
286
  # Add memory via the agent's memory manager
278
287
  try:
279
- mm = getattr(
280
- self.agent, "_sqlsaber_memory_manager", None
281
- )
288
+ mm = self.sqlsaber_agent.memory_manager
282
289
  if mm and self.database_name:
283
290
  memory = mm.add_memory(
284
291
  self.database_name, memory_content
sqlsaber/cli/streaming.py CHANGED
@@ -8,9 +8,9 @@ rendered via DisplayManager helpers.
8
8
  import asyncio
9
9
  import json
10
10
  from functools import singledispatchmethod
11
- from typing import AsyncIterable
11
+ from typing import TYPE_CHECKING, AsyncIterable
12
12
 
13
- from pydantic_ai import Agent, RunContext
13
+ from pydantic_ai import RunContext
14
14
  from pydantic_ai.messages import (
15
15
  AgentStreamEvent,
16
16
  FunctionToolCallEvent,
@@ -26,6 +26,9 @@ from rich.console import Console
26
26
 
27
27
  from sqlsaber.cli.display import DisplayManager
28
28
 
29
+ if TYPE_CHECKING:
30
+ from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
31
+
29
32
 
30
33
  class StreamingQueryHandler:
31
34
  """
@@ -130,7 +133,7 @@ class StreamingQueryHandler:
130
133
  async def execute_streaming_query(
131
134
  self,
132
135
  user_query: str,
133
- agent: Agent,
136
+ sqlsaber_agent: "SQLSaberAgent",
134
137
  cancellation_token: asyncio.Event | None = None,
135
138
  message_history: list | None = None,
136
139
  ):
@@ -139,21 +142,16 @@ class StreamingQueryHandler:
139
142
  try:
140
143
  # If Anthropic OAuth, inject SQLsaber instructions before the first user prompt
141
144
  prepared_prompt: str | list[str] = user_query
142
- is_oauth = bool(getattr(agent, "_sqlsaber_is_oauth", False))
143
145
  no_history = not message_history
144
- if is_oauth and no_history:
145
- ib = getattr(agent, "_sqlsaber_instruction_builder", None)
146
- mm = getattr(agent, "_sqlsaber_memory_manager", None)
147
- db_type = getattr(agent, "_sqlsaber_db_type", "database")
148
- db_name = getattr(agent, "_sqlsaber_database_name", None)
149
- instructions = (
150
- ib.build_instructions(db_type=db_type) if ib is not None else ""
151
- )
152
- mem = (
153
- mm.format_memories_for_prompt(db_name)
154
- if (mm is not None and db_name)
155
- else ""
146
+ if sqlsaber_agent.is_oauth and no_history:
147
+ instructions = sqlsaber_agent.instruction_builder.build_instructions(
148
+ db_type=sqlsaber_agent.db_type
156
149
  )
150
+ mem = ""
151
+ if sqlsaber_agent.database_name:
152
+ mem = sqlsaber_agent.memory_manager.format_memories_for_prompt(
153
+ sqlsaber_agent.database_name
154
+ )
157
155
  parts = [p for p in (instructions, mem) if p and str(p).strip()]
158
156
  if parts:
159
157
  injected = "\n\n".join(parts)
@@ -163,7 +161,7 @@ class StreamingQueryHandler:
163
161
  self.display.live.start_status("Crunching data...")
164
162
 
165
163
  # Run the agent with our event stream handler
166
- run = await agent.run(
164
+ run = await sqlsaber_agent.agent.run(
167
165
  prepared_prompt,
168
166
  message_history=message_history,
169
167
  event_stream_handler=self._event_stream_handler,
sqlsaber/cli/threads.py CHANGED
@@ -148,7 +148,9 @@ def _render_transcript(
148
148
  )
149
149
  else:
150
150
  if is_redirected:
151
- console.print(f"**Tool result ({name}):**\n\n{content_str}\n")
151
+ console.print(
152
+ f"**Tool result ({name}):**\n\n{content_str}\n"
153
+ )
152
154
  else:
153
155
  console.print(
154
156
  Panel.fit(
@@ -159,7 +161,9 @@ def _render_transcript(
159
161
  )
160
162
  except Exception:
161
163
  if is_redirected:
162
- console.print(f"**Tool result ({name}):**\n\n{content_str}\n")
164
+ console.print(
165
+ f"**Tool result ({name}):**\n\n{content_str}\n"
166
+ )
163
167
  else:
164
168
  console.print(
165
169
  Panel.fit(
@@ -258,10 +262,10 @@ def resume(
258
262
 
259
263
  async def _run() -> None:
260
264
  # Lazy imports to avoid heavy modules at CLI startup
261
- from sqlsaber.agents import build_sqlsaber_agent
265
+ from sqlsaber.agents import SQLSaberAgent
262
266
  from sqlsaber.cli.interactive import InteractiveSession
263
267
  from sqlsaber.config.database import DatabaseConfigManager
264
- from sqlsaber.database.connection import DatabaseConnection
268
+ from sqlsaber.database import DatabaseConnection
265
269
  from sqlsaber.database.resolver import (
266
270
  DatabaseResolutionError,
267
271
  resolve_database,
@@ -288,7 +292,7 @@ def resume(
288
292
 
289
293
  db_conn = DatabaseConnection(connection_string)
290
294
  try:
291
- agent = build_sqlsaber_agent(db_conn, db_name)
295
+ sqlsaber_agent = SQLSaberAgent(db_conn, db_name)
292
296
  history = await store.get_thread_messages(thread_id)
293
297
  if console.is_terminal:
294
298
  console.print(Panel.fit(f"Thread: {thread.id}", border_style="blue"))
@@ -297,7 +301,7 @@ def resume(
297
301
  _render_transcript(console, history, None)
298
302
  session = InteractiveSession(
299
303
  console=console,
300
- agent=agent,
304
+ sqlsaber_agent=sqlsaber_agent,
301
305
  db_conn=db_conn,
302
306
  database_name=db_name,
303
307
  initial_thread_id=thread_id,
@@ -46,7 +46,10 @@ class ModelConfigManager:
46
46
  def _load_config(self) -> dict[str, Any]:
47
47
  """Load configuration from file."""
48
48
  if not self.config_file.exists():
49
- return {"model": self.DEFAULT_MODEL}
49
+ return {
50
+ "model": self.DEFAULT_MODEL,
51
+ "thinking_enabled": False,
52
+ }
50
53
 
51
54
  try:
52
55
  with open(self.config_file, "r") as f:
@@ -54,9 +57,15 @@ class ModelConfigManager:
54
57
  # Ensure we have a model set
55
58
  if "model" not in config:
56
59
  config["model"] = self.DEFAULT_MODEL
60
+ # Set defaults for thinking if not present
61
+ if "thinking_enabled" not in config:
62
+ config["thinking_enabled"] = False
57
63
  return config
58
64
  except (json.JSONDecodeError, IOError):
59
- return {"model": self.DEFAULT_MODEL}
65
+ return {
66
+ "model": self.DEFAULT_MODEL,
67
+ "thinking_enabled": False,
68
+ }
60
69
 
61
70
  def _save_config(self, config: dict[str, Any]) -> None:
62
71
  """Save configuration to file."""
@@ -76,6 +85,17 @@ class ModelConfigManager:
76
85
  config["model"] = model
77
86
  self._save_config(config)
78
87
 
88
+ def get_thinking_enabled(self) -> bool:
89
+ """Get whether thinking is enabled."""
90
+ config = self._load_config()
91
+ return config.get("thinking_enabled", False)
92
+
93
+ def set_thinking_enabled(self, enabled: bool) -> None:
94
+ """Set whether thinking is enabled."""
95
+ config = self._load_config()
96
+ config["thinking_enabled"] = enabled
97
+ self._save_config(config)
98
+
79
99
 
80
100
  class Config:
81
101
  """Configuration class for SQLSaber."""
@@ -86,6 +106,9 @@ class Config:
86
106
  self.api_key_manager = APIKeyManager()
87
107
  self.auth_config_manager = AuthConfigManager()
88
108
 
109
+ # Thinking configuration
110
+ self.thinking_enabled = self.model_config_manager.get_thinking_enabled()
111
+
89
112
  # Authentication method (API key or Anthropic OAuth)
90
113
  self.auth_method = self.auth_config_manager.get_auth_method()
91
114
 
@@ -1,9 +1,63 @@
1
1
  """Database module for SQLSaber."""
2
2
 
3
- from .connection import DatabaseConnection
3
+ from .base import (
4
+ DEFAULT_QUERY_TIMEOUT,
5
+ BaseDatabaseConnection,
6
+ BaseSchemaIntrospector,
7
+ ColumnInfo,
8
+ ForeignKeyInfo,
9
+ IndexInfo,
10
+ QueryTimeoutError,
11
+ SchemaInfo,
12
+ )
13
+ from .csv import CSVConnection, CSVSchemaIntrospector
14
+ from .duckdb import DuckDBConnection, DuckDBSchemaIntrospector
15
+ from .mysql import MySQLConnection, MySQLSchemaIntrospector
16
+ from .postgresql import PostgreSQLConnection, PostgreSQLSchemaIntrospector
4
17
  from .schema import SchemaManager
18
+ from .sqlite import SQLiteConnection, SQLiteSchemaIntrospector
19
+
20
+
21
def DatabaseConnection(connection_string: str) -> BaseDatabaseConnection:
    """Create the concrete connection matching the connection string's scheme.

    Dispatches on the URL prefix; raises ValueError for unrecognized schemes.
    Named like a class for backward compatibility with the old
    ``connection.DatabaseConnection`` entry point.
    """
    # Scheme prefix -> concrete connection class, checked in order.
    dispatch = (
        ("postgresql://", PostgreSQLConnection),
        ("mysql://", MySQLConnection),
        ("sqlite:///", SQLiteConnection),
        ("duckdb://", DuckDBConnection),
        ("csv:///", CSVConnection),
    )
    for prefix, connection_cls in dispatch:
        if connection_string.startswith(prefix):
            return connection_cls(connection_string)
    raise ValueError(
        f"Unsupported database type in connection string: {connection_string}"
    )
37
+
5
38
 
6
39
  __all__ = [
40
+ # Base classes and types
41
+ "BaseDatabaseConnection",
42
+ "BaseSchemaIntrospector",
43
+ "ColumnInfo",
44
+ "DEFAULT_QUERY_TIMEOUT",
45
+ "ForeignKeyInfo",
46
+ "IndexInfo",
47
+ "QueryTimeoutError",
48
+ "SchemaInfo",
49
+ # Concrete implementations
50
+ "PostgreSQLConnection",
51
+ "MySQLConnection",
52
+ "SQLiteConnection",
53
+ "DuckDBConnection",
54
+ "CSVConnection",
55
+ "PostgreSQLSchemaIntrospector",
56
+ "MySQLSchemaIntrospector",
57
+ "SQLiteSchemaIntrospector",
58
+ "DuckDBSchemaIntrospector",
59
+ "CSVSchemaIntrospector",
60
+ # Factory function and manager
7
61
  "DatabaseConnection",
8
62
  "SchemaManager",
9
63
  ]
@@ -0,0 +1,124 @@
1
+ """Base classes and type definitions for database connections and schema introspection."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, TypedDict
5
+
6
+ # Default query timeout to prevent runaway queries
7
+ DEFAULT_QUERY_TIMEOUT = 30.0 # seconds
8
+
9
+
10
class QueryTimeoutError(RuntimeError):
    """Raised when a database query runs longer than its allotted timeout."""

    def __init__(self, seconds: float):
        # Keep the limit on the instance so callers can report or retry with it.
        super().__init__(f"Query exceeded timeout of {seconds}s")
        self.timeout = seconds
16
+
17
+
18
class ColumnInfo(TypedDict):
    """Type definition for column information.

    NOTE(review): presumably populated by the per-backend schema
    introspectors — confirm the exact producers at call sites.
    """

    data_type: str  # declared column type as reported by the backend
    nullable: bool  # True when the column accepts NULL
    default: str | None  # textual default expression; None when absent
    max_length: int | None  # character length limit; None when not applicable
    precision: int | None  # numeric precision; None when not applicable
    scale: int | None  # numeric scale; None when not applicable
27
+
28
+
29
class ForeignKeyInfo(TypedDict):
    """Type definition for foreign key information."""

    column: str  # referencing column on this table
    references: dict[str, str]  # {"table": "schema.table", "column": "column_name"}
34
+
35
+
36
class IndexInfo(TypedDict):
    """Type definition for index information."""

    name: str  # index name as defined in the database
    columns: list[str]  # ordered
    unique: bool  # True for unique indexes
    type: str | None  # btree, gin, FULLTEXT, etc. None if unknown
43
+
44
+
45
class SchemaInfo(TypedDict):
    """Type definition for schema information.

    Aggregates the per-table details described by the other TypedDicts in
    this module (columns, primary/foreign keys, indexes).
    """

    schema: str  # namespace/schema the object belongs to
    name: str  # object name
    type: str  # object kind label (presumably table vs. view) — confirm at producers
    columns: dict[str, ColumnInfo]  # column name -> column details
    primary_keys: list[str]  # primary key column names
    foreign_keys: list[ForeignKeyInfo]  # outgoing foreign keys
    indexes: list[IndexInfo]  # indexes defined on the object
55
+
56
+
57
+ class BaseDatabaseConnection(ABC):
58
+ """Abstract base class for database connections."""
59
+
60
+ def __init__(self, connection_string: str):
61
+ self.connection_string = connection_string
62
+ self._pool = None
63
+
64
+ @abstractmethod
65
+ async def get_pool(self):
66
+ """Get or create connection pool."""
67
+ pass
68
+
69
+ @abstractmethod
70
+ async def close(self):
71
+ """Close the connection pool."""
72
+ pass
73
+
74
+ @abstractmethod
75
+ async def execute_query(
76
+ self, query: str, *args, timeout: float | None = None
77
+ ) -> list[dict[str, Any]]:
78
+ """Execute a query and return results as list of dicts.
79
+
80
+ All queries run in a transaction that is rolled back at the end,
81
+ ensuring no changes are persisted to the database.
82
+
83
+ Args:
84
+ query: SQL query to execute
85
+ *args: Query parameters
86
+ timeout: Query timeout in seconds (overrides default_timeout)
87
+ """
88
+ pass
89
+
90
+
91
+ class BaseSchemaIntrospector(ABC):
92
+ """Abstract base class for database-specific schema introspection."""
93
+
94
+ @abstractmethod
95
+ async def get_tables_info(
96
+ self, connection, table_pattern: str | None = None
97
+ ) -> dict[str, Any]:
98
+ """Get tables information for the specific database type."""
99
+ pass
100
+
101
+ @abstractmethod
102
+ async def get_columns_info(self, connection, tables: list) -> list:
103
+ """Get columns information for the specific database type."""
104
+ pass
105
+
106
+ @abstractmethod
107
+ async def get_foreign_keys_info(self, connection, tables: list) -> list:
108
+ """Get foreign keys information for the specific database type."""
109
+ pass
110
+
111
+ @abstractmethod
112
+ async def get_primary_keys_info(self, connection, tables: list) -> list:
113
+ """Get primary keys information for the specific database type."""
114
+ pass
115
+
116
+ @abstractmethod
117
+ async def get_indexes_info(self, connection, tables: list) -> list:
118
+ """Get indexes information for the specific database type."""
119
+ pass
120
+
121
+ @abstractmethod
122
+ async def list_tables_info(self, connection) -> list[dict[str, Any]]:
123
+ """Get list of tables with basic information."""
124
+ pass
@@ -0,0 +1,133 @@
1
+ """CSV database connection using DuckDB backend."""
2
+
3
+ import asyncio
4
+ from pathlib import Path
5
+ from typing import Any
6
+ from urllib.parse import parse_qs, urlparse
7
+
8
+ import duckdb
9
+
10
+ from .base import DEFAULT_QUERY_TIMEOUT, BaseDatabaseConnection, QueryTimeoutError
11
+ from .duckdb import DuckDBSchemaIntrospector
12
+
13
+
14
def _execute_duckdb_transaction(
    conn: duckdb.DuckDBPyConnection, query: str, args: tuple[Any, ...]
) -> list[dict[str, Any]]:
    """Run a DuckDB query inside a transaction and return list of dicts.

    The transaction is rolled back on every path — success or failure — so
    the query can never persist changes to the underlying data.

    Args:
        conn: An open DuckDB connection.
        query: SQL text to execute.
        args: Positional query parameters (empty tuple for none).

    Returns:
        One dict per result row keyed by column name; an empty list for
        statements that produce no result set.
    """
    conn.execute("BEGIN TRANSACTION")
    try:
        if args:
            conn.execute(query, args)
        else:
            conn.execute(query)

        # description is None for statements without a result set.
        if conn.description is None:
            return []
        columns = [col[0] for col in conn.description]
        return [dict(zip(columns, row)) for row in conn.fetchall()]
    finally:
        # Single rollback point covers both the success and the error path.
        conn.execute("ROLLBACK")
37
+
38
+
39
class CSVConnection(BaseDatabaseConnection):
    """CSV file connection using DuckDB per query."""

    def __init__(self, connection_string: str):
        super().__init__(connection_string)

        # Path portion: strip the scheme, then drop any ?query suffix.
        without_scheme = connection_string.replace("csv:///", "", 1)
        self.csv_path = without_scheme.split("?", 1)[0]

        # CSV dialect defaults, overridable via query-string parameters.
        self.delimiter = ","
        self.encoding = "utf-8"
        self.has_header = True

        parsed = urlparse(connection_string)
        if parsed.query:
            options = parse_qs(parsed.query)
            self.delimiter = options.get("delimiter", [self.delimiter])[0]
            self.encoding = options.get("encoding", [self.encoding])[0]
            self.has_header = options.get("header", ["true"])[0].lower() == "true"

        # The file is exposed as a single view named after the file stem.
        self.table_name = Path(self.csv_path).stem or "csv_table"

    async def get_pool(self):
        """CSV connections do not maintain a pool."""
        return None

    async def close(self):
        """No persistent resources to close for CSV connections."""
        return None

    def _quote_identifier(self, identifier: str) -> str:
        # Double-quote an identifier, doubling embedded quotes per SQL rules.
        doubled = identifier.replace('"', '""')
        return f'"{doubled}"'

    def _quote_literal(self, value: str) -> str:
        # Single-quote a literal, doubling embedded quotes per SQL rules.
        doubled = value.replace("'", "''")
        return f"'{doubled}'"

    def _normalized_encoding(self) -> str | None:
        # UTF-8 is the default and needs no explicit option; other names are
        # upper-cased with separators removed.
        encoding = (self.encoding or "").strip()
        if encoding and encoding.lower() != "utf-8":
            return encoding.replace("-", "").replace("_", "").upper()
        return None

    def _create_view(self, conn: duckdb.DuckDBPyConnection) -> None:
        """Create a view over the CSV file via read_csv_auto."""
        options = [f"HEADER={'TRUE' if self.has_header else 'FALSE'}"]
        if self.delimiter:
            options.append(f"DELIM={self._quote_literal(self.delimiter)}")
        normalized = self._normalized_encoding()
        if normalized:
            options.append(f"ENCODING={self._quote_literal(normalized)}")

        options_sql = (", " + ", ".join(options)) if options else ""
        source_sql = (
            f"read_csv_auto({self._quote_literal(self.csv_path)}{options_sql})"
        )
        conn.execute(
            f"CREATE VIEW {self._quote_identifier(self.table_name)} AS "
            f"SELECT * FROM {source_sql}"
        )

    async def execute_query(
        self, query: str, *args, timeout: float | None = None
    ) -> list[dict[str, Any]]:
        """Run *query* against a fresh in-memory DuckDB wrapping the CSV file."""
        effective_timeout = timeout or DEFAULT_QUERY_TIMEOUT
        params = tuple(args)

        def _worker() -> list[dict[str, Any]]:
            # Each query gets its own throwaway in-memory database.
            conn = duckdb.connect(":memory:")
            try:
                self._create_view(conn)
                return _execute_duckdb_transaction(conn, query, params)
            finally:
                conn.close()

        try:
            return await asyncio.wait_for(
                asyncio.to_thread(_worker), timeout=effective_timeout
            )
        except asyncio.TimeoutError as exc:
            raise QueryTimeoutError(effective_timeout or 0) from exc
128
+
129
+
130
class CSVSchemaIntrospector(DuckDBSchemaIntrospector):
    """Schema introspection for CSV sources, reusing the DuckDB logic.

    CSV files are surfaced through a DuckDB view, so the DuckDB
    introspector's queries apply unchanged.
    """