sqlsaber 0.13.0__tar.gz → 0.15.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/.github/workflows/claude-code-review.yml +12 -13
- sqlsaber-0.15.0/.github/workflows/test.yml +33 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/CHANGELOG.md +35 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/PKG-INFO +1 -1
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/pyproject.toml +1 -1
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/agents/anthropic.py +63 -123
- sqlsaber-0.15.0/src/sqlsaber/agents/base.py +187 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/interactive.py +6 -2
- sqlsaber-0.15.0/src/sqlsaber/conversation/__init__.py +12 -0
- sqlsaber-0.15.0/src/sqlsaber/conversation/manager.py +224 -0
- sqlsaber-0.15.0/src/sqlsaber/conversation/models.py +120 -0
- sqlsaber-0.15.0/src/sqlsaber/conversation/storage.py +362 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/database/schema.py +2 -51
- sqlsaber-0.15.0/src/sqlsaber/mcp/mcp.py +129 -0
- sqlsaber-0.15.0/src/sqlsaber/tools/__init__.py +25 -0
- sqlsaber-0.15.0/src/sqlsaber/tools/base.py +83 -0
- sqlsaber-0.15.0/src/sqlsaber/tools/enums.py +21 -0
- sqlsaber-0.15.0/src/sqlsaber/tools/instructions.py +251 -0
- sqlsaber-0.15.0/src/sqlsaber/tools/registry.py +130 -0
- sqlsaber-0.15.0/src/sqlsaber/tools/sql_tools.py +275 -0
- sqlsaber-0.15.0/src/sqlsaber/tools/visualization_tools.py +144 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_cli/test_commands.py +23 -23
- sqlsaber-0.15.0/tests/test_conversation_storage.py +136 -0
- sqlsaber-0.15.0/tests/test_tools/__init__.py +1 -0
- sqlsaber-0.15.0/tests/test_tools/test_base.py +63 -0
- sqlsaber-0.15.0/tests/test_tools/test_instructions.py +255 -0
- sqlsaber-0.15.0/tests/test_tools/test_registry.py +189 -0
- sqlsaber-0.15.0/tests/test_tools/test_sql_tools.py +218 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/uv.lock +1 -1
- sqlsaber-0.13.0/src/sqlsaber/agents/base.py +0 -286
- sqlsaber-0.13.0/src/sqlsaber/mcp/mcp.py +0 -137
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/.github/workflows/claude.yml +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/.github/workflows/publish.yml +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/.gitignore +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/.python-version +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/AGENT.md +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/CLAUDE.md +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/LICENSE +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/README.md +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/pytest.ini +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/sqlsaber.svg +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/__init__.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/__main__.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/agents/__init__.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/agents/mcp.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/agents/streaming.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/__init__.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/auth.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/commands.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/completers.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/database.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/display.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/memory.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/models.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/streaming.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/__init__.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/anthropic.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/base.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/exceptions.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/models.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/streaming.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/config/__init__.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/config/api_keys.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/config/auth.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/config/database.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/config/oauth_flow.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/config/oauth_tokens.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/config/settings.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/database/__init__.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/database/connection.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/database/resolver.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/mcp/__init__.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/memory/__init__.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/memory/manager.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/memory/storage.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/models/__init__.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/models/events.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/src/sqlsaber/models/types.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/__init__.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/conftest.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_agents/test_anthropic_oauth.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_cli/__init__.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_clients/test_anthropic_client.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_clients/test_streaming.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_config/__init__.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_config/test_database.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_config/test_oauth.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_config/test_settings.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_database/__init__.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_database/test_connection.py +0 -0
- {sqlsaber-0.13.0 → sqlsaber-0.15.0}/tests/test_database_resolver.py +0 -0
|
@@ -17,14 +17,14 @@ jobs:
|
|
|
17
17
|
# github.event.pull_request.user.login == 'external-contributor' ||
|
|
18
18
|
# github.event.pull_request.user.login == 'new-developer' ||
|
|
19
19
|
# github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR'
|
|
20
|
-
|
|
20
|
+
|
|
21
21
|
runs-on: ubuntu-latest
|
|
22
22
|
permissions:
|
|
23
23
|
contents: read
|
|
24
24
|
pull-requests: read
|
|
25
25
|
issues: read
|
|
26
26
|
id-token: write
|
|
27
|
-
|
|
27
|
+
|
|
28
28
|
steps:
|
|
29
29
|
- name: Checkout repository
|
|
30
30
|
uses: actions/checkout@v4
|
|
@@ -39,7 +39,7 @@ jobs:
|
|
|
39
39
|
|
|
40
40
|
# Optional: Specify model (defaults to Claude Sonnet 4, uncomment for Claude Opus 4)
|
|
41
41
|
# model: "claude-opus-4-20250514"
|
|
42
|
-
|
|
42
|
+
|
|
43
43
|
# Direct prompt for automated review (no @claude mention needed)
|
|
44
44
|
direct_prompt: |
|
|
45
45
|
Please review this pull request and provide feedback on:
|
|
@@ -48,12 +48,12 @@ jobs:
|
|
|
48
48
|
- Performance considerations
|
|
49
49
|
- Security concerns
|
|
50
50
|
- Test coverage
|
|
51
|
-
|
|
51
|
+
|
|
52
52
|
Be constructive and helpful in your feedback.
|
|
53
53
|
|
|
54
54
|
# Optional: Use sticky comments to make Claude reuse the same comment on subsequent pushes to the same PR
|
|
55
55
|
# use_sticky_comment: true
|
|
56
|
-
|
|
56
|
+
|
|
57
57
|
# Optional: Customize review based on file types
|
|
58
58
|
# direct_prompt: |
|
|
59
59
|
# Review this PR focusing on:
|
|
@@ -61,18 +61,17 @@ jobs:
|
|
|
61
61
|
# - For API endpoints: Security, input validation, and error handling
|
|
62
62
|
# - For React components: Performance, accessibility, and best practices
|
|
63
63
|
# - For tests: Coverage, edge cases, and test quality
|
|
64
|
-
|
|
64
|
+
|
|
65
65
|
# Optional: Different prompts for different authors
|
|
66
66
|
# direct_prompt: |
|
|
67
|
-
# ${{ github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR' &&
|
|
67
|
+
# ${{ github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR' &&
|
|
68
68
|
# 'Welcome! Please review this PR from a first-time contributor. Be encouraging and provide detailed explanations for any suggestions.' ||
|
|
69
69
|
# 'Please provide a thorough code review focusing on our coding standards and best practices.' }}
|
|
70
|
-
|
|
70
|
+
|
|
71
71
|
# Optional: Add specific tools for running tests or linting
|
|
72
72
|
# allowed_tools: "Bash(npm run test),Bash(npm run lint),Bash(npm run typecheck)"
|
|
73
|
-
|
|
74
|
-
# Optional: Skip review for certain conditions
|
|
75
|
-
# if: |
|
|
76
|
-
# !contains(github.event.pull_request.title, '[skip-review]') &&
|
|
77
|
-
# !contains(github.event.pull_request.title, '[WIP]')
|
|
78
73
|
|
|
74
|
+
# Optional: Skip review for certain conditions
|
|
75
|
+
if: |
|
|
76
|
+
!contains(github.event.pull_request.title, '[skip-review]') &&
|
|
77
|
+
!contains(github.event.pull_request.title, '[WIP]')
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
name: Tests
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches: [main]
|
|
6
|
+
pull_request:
|
|
7
|
+
|
|
8
|
+
# Cancel active CI runs for a PR before starting another run
|
|
9
|
+
concurrency:
|
|
10
|
+
group: ${{ github.workflow}}-${{ github.ref }}
|
|
11
|
+
cancel-in-progress: ${{ github.event_name == 'pull_request' }}
|
|
12
|
+
|
|
13
|
+
jobs:
|
|
14
|
+
test:
|
|
15
|
+
runs-on: ubuntu-latest
|
|
16
|
+
steps:
|
|
17
|
+
- uses: actions/checkout@v4
|
|
18
|
+
|
|
19
|
+
- name: Install uv
|
|
20
|
+
uses: astral-sh/setup-uv@v5
|
|
21
|
+
with:
|
|
22
|
+
enable-cache: true
|
|
23
|
+
|
|
24
|
+
- name: Set up Python
|
|
25
|
+
uses: actions/setup-python@v5
|
|
26
|
+
with:
|
|
27
|
+
python-version: "3.12"
|
|
28
|
+
|
|
29
|
+
- name: Install dependencies
|
|
30
|
+
run: uv sync --locked --all-extras --dev
|
|
31
|
+
|
|
32
|
+
- name: Run tests
|
|
33
|
+
run: uv run python -m pytest
|
|
@@ -4,6 +4,41 @@ All notable changes to SQLSaber will be documented in this file.
|
|
|
4
4
|
|
|
5
5
|
## [Unreleased]
|
|
6
6
|
|
|
7
|
+
## [0.15.0] - 2025-08-18
|
|
8
|
+
|
|
9
|
+
### Added
|
|
10
|
+
|
|
11
|
+
- Tool abstraction system with centralized registry (new `Tool` base class, `ToolRegistry`, decorators)
|
|
12
|
+
- Dynamic instruction generation system (`InstructionBuilder`)
|
|
13
|
+
- Comprehensive test suite for the tools module
|
|
14
|
+
|
|
15
|
+
### Changed
|
|
16
|
+
|
|
17
|
+
- Refactored agents to use centralized tool registry instead of hardcoded tools
|
|
18
|
+
- Enhanced MCP server with dynamic tool registration
|
|
19
|
+
- Moved core SQL functionality to dedicated tool classes
|
|
20
|
+
|
|
21
|
+
## [0.14.0] - 2025-08-01
|
|
22
|
+
|
|
23
|
+
### Added
|
|
24
|
+
|
|
25
|
+
- Local conversation storage between user and agent
|
|
26
|
+
- Store conversation history persistently
|
|
27
|
+
- Track messages with proper attribution
|
|
28
|
+
- Added automated test execution in CI
|
|
29
|
+
- New GitHub Actions workflow for running tests
|
|
30
|
+
- Updated code review workflow
|
|
31
|
+
|
|
32
|
+
### Fixed
|
|
33
|
+
|
|
34
|
+
- Fixed CLI commands test suite (#11)
|
|
35
|
+
|
|
36
|
+
### Changed
|
|
37
|
+
|
|
38
|
+
- Removed schema caching from SchemaManager
|
|
39
|
+
- Simplified schema introspection by removing cache logic
|
|
40
|
+
- Direct database queries for schema information
|
|
41
|
+
|
|
7
42
|
## [0.13.0] - 2025-07-26
|
|
8
43
|
|
|
9
44
|
### Added
|
|
@@ -21,6 +21,8 @@ from sqlsaber.config.settings import Config
|
|
|
21
21
|
from sqlsaber.database.connection import BaseDatabaseConnection
|
|
22
22
|
from sqlsaber.memory.manager import MemoryManager
|
|
23
23
|
from sqlsaber.models.events import StreamEvent
|
|
24
|
+
from sqlsaber.tools import tool_registry
|
|
25
|
+
from sqlsaber.tools.instructions import InstructionBuilder
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
class AnthropicSQLAgent(BaseSQLAgent):
|
|
@@ -51,89 +53,11 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
51
53
|
self._last_results = None
|
|
52
54
|
self._last_query = None
|
|
53
55
|
|
|
54
|
-
#
|
|
55
|
-
self.tools: list[ToolDefinition] =
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
input_schema={
|
|
60
|
-
"type": "object",
|
|
61
|
-
"properties": {},
|
|
62
|
-
"required": [],
|
|
63
|
-
},
|
|
64
|
-
),
|
|
65
|
-
ToolDefinition(
|
|
66
|
-
name="introspect_schema",
|
|
67
|
-
description="Introspect database schema to understand table structures.",
|
|
68
|
-
input_schema={
|
|
69
|
-
"type": "object",
|
|
70
|
-
"properties": {
|
|
71
|
-
"table_pattern": {
|
|
72
|
-
"type": "string",
|
|
73
|
-
"description": "Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')",
|
|
74
|
-
}
|
|
75
|
-
},
|
|
76
|
-
"required": [],
|
|
77
|
-
},
|
|
78
|
-
),
|
|
79
|
-
ToolDefinition(
|
|
80
|
-
name="execute_sql",
|
|
81
|
-
description="Execute a SQL query against the database.",
|
|
82
|
-
input_schema={
|
|
83
|
-
"type": "object",
|
|
84
|
-
"properties": {
|
|
85
|
-
"query": {
|
|
86
|
-
"type": "string",
|
|
87
|
-
"description": "SQL query to execute",
|
|
88
|
-
},
|
|
89
|
-
"limit": {
|
|
90
|
-
"type": "integer",
|
|
91
|
-
"description": f"Maximum number of rows to return (default: {AnthropicSQLAgent.DEFAULT_SQL_LIMIT})",
|
|
92
|
-
"default": AnthropicSQLAgent.DEFAULT_SQL_LIMIT,
|
|
93
|
-
},
|
|
94
|
-
},
|
|
95
|
-
"required": ["query"],
|
|
96
|
-
},
|
|
97
|
-
),
|
|
98
|
-
ToolDefinition(
|
|
99
|
-
name="plot_data",
|
|
100
|
-
description="Create a plot of query results.",
|
|
101
|
-
input_schema={
|
|
102
|
-
"type": "object",
|
|
103
|
-
"properties": {
|
|
104
|
-
"y_values": {
|
|
105
|
-
"type": "array",
|
|
106
|
-
"items": {"type": ["number", "null"]},
|
|
107
|
-
"description": "Y-axis data points (required)",
|
|
108
|
-
},
|
|
109
|
-
"x_values": {
|
|
110
|
-
"type": "array",
|
|
111
|
-
"items": {"type": ["number", "null"]},
|
|
112
|
-
"description": "X-axis data points (optional, will use indices if not provided)",
|
|
113
|
-
},
|
|
114
|
-
"plot_type": {
|
|
115
|
-
"type": "string",
|
|
116
|
-
"enum": ["line", "scatter", "histogram"],
|
|
117
|
-
"description": "Type of plot to create (default: line)",
|
|
118
|
-
"default": "line",
|
|
119
|
-
},
|
|
120
|
-
"title": {
|
|
121
|
-
"type": "string",
|
|
122
|
-
"description": "Title for the plot",
|
|
123
|
-
},
|
|
124
|
-
"x_label": {
|
|
125
|
-
"type": "string",
|
|
126
|
-
"description": "Label for X-axis",
|
|
127
|
-
},
|
|
128
|
-
"y_label": {
|
|
129
|
-
"type": "string",
|
|
130
|
-
"description": "Label for Y-axis",
|
|
131
|
-
},
|
|
132
|
-
},
|
|
133
|
-
"required": ["y_values"],
|
|
134
|
-
},
|
|
135
|
-
),
|
|
136
|
-
]
|
|
56
|
+
# Get tool definitions from registry
|
|
57
|
+
self.tools: list[ToolDefinition] = tool_registry.get_tool_definitions()
|
|
58
|
+
|
|
59
|
+
# Initialize instruction builder
|
|
60
|
+
self.instruction_builder = InstructionBuilder(tool_registry)
|
|
137
61
|
|
|
138
62
|
# Build system prompt with memories if available
|
|
139
63
|
self.system_prompt = self._build_system_prompt()
|
|
@@ -157,31 +81,9 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
157
81
|
def _get_sql_assistant_instructions(self) -> str:
|
|
158
82
|
"""Get the detailed SQL assistant instructions."""
|
|
159
83
|
db_type = self._get_database_type_name()
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
1. Understand user's natural language requests, think and convert them to SQL
|
|
164
|
-
2. Use the provided tools efficiently to explore database schema
|
|
165
|
-
3. Generate appropriate SQL queries
|
|
166
|
-
4. Execute queries safely - queries that modify the database are not allowed
|
|
167
|
-
5. Format and explain results clearly
|
|
168
|
-
6. Create visualizations when requested or when they would be helpful
|
|
169
|
-
|
|
170
|
-
IMPORTANT - Schema Discovery Strategy:
|
|
171
|
-
1. ALWAYS start with 'list_tables' to see available tables and row counts
|
|
172
|
-
2. Based on the user's query, identify which specific tables are relevant
|
|
173
|
-
3. Use 'introspect_schema' with a table_pattern to get details ONLY for relevant tables
|
|
174
|
-
4. Timestamp columns must be converted to text when you write queries
|
|
175
|
-
|
|
176
|
-
Guidelines:
|
|
177
|
-
- Use list_tables first, then introspect_schema for specific tables only
|
|
178
|
-
- Use table patterns like 'sample%' or '%experiment%' to filter related tables
|
|
179
|
-
- Use proper JOIN syntax and avoid cartesian products
|
|
180
|
-
- Include appropriate WHERE clauses to limit results
|
|
181
|
-
- Explain what the query does in simple terms
|
|
182
|
-
- Handle errors gracefully and suggest fixes
|
|
183
|
-
- Be security conscious - use parameterized queries when needed
|
|
184
|
-
"""
|
|
84
|
+
|
|
85
|
+
# Build dynamic instructions from available tools
|
|
86
|
+
instructions = self.instruction_builder.build_instructions(db_type=db_type)
|
|
185
87
|
|
|
186
88
|
# Add memory context if database name is available
|
|
187
89
|
if self.database_name:
|
|
@@ -189,7 +91,7 @@ Guidelines:
|
|
|
189
91
|
self.database_name
|
|
190
92
|
)
|
|
191
93
|
if memory_context.strip():
|
|
192
|
-
instructions += memory_context
|
|
94
|
+
instructions += "\n\n" + memory_context
|
|
193
95
|
|
|
194
96
|
return instructions
|
|
195
97
|
|
|
@@ -199,16 +101,19 @@ Guidelines:
|
|
|
199
101
|
return None
|
|
200
102
|
|
|
201
103
|
memory = self.memory_manager.add_memory(self.database_name, content)
|
|
202
|
-
# Rebuild system prompt with new memory
|
|
104
|
+
# Rebuild system prompt with new memory (includes dynamic instructions)
|
|
203
105
|
self.system_prompt = self._build_system_prompt()
|
|
204
106
|
return memory.id
|
|
205
107
|
|
|
206
|
-
async def
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
108
|
+
async def _execute_sql_with_tracking(
|
|
109
|
+
self, query: str, limit: int | None = None
|
|
110
|
+
) -> str:
|
|
111
|
+
"""Execute SQL and track results for streaming."""
|
|
112
|
+
# Get the execute_sql tool and run it
|
|
113
|
+
tool = tool_registry.get_tool("execute_sql")
|
|
114
|
+
result = await tool.execute(query=query, limit=limit)
|
|
210
115
|
|
|
211
|
-
# Parse result to extract data for streaming
|
|
116
|
+
# Parse result to extract data for streaming
|
|
212
117
|
try:
|
|
213
118
|
result_data = json.loads(result)
|
|
214
119
|
if result_data.get("success") and "results" in result_data:
|
|
@@ -228,7 +133,14 @@ Guidelines:
|
|
|
228
133
|
self, tool_name: str, tool_input: dict[str, Any]
|
|
229
134
|
) -> str:
|
|
230
135
|
"""Process a tool call and return the result."""
|
|
231
|
-
#
|
|
136
|
+
# Special handling for execute_sql to track results
|
|
137
|
+
if tool_name == "execute_sql":
|
|
138
|
+
return await self._execute_sql_with_tracking(
|
|
139
|
+
tool_input.get("query", ""),
|
|
140
|
+
tool_input.get("limit", self.DEFAULT_SQL_LIMIT),
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# Use parent implementation for all other tools
|
|
232
144
|
return await super().process_tool_call(tool_name, tool_input)
|
|
233
145
|
|
|
234
146
|
def _convert_user_message_to_message(
|
|
@@ -450,6 +362,16 @@ Guidelines:
|
|
|
450
362
|
self._last_query = None
|
|
451
363
|
|
|
452
364
|
try:
|
|
365
|
+
# Ensure conversation is active for persistence
|
|
366
|
+
await self._ensure_conversation()
|
|
367
|
+
|
|
368
|
+
# Store user message in conversation history and persistence
|
|
369
|
+
if use_history:
|
|
370
|
+
self.conversation_history.append(
|
|
371
|
+
{"role": "user", "content": user_query}
|
|
372
|
+
)
|
|
373
|
+
await self._store_user_message(user_query)
|
|
374
|
+
|
|
453
375
|
# Build messages with history if requested
|
|
454
376
|
messages = []
|
|
455
377
|
if use_history:
|
|
@@ -461,8 +383,9 @@ Guidelines:
|
|
|
461
383
|
instructions = self._get_sql_assistant_instructions()
|
|
462
384
|
messages.append(Message(MessageRole.USER, instructions))
|
|
463
385
|
|
|
464
|
-
# Add current user message
|
|
465
|
-
|
|
386
|
+
# Add current user message if not already in messages from history
|
|
387
|
+
if not use_history:
|
|
388
|
+
messages.append(Message(MessageRole.USER, user_query))
|
|
466
389
|
|
|
467
390
|
# Create initial request and get response
|
|
468
391
|
request = self._create_message_request(messages)
|
|
@@ -484,9 +407,12 @@ Guidelines:
|
|
|
484
407
|
return
|
|
485
408
|
|
|
486
409
|
# Add assistant's response to conversation
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
410
|
+
assistant_content = {"role": "assistant", "content": response.content}
|
|
411
|
+
collected_content.append(assistant_content)
|
|
412
|
+
|
|
413
|
+
# Store the assistant message immediately (not from collected_content)
|
|
414
|
+
if use_history:
|
|
415
|
+
await self._store_assistant_message(response.content)
|
|
490
416
|
|
|
491
417
|
# Execute tools and get results
|
|
492
418
|
tool_results = []
|
|
@@ -499,9 +425,19 @@ Guidelines:
|
|
|
499
425
|
tool_results = event
|
|
500
426
|
|
|
501
427
|
# Continue conversation with tool results
|
|
502
|
-
|
|
428
|
+
tool_content = {"role": "user", "content": tool_results}
|
|
429
|
+
collected_content.append(tool_content)
|
|
430
|
+
|
|
431
|
+
# Store the tool message immediately and update history
|
|
503
432
|
if use_history:
|
|
504
|
-
|
|
433
|
+
# Only add the NEW messages to history (not the accumulated ones)
|
|
434
|
+
# collected_content has [assistant1, tool1, assistant2, tool2, ...]
|
|
435
|
+
# We only want to add the last 2 items that were just added
|
|
436
|
+
new_messages_for_history = collected_content[
|
|
437
|
+
-2:
|
|
438
|
+
] # Last assistant + tool pair
|
|
439
|
+
self.conversation_history.extend(new_messages_for_history)
|
|
440
|
+
await self._store_tool_message(tool_results)
|
|
505
441
|
|
|
506
442
|
if cancellation_token is not None and cancellation_token.is_set():
|
|
507
443
|
return
|
|
@@ -541,6 +477,10 @@ Guidelines:
|
|
|
541
477
|
{"role": "assistant", "content": response.content}
|
|
542
478
|
)
|
|
543
479
|
|
|
480
|
+
# Store final assistant message in persistence (only if not tool_use)
|
|
481
|
+
if response.stop_reason != "tool_use":
|
|
482
|
+
await self._store_assistant_message(response.content)
|
|
483
|
+
|
|
544
484
|
except asyncio.CancelledError:
|
|
545
485
|
return
|
|
546
486
|
except Exception as e:
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""Abstract base class for SQL agents."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Any, AsyncIterator
|
|
7
|
+
|
|
8
|
+
from sqlsaber.conversation.manager import ConversationManager
|
|
9
|
+
from sqlsaber.database.connection import (
|
|
10
|
+
BaseDatabaseConnection,
|
|
11
|
+
CSVConnection,
|
|
12
|
+
MySQLConnection,
|
|
13
|
+
PostgreSQLConnection,
|
|
14
|
+
SQLiteConnection,
|
|
15
|
+
)
|
|
16
|
+
from sqlsaber.database.schema import SchemaManager
|
|
17
|
+
from sqlsaber.models.events import StreamEvent
|
|
18
|
+
from sqlsaber.tools import SQLTool, tool_registry
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BaseSQLAgent(ABC):
|
|
22
|
+
"""Abstract base class for SQL agents."""
|
|
23
|
+
|
|
24
|
+
def __init__(self, db_connection: BaseDatabaseConnection):
|
|
25
|
+
self.db = db_connection
|
|
26
|
+
self.schema_manager = SchemaManager(db_connection)
|
|
27
|
+
self.conversation_history: list[dict[str, Any]] = []
|
|
28
|
+
|
|
29
|
+
# Conversation persistence
|
|
30
|
+
self._conv_manager = ConversationManager()
|
|
31
|
+
self._conversation_id: str | None = None
|
|
32
|
+
self._msg_index: int = 0
|
|
33
|
+
|
|
34
|
+
# Initialize SQL tools with database connection
|
|
35
|
+
self._init_tools()
|
|
36
|
+
|
|
37
|
+
@abstractmethod
|
|
38
|
+
async def query_stream(
|
|
39
|
+
self,
|
|
40
|
+
user_query: str,
|
|
41
|
+
use_history: bool = True,
|
|
42
|
+
cancellation_token: asyncio.Event | None = None,
|
|
43
|
+
) -> AsyncIterator[StreamEvent]:
|
|
44
|
+
"""Process a user query and stream responses.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
user_query: The user's query to process
|
|
48
|
+
use_history: Whether to include conversation history
|
|
49
|
+
cancellation_token: Optional event to signal cancellation
|
|
50
|
+
"""
|
|
51
|
+
pass
|
|
52
|
+
|
|
53
|
+
async def clear_history(self):
|
|
54
|
+
"""Clear conversation history."""
|
|
55
|
+
# End current conversation in storage
|
|
56
|
+
await self._end_conversation()
|
|
57
|
+
|
|
58
|
+
# Clear in-memory history
|
|
59
|
+
self.conversation_history = []
|
|
60
|
+
|
|
61
|
+
def _get_database_type_name(self) -> str:
|
|
62
|
+
"""Get the human-readable database type name."""
|
|
63
|
+
if isinstance(self.db, PostgreSQLConnection):
|
|
64
|
+
return "PostgreSQL"
|
|
65
|
+
elif isinstance(self.db, MySQLConnection):
|
|
66
|
+
return "MySQL"
|
|
67
|
+
elif isinstance(self.db, SQLiteConnection):
|
|
68
|
+
return "SQLite"
|
|
69
|
+
elif isinstance(self.db, CSVConnection):
|
|
70
|
+
return "SQLite" # we convert csv to in-memory sqlite
|
|
71
|
+
else:
|
|
72
|
+
return "database" # Fallback
|
|
73
|
+
|
|
74
|
+
def _init_tools(self) -> None:
|
|
75
|
+
"""Initialize SQL tools with database connection."""
|
|
76
|
+
# Get all SQL tools and set their database connection
|
|
77
|
+
for tool_name in tool_registry.list_tools(category="sql"):
|
|
78
|
+
tool = tool_registry.get_tool(tool_name)
|
|
79
|
+
if isinstance(tool, SQLTool):
|
|
80
|
+
tool.set_connection(self.db)
|
|
81
|
+
|
|
82
|
+
async def process_tool_call(
|
|
83
|
+
self, tool_name: str, tool_input: dict[str, Any]
|
|
84
|
+
) -> str:
|
|
85
|
+
"""Process a tool call and return the result."""
|
|
86
|
+
try:
|
|
87
|
+
tool = tool_registry.get_tool(tool_name)
|
|
88
|
+
return await tool.execute(**tool_input)
|
|
89
|
+
except KeyError:
|
|
90
|
+
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
|
91
|
+
except Exception as e:
|
|
92
|
+
return json.dumps(
|
|
93
|
+
{"error": f"Error executing tool '{tool_name}': {str(e)}"}
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
# Conversation persistence helpers
|
|
97
|
+
|
|
98
|
+
async def _ensure_conversation(self) -> None:
|
|
99
|
+
"""Ensure a conversation is active for storing messages."""
|
|
100
|
+
if self._conversation_id is None:
|
|
101
|
+
db_name = getattr(self, "database_name", "unknown")
|
|
102
|
+
self._conversation_id = await self._conv_manager.start_conversation(db_name)
|
|
103
|
+
self._msg_index = 0
|
|
104
|
+
|
|
105
|
+
async def _store_user_message(self, content: str | dict[str, Any]) -> None:
|
|
106
|
+
"""Store a user message in conversation history."""
|
|
107
|
+
if self._conversation_id is None:
|
|
108
|
+
return
|
|
109
|
+
|
|
110
|
+
await self._conv_manager.add_user_message(
|
|
111
|
+
self._conversation_id, content, self._msg_index
|
|
112
|
+
)
|
|
113
|
+
self._msg_index += 1
|
|
114
|
+
|
|
115
|
+
async def _store_assistant_message(
|
|
116
|
+
self, content: list[dict[str, Any]] | dict[str, Any]
|
|
117
|
+
) -> None:
|
|
118
|
+
"""Store an assistant message in conversation history."""
|
|
119
|
+
if self._conversation_id is None:
|
|
120
|
+
return
|
|
121
|
+
|
|
122
|
+
await self._conv_manager.add_assistant_message(
|
|
123
|
+
self._conversation_id, content, self._msg_index
|
|
124
|
+
)
|
|
125
|
+
self._msg_index += 1
|
|
126
|
+
|
|
127
|
+
async def _store_tool_message(
|
|
128
|
+
self, content: list[dict[str, Any]] | dict[str, Any]
|
|
129
|
+
) -> None:
|
|
130
|
+
"""Store a tool/system message in conversation history."""
|
|
131
|
+
if self._conversation_id is None:
|
|
132
|
+
return
|
|
133
|
+
|
|
134
|
+
await self._conv_manager.add_tool_message(
|
|
135
|
+
self._conversation_id, content, self._msg_index
|
|
136
|
+
)
|
|
137
|
+
self._msg_index += 1
|
|
138
|
+
|
|
139
|
+
async def _end_conversation(self) -> None:
|
|
140
|
+
"""End the current conversation."""
|
|
141
|
+
if self._conversation_id:
|
|
142
|
+
await self._conv_manager.end_conversation(self._conversation_id)
|
|
143
|
+
self._conversation_id = None
|
|
144
|
+
self._msg_index = 0
|
|
145
|
+
|
|
146
|
+
async def restore_conversation(self, conversation_id: str) -> bool:
|
|
147
|
+
"""Restore a conversation from storage to in-memory history.
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
conversation_id: ID of the conversation to restore
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
True if successfully restored, False otherwise
|
|
154
|
+
"""
|
|
155
|
+
success = await self._conv_manager.restore_conversation_to_agent(
|
|
156
|
+
conversation_id, self.conversation_history
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
if success:
|
|
160
|
+
# Set up for continuing this conversation
|
|
161
|
+
self._conversation_id = conversation_id
|
|
162
|
+
self._msg_index = len(self.conversation_history)
|
|
163
|
+
|
|
164
|
+
return success
|
|
165
|
+
|
|
166
|
+
async def list_conversations(self, limit: int = 50) -> list:
|
|
167
|
+
"""List conversations for this agent's database.
|
|
168
|
+
|
|
169
|
+
Args:
|
|
170
|
+
limit: Maximum number of conversations to return
|
|
171
|
+
|
|
172
|
+
Returns:
|
|
173
|
+
List of conversation data
|
|
174
|
+
"""
|
|
175
|
+
db_name = getattr(self, "database_name", None)
|
|
176
|
+
conversations = await self._conv_manager.list_conversations(db_name, limit)
|
|
177
|
+
|
|
178
|
+
return [
|
|
179
|
+
{
|
|
180
|
+
"id": conv.id,
|
|
181
|
+
"database_name": conv.database_name,
|
|
182
|
+
"started_at": conv.formatted_start_time(),
|
|
183
|
+
"ended_at": conv.formatted_end_time(),
|
|
184
|
+
"duration": conv.duration_seconds(),
|
|
185
|
+
}
|
|
186
|
+
for conv in conversations
|
|
187
|
+
]
|
|
@@ -136,11 +136,15 @@ class InteractiveSession:
|
|
|
136
136
|
if not user_query:
|
|
137
137
|
continue
|
|
138
138
|
|
|
139
|
-
if
|
|
139
|
+
if (
|
|
140
|
+
user_query in ["/exit", "/quit"]
|
|
141
|
+
or user_query.startswith("/exit")
|
|
142
|
+
or user_query.startswith("/quit")
|
|
143
|
+
):
|
|
140
144
|
break
|
|
141
145
|
|
|
142
146
|
if user_query == "/clear":
|
|
143
|
-
self.agent.clear_history()
|
|
147
|
+
await self.agent.clear_history()
|
|
144
148
|
self.console.print("[green]Conversation history cleared.[/green]\n")
|
|
145
149
|
continue
|
|
146
150
|
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Conversation history storage for SQLSaber."""
|
|
2
|
+
|
|
3
|
+
from .manager import ConversationManager
|
|
4
|
+
from .models import Conversation, ConversationMessage
|
|
5
|
+
from .storage import ConversationStorage
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"Conversation",
|
|
9
|
+
"ConversationMessage",
|
|
10
|
+
"ConversationStorage",
|
|
11
|
+
"ConversationManager",
|
|
12
|
+
]
|