sqlsaber 0.13.0__tar.gz → 0.15.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

Files changed (91) hide show
  1. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/.github/workflows/claude-code-review.yml +12 -13
  2. sqlsaber-0.15.0/.github/workflows/test.yml +33 -0
  3. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/CHANGELOG.md +35 -0
  4. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/PKG-INFO +1 -1
  5. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/pyproject.toml +1 -1
  6. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/agents/anthropic.py +63 -123
  7. sqlsaber-0.15.0/src/sqlsaber/agents/base.py +187 -0
  8. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/interactive.py +6 -2
  9. sqlsaber-0.15.0/src/sqlsaber/conversation/__init__.py +12 -0
  10. sqlsaber-0.15.0/src/sqlsaber/conversation/manager.py +224 -0
  11. sqlsaber-0.15.0/src/sqlsaber/conversation/models.py +120 -0
  12. sqlsaber-0.15.0/src/sqlsaber/conversation/storage.py +362 -0
  13. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/database/schema.py +2 -51
  14. sqlsaber-0.15.0/src/sqlsaber/mcp/mcp.py +129 -0
  15. sqlsaber-0.15.0/src/sqlsaber/tools/__init__.py +25 -0
  16. sqlsaber-0.15.0/src/sqlsaber/tools/base.py +83 -0
  17. sqlsaber-0.15.0/src/sqlsaber/tools/enums.py +21 -0
  18. sqlsaber-0.15.0/src/sqlsaber/tools/instructions.py +251 -0
  19. sqlsaber-0.15.0/src/sqlsaber/tools/registry.py +130 -0
  20. sqlsaber-0.15.0/src/sqlsaber/tools/sql_tools.py +275 -0
  21. sqlsaber-0.15.0/src/sqlsaber/tools/visualization_tools.py +144 -0
  22. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_cli/test_commands.py +23 -23
  23. sqlsaber-0.15.0/tests/test_conversation_storage.py +136 -0
  24. sqlsaber-0.15.0/tests/test_tools/__init__.py +1 -0
  25. sqlsaber-0.15.0/tests/test_tools/test_base.py +63 -0
  26. sqlsaber-0.15.0/tests/test_tools/test_instructions.py +255 -0
  27. sqlsaber-0.15.0/tests/test_tools/test_registry.py +189 -0
  28. sqlsaber-0.15.0/tests/test_tools/test_sql_tools.py +218 -0
  29. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/uv.lock +1 -1
  30. sqlsaber-0.13.0/src/sqlsaber/agents/base.py +0 -286
  31. sqlsaber-0.13.0/src/sqlsaber/mcp/mcp.py +0 -137
  32. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/.github/workflows/claude.yml +0 -0
  33. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/.github/workflows/publish.yml +0 -0
  34. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/.gitignore +0 -0
  35. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/.python-version +0 -0
  36. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/AGENT.md +0 -0
  37. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/CLAUDE.md +0 -0
  38. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/LICENSE +0 -0
  39. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/README.md +0 -0
  40. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/pytest.ini +0 -0
  41. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/sqlsaber.svg +0 -0
  42. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/__init__.py +0 -0
  43. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/__main__.py +0 -0
  44. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/agents/__init__.py +0 -0
  45. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/agents/mcp.py +0 -0
  46. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/agents/streaming.py +0 -0
  47. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/__init__.py +0 -0
  48. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/auth.py +0 -0
  49. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/commands.py +0 -0
  50. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/completers.py +0 -0
  51. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/database.py +0 -0
  52. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/display.py +0 -0
  53. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/memory.py +0 -0
  54. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/models.py +0 -0
  55. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/streaming.py +0 -0
  56. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/__init__.py +0 -0
  57. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/anthropic.py +0 -0
  58. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/base.py +0 -0
  59. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/exceptions.py +0 -0
  60. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/models.py +0 -0
  61. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/streaming.py +0 -0
  62. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/config/__init__.py +0 -0
  63. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/config/api_keys.py +0 -0
  64. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/config/auth.py +0 -0
  65. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/config/database.py +0 -0
  66. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/config/oauth_flow.py +0 -0
  67. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/config/oauth_tokens.py +0 -0
  68. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/config/settings.py +0 -0
  69. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/database/__init__.py +0 -0
  70. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/database/connection.py +0 -0
  71. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/database/resolver.py +0 -0
  72. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/mcp/__init__.py +0 -0
  73. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/memory/__init__.py +0 -0
  74. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/memory/manager.py +0 -0
  75. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/memory/storage.py +0 -0
  76. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/models/__init__.py +0 -0
  77. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/models/events.py +0 -0
  78. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/models/types.py +0 -0
  79. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/__init__.py +0 -0
  80. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/conftest.py +0 -0
  81. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_agents/test_anthropic_oauth.py +0 -0
  82. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_cli/__init__.py +0 -0
  83. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_clients/test_anthropic_client.py +0 -0
  84. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_clients/test_streaming.py +0 -0
  85. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_config/__init__.py +0 -0
  86. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_config/test_database.py +0 -0
  87. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_config/test_oauth.py +0 -0
  88. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_config/test_settings.py +0 -0
  89. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_database/__init__.py +0 -0
  90. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_database/test_connection.py +0 -0
  91. {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_database_resolver.py +0 -0
@@ -17,14 +17,14 @@ jobs:
17
17
  # github.event.pull_request.user.login == 'external-contributor' ||
18
18
  # github.event.pull_request.user.login == 'new-developer' ||
19
19
  # github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR'
20
-
20
+
21
21
  runs-on: ubuntu-latest
22
22
  permissions:
23
23
  contents: read
24
24
  pull-requests: read
25
25
  issues: read
26
26
  id-token: write
27
-
27
+
28
28
  steps:
29
29
  - name: Checkout repository
30
30
  uses: actions/checkout@v4
@@ -39,7 +39,7 @@ jobs:
39
39
 
40
40
  # Optional: Specify model (defaults to Claude Sonnet 4, uncomment for Claude Opus 4)
41
41
  # model: "claude-opus-4-20250514"
42
-
42
+
43
43
  # Direct prompt for automated review (no @claude mention needed)
44
44
  direct_prompt: |
45
45
  Please review this pull request and provide feedback on:
@@ -48,12 +48,12 @@ jobs:
48
48
  - Performance considerations
49
49
  - Security concerns
50
50
  - Test coverage
51
-
51
+
52
52
  Be constructive and helpful in your feedback.
53
53
 
54
54
  # Optional: Use sticky comments to make Claude reuse the same comment on subsequent pushes to the same PR
55
55
  # use_sticky_comment: true
56
-
56
+
57
57
  # Optional: Customize review based on file types
58
58
  # direct_prompt: |
59
59
  # Review this PR focusing on:
@@ -61,18 +61,17 @@ jobs:
61
61
  # - For API endpoints: Security, input validation, and error handling
62
62
  # - For React components: Performance, accessibility, and best practices
63
63
  # - For tests: Coverage, edge cases, and test quality
64
-
64
+
65
65
  # Optional: Different prompts for different authors
66
66
  # direct_prompt: |
67
- # ${{ github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR' &&
67
+ # ${{ github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR' &&
68
68
  # 'Welcome! Please review this PR from a first-time contributor. Be encouraging and provide detailed explanations for any suggestions.' ||
69
69
  # 'Please provide a thorough code review focusing on our coding standards and best practices.' }}
70
-
70
+
71
71
  # Optional: Add specific tools for running tests or linting
72
72
  # allowed_tools: "Bash(npm run test),Bash(npm run lint),Bash(npm run typecheck)"
73
-
74
- # Optional: Skip review for certain conditions
75
- # if: |
76
- # !contains(github.event.pull_request.title, '[skip-review]') &&
77
- # !contains(github.event.pull_request.title, '[WIP]')
78
73
 
74
+ # Optional: Skip review for certain conditions
75
+ if: |
76
+ !contains(github.event.pull_request.title, '[skip-review]') &&
77
+ !contains(github.event.pull_request.title, '[WIP]')
@@ -0,0 +1,33 @@
1
+ name: Tests
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+
8
+ # Cancel active CI runs for a PR before starting another run
9
+ concurrency:
10
+ group: ${{ github.workflow}}-${{ github.ref }}
11
+ cancel-in-progress: ${{ github.event_name == 'pull_request' }}
12
+
13
+ jobs:
14
+ test:
15
+ runs-on: ubuntu-latest
16
+ steps:
17
+ - uses: actions/checkout@v4
18
+
19
+ - name: Install uv
20
+ uses: astral-sh/setup-uv@v5
21
+ with:
22
+ enable-cache: true
23
+
24
+ - name: Set up Python
25
+ uses: actions/setup-python@v5
26
+ with:
27
+ python-version: "3.12"
28
+
29
+ - name: Install dependencies
30
+ run: uv sync --locked --all-extras --dev
31
+
32
+ - name: Run tests
33
+ run: uv run python -m pytest
@@ -4,6 +4,41 @@ All notable changes to SQLSaber will be documented in this file.
4
4
 
5
5
  ## [Unreleased]
6
6
 
7
+ ## [0.15.0] - 2025-08-18
8
+
9
+ ### Added
10
+
11
+ - Tool abstraction system with centralized registry (new `Tool` base class, `ToolRegistry`, decorators)
12
+ - Dynamic instruction generation system (`InstructionBuilder`)
13
+ - Comprehensive test suite for the tools module
14
+
15
+ ### Changed
16
+
17
+ - Refactored agents to use centralized tool registry instead of hardcoded tools
18
+ - Enhanced MCP server with dynamic tool registration
19
+ - Moved core SQL functionality to dedicated tool classes
20
+
21
+ ## [0.14.0] - 2025-08-01
22
+
23
+ ### Added
24
+
25
+ - Local conversation storage between user and agent
26
+ - Store conversation history persistently
27
+ - Track messages with proper attribution
28
+ - Added automated test execution in CI
29
+ - New GitHub Actions workflow for running tests
30
+ - Updated code review workflow
31
+
32
+ ### Fixed
33
+
34
+ - Fixed CLI commands test suite (#11)
35
+
36
+ ### Changed
37
+
38
+ - Removed schema caching from SchemaManager
39
+ - Simplified schema introspection by removing cache logic
40
+ - Direct database queries for schema information
41
+
7
42
  ## [0.13.0] - 2025-07-26
8
43
 
9
44
  ### Added
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sqlsaber
3
- Version: 0.13.0
3
+ Version: 0.15.0
4
4
  Summary: SQLSaber - Agentic SQL assistant like Claude Code
5
5
  License-File: LICENSE
6
6
  Requires-Python: >=3.12
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "sqlsaber"
3
- version = "0.13.0"
3
+ version = "0.15.0"
4
4
  description = "SQLSaber - Agentic SQL assistant like Claude Code"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -21,6 +21,8 @@ from sqlsaber.config.settings import Config
21
21
  from sqlsaber.database.connection import BaseDatabaseConnection
22
22
  from sqlsaber.memory.manager import MemoryManager
23
23
  from sqlsaber.models.events import StreamEvent
24
+ from sqlsaber.tools import tool_registry
25
+ from sqlsaber.tools.instructions import InstructionBuilder
24
26
 
25
27
 
26
28
  class AnthropicSQLAgent(BaseSQLAgent):
@@ -51,89 +53,11 @@ class AnthropicSQLAgent(BaseSQLAgent):
51
53
  self._last_results = None
52
54
  self._last_query = None
53
55
 
54
- # Define tools in the new format
55
- self.tools: list[ToolDefinition] = [
56
- ToolDefinition(
57
- name="list_tables",
58
- description="Get a list of all tables in the database with row counts. Use this first to discover available tables.",
59
- input_schema={
60
- "type": "object",
61
- "properties": {},
62
- "required": [],
63
- },
64
- ),
65
- ToolDefinition(
66
- name="introspect_schema",
67
- description="Introspect database schema to understand table structures.",
68
- input_schema={
69
- "type": "object",
70
- "properties": {
71
- "table_pattern": {
72
- "type": "string",
73
- "description": "Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')",
74
- }
75
- },
76
- "required": [],
77
- },
78
- ),
79
- ToolDefinition(
80
- name="execute_sql",
81
- description="Execute a SQL query against the database.",
82
- input_schema={
83
- "type": "object",
84
- "properties": {
85
- "query": {
86
- "type": "string",
87
- "description": "SQL query to execute",
88
- },
89
- "limit": {
90
- "type": "integer",
91
- "description": f"Maximum number of rows to return (default: {AnthropicSQLAgent.DEFAULT_SQL_LIMIT})",
92
- "default": AnthropicSQLAgent.DEFAULT_SQL_LIMIT,
93
- },
94
- },
95
- "required": ["query"],
96
- },
97
- ),
98
- ToolDefinition(
99
- name="plot_data",
100
- description="Create a plot of query results.",
101
- input_schema={
102
- "type": "object",
103
- "properties": {
104
- "y_values": {
105
- "type": "array",
106
- "items": {"type": ["number", "null"]},
107
- "description": "Y-axis data points (required)",
108
- },
109
- "x_values": {
110
- "type": "array",
111
- "items": {"type": ["number", "null"]},
112
- "description": "X-axis data points (optional, will use indices if not provided)",
113
- },
114
- "plot_type": {
115
- "type": "string",
116
- "enum": ["line", "scatter", "histogram"],
117
- "description": "Type of plot to create (default: line)",
118
- "default": "line",
119
- },
120
- "title": {
121
- "type": "string",
122
- "description": "Title for the plot",
123
- },
124
- "x_label": {
125
- "type": "string",
126
- "description": "Label for X-axis",
127
- },
128
- "y_label": {
129
- "type": "string",
130
- "description": "Label for Y-axis",
131
- },
132
- },
133
- "required": ["y_values"],
134
- },
135
- ),
136
- ]
56
+ # Get tool definitions from registry
57
+ self.tools: list[ToolDefinition] = tool_registry.get_tool_definitions()
58
+
59
+ # Initialize instruction builder
60
+ self.instruction_builder = InstructionBuilder(tool_registry)
137
61
 
138
62
  # Build system prompt with memories if available
139
63
  self.system_prompt = self._build_system_prompt()
@@ -157,31 +81,9 @@ class AnthropicSQLAgent(BaseSQLAgent):
157
81
  def _get_sql_assistant_instructions(self) -> str:
158
82
  """Get the detailed SQL assistant instructions."""
159
83
  db_type = self._get_database_type_name()
160
- instructions = f"""You are also a helpful SQL assistant that helps users query their {db_type} database.
161
-
162
- Your responsibilities:
163
- 1. Understand user's natural language requests, think and convert them to SQL
164
- 2. Use the provided tools efficiently to explore database schema
165
- 3. Generate appropriate SQL queries
166
- 4. Execute queries safely - queries that modify the database are not allowed
167
- 5. Format and explain results clearly
168
- 6. Create visualizations when requested or when they would be helpful
169
-
170
- IMPORTANT - Schema Discovery Strategy:
171
- 1. ALWAYS start with 'list_tables' to see available tables and row counts
172
- 2. Based on the user's query, identify which specific tables are relevant
173
- 3. Use 'introspect_schema' with a table_pattern to get details ONLY for relevant tables
174
- 4. Timestamp columns must be converted to text when you write queries
175
-
176
- Guidelines:
177
- - Use list_tables first, then introspect_schema for specific tables only
178
- - Use table patterns like 'sample%' or '%experiment%' to filter related tables
179
- - Use proper JOIN syntax and avoid cartesian products
180
- - Include appropriate WHERE clauses to limit results
181
- - Explain what the query does in simple terms
182
- - Handle errors gracefully and suggest fixes
183
- - Be security conscious - use parameterized queries when needed
184
- """
84
+
85
+ # Build dynamic instructions from available tools
86
+ instructions = self.instruction_builder.build_instructions(db_type=db_type)
185
87
 
186
88
  # Add memory context if database name is available
187
89
  if self.database_name:
@@ -189,7 +91,7 @@ Guidelines:
189
91
  self.database_name
190
92
  )
191
93
  if memory_context.strip():
192
- instructions += memory_context
94
+ instructions += "\n\n" + memory_context
193
95
 
194
96
  return instructions
195
97
 
@@ -199,16 +101,19 @@ Guidelines:
199
101
  return None
200
102
 
201
103
  memory = self.memory_manager.add_memory(self.database_name, content)
202
- # Rebuild system prompt with new memory
104
+ # Rebuild system prompt with new memory (includes dynamic instructions)
203
105
  self.system_prompt = self._build_system_prompt()
204
106
  return memory.id
205
107
 
206
- async def execute_sql(self, query: str, limit: int | None = None) -> str:
207
- """Execute a SQL query against the database with streaming support."""
208
- # Call parent implementation for core functionality
209
- result = await super().execute_sql(query, limit)
108
+ async def _execute_sql_with_tracking(
109
+ self, query: str, limit: int | None = None
110
+ ) -> str:
111
+ """Execute SQL and track results for streaming."""
112
+ # Get the execute_sql tool and run it
113
+ tool = tool_registry.get_tool("execute_sql")
114
+ result = await tool.execute(query=query, limit=limit)
210
115
 
211
- # Parse result to extract data for streaming (AnthropicSQLAgent specific)
116
+ # Parse result to extract data for streaming
212
117
  try:
213
118
  result_data = json.loads(result)
214
119
  if result_data.get("success") and "results" in result_data:
@@ -228,7 +133,14 @@ Guidelines:
228
133
  self, tool_name: str, tool_input: dict[str, Any]
229
134
  ) -> str:
230
135
  """Process a tool call and return the result."""
231
- # Use parent implementation for core tools
136
+ # Special handling for execute_sql to track results
137
+ if tool_name == "execute_sql":
138
+ return await self._execute_sql_with_tracking(
139
+ tool_input.get("query", ""),
140
+ tool_input.get("limit", self.DEFAULT_SQL_LIMIT),
141
+ )
142
+
143
+ # Use parent implementation for all other tools
232
144
  return await super().process_tool_call(tool_name, tool_input)
233
145
 
234
146
  def _convert_user_message_to_message(
@@ -450,6 +362,16 @@ Guidelines:
450
362
  self._last_query = None
451
363
 
452
364
  try:
365
+ # Ensure conversation is active for persistence
366
+ await self._ensure_conversation()
367
+
368
+ # Store user message in conversation history and persistence
369
+ if use_history:
370
+ self.conversation_history.append(
371
+ {"role": "user", "content": user_query}
372
+ )
373
+ await self._store_user_message(user_query)
374
+
453
375
  # Build messages with history if requested
454
376
  messages = []
455
377
  if use_history:
@@ -461,8 +383,9 @@ Guidelines:
461
383
  instructions = self._get_sql_assistant_instructions()
462
384
  messages.append(Message(MessageRole.USER, instructions))
463
385
 
464
- # Add current user message
465
- messages.append(Message(MessageRole.USER, user_query))
386
+ # Add current user message if not already in messages from history
387
+ if not use_history:
388
+ messages.append(Message(MessageRole.USER, user_query))
466
389
 
467
390
  # Create initial request and get response
468
391
  request = self._create_message_request(messages)
@@ -484,9 +407,12 @@ Guidelines:
484
407
  return
485
408
 
486
409
  # Add assistant's response to conversation
487
- collected_content.append(
488
- {"role": "assistant", "content": response.content}
489
- )
410
+ assistant_content = {"role": "assistant", "content": response.content}
411
+ collected_content.append(assistant_content)
412
+
413
+ # Store the assistant message immediately (not from collected_content)
414
+ if use_history:
415
+ await self._store_assistant_message(response.content)
490
416
 
491
417
  # Execute tools and get results
492
418
  tool_results = []
@@ -499,9 +425,19 @@ Guidelines:
499
425
  tool_results = event
500
426
 
501
427
  # Continue conversation with tool results
502
- collected_content.append({"role": "user", "content": tool_results})
428
+ tool_content = {"role": "user", "content": tool_results}
429
+ collected_content.append(tool_content)
430
+
431
+ # Store the tool message immediately and update history
503
432
  if use_history:
504
- self.conversation_history.extend(collected_content)
433
+ # Only add the NEW messages to history (not the accumulated ones)
434
+ # collected_content has [assistant1, tool1, assistant2, tool2, ...]
435
+ # We only want to add the last 2 items that were just added
436
+ new_messages_for_history = collected_content[
437
+ -2:
438
+ ] # Last assistant + tool pair
439
+ self.conversation_history.extend(new_messages_for_history)
440
+ await self._store_tool_message(tool_results)
505
441
 
506
442
  if cancellation_token is not None and cancellation_token.is_set():
507
443
  return
@@ -541,6 +477,10 @@ Guidelines:
541
477
  {"role": "assistant", "content": response.content}
542
478
  )
543
479
 
480
+ # Store final assistant message in persistence (only if not tool_use)
481
+ if response.stop_reason != "tool_use":
482
+ await self._store_assistant_message(response.content)
483
+
544
484
  except asyncio.CancelledError:
545
485
  return
546
486
  except Exception as e:
@@ -0,0 +1,187 @@
1
+ """Abstract base class for SQL agents."""
2
+
3
+ import asyncio
4
+ import json
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any, AsyncIterator
7
+
8
+ from sqlsaber.conversation.manager import ConversationManager
9
+ from sqlsaber.database.connection import (
10
+ BaseDatabaseConnection,
11
+ CSVConnection,
12
+ MySQLConnection,
13
+ PostgreSQLConnection,
14
+ SQLiteConnection,
15
+ )
16
+ from sqlsaber.database.schema import SchemaManager
17
+ from sqlsaber.models.events import StreamEvent
18
+ from sqlsaber.tools import SQLTool, tool_registry
19
+
20
+
21
+ class BaseSQLAgent(ABC):
22
+ """Abstract base class for SQL agents."""
23
+
24
+ def __init__(self, db_connection: BaseDatabaseConnection):
25
+ self.db = db_connection
26
+ self.schema_manager = SchemaManager(db_connection)
27
+ self.conversation_history: list[dict[str, Any]] = []
28
+
29
+ # Conversation persistence
30
+ self._conv_manager = ConversationManager()
31
+ self._conversation_id: str | None = None
32
+ self._msg_index: int = 0
33
+
34
+ # Initialize SQL tools with database connection
35
+ self._init_tools()
36
+
37
+ @abstractmethod
38
+ async def query_stream(
39
+ self,
40
+ user_query: str,
41
+ use_history: bool = True,
42
+ cancellation_token: asyncio.Event | None = None,
43
+ ) -> AsyncIterator[StreamEvent]:
44
+ """Process a user query and stream responses.
45
+
46
+ Args:
47
+ user_query: The user's query to process
48
+ use_history: Whether to include conversation history
49
+ cancellation_token: Optional event to signal cancellation
50
+ """
51
+ pass
52
+
53
+ async def clear_history(self):
54
+ """Clear conversation history."""
55
+ # End current conversation in storage
56
+ await self._end_conversation()
57
+
58
+ # Clear in-memory history
59
+ self.conversation_history = []
60
+
61
+ def _get_database_type_name(self) -> str:
62
+ """Get the human-readable database type name."""
63
+ if isinstance(self.db, PostgreSQLConnection):
64
+ return "PostgreSQL"
65
+ elif isinstance(self.db, MySQLConnection):
66
+ return "MySQL"
67
+ elif isinstance(self.db, SQLiteConnection):
68
+ return "SQLite"
69
+ elif isinstance(self.db, CSVConnection):
70
+ return "SQLite" # we convert csv to in-memory sqlite
71
+ else:
72
+ return "database" # Fallback
73
+
74
+ def _init_tools(self) -> None:
75
+ """Initialize SQL tools with database connection."""
76
+ # Get all SQL tools and set their database connection
77
+ for tool_name in tool_registry.list_tools(category="sql"):
78
+ tool = tool_registry.get_tool(tool_name)
79
+ if isinstance(tool, SQLTool):
80
+ tool.set_connection(self.db)
81
+
82
+ async def process_tool_call(
83
+ self, tool_name: str, tool_input: dict[str, Any]
84
+ ) -> str:
85
+ """Process a tool call and return the result."""
86
+ try:
87
+ tool = tool_registry.get_tool(tool_name)
88
+ return await tool.execute(**tool_input)
89
+ except KeyError:
90
+ return json.dumps({"error": f"Unknown tool: {tool_name}"})
91
+ except Exception as e:
92
+ return json.dumps(
93
+ {"error": f"Error executing tool '{tool_name}': {str(e)}"}
94
+ )
95
+
96
+ # Conversation persistence helpers
97
+
98
+ async def _ensure_conversation(self) -> None:
99
+ """Ensure a conversation is active for storing messages."""
100
+ if self._conversation_id is None:
101
+ db_name = getattr(self, "database_name", "unknown")
102
+ self._conversation_id = await self._conv_manager.start_conversation(db_name)
103
+ self._msg_index = 0
104
+
105
+ async def _store_user_message(self, content: str | dict[str, Any]) -> None:
106
+ """Store a user message in conversation history."""
107
+ if self._conversation_id is None:
108
+ return
109
+
110
+ await self._conv_manager.add_user_message(
111
+ self._conversation_id, content, self._msg_index
112
+ )
113
+ self._msg_index += 1
114
+
115
+ async def _store_assistant_message(
116
+ self, content: list[dict[str, Any]] | dict[str, Any]
117
+ ) -> None:
118
+ """Store an assistant message in conversation history."""
119
+ if self._conversation_id is None:
120
+ return
121
+
122
+ await self._conv_manager.add_assistant_message(
123
+ self._conversation_id, content, self._msg_index
124
+ )
125
+ self._msg_index += 1
126
+
127
+ async def _store_tool_message(
128
+ self, content: list[dict[str, Any]] | dict[str, Any]
129
+ ) -> None:
130
+ """Store a tool/system message in conversation history."""
131
+ if self._conversation_id is None:
132
+ return
133
+
134
+ await self._conv_manager.add_tool_message(
135
+ self._conversation_id, content, self._msg_index
136
+ )
137
+ self._msg_index += 1
138
+
139
+ async def _end_conversation(self) -> None:
140
+ """End the current conversation."""
141
+ if self._conversation_id:
142
+ await self._conv_manager.end_conversation(self._conversation_id)
143
+ self._conversation_id = None
144
+ self._msg_index = 0
145
+
146
+ async def restore_conversation(self, conversation_id: str) -> bool:
147
+ """Restore a conversation from storage to in-memory history.
148
+
149
+ Args:
150
+ conversation_id: ID of the conversation to restore
151
+
152
+ Returns:
153
+ True if successfully restored, False otherwise
154
+ """
155
+ success = await self._conv_manager.restore_conversation_to_agent(
156
+ conversation_id, self.conversation_history
157
+ )
158
+
159
+ if success:
160
+ # Set up for continuing this conversation
161
+ self._conversation_id = conversation_id
162
+ self._msg_index = len(self.conversation_history)
163
+
164
+ return success
165
+
166
+ async def list_conversations(self, limit: int = 50) -> list:
167
+ """List conversations for this agent's database.
168
+
169
+ Args:
170
+ limit: Maximum number of conversations to return
171
+
172
+ Returns:
173
+ List of conversation data
174
+ """
175
+ db_name = getattr(self, "database_name", None)
176
+ conversations = await self._conv_manager.list_conversations(db_name, limit)
177
+
178
+ return [
179
+ {
180
+ "id": conv.id,
181
+ "database_name": conv.database_name,
182
+ "started_at": conv.formatted_start_time(),
183
+ "ended_at": conv.formatted_end_time(),
184
+ "duration": conv.duration_seconds(),
185
+ }
186
+ for conv in conversations
187
+ ]
@@ -136,11 +136,15 @@ class InteractiveSession:
136
136
  if not user_query:
137
137
  continue
138
138
 
139
- if user_query in ["/exit", "/quit"]:
139
+ if (
140
+ user_query in ["/exit", "/quit"]
141
+ or user_query.startswith("/exit")
142
+ or user_query.startswith("/quit")
143
+ ):
140
144
  break
141
145
 
142
146
  if user_query == "/clear":
143
- self.agent.clear_history()
147
+ await self.agent.clear_history()
144
148
  self.console.print("[green]Conversation history cleared.[/green]\n")
145
149
  continue
146
150
 
@@ -0,0 +1,12 @@
1
+ """Conversation history storage for SQLSaber."""
2
+
3
+ from .manager import ConversationManager
4
+ from .models import Conversation, ConversationMessage
5
+ from .storage import ConversationStorage
6
+
7
+ __all__ = [
8
+ "Conversation",
9
+ "ConversationMessage",
10
+ "ConversationStorage",
11
+ "ConversationManager",
12
+ ]