sqlsaber 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic.

Files changed (38)
  1. sqlsaber/agents/__init__.py +2 -4
  2. sqlsaber/agents/base.py +18 -221
  3. sqlsaber/agents/mcp.py +2 -2
  4. sqlsaber/agents/pydantic_ai_agent.py +170 -0
  5. sqlsaber/cli/auth.py +146 -79
  6. sqlsaber/cli/commands.py +22 -7
  7. sqlsaber/cli/database.py +1 -1
  8. sqlsaber/cli/interactive.py +65 -30
  9. sqlsaber/cli/models.py +58 -29
  10. sqlsaber/cli/streaming.py +114 -77
  11. sqlsaber/config/api_keys.py +9 -11
  12. sqlsaber/config/providers.py +116 -0
  13. sqlsaber/config/settings.py +50 -30
  14. sqlsaber/database/connection.py +3 -3
  15. sqlsaber/mcp/mcp.py +43 -51
  16. sqlsaber/models/__init__.py +0 -3
  17. sqlsaber/tools/__init__.py +25 -0
  18. sqlsaber/tools/base.py +85 -0
  19. sqlsaber/tools/enums.py +21 -0
  20. sqlsaber/tools/instructions.py +251 -0
  21. sqlsaber/tools/registry.py +130 -0
  22. sqlsaber/tools/sql_tools.py +275 -0
  23. sqlsaber/tools/visualization_tools.py +144 -0
  24. {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/METADATA +20 -39
  25. sqlsaber-0.16.0.dist-info/RECORD +51 -0
  26. sqlsaber/agents/anthropic.py +0 -579
  27. sqlsaber/agents/streaming.py +0 -16
  28. sqlsaber/clients/__init__.py +0 -6
  29. sqlsaber/clients/anthropic.py +0 -285
  30. sqlsaber/clients/base.py +0 -31
  31. sqlsaber/clients/exceptions.py +0 -117
  32. sqlsaber/clients/models.py +0 -282
  33. sqlsaber/clients/streaming.py +0 -257
  34. sqlsaber/models/events.py +0 -28
  35. sqlsaber-0.14.0.dist-info/RECORD +0 -51
  36. {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/WHEEL +0 -0
  37. {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/entry_points.txt +0 -0
  38. {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/agents/__init__.py CHANGED
@@ -1,9 +1,7 @@
  """Agents module for SQLSaber."""

- from .anthropic import AnthropicSQLAgent
- from .base import BaseSQLAgent
+ from .pydantic_ai_agent import build_sqlsaber_agent

  __all__ = [
-     "BaseSQLAgent",
-     "AnthropicSQLAgent",
+     "build_sqlsaber_agent",
  ]
sqlsaber/agents/base.py CHANGED
@@ -5,8 +5,6 @@ import json
  from abc import ABC, abstractmethod
  from typing import Any, AsyncIterator

- from uniplot import histogram, plot
-
  from sqlsaber.conversation.manager import ConversationManager
  from sqlsaber.database.connection import (
      BaseDatabaseConnection,
@@ -16,7 +14,7 @@ from sqlsaber.database.connection import (
      SQLiteConnection,
  )
  from sqlsaber.database.schema import SchemaManager
- from sqlsaber.models.events import StreamEvent
+ from sqlsaber.tools import SQLTool, tool_registry


  class BaseSQLAgent(ABC):
@@ -32,13 +30,16 @@ class BaseSQLAgent(ABC):
          self._conversation_id: str | None = None
          self._msg_index: int = 0

+         # Initialize SQL tools with database connection
+         self._init_tools()
+
      @abstractmethod
      async def query_stream(
          self,
          user_query: str,
          use_history: bool = True,
          cancellation_token: asyncio.Event | None = None,
-     ) -> AsyncIterator[StreamEvent]:
+     ) -> AsyncIterator:
          """Process a user query and stream responses.

          Args:
@@ -69,232 +70,28 @@
          else:
              return "database" # Fallback

-     async def introspect_schema(self, table_pattern: str | None = None) -> str:
-         """Introspect database schema to understand table structures."""
-         try:
-             # Pass table_pattern to get_schema_info for efficient filtering at DB level
-             schema_info = await self.schema_manager.get_schema_info(table_pattern)
-
-             # Format the schema information
-             formatted_info = {}
-             for table_name, table_info in schema_info.items():
-                 formatted_info[table_name] = {
-                     "columns": {
-                         col_name: {
-                             "type": col_info["data_type"],
-                             "nullable": col_info["nullable"],
-                             "default": col_info["default"],
-                         }
-                         for col_name, col_info in table_info["columns"].items()
-                     },
-                     "primary_keys": table_info["primary_keys"],
-                     "foreign_keys": [
-                         f"{fk['column']} -> {fk['references']['table']}.{fk['references']['column']}"
-                         for fk in table_info["foreign_keys"]
-                     ],
-                 }
-
-             return json.dumps(formatted_info)
-         except Exception as e:
-             return json.dumps({"error": f"Error introspecting schema: {str(e)}"})
-
-     async def list_tables(self) -> str:
-         """List all tables in the database with basic information."""
-         try:
-             tables_info = await self.schema_manager.list_tables()
-             return json.dumps(tables_info)
-         except Exception as e:
-             return json.dumps({"error": f"Error listing tables: {str(e)}"})
-
-     async def execute_sql(self, query: str, limit: int | None = None) -> str:
-         """Execute a SQL query against the database."""
-         try:
-             # Security check - only allow SELECT queries unless write is enabled
-             write_error = self._validate_write_operation(query)
-             if write_error:
-                 return json.dumps(
-                     {
-                         "error": write_error,
-                     }
-                 )
-
-             # Add LIMIT if not present and it's a SELECT query
-             query = self._add_limit_to_query(query, limit)
-
-             # Execute the query (wrapped in a transaction for safety)
-             results = await self.db.execute_query(query)
-
-             # Format results
-             actual_limit = limit if limit is not None else len(results)
-
-             return json.dumps(
-                 {
-                     "success": True,
-                     "row_count": len(results),
-                     "results": results[:actual_limit], # Extra safety for limit
-                     "truncated": len(results) > actual_limit,
-                 }
-             )
-
-         except Exception as e:
-             error_msg = str(e)
-
-             # Provide helpful error messages
-             suggestions = []
-             if "column" in error_msg.lower() and "does not exist" in error_msg.lower():
-                 suggestions.append(
-                     "Check column names using the schema introspection tool"
-                 )
-             elif "table" in error_msg.lower() and "does not exist" in error_msg.lower():
-                 suggestions.append(
-                     "Check table names using the schema introspection tool"
-                 )
-             elif "syntax error" in error_msg.lower():
-                 suggestions.append(
-                     "Review SQL syntax, especially JOIN conditions and WHERE clauses"
-                 )
-
-             return json.dumps({"error": error_msg, "suggestions": suggestions})
+     def _init_tools(self) -> None:
+         """Initialize SQL tools with database connection."""
+         # Get all SQL tools and set their database connection
+         for tool_name in tool_registry.list_tools(category="sql"):
+             tool = tool_registry.get_tool(tool_name)
+             if isinstance(tool, SQLTool):
+                 tool.set_connection(self.db)

      async def process_tool_call(
          self, tool_name: str, tool_input: dict[str, Any]
      ) -> str:
          """Process a tool call and return the result."""
-         if tool_name == "list_tables":
-             return await self.list_tables()
-         elif tool_name == "introspect_schema":
-             return await self.introspect_schema(tool_input.get("table_pattern"))
-         elif tool_name == "execute_sql":
-             return await self.execute_sql(
-                 tool_input["query"], tool_input.get("limit", 100)
-             )
-         elif tool_name == "plot_data":
-             return await self.plot_data(
-                 y_values=tool_input["y_values"],
-                 x_values=tool_input.get("x_values"),
-                 plot_type=tool_input.get("plot_type", "line"),
-                 title=tool_input.get("title"),
-                 x_label=tool_input.get("x_label"),
-                 y_label=tool_input.get("y_label"),
-             )
-         else:
-             return json.dumps({"error": f"Unknown tool: {tool_name}"})
-
-     def _validate_write_operation(self, query: str) -> str | None:
-         """Validate if a write operation is allowed.
-
-         Returns:
-             None if operation is allowed, error message if not allowed.
-         """
-         query_upper = query.strip().upper()
-
-         # Check for write operations
-         write_keywords = [
-             "INSERT",
-             "UPDATE",
-             "DELETE",
-             "DROP",
-             "CREATE",
-             "ALTER",
-             "TRUNCATE",
-         ]
-         is_write_query = any(query_upper.startswith(kw) for kw in write_keywords)
-
-         if is_write_query:
-             return (
-                 "Write operations are not allowed. Only SELECT queries are permitted."
-             )
-
-         return None
-
-     def _add_limit_to_query(self, query: str, limit: int = 100) -> str:
-         """Add LIMIT clause to SELECT queries if not present."""
-         query_upper = query.strip().upper()
-         if query_upper.startswith("SELECT") and "LIMIT" not in query_upper:
-             return f"{query.rstrip(';')} LIMIT {limit};"
-         return query
-
-     async def plot_data(
-         self,
-         y_values: list[float],
-         x_values: list[float] | None = None,
-         plot_type: str = "line",
-         title: str | None = None,
-         x_label: str | None = None,
-         y_label: str | None = None,
-     ) -> str:
-         """Create a terminal plot using uniplot.
-
-         Args:
-             y_values: Y-axis data points
-             x_values: X-axis data points (optional)
-             plot_type: Type of plot - "line", "scatter", or "histogram"
-             title: Plot title
-             x_label: X-axis label
-             y_label: Y-axis label
-
-         Returns:
-             JSON string with success status and plot details
-         """
          try:
-             # Validate inputs
-             if not y_values:
-                 return json.dumps({"error": "No data provided for plotting"})
-
-             # Convert to floats if needed
-             try:
-                 y_values = [float(v) if v is not None else None for v in y_values]
-                 if x_values:
-                     x_values = [float(v) if v is not None else None for v in x_values]
-             except (ValueError, TypeError) as e:
-                 return json.dumps({"error": f"Invalid data format: {str(e)}"})
-
-             # Create the plot
-             if plot_type == "histogram":
-                 # For histogram, we only need y_values
-                 histogram(
-                     y_values,
-                     title=title,
-                     bins=min(20, len(set(y_values))), # Adaptive bin count
-                 )
-                 plot_info = {
-                     "type": "histogram",
-                     "data_points": len(y_values),
-                     "title": title or "Histogram",
-                 }
-             elif plot_type in ["line", "scatter"]:
-                 # For line/scatter plots
-                 plot_kwargs = {
-                     "ys": y_values,
-                     "title": title,
-                     "lines": plot_type == "line",
-                 }
-
-                 if x_values:
-                     plot_kwargs["xs"] = x_values
-                 if x_label:
-                     plot_kwargs["x_unit"] = x_label
-                 if y_label:
-                     plot_kwargs["y_unit"] = y_label
-
-                 plot(**plot_kwargs)
-
-                 plot_info = {
-                     "type": plot_type,
-                     "data_points": len(y_values),
-                     "title": title or f"{plot_type.capitalize()} Plot",
-                     "has_x_values": x_values is not None,
-                 }
-             else:
-                 return json.dumps({"error": f"Unsupported plot type: {plot_type}"})
-
+             tool = tool_registry.get_tool(tool_name)
+             return await tool.execute(**tool_input)
+         except KeyError:
+             return json.dumps({"error": f"Unknown tool: {tool_name}"})
+         except Exception as e:
              return json.dumps(
-                 {"success": True, "plot_rendered": True, "plot_info": plot_info}
+                 {"error": f"Error executing tool '{tool_name}': {str(e)}"}
              )

-         except Exception as e:
-             return json.dumps({"error": f"Error creating plot: {str(e)}"})
-
      # Conversation persistence helpers

      async def _ensure_conversation(self) -> None:
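
Note: the new _init_tools and process_tool_call above delegate to the tool registry introduced in this release (sqlsaber/tools/registry.py and sqlsaber/tools/base.py), whose bodies are not shown in this hunk. The following is a minimal sketch of the interface those calls assume (list_tools(category=...), get_tool(name) raising KeyError for unknown names, and an async execute(**kwargs) returning a JSON string); names and structure are illustrative only, not the packaged implementation.

from abc import ABC, abstractmethod


class Tool(ABC):
    """Hypothetical base tool; the real one lives in sqlsaber/tools/base.py."""

    name: str = ""
    category: str = "general"

    @abstractmethod
    async def execute(self, **kwargs) -> str:
        """Run the tool and return a JSON-encoded result string."""


class ToolRegistry:
    """Name -> Tool lookup matching the calls made in base.py above."""

    def __init__(self) -> None:
        self._tools: dict[str, Tool] = {}

    def register(self, tool: Tool) -> None:
        self._tools[tool.name] = tool

    def list_tools(self, category: str | None = None) -> list[str]:
        # Used as tool_registry.list_tools(category="sql") in _init_tools()
        return [
            name
            for name, tool in self._tools.items()
            if category is None or tool.category == category
        ]

    def get_tool(self, name: str) -> Tool:
        # A missing name raises KeyError, which process_tool_call maps to
        # the {"error": "Unknown tool: ..."} payload
        return self._tools[name]


tool_registry = ToolRegistry()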
sqlsaber/agents/mcp.py CHANGED
@@ -1,9 +1,9 @@
  """Generic SQL agent implementation for MCP tools."""

  from typing import AsyncIterator
+
  from sqlsaber.agents.base import BaseSQLAgent
  from sqlsaber.database.connection import BaseDatabaseConnection
- from sqlsaber.models.events import StreamEvent


  class MCPSQLAgent(BaseSQLAgent):
@@ -14,7 +14,7 @@ class MCPSQLAgent(BaseSQLAgent):

      async def query_stream(
          self, user_query: str, use_history: bool = True
-     ) -> AsyncIterator[StreamEvent]:
+     ) -> AsyncIterator:
          """Not implemented for generic agent as it's only used for tool operations."""
          raise NotImplementedError(
              "MCPSQLAgent does not support query streaming. Use specific agent implementations for conversation."
sqlsaber/agents/pydantic_ai_agent.py ADDED
@@ -0,0 +1,170 @@
+ """Pydantic-AI Agent for SQLSaber.
+
+ This replaces the custom AnthropicSQLAgent and uses pydantic-ai's Agent,
+ function tools, and streaming event types directly.
+ """
+
+ import httpx
+ from pydantic_ai import Agent, RunContext
+ from pydantic_ai.models.anthropic import AnthropicModel
+ from pydantic_ai.models.google import GoogleModel
+ from pydantic_ai.providers.anthropic import AnthropicProvider
+ from pydantic_ai.providers.google import GoogleProvider
+
+ from sqlsaber.config import providers
+ from sqlsaber.config.settings import Config
+ from sqlsaber.database.connection import (
+     BaseDatabaseConnection,
+     CSVConnection,
+     MySQLConnection,
+     PostgreSQLConnection,
+     SQLiteConnection,
+ )
+ from sqlsaber.memory.manager import MemoryManager
+ from sqlsaber.tools.instructions import InstructionBuilder
+ from sqlsaber.tools.registry import tool_registry
+ from sqlsaber.tools.sql_tools import SQLTool
+
+
+ def build_sqlsaber_agent(
+     db_connection: BaseDatabaseConnection,
+     database_name: str | None,
+ ) -> Agent:
+     """Create and configure a pydantic-ai Agent for SQLSaber.
+
+     - Registers function tools that delegate to the existing tool registry
+     - Attaches dynamic system prompt built from InstructionBuilder + MemoryManager
+     - Ensures SQL tools have the active DB connection
+     """
+     # Ensure SQL tools receive the active connection
+     for tool_name in tool_registry.list_tools(category="sql"):
+         tool = tool_registry.get_tool(tool_name)
+         if isinstance(tool, SQLTool):
+             tool.set_connection(db_connection)
+
+     cfg = Config()
+     # Ensure provider env var is hydrated from keyring for current provider (Config.validate handles it)
+     cfg.validate()
+
+     # Build model/agent. For some providers (e.g., google), construct provider model explicitly to
+     # allow arbitrary model IDs even if not in pydantic-ai's KnownModelName.
+     model_name_only = (
+         cfg.model_name.split(":", 1)[1] if ":" in cfg.model_name else cfg.model_name
+     )
+
+     provider = providers.provider_from_model(cfg.model_name) or ""
+     if provider == "google":
+         model_obj = GoogleModel(
+             model_name_only, provider=GoogleProvider(api_key=cfg.api_key)
+         )
+         agent = Agent(model_obj, name="sqlsaber")
+     elif provider == "anthropic" and bool(getattr(cfg, "oauth_token", None)):
+         # Build custom httpx client to inject OAuth headers for Anthropic
+         async def add_oauth_headers(request: httpx.Request) -> None: # type: ignore[override]
+             # Remove API-key header if present and add OAuth headers
+             if "x-api-key" in request.headers:
+                 del request.headers["x-api-key"]
+             request.headers.update(
+                 {
+                     "Authorization": f"Bearer {cfg.oauth_token}",
+                     "anthropic-version": "2023-06-01",
+                     "anthropic-beta": "oauth-2025-04-20",
+                     "User-Agent": "ClaudeCode/1.0 (Anthropic Claude Code CLI)",
+                     "X-Client-Name": "claude-code",
+                     "X-Client-Version": "1.0.0",
+                 }
+             )
+
+         http_client = httpx.AsyncClient(event_hooks={"request": [add_oauth_headers]})
+         provider_obj = AnthropicProvider(api_key="placeholder", http_client=http_client)
+         model_obj = AnthropicModel(model_name_only, provider=provider_obj)
+         agent = Agent(model_obj, name="sqlsaber")
+     else:
+         agent = Agent(cfg.model_name, name="sqlsaber")
+
+     # Memory + dynamic system prompt
+     memory_manager = MemoryManager()
+     instruction_builder = InstructionBuilder(tool_registry)
+
+     is_oauth = provider == "anthropic" and bool(getattr(cfg, "oauth_token", None))
+
+     if not is_oauth:
+
+         @agent.system_prompt(dynamic=True)
+         async def sqlsaber_system_prompt(ctx: RunContext) -> str:
+             db_type = _get_database_type_name(db_connection)
+             instructions = instruction_builder.build_instructions(db_type=db_type)
+
+             # Add memory context if available
+             if database_name:
+                 mem = memory_manager.format_memories_for_prompt(database_name)
+             else:
+                 mem = ""
+
+             parts = [p for p in (instructions, mem) if p and p.strip()]
+             return "\n\n".join(parts) if parts else ""
+     else:
+
+         @agent.system_prompt(dynamic=True)
+         async def sqlsaber_system_prompt(ctx: RunContext) -> str:
+             # Minimal system prompt in OAuth mode to match Claude Code identity
+             return "You are Claude Code, Anthropic's official CLI for Claude."
+
+     # Expose helpers and context on agent instance
+     agent._sqlsaber_memory_manager = memory_manager # type: ignore[attr-defined]
+     agent._sqlsaber_database_name = database_name # type: ignore[attr-defined]
+     agent._sqlsaber_instruction_builder = instruction_builder # type: ignore[attr-defined]
+     agent._sqlsaber_db_type = _get_database_type_name(db_connection) # type: ignore[attr-defined]
+     agent._sqlsaber_is_oauth = is_oauth # type: ignore[attr-defined]
+
+     # Tool wrappers that invoke the registered tools
+     @agent.tool(name="list_tables")
+     async def list_tables(ctx: RunContext) -> str:
+         """
+         Get a list of all tables in the database with row counts.
+         Use this first to discover available tables.
+         """
+         tool = tool_registry.get_tool("list_tables")
+         return await tool.execute()
+
+     @agent.tool(name="introspect_schema")
+     async def introspect_schema(
+         ctx: RunContext, table_pattern: str | None = None
+     ) -> str:
+         """
+         Introspect database schema to understand table structures.
+
+         Args:
+             table_pattern: Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')
+         """
+         tool = tool_registry.get_tool("introspect_schema")
+         return await tool.execute(table_pattern=table_pattern)
+
+     @agent.tool(name="execute_sql")
+     async def execute_sql(ctx: RunContext, query: str, limit: int | None = 100) -> str:
+         """
+         Execute a SQL query and return the results.
+
+         Args:
+             query: SQL query to execute
+             limit: Maximum number of rows to return (default: 100)
+         """
+         tool = tool_registry.get_tool("execute_sql")
+         return await tool.execute(query=query, limit=limit)
+
+     return agent
+
+
+ def _get_database_type_name(db: BaseDatabaseConnection) -> str:
+     """Get the human-readable database type name (mirrors BaseSQLAgent)."""
+
+     if isinstance(db, PostgreSQLConnection):
+         return "PostgreSQL"
+     elif isinstance(db, MySQLConnection):
+         return "MySQL"
+     elif isinstance(db, SQLiteConnection):
+         return "SQLite"
+     elif isinstance(db, CSVConnection):
+         return "SQLite"
+     else:
+         return "database"