sqlsaber 0.13.0__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/anthropic.py +63 -123
- sqlsaber/agents/base.py +111 -210
- sqlsaber/cli/interactive.py +6 -2
- sqlsaber/conversation/__init__.py +12 -0
- sqlsaber/conversation/manager.py +224 -0
- sqlsaber/conversation/models.py +120 -0
- sqlsaber/conversation/storage.py +362 -0
- sqlsaber/database/schema.py +2 -51
- sqlsaber/mcp/mcp.py +43 -51
- sqlsaber/tools/__init__.py +25 -0
- sqlsaber/tools/base.py +83 -0
- sqlsaber/tools/enums.py +21 -0
- sqlsaber/tools/instructions.py +251 -0
- sqlsaber/tools/registry.py +130 -0
- sqlsaber/tools/sql_tools.py +275 -0
- sqlsaber/tools/visualization_tools.py +144 -0
- {sqlsaber-0.13.0.dist-info → sqlsaber-0.15.0.dist-info}/METADATA +1 -1
- {sqlsaber-0.13.0.dist-info → sqlsaber-0.15.0.dist-info}/RECORD +21 -10
- {sqlsaber-0.13.0.dist-info → sqlsaber-0.15.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.13.0.dist-info → sqlsaber-0.15.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.13.0.dist-info → sqlsaber-0.15.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/agents/anthropic.py
CHANGED
|
@@ -21,6 +21,8 @@ from sqlsaber.config.settings import Config
|
|
|
21
21
|
from sqlsaber.database.connection import BaseDatabaseConnection
|
|
22
22
|
from sqlsaber.memory.manager import MemoryManager
|
|
23
23
|
from sqlsaber.models.events import StreamEvent
|
|
24
|
+
from sqlsaber.tools import tool_registry
|
|
25
|
+
from sqlsaber.tools.instructions import InstructionBuilder
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
class AnthropicSQLAgent(BaseSQLAgent):
|
|
@@ -51,89 +53,11 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
51
53
|
self._last_results = None
|
|
52
54
|
self._last_query = None
|
|
53
55
|
|
|
54
|
-
#
|
|
55
|
-
self.tools: list[ToolDefinition] =
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
input_schema={
|
|
60
|
-
"type": "object",
|
|
61
|
-
"properties": {},
|
|
62
|
-
"required": [],
|
|
63
|
-
},
|
|
64
|
-
),
|
|
65
|
-
ToolDefinition(
|
|
66
|
-
name="introspect_schema",
|
|
67
|
-
description="Introspect database schema to understand table structures.",
|
|
68
|
-
input_schema={
|
|
69
|
-
"type": "object",
|
|
70
|
-
"properties": {
|
|
71
|
-
"table_pattern": {
|
|
72
|
-
"type": "string",
|
|
73
|
-
"description": "Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')",
|
|
74
|
-
}
|
|
75
|
-
},
|
|
76
|
-
"required": [],
|
|
77
|
-
},
|
|
78
|
-
),
|
|
79
|
-
ToolDefinition(
|
|
80
|
-
name="execute_sql",
|
|
81
|
-
description="Execute a SQL query against the database.",
|
|
82
|
-
input_schema={
|
|
83
|
-
"type": "object",
|
|
84
|
-
"properties": {
|
|
85
|
-
"query": {
|
|
86
|
-
"type": "string",
|
|
87
|
-
"description": "SQL query to execute",
|
|
88
|
-
},
|
|
89
|
-
"limit": {
|
|
90
|
-
"type": "integer",
|
|
91
|
-
"description": f"Maximum number of rows to return (default: {AnthropicSQLAgent.DEFAULT_SQL_LIMIT})",
|
|
92
|
-
"default": AnthropicSQLAgent.DEFAULT_SQL_LIMIT,
|
|
93
|
-
},
|
|
94
|
-
},
|
|
95
|
-
"required": ["query"],
|
|
96
|
-
},
|
|
97
|
-
),
|
|
98
|
-
ToolDefinition(
|
|
99
|
-
name="plot_data",
|
|
100
|
-
description="Create a plot of query results.",
|
|
101
|
-
input_schema={
|
|
102
|
-
"type": "object",
|
|
103
|
-
"properties": {
|
|
104
|
-
"y_values": {
|
|
105
|
-
"type": "array",
|
|
106
|
-
"items": {"type": ["number", "null"]},
|
|
107
|
-
"description": "Y-axis data points (required)",
|
|
108
|
-
},
|
|
109
|
-
"x_values": {
|
|
110
|
-
"type": "array",
|
|
111
|
-
"items": {"type": ["number", "null"]},
|
|
112
|
-
"description": "X-axis data points (optional, will use indices if not provided)",
|
|
113
|
-
},
|
|
114
|
-
"plot_type": {
|
|
115
|
-
"type": "string",
|
|
116
|
-
"enum": ["line", "scatter", "histogram"],
|
|
117
|
-
"description": "Type of plot to create (default: line)",
|
|
118
|
-
"default": "line",
|
|
119
|
-
},
|
|
120
|
-
"title": {
|
|
121
|
-
"type": "string",
|
|
122
|
-
"description": "Title for the plot",
|
|
123
|
-
},
|
|
124
|
-
"x_label": {
|
|
125
|
-
"type": "string",
|
|
126
|
-
"description": "Label for X-axis",
|
|
127
|
-
},
|
|
128
|
-
"y_label": {
|
|
129
|
-
"type": "string",
|
|
130
|
-
"description": "Label for Y-axis",
|
|
131
|
-
},
|
|
132
|
-
},
|
|
133
|
-
"required": ["y_values"],
|
|
134
|
-
},
|
|
135
|
-
),
|
|
136
|
-
]
|
|
56
|
+
# Get tool definitions from registry
|
|
57
|
+
self.tools: list[ToolDefinition] = tool_registry.get_tool_definitions()
|
|
58
|
+
|
|
59
|
+
# Initialize instruction builder
|
|
60
|
+
self.instruction_builder = InstructionBuilder(tool_registry)
|
|
137
61
|
|
|
138
62
|
# Build system prompt with memories if available
|
|
139
63
|
self.system_prompt = self._build_system_prompt()
|
|
@@ -157,31 +81,9 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
157
81
|
def _get_sql_assistant_instructions(self) -> str:
|
|
158
82
|
"""Get the detailed SQL assistant instructions."""
|
|
159
83
|
db_type = self._get_database_type_name()
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
1. Understand user's natural language requests, think and convert them to SQL
|
|
164
|
-
2. Use the provided tools efficiently to explore database schema
|
|
165
|
-
3. Generate appropriate SQL queries
|
|
166
|
-
4. Execute queries safely - queries that modify the database are not allowed
|
|
167
|
-
5. Format and explain results clearly
|
|
168
|
-
6. Create visualizations when requested or when they would be helpful
|
|
169
|
-
|
|
170
|
-
IMPORTANT - Schema Discovery Strategy:
|
|
171
|
-
1. ALWAYS start with 'list_tables' to see available tables and row counts
|
|
172
|
-
2. Based on the user's query, identify which specific tables are relevant
|
|
173
|
-
3. Use 'introspect_schema' with a table_pattern to get details ONLY for relevant tables
|
|
174
|
-
4. Timestamp columns must be converted to text when you write queries
|
|
175
|
-
|
|
176
|
-
Guidelines:
|
|
177
|
-
- Use list_tables first, then introspect_schema for specific tables only
|
|
178
|
-
- Use table patterns like 'sample%' or '%experiment%' to filter related tables
|
|
179
|
-
- Use proper JOIN syntax and avoid cartesian products
|
|
180
|
-
- Include appropriate WHERE clauses to limit results
|
|
181
|
-
- Explain what the query does in simple terms
|
|
182
|
-
- Handle errors gracefully and suggest fixes
|
|
183
|
-
- Be security conscious - use parameterized queries when needed
|
|
184
|
-
"""
|
|
84
|
+
|
|
85
|
+
# Build dynamic instructions from available tools
|
|
86
|
+
instructions = self.instruction_builder.build_instructions(db_type=db_type)
|
|
185
87
|
|
|
186
88
|
# Add memory context if database name is available
|
|
187
89
|
if self.database_name:
|
|
@@ -189,7 +91,7 @@ Guidelines:
|
|
|
189
91
|
self.database_name
|
|
190
92
|
)
|
|
191
93
|
if memory_context.strip():
|
|
192
|
-
instructions += memory_context
|
|
94
|
+
instructions += "\n\n" + memory_context
|
|
193
95
|
|
|
194
96
|
return instructions
|
|
195
97
|
|
|
@@ -199,16 +101,19 @@ Guidelines:
|
|
|
199
101
|
return None
|
|
200
102
|
|
|
201
103
|
memory = self.memory_manager.add_memory(self.database_name, content)
|
|
202
|
-
# Rebuild system prompt with new memory
|
|
104
|
+
# Rebuild system prompt with new memory (includes dynamic instructions)
|
|
203
105
|
self.system_prompt = self._build_system_prompt()
|
|
204
106
|
return memory.id
|
|
205
107
|
|
|
206
|
-
async def
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
108
|
+
async def _execute_sql_with_tracking(
|
|
109
|
+
self, query: str, limit: int | None = None
|
|
110
|
+
) -> str:
|
|
111
|
+
"""Execute SQL and track results for streaming."""
|
|
112
|
+
# Get the execute_sql tool and run it
|
|
113
|
+
tool = tool_registry.get_tool("execute_sql")
|
|
114
|
+
result = await tool.execute(query=query, limit=limit)
|
|
210
115
|
|
|
211
|
-
# Parse result to extract data for streaming
|
|
116
|
+
# Parse result to extract data for streaming
|
|
212
117
|
try:
|
|
213
118
|
result_data = json.loads(result)
|
|
214
119
|
if result_data.get("success") and "results" in result_data:
|
|
@@ -228,7 +133,14 @@ Guidelines:
|
|
|
228
133
|
self, tool_name: str, tool_input: dict[str, Any]
|
|
229
134
|
) -> str:
|
|
230
135
|
"""Process a tool call and return the result."""
|
|
231
|
-
#
|
|
136
|
+
# Special handling for execute_sql to track results
|
|
137
|
+
if tool_name == "execute_sql":
|
|
138
|
+
return await self._execute_sql_with_tracking(
|
|
139
|
+
tool_input.get("query", ""),
|
|
140
|
+
tool_input.get("limit", self.DEFAULT_SQL_LIMIT),
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# Use parent implementation for all other tools
|
|
232
144
|
return await super().process_tool_call(tool_name, tool_input)
|
|
233
145
|
|
|
234
146
|
def _convert_user_message_to_message(
|
|
@@ -450,6 +362,16 @@ Guidelines:
|
|
|
450
362
|
self._last_query = None
|
|
451
363
|
|
|
452
364
|
try:
|
|
365
|
+
# Ensure conversation is active for persistence
|
|
366
|
+
await self._ensure_conversation()
|
|
367
|
+
|
|
368
|
+
# Store user message in conversation history and persistence
|
|
369
|
+
if use_history:
|
|
370
|
+
self.conversation_history.append(
|
|
371
|
+
{"role": "user", "content": user_query}
|
|
372
|
+
)
|
|
373
|
+
await self._store_user_message(user_query)
|
|
374
|
+
|
|
453
375
|
# Build messages with history if requested
|
|
454
376
|
messages = []
|
|
455
377
|
if use_history:
|
|
@@ -461,8 +383,9 @@ Guidelines:
|
|
|
461
383
|
instructions = self._get_sql_assistant_instructions()
|
|
462
384
|
messages.append(Message(MessageRole.USER, instructions))
|
|
463
385
|
|
|
464
|
-
# Add current user message
|
|
465
|
-
|
|
386
|
+
# Add current user message if not already in messages from history
|
|
387
|
+
if not use_history:
|
|
388
|
+
messages.append(Message(MessageRole.USER, user_query))
|
|
466
389
|
|
|
467
390
|
# Create initial request and get response
|
|
468
391
|
request = self._create_message_request(messages)
|
|
@@ -484,9 +407,12 @@ Guidelines:
|
|
|
484
407
|
return
|
|
485
408
|
|
|
486
409
|
# Add assistant's response to conversation
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
410
|
+
assistant_content = {"role": "assistant", "content": response.content}
|
|
411
|
+
collected_content.append(assistant_content)
|
|
412
|
+
|
|
413
|
+
# Store the assistant message immediately (not from collected_content)
|
|
414
|
+
if use_history:
|
|
415
|
+
await self._store_assistant_message(response.content)
|
|
490
416
|
|
|
491
417
|
# Execute tools and get results
|
|
492
418
|
tool_results = []
|
|
@@ -499,9 +425,19 @@ Guidelines:
|
|
|
499
425
|
tool_results = event
|
|
500
426
|
|
|
501
427
|
# Continue conversation with tool results
|
|
502
|
-
|
|
428
|
+
tool_content = {"role": "user", "content": tool_results}
|
|
429
|
+
collected_content.append(tool_content)
|
|
430
|
+
|
|
431
|
+
# Store the tool message immediately and update history
|
|
503
432
|
if use_history:
|
|
504
|
-
|
|
433
|
+
# Only add the NEW messages to history (not the accumulated ones)
|
|
434
|
+
# collected_content has [assistant1, tool1, assistant2, tool2, ...]
|
|
435
|
+
# We only want to add the last 2 items that were just added
|
|
436
|
+
new_messages_for_history = collected_content[
|
|
437
|
+
-2:
|
|
438
|
+
] # Last assistant + tool pair
|
|
439
|
+
self.conversation_history.extend(new_messages_for_history)
|
|
440
|
+
await self._store_tool_message(tool_results)
|
|
505
441
|
|
|
506
442
|
if cancellation_token is not None and cancellation_token.is_set():
|
|
507
443
|
return
|
|
@@ -541,6 +477,10 @@ Guidelines:
|
|
|
541
477
|
{"role": "assistant", "content": response.content}
|
|
542
478
|
)
|
|
543
479
|
|
|
480
|
+
# Store final assistant message in persistence (only if not tool_use)
|
|
481
|
+
if response.stop_reason != "tool_use":
|
|
482
|
+
await self._store_assistant_message(response.content)
|
|
483
|
+
|
|
544
484
|
except asyncio.CancelledError:
|
|
545
485
|
return
|
|
546
486
|
except Exception as e:
|
sqlsaber/agents/base.py
CHANGED
|
@@ -5,8 +5,7 @@ import json
|
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
6
|
from typing import Any, AsyncIterator
|
|
7
7
|
|
|
8
|
-
from
|
|
9
|
-
|
|
8
|
+
from sqlsaber.conversation.manager import ConversationManager
|
|
10
9
|
from sqlsaber.database.connection import (
|
|
11
10
|
BaseDatabaseConnection,
|
|
12
11
|
CSVConnection,
|
|
@@ -16,6 +15,7 @@ from sqlsaber.database.connection import (
|
|
|
16
15
|
)
|
|
17
16
|
from sqlsaber.database.schema import SchemaManager
|
|
18
17
|
from sqlsaber.models.events import StreamEvent
|
|
18
|
+
from sqlsaber.tools import SQLTool, tool_registry
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class BaseSQLAgent(ABC):
|
|
@@ -26,6 +26,14 @@ class BaseSQLAgent(ABC):
|
|
|
26
26
|
self.schema_manager = SchemaManager(db_connection)
|
|
27
27
|
self.conversation_history: list[dict[str, Any]] = []
|
|
28
28
|
|
|
29
|
+
# Conversation persistence
|
|
30
|
+
self._conv_manager = ConversationManager()
|
|
31
|
+
self._conversation_id: str | None = None
|
|
32
|
+
self._msg_index: int = 0
|
|
33
|
+
|
|
34
|
+
# Initialize SQL tools with database connection
|
|
35
|
+
self._init_tools()
|
|
36
|
+
|
|
29
37
|
@abstractmethod
|
|
30
38
|
async def query_stream(
|
|
31
39
|
self,
|
|
@@ -42,8 +50,12 @@ class BaseSQLAgent(ABC):
|
|
|
42
50
|
"""
|
|
43
51
|
pass
|
|
44
52
|
|
|
45
|
-
def clear_history(self):
|
|
53
|
+
async def clear_history(self):
|
|
46
54
|
"""Clear conversation history."""
|
|
55
|
+
# End current conversation in storage
|
|
56
|
+
await self._end_conversation()
|
|
57
|
+
|
|
58
|
+
# Clear in-memory history
|
|
47
59
|
self.conversation_history = []
|
|
48
60
|
|
|
49
61
|
def _get_database_type_name(self) -> str:
|
|
@@ -59,228 +71,117 @@ class BaseSQLAgent(ABC):
|
|
|
59
71
|
else:
|
|
60
72
|
return "database" # Fallback
|
|
61
73
|
|
|
62
|
-
|
|
63
|
-
"""
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
formatted_info = {}
|
|
70
|
-
for table_name, table_info in schema_info.items():
|
|
71
|
-
formatted_info[table_name] = {
|
|
72
|
-
"columns": {
|
|
73
|
-
col_name: {
|
|
74
|
-
"type": col_info["data_type"],
|
|
75
|
-
"nullable": col_info["nullable"],
|
|
76
|
-
"default": col_info["default"],
|
|
77
|
-
}
|
|
78
|
-
for col_name, col_info in table_info["columns"].items()
|
|
79
|
-
},
|
|
80
|
-
"primary_keys": table_info["primary_keys"],
|
|
81
|
-
"foreign_keys": [
|
|
82
|
-
f"{fk['column']} -> {fk['references']['table']}.{fk['references']['column']}"
|
|
83
|
-
for fk in table_info["foreign_keys"]
|
|
84
|
-
],
|
|
85
|
-
}
|
|
86
|
-
|
|
87
|
-
return json.dumps(formatted_info)
|
|
88
|
-
except Exception as e:
|
|
89
|
-
return json.dumps({"error": f"Error introspecting schema: {str(e)}"})
|
|
90
|
-
|
|
91
|
-
async def list_tables(self) -> str:
|
|
92
|
-
"""List all tables in the database with basic information."""
|
|
93
|
-
try:
|
|
94
|
-
tables_info = await self.schema_manager.list_tables()
|
|
95
|
-
return json.dumps(tables_info)
|
|
96
|
-
except Exception as e:
|
|
97
|
-
return json.dumps({"error": f"Error listing tables: {str(e)}"})
|
|
98
|
-
|
|
99
|
-
async def execute_sql(self, query: str, limit: int | None = None) -> str:
|
|
100
|
-
"""Execute a SQL query against the database."""
|
|
101
|
-
try:
|
|
102
|
-
# Security check - only allow SELECT queries unless write is enabled
|
|
103
|
-
write_error = self._validate_write_operation(query)
|
|
104
|
-
if write_error:
|
|
105
|
-
return json.dumps(
|
|
106
|
-
{
|
|
107
|
-
"error": write_error,
|
|
108
|
-
}
|
|
109
|
-
)
|
|
110
|
-
|
|
111
|
-
# Add LIMIT if not present and it's a SELECT query
|
|
112
|
-
query = self._add_limit_to_query(query, limit)
|
|
113
|
-
|
|
114
|
-
# Execute the query (wrapped in a transaction for safety)
|
|
115
|
-
results = await self.db.execute_query(query)
|
|
116
|
-
|
|
117
|
-
# Format results
|
|
118
|
-
actual_limit = limit if limit is not None else len(results)
|
|
119
|
-
|
|
120
|
-
return json.dumps(
|
|
121
|
-
{
|
|
122
|
-
"success": True,
|
|
123
|
-
"row_count": len(results),
|
|
124
|
-
"results": results[:actual_limit], # Extra safety for limit
|
|
125
|
-
"truncated": len(results) > actual_limit,
|
|
126
|
-
}
|
|
127
|
-
)
|
|
128
|
-
|
|
129
|
-
except Exception as e:
|
|
130
|
-
error_msg = str(e)
|
|
131
|
-
|
|
132
|
-
# Provide helpful error messages
|
|
133
|
-
suggestions = []
|
|
134
|
-
if "column" in error_msg.lower() and "does not exist" in error_msg.lower():
|
|
135
|
-
suggestions.append(
|
|
136
|
-
"Check column names using the schema introspection tool"
|
|
137
|
-
)
|
|
138
|
-
elif "table" in error_msg.lower() and "does not exist" in error_msg.lower():
|
|
139
|
-
suggestions.append(
|
|
140
|
-
"Check table names using the schema introspection tool"
|
|
141
|
-
)
|
|
142
|
-
elif "syntax error" in error_msg.lower():
|
|
143
|
-
suggestions.append(
|
|
144
|
-
"Review SQL syntax, especially JOIN conditions and WHERE clauses"
|
|
145
|
-
)
|
|
146
|
-
|
|
147
|
-
return json.dumps({"error": error_msg, "suggestions": suggestions})
|
|
74
|
+
def _init_tools(self) -> None:
|
|
75
|
+
"""Initialize SQL tools with database connection."""
|
|
76
|
+
# Get all SQL tools and set their database connection
|
|
77
|
+
for tool_name in tool_registry.list_tools(category="sql"):
|
|
78
|
+
tool = tool_registry.get_tool(tool_name)
|
|
79
|
+
if isinstance(tool, SQLTool):
|
|
80
|
+
tool.set_connection(self.db)
|
|
148
81
|
|
|
149
82
|
async def process_tool_call(
|
|
150
83
|
self, tool_name: str, tool_input: dict[str, Any]
|
|
151
84
|
) -> str:
|
|
152
85
|
"""Process a tool call and return the result."""
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
elif tool_name == "execute_sql":
|
|
158
|
-
return await self.execute_sql(
|
|
159
|
-
tool_input["query"], tool_input.get("limit", 100)
|
|
160
|
-
)
|
|
161
|
-
elif tool_name == "plot_data":
|
|
162
|
-
return await self.plot_data(
|
|
163
|
-
y_values=tool_input["y_values"],
|
|
164
|
-
x_values=tool_input.get("x_values"),
|
|
165
|
-
plot_type=tool_input.get("plot_type", "line"),
|
|
166
|
-
title=tool_input.get("title"),
|
|
167
|
-
x_label=tool_input.get("x_label"),
|
|
168
|
-
y_label=tool_input.get("y_label"),
|
|
169
|
-
)
|
|
170
|
-
else:
|
|
86
|
+
try:
|
|
87
|
+
tool = tool_registry.get_tool(tool_name)
|
|
88
|
+
return await tool.execute(**tool_input)
|
|
89
|
+
except KeyError:
|
|
171
90
|
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
|
91
|
+
except Exception as e:
|
|
92
|
+
return json.dumps(
|
|
93
|
+
{"error": f"Error executing tool '{tool_name}': {str(e)}"}
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
# Conversation persistence helpers
|
|
97
|
+
|
|
98
|
+
async def _ensure_conversation(self) -> None:
|
|
99
|
+
"""Ensure a conversation is active for storing messages."""
|
|
100
|
+
if self._conversation_id is None:
|
|
101
|
+
db_name = getattr(self, "database_name", "unknown")
|
|
102
|
+
self._conversation_id = await self._conv_manager.start_conversation(db_name)
|
|
103
|
+
self._msg_index = 0
|
|
104
|
+
|
|
105
|
+
async def _store_user_message(self, content: str | dict[str, Any]) -> None:
|
|
106
|
+
"""Store a user message in conversation history."""
|
|
107
|
+
if self._conversation_id is None:
|
|
108
|
+
return
|
|
109
|
+
|
|
110
|
+
await self._conv_manager.add_user_message(
|
|
111
|
+
self._conversation_id, content, self._msg_index
|
|
112
|
+
)
|
|
113
|
+
self._msg_index += 1
|
|
114
|
+
|
|
115
|
+
async def _store_assistant_message(
|
|
116
|
+
self, content: list[dict[str, Any]] | dict[str, Any]
|
|
117
|
+
) -> None:
|
|
118
|
+
"""Store an assistant message in conversation history."""
|
|
119
|
+
if self._conversation_id is None:
|
|
120
|
+
return
|
|
121
|
+
|
|
122
|
+
await self._conv_manager.add_assistant_message(
|
|
123
|
+
self._conversation_id, content, self._msg_index
|
|
124
|
+
)
|
|
125
|
+
self._msg_index += 1
|
|
126
|
+
|
|
127
|
+
async def _store_tool_message(
|
|
128
|
+
self, content: list[dict[str, Any]] | dict[str, Any]
|
|
129
|
+
) -> None:
|
|
130
|
+
"""Store a tool/system message in conversation history."""
|
|
131
|
+
if self._conversation_id is None:
|
|
132
|
+
return
|
|
133
|
+
|
|
134
|
+
await self._conv_manager.add_tool_message(
|
|
135
|
+
self._conversation_id, content, self._msg_index
|
|
136
|
+
)
|
|
137
|
+
self._msg_index += 1
|
|
138
|
+
|
|
139
|
+
async def _end_conversation(self) -> None:
|
|
140
|
+
"""End the current conversation."""
|
|
141
|
+
if self._conversation_id:
|
|
142
|
+
await self._conv_manager.end_conversation(self._conversation_id)
|
|
143
|
+
self._conversation_id = None
|
|
144
|
+
self._msg_index = 0
|
|
145
|
+
|
|
146
|
+
async def restore_conversation(self, conversation_id: str) -> bool:
|
|
147
|
+
"""Restore a conversation from storage to in-memory history.
|
|
172
148
|
|
|
173
|
-
|
|
174
|
-
|
|
149
|
+
Args:
|
|
150
|
+
conversation_id: ID of the conversation to restore
|
|
175
151
|
|
|
176
152
|
Returns:
|
|
177
|
-
|
|
153
|
+
True if successfully restored, False otherwise
|
|
178
154
|
"""
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
write_keywords = [
|
|
183
|
-
"INSERT",
|
|
184
|
-
"UPDATE",
|
|
185
|
-
"DELETE",
|
|
186
|
-
"DROP",
|
|
187
|
-
"CREATE",
|
|
188
|
-
"ALTER",
|
|
189
|
-
"TRUNCATE",
|
|
190
|
-
]
|
|
191
|
-
is_write_query = any(query_upper.startswith(kw) for kw in write_keywords)
|
|
192
|
-
|
|
193
|
-
if is_write_query:
|
|
194
|
-
return (
|
|
195
|
-
"Write operations are not allowed. Only SELECT queries are permitted."
|
|
196
|
-
)
|
|
155
|
+
success = await self._conv_manager.restore_conversation_to_agent(
|
|
156
|
+
conversation_id, self.conversation_history
|
|
157
|
+
)
|
|
197
158
|
|
|
198
|
-
|
|
159
|
+
if success:
|
|
160
|
+
# Set up for continuing this conversation
|
|
161
|
+
self._conversation_id = conversation_id
|
|
162
|
+
self._msg_index = len(self.conversation_history)
|
|
199
163
|
|
|
200
|
-
|
|
201
|
-
"""Add LIMIT clause to SELECT queries if not present."""
|
|
202
|
-
query_upper = query.strip().upper()
|
|
203
|
-
if query_upper.startswith("SELECT") and "LIMIT" not in query_upper:
|
|
204
|
-
return f"{query.rstrip(';')} LIMIT {limit};"
|
|
205
|
-
return query
|
|
164
|
+
return success
|
|
206
165
|
|
|
207
|
-
async def
|
|
208
|
-
|
|
209
|
-
y_values: list[float],
|
|
210
|
-
x_values: list[float] | None = None,
|
|
211
|
-
plot_type: str = "line",
|
|
212
|
-
title: str | None = None,
|
|
213
|
-
x_label: str | None = None,
|
|
214
|
-
y_label: str | None = None,
|
|
215
|
-
) -> str:
|
|
216
|
-
"""Create a terminal plot using uniplot.
|
|
166
|
+
async def list_conversations(self, limit: int = 50) -> list:
|
|
167
|
+
"""List conversations for this agent's database.
|
|
217
168
|
|
|
218
169
|
Args:
|
|
219
|
-
|
|
220
|
-
x_values: X-axis data points (optional)
|
|
221
|
-
plot_type: Type of plot - "line", "scatter", or "histogram"
|
|
222
|
-
title: Plot title
|
|
223
|
-
x_label: X-axis label
|
|
224
|
-
y_label: Y-axis label
|
|
170
|
+
limit: Maximum number of conversations to return
|
|
225
171
|
|
|
226
172
|
Returns:
|
|
227
|
-
|
|
173
|
+
List of conversation data
|
|
228
174
|
"""
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
# Create the plot
|
|
243
|
-
if plot_type == "histogram":
|
|
244
|
-
# For histogram, we only need y_values
|
|
245
|
-
histogram(
|
|
246
|
-
y_values,
|
|
247
|
-
title=title,
|
|
248
|
-
bins=min(20, len(set(y_values))), # Adaptive bin count
|
|
249
|
-
)
|
|
250
|
-
plot_info = {
|
|
251
|
-
"type": "histogram",
|
|
252
|
-
"data_points": len(y_values),
|
|
253
|
-
"title": title or "Histogram",
|
|
254
|
-
}
|
|
255
|
-
elif plot_type in ["line", "scatter"]:
|
|
256
|
-
# For line/scatter plots
|
|
257
|
-
plot_kwargs = {
|
|
258
|
-
"ys": y_values,
|
|
259
|
-
"title": title,
|
|
260
|
-
"lines": plot_type == "line",
|
|
261
|
-
}
|
|
262
|
-
|
|
263
|
-
if x_values:
|
|
264
|
-
plot_kwargs["xs"] = x_values
|
|
265
|
-
if x_label:
|
|
266
|
-
plot_kwargs["x_unit"] = x_label
|
|
267
|
-
if y_label:
|
|
268
|
-
plot_kwargs["y_unit"] = y_label
|
|
269
|
-
|
|
270
|
-
plot(**plot_kwargs)
|
|
271
|
-
|
|
272
|
-
plot_info = {
|
|
273
|
-
"type": plot_type,
|
|
274
|
-
"data_points": len(y_values),
|
|
275
|
-
"title": title or f"{plot_type.capitalize()} Plot",
|
|
276
|
-
"has_x_values": x_values is not None,
|
|
277
|
-
}
|
|
278
|
-
else:
|
|
279
|
-
return json.dumps({"error": f"Unsupported plot type: {plot_type}"})
|
|
280
|
-
|
|
281
|
-
return json.dumps(
|
|
282
|
-
{"success": True, "plot_rendered": True, "plot_info": plot_info}
|
|
283
|
-
)
|
|
284
|
-
|
|
285
|
-
except Exception as e:
|
|
286
|
-
return json.dumps({"error": f"Error creating plot: {str(e)}"})
|
|
175
|
+
db_name = getattr(self, "database_name", None)
|
|
176
|
+
conversations = await self._conv_manager.list_conversations(db_name, limit)
|
|
177
|
+
|
|
178
|
+
return [
|
|
179
|
+
{
|
|
180
|
+
"id": conv.id,
|
|
181
|
+
"database_name": conv.database_name,
|
|
182
|
+
"started_at": conv.formatted_start_time(),
|
|
183
|
+
"ended_at": conv.formatted_end_time(),
|
|
184
|
+
"duration": conv.duration_seconds(),
|
|
185
|
+
}
|
|
186
|
+
for conv in conversations
|
|
187
|
+
]
|
sqlsaber/cli/interactive.py
CHANGED
|
@@ -136,11 +136,15 @@ class InteractiveSession:
|
|
|
136
136
|
if not user_query:
|
|
137
137
|
continue
|
|
138
138
|
|
|
139
|
-
if
|
|
139
|
+
if (
|
|
140
|
+
user_query in ["/exit", "/quit"]
|
|
141
|
+
or user_query.startswith("/exit")
|
|
142
|
+
or user_query.startswith("/quit")
|
|
143
|
+
):
|
|
140
144
|
break
|
|
141
145
|
|
|
142
146
|
if user_query == "/clear":
|
|
143
|
-
self.agent.clear_history()
|
|
147
|
+
await self.agent.clear_history()
|
|
144
148
|
self.console.print("[green]Conversation history cleared.[/green]\n")
|
|
145
149
|
continue
|
|
146
150
|
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Conversation history storage for SQLSaber."""
|
|
2
|
+
|
|
3
|
+
from .manager import ConversationManager
|
|
4
|
+
from .models import Conversation, ConversationMessage
|
|
5
|
+
from .storage import ConversationStorage
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"Conversation",
|
|
9
|
+
"ConversationMessage",
|
|
10
|
+
"ConversationStorage",
|
|
11
|
+
"ConversationManager",
|
|
12
|
+
]
|