sqlsaber 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/base.py +0 -108
- sqlsaber/cli/commands.py +34 -2
- sqlsaber/cli/display.py +0 -30
- sqlsaber/cli/interactive.py +76 -24
- sqlsaber/cli/streaming.py +0 -4
- sqlsaber/cli/threads.py +301 -0
- sqlsaber/database/schema.py +30 -2
- sqlsaber/threads/__init__.py +5 -0
- sqlsaber/threads/storage.py +303 -0
- sqlsaber/tools/__init__.py +0 -2
- sqlsaber/tools/base.py +0 -12
- sqlsaber/tools/enums.py +0 -2
- sqlsaber/tools/instructions.py +3 -23
- sqlsaber/tools/registry.py +0 -12
- {sqlsaber-0.16.0.dist-info → sqlsaber-0.17.0.dist-info}/METADATA +12 -3
- {sqlsaber-0.16.0.dist-info → sqlsaber-0.17.0.dist-info}/RECORD +19 -23
- sqlsaber/conversation/__init__.py +0 -12
- sqlsaber/conversation/manager.py +0 -224
- sqlsaber/conversation/models.py +0 -120
- sqlsaber/conversation/storage.py +0 -362
- sqlsaber/models/__init__.py +0 -10
- sqlsaber/models/types.py +0 -40
- sqlsaber/tools/visualization_tools.py +0 -144
- {sqlsaber-0.16.0.dist-info → sqlsaber-0.17.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.16.0.dist-info → sqlsaber-0.17.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.16.0.dist-info → sqlsaber-0.17.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/tools/enums.py
CHANGED
|
@@ -8,7 +8,6 @@ class ToolCategory(Enum):
|
|
|
8
8
|
|
|
9
9
|
GENERAL = "general"
|
|
10
10
|
SQL = "sql"
|
|
11
|
-
VISUALIZATION = "visualization"
|
|
12
11
|
|
|
13
12
|
|
|
14
13
|
class WorkflowPosition(Enum):
|
|
@@ -17,5 +16,4 @@ class WorkflowPosition(Enum):
|
|
|
17
16
|
DISCOVERY = "discovery"
|
|
18
17
|
ANALYSIS = "analysis"
|
|
19
18
|
EXECUTION = "execution"
|
|
20
|
-
VISUALIZATION = "visualization"
|
|
21
19
|
OTHER = "other"
|
sqlsaber/tools/instructions.py
CHANGED
|
@@ -69,8 +69,7 @@ Your responsibilities:
|
|
|
69
69
|
2. Use the provided tools efficiently to explore database schema
|
|
70
70
|
3. Generate appropriate SQL queries
|
|
71
71
|
4. Execute queries safely - queries that modify the database are not allowed
|
|
72
|
-
5. Format and explain results clearly
|
|
73
|
-
6. Create visualizations when requested or when they would be helpful"""
|
|
72
|
+
5. Format and explain results clearly"""
|
|
74
73
|
|
|
75
74
|
def _sort_tools_by_workflow(self, tools: list[Tool]) -> list[Tool]:
|
|
76
75
|
"""Sort tools by priority and workflow position."""
|
|
@@ -79,14 +78,13 @@ Your responsibilities:
|
|
|
79
78
|
WorkflowPosition.DISCOVERY: 1,
|
|
80
79
|
WorkflowPosition.ANALYSIS: 2,
|
|
81
80
|
WorkflowPosition.EXECUTION: 3,
|
|
82
|
-
WorkflowPosition.VISUALIZATION: 4,
|
|
83
|
-
WorkflowPosition.OTHER: 5,
|
|
81
|
+
WorkflowPosition.OTHER: 4,
|
|
84
82
|
}
|
|
85
83
|
|
|
86
84
|
return sorted(
|
|
87
85
|
tools,
|
|
88
86
|
key=lambda tool: (
|
|
89
|
-
position_order.get(tool.get_workflow_position(), 5),
|
|
87
|
+
position_order.get(tool.get_workflow_position(), 4),
|
|
90
88
|
tool.get_priority(),
|
|
91
89
|
tool.name,
|
|
92
90
|
),
|
|
@@ -145,19 +143,6 @@ Your responsibilities:
|
|
|
145
143
|
)
|
|
146
144
|
step += 1
|
|
147
145
|
|
|
148
|
-
# Add visualization tools
|
|
149
|
-
if WorkflowPosition.VISUALIZATION in workflow_groups:
|
|
150
|
-
viz_tools = workflow_groups[WorkflowPosition.VISUALIZATION]
|
|
151
|
-
for tool in viz_tools:
|
|
152
|
-
usage = tool.get_usage_instructions()
|
|
153
|
-
if usage:
|
|
154
|
-
instructions.append(f"{step}. {usage}")
|
|
155
|
-
else:
|
|
156
|
-
instructions.append(
|
|
157
|
-
f"{step}. Use '{tool.name}' when creating visualizations"
|
|
158
|
-
)
|
|
159
|
-
step += 1
|
|
160
|
-
|
|
161
146
|
return "\n".join(instructions) if len(instructions) > 1 else ""
|
|
162
147
|
|
|
163
148
|
def _build_tool_guidelines(self, sorted_tools: list[Tool]) -> str:
|
|
@@ -195,11 +180,6 @@ Your responsibilities:
|
|
|
195
180
|
]
|
|
196
181
|
)
|
|
197
182
|
|
|
198
|
-
if ToolCategory.VISUALIZATION in categories:
|
|
199
|
-
guidelines.append(
|
|
200
|
-
"- Create visualizations when they would enhance understanding of the data"
|
|
201
|
-
)
|
|
202
|
-
|
|
203
183
|
return "\n".join(guidelines)
|
|
204
184
|
|
|
205
185
|
def _is_usage_in_workflow(self, usage: str) -> bool:
|
sqlsaber/tools/registry.py
CHANGED
|
@@ -101,18 +101,6 @@ class ToolRegistry:
|
|
|
101
101
|
names = self.list_tools(category)
|
|
102
102
|
return [self.get_tool(name) for name in names]
|
|
103
103
|
|
|
104
|
-
def get_tool_definitions(self, category: str | ToolCategory | None = None) -> list:
|
|
105
|
-
"""Get tool definitions for all tools.
|
|
106
|
-
|
|
107
|
-
Args:
|
|
108
|
-
category: Optional category to filter by (string or ToolCategory enum)
|
|
109
|
-
|
|
110
|
-
Returns:
|
|
111
|
-
List of ToolDefinition objects
|
|
112
|
-
"""
|
|
113
|
-
tools = self.get_all_tools(category)
|
|
114
|
-
return [tool.to_definition() for tool in tools]
|
|
115
|
-
|
|
116
104
|
|
|
117
105
|
# Global registry instance
|
|
118
106
|
tool_registry = ToolRegistry()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sqlsaber
|
|
3
|
-
Version: 0.16.0
|
|
3
|
+
Version: 0.17.0
|
|
4
4
|
Summary: SQLsaber - Open-source agentic SQL assistant
|
|
5
5
|
License-File: LICENSE
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -16,7 +16,6 @@ Requires-Dist: platformdirs>=4.0.0
|
|
|
16
16
|
Requires-Dist: pydantic-ai
|
|
17
17
|
Requires-Dist: questionary>=2.1.0
|
|
18
18
|
Requires-Dist: rich>=13.7.0
|
|
19
|
-
Requires-Dist: uniplot>=0.21.2
|
|
20
19
|
Description-Content-Type: text/markdown
|
|
21
20
|
|
|
22
21
|
# SQLsaber
|
|
@@ -40,6 +39,7 @@ Ask your questions in natural language and `sqlsaber` will gather the right cont
|
|
|
40
39
|
- [Usage](#usage)
|
|
41
40
|
- [Interactive Mode](#interactive-mode)
|
|
42
41
|
- [Single Query](#single-query)
|
|
42
|
+
- [Resume Past Conversation](#resume-past-conversation)
|
|
43
43
|
- [Database Selection](#database-selection)
|
|
44
44
|
- [Examples](#examples)
|
|
45
45
|
- [MCP Server Integration](#mcp-server-integration)
|
|
@@ -56,6 +56,7 @@ Ask your questions in natural language and `sqlsaber` will gather the right cont
|
|
|
56
56
|
- 🛡️ Safe query execution (read-only by default)
|
|
57
57
|
- 🧠 Memory management
|
|
58
58
|
- 💬 Interactive REPL mode
|
|
59
|
+
- 🧵 Conversation threads (store, display, and resume conversations)
|
|
59
60
|
- 🗄️ Support for PostgreSQL, SQLite, and MySQL
|
|
60
61
|
- 🔌 MCP (Model Context Protocol) server support
|
|
61
62
|
- 🎨 Beautiful formatted output
|
|
@@ -147,6 +148,14 @@ echo "show me all users created this month" | saber
|
|
|
147
148
|
cat query.txt | saber
|
|
148
149
|
```
|
|
149
150
|
|
|
151
|
+
### Resume Past Conversation
|
|
152
|
+
|
|
153
|
+
Continue a previous conversation thread:
|
|
154
|
+
|
|
155
|
+
```bash
|
|
156
|
+
saber threads resume THREAD_ID
|
|
157
|
+
```
|
|
158
|
+
|
|
150
159
|
### Database Selection
|
|
151
160
|
|
|
152
161
|
Use a specific database connection:
|
|
@@ -230,7 +239,7 @@ SQLsaber uses a multi-step agentic process to gather the right context and execu
|
|
|
230
239
|
|
|
231
240
|
4. **SQL Generation**: Creates optimized SQL queries based on natural language input
|
|
232
241
|
5. **Safe Execution**: Runs read-only queries with built-in protections against destructive operations
|
|
233
|
-
6. **Result Formatting**: Presents results with explanations in tables
|
|
242
|
+
6. **Result Formatting**: Presents results with explanations in tables
|
|
234
243
|
|
|
235
244
|
## Contributing
|
|
236
245
|
|
|
@@ -1,19 +1,20 @@
|
|
|
1
1
|
sqlsaber/__init__.py,sha256=HjS8ULtP4MGpnTL7njVY45NKV9Fi4e_yeYuY-hyXWQc,73
|
|
2
2
|
sqlsaber/__main__.py,sha256=RIHxWeWh2QvLfah-2OkhI5IJxojWfy4fXpMnVEJYvxw,78
|
|
3
3
|
sqlsaber/agents/__init__.py,sha256=i_MI2eWMQaVzGikKU71FPCmSQxNDKq36Imq1PrYoIPU,130
|
|
4
|
-
sqlsaber/agents/base.py,sha256=
|
|
4
|
+
sqlsaber/agents/base.py,sha256=7zOZTHKxUuU0uMc-NTaCkkBfDnU3jtwbT8_eP1ZtJ2k,2615
|
|
5
5
|
sqlsaber/agents/mcp.py,sha256=GcJTx7YDYH6aaxIADEIxSgcWAdWakUx395JIzVnf17U,768
|
|
6
6
|
sqlsaber/agents/pydantic_ai_agent.py,sha256=dGdsgyxCZvfK-v-MH8KimKOr-xb2aSfSWY8CMcOUCT8,6795
|
|
7
7
|
sqlsaber/cli/__init__.py,sha256=qVSLVJLLJYzoC6aj6y9MFrzZvAwc4_OgxU9DlkQnZ4M,86
|
|
8
8
|
sqlsaber/cli/auth.py,sha256=jTsRgbmlGPlASSuIKmdjjwfqtKvjfKd_cTYxX0-QqaQ,7400
|
|
9
|
-
sqlsaber/cli/commands.py,sha256=
|
|
9
|
+
sqlsaber/cli/commands.py,sha256=ffEJq8WOfX7YBJzn5UCgeXBpU5lXnVBUMkbJEQ0R9WY,8169
|
|
10
10
|
sqlsaber/cli/completers.py,sha256=HsUPjaZweLSeYCWkAcgMl8FylQ1xjWBWYTEL_9F6xfU,6430
|
|
11
11
|
sqlsaber/cli/database.py,sha256=atwg3l8acQ3YTDuhq7vNrBN6tpOv0syz6V62KTF-Bh8,12910
|
|
12
|
-
sqlsaber/cli/display.py,sha256=
|
|
13
|
-
sqlsaber/cli/interactive.py,sha256=
|
|
12
|
+
sqlsaber/cli/display.py,sha256=wa7BjTBwXwqLT145Q1AEL0C28pQJTrvDN10mnFMjqsg,8554
|
|
13
|
+
sqlsaber/cli/interactive.py,sha256=QqBjSsjtb6XoBVRyGS520cQrm7DvWC-obQ_EflcygbU,12051
|
|
14
14
|
sqlsaber/cli/memory.py,sha256=OufHFJFwV0_GGn7LvKRTJikkWhV1IwNIUDOxFPHXOaQ,7794
|
|
15
15
|
sqlsaber/cli/models.py,sha256=ZewtwGQwhd9b-yxBAPKePolvI1qQG-EkmeWAGMqtWNQ,8986
|
|
16
|
-
sqlsaber/cli/streaming.py,sha256=
|
|
16
|
+
sqlsaber/cli/streaming.py,sha256=WNqBYYbWtL5CNQkRg5YWhYpWKI8qz7JmqneB2DXTOHY,5259
|
|
17
|
+
sqlsaber/cli/threads.py,sha256=xti7_kvh3loQfLb7_GC8wSULJ4Oj56jXY8GQp69CQCI,11111
|
|
17
18
|
sqlsaber/config/__init__.py,sha256=olwC45k8Nc61yK0WmPUk7XHdbsZH9HuUAbwnmKe3IgA,100
|
|
18
19
|
sqlsaber/config/api_keys.py,sha256=RqWQCko1tY7sES7YOlexgBH5Hd5ne_kGXHdBDNqcV2U,3649
|
|
19
20
|
sqlsaber/config/auth.py,sha256=b5qB2h1doXyO9Bn8z0CcL8LAR2jF431gGXBGKLgTmtQ,2756
|
|
@@ -22,30 +23,25 @@ sqlsaber/config/oauth_flow.py,sha256=A3bSXaBLzuAfXV2ZPA94m9NV33c2MyL6M4ii9oEkswQ
|
|
|
22
23
|
sqlsaber/config/oauth_tokens.py,sha256=C9z35hyx-PvSAYdC1LNf3rg9_wsEIY56hkEczelbad0,6015
|
|
23
24
|
sqlsaber/config/providers.py,sha256=JFjeJv1K5Q93zWSlWq3hAvgch1TlgoF0qFa0KJROkKY,2957
|
|
24
25
|
sqlsaber/config/settings.py,sha256=vgb_RXaM-7DgbxYDmWNw1cSyMqwys4j3qNCvM4bljwI,5586
|
|
25
|
-
sqlsaber/conversation/__init__.py,sha256=xa-1gX6NsZpVGg_LDrsZAtDtsDo5FZc1SO8gwtm_IPk,302
|
|
26
|
-
sqlsaber/conversation/manager.py,sha256=LDfmKGIMvTzsL7S0aXGWw6Ve54CHIeTGLU4qwes2NgU,7046
|
|
27
|
-
sqlsaber/conversation/models.py,sha256=fq4wpIB2yxLCQtsXhdpDji4FpscG2ayrOBACrNvgF14,3510
|
|
28
|
-
sqlsaber/conversation/storage.py,sha256=phpGEnZjXVFTmV5PalCKZpiO9VFHubMMfWA9OJCDbwc,11626
|
|
29
26
|
sqlsaber/database/__init__.py,sha256=a_gtKRJnZVO8-fEZI7g3Z8YnGa6Nio-5Y50PgVp07ss,176
|
|
30
27
|
sqlsaber/database/connection.py,sha256=sJtIIe0GVbo-1Py9-j66UxJoY1aKL9gqk68jkDL-Kvk,15123
|
|
31
28
|
sqlsaber/database/resolver.py,sha256=RPXF5EoKzvQDDLmPGNHYd2uG_oNICH8qvUjBp6iXmNY,3348
|
|
32
|
-
sqlsaber/database/schema.py,sha256=
|
|
29
|
+
sqlsaber/database/schema.py,sha256=r12qoN3tdtAXdO22EKlauAe7QwOm8lL2vTMM59XEMMY,26594
|
|
33
30
|
sqlsaber/mcp/__init__.py,sha256=COdWq7wauPBp5Ew8tfZItFzbcLDSEkHBJSMhxzy8C9c,112
|
|
34
31
|
sqlsaber/mcp/mcp.py,sha256=X12oCMZYAtgJ7MNuh5cqz8y3lALrOzkXWcfpuY0Ijxk,3950
|
|
35
32
|
sqlsaber/memory/__init__.py,sha256=GiWkU6f6YYVV0EvvXDmFWe_CxarmDCql05t70MkTEWs,63
|
|
36
33
|
sqlsaber/memory/manager.py,sha256=p3fybMVfH-E4ApT1ZRZUnQIWSk9dkfUPCyfkmA0HALs,2739
|
|
37
34
|
sqlsaber/memory/storage.py,sha256=ne8szLlGj5NELheqLnI7zu21V8YS4rtpYGGC7tOmi-s,5745
|
|
38
|
-
sqlsaber/
|
|
39
|
-
sqlsaber/
|
|
40
|
-
sqlsaber/tools/__init__.py,sha256=
|
|
41
|
-
sqlsaber/tools/base.py,sha256=
|
|
42
|
-
sqlsaber/tools/enums.py,sha256=
|
|
43
|
-
sqlsaber/tools/instructions.py,sha256=
|
|
44
|
-
sqlsaber/tools/registry.py,sha256=
|
|
35
|
+
sqlsaber/threads/__init__.py,sha256=Hh3dIG1tuC8fXprREUpslCIgPYz8_6o7aRLx4yNeO48,139
|
|
36
|
+
sqlsaber/threads/storage.py,sha256=rsUdxT4CR52D7xtGir9UlsFnBMk11jZeflzDrk2q4ME,11183
|
|
37
|
+
sqlsaber/tools/__init__.py,sha256=x3YdmX_7P0Qq_HtZHAgfIVKTLxYqKk6oc4tGsujQWsc,586
|
|
38
|
+
sqlsaber/tools/base.py,sha256=mHhvAj27BHmckyvuDLCPlAQdzABJyYxd9SJnaYAwwuA,1777
|
|
39
|
+
sqlsaber/tools/enums.py,sha256=CH32mL-0k9ZA18911xLpNtsgpV6tB85TktMj6uqGz54,411
|
|
40
|
+
sqlsaber/tools/instructions.py,sha256=X-x8maVkkyi16b6Tl0hcAFgjiYceZaSwyWTfmrvx8U8,9024
|
|
41
|
+
sqlsaber/tools/registry.py,sha256=HWOQMsNIdL4XZS6TeNUyrL-5KoSDH6PHsWd3X66o-18,3211
|
|
45
42
|
sqlsaber/tools/sql_tools.py,sha256=hM6tKqW5MDhFUt6MesoqhTUqIpq_5baIIDoN1MjDCXY,9647
|
|
46
|
-
sqlsaber/
|
|
47
|
-
sqlsaber-0.
|
|
48
|
-
sqlsaber-0.
|
|
49
|
-
sqlsaber-0.
|
|
50
|
-
sqlsaber-0.
|
|
51
|
-
sqlsaber-0.16.0.dist-info/RECORD,,
|
|
43
|
+
sqlsaber-0.17.0.dist-info/METADATA,sha256=epsxWGdmHWrFxLbbxWmOVsGqFMYmH1xUNAIGyGxpl_w,6141
|
|
44
|
+
sqlsaber-0.17.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
45
|
+
sqlsaber-0.17.0.dist-info/entry_points.txt,sha256=qEbOB7OffXPFgyJc7qEIJlMEX5RN9xdzLmWZa91zCQQ,162
|
|
46
|
+
sqlsaber-0.17.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
47
|
+
sqlsaber-0.17.0.dist-info/RECORD,,
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
"""Conversation history storage for SQLSaber."""
|
|
2
|
-
|
|
3
|
-
from .manager import ConversationManager
|
|
4
|
-
from .models import Conversation, ConversationMessage
|
|
5
|
-
from .storage import ConversationStorage
|
|
6
|
-
|
|
7
|
-
__all__ = [
|
|
8
|
-
"Conversation",
|
|
9
|
-
"ConversationMessage",
|
|
10
|
-
"ConversationStorage",
|
|
11
|
-
"ConversationManager",
|
|
12
|
-
]
|
sqlsaber/conversation/manager.py
DELETED
|
@@ -1,224 +0,0 @@
|
|
|
1
|
-
"""Manager for conversation storage operations."""
|
|
2
|
-
|
|
3
|
-
import logging
|
|
4
|
-
import uuid
|
|
5
|
-
from typing import Any
|
|
6
|
-
|
|
7
|
-
from .models import Conversation, ConversationMessage
|
|
8
|
-
from .storage import ConversationStorage
|
|
9
|
-
|
|
10
|
-
logger = logging.getLogger(__name__)
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class ConversationManager:
|
|
14
|
-
"""High-level manager for conversation storage operations."""
|
|
15
|
-
|
|
16
|
-
def __init__(self):
|
|
17
|
-
"""Initialize conversation manager."""
|
|
18
|
-
self._storage = ConversationStorage()
|
|
19
|
-
|
|
20
|
-
async def start_conversation(self, database_name: str) -> str:
|
|
21
|
-
"""Start a new conversation.
|
|
22
|
-
|
|
23
|
-
Args:
|
|
24
|
-
database_name: Name of the database for this conversation
|
|
25
|
-
|
|
26
|
-
Returns:
|
|
27
|
-
Conversation ID
|
|
28
|
-
"""
|
|
29
|
-
try:
|
|
30
|
-
return await self._storage.create_conversation(database_name)
|
|
31
|
-
except Exception as e:
|
|
32
|
-
logger.warning(f"Failed to start conversation: {e}")
|
|
33
|
-
return str(uuid.uuid4())
|
|
34
|
-
|
|
35
|
-
async def add_user_message(
|
|
36
|
-
self, conversation_id: str, content: str | dict[str, Any], index: int
|
|
37
|
-
) -> str:
|
|
38
|
-
"""Add a user message to the conversation.
|
|
39
|
-
|
|
40
|
-
Args:
|
|
41
|
-
conversation_id: ID of the conversation
|
|
42
|
-
content: Message content
|
|
43
|
-
index: Sequential index in conversation
|
|
44
|
-
|
|
45
|
-
Returns:
|
|
46
|
-
Message ID
|
|
47
|
-
"""
|
|
48
|
-
try:
|
|
49
|
-
return await self._storage.add_message(
|
|
50
|
-
conversation_id, "user", content, index
|
|
51
|
-
)
|
|
52
|
-
except Exception as e:
|
|
53
|
-
logger.warning(f"Failed to add user message: {e}")
|
|
54
|
-
return str(uuid.uuid4())
|
|
55
|
-
|
|
56
|
-
async def add_assistant_message(
|
|
57
|
-
self,
|
|
58
|
-
conversation_id: str,
|
|
59
|
-
content: list[dict[str, Any]] | dict[str, Any],
|
|
60
|
-
index: int,
|
|
61
|
-
) -> str:
|
|
62
|
-
"""Add an assistant message to the conversation.
|
|
63
|
-
|
|
64
|
-
Args:
|
|
65
|
-
conversation_id: ID of the conversation
|
|
66
|
-
content: Message content (typically ContentBlock list)
|
|
67
|
-
index: Sequential index in conversation
|
|
68
|
-
|
|
69
|
-
Returns:
|
|
70
|
-
Message ID
|
|
71
|
-
"""
|
|
72
|
-
try:
|
|
73
|
-
return await self._storage.add_message(
|
|
74
|
-
conversation_id, "assistant", content, index
|
|
75
|
-
)
|
|
76
|
-
except Exception as e:
|
|
77
|
-
logger.warning(f"Failed to add assistant message: {e}")
|
|
78
|
-
return str(uuid.uuid4())
|
|
79
|
-
|
|
80
|
-
async def add_tool_message(
|
|
81
|
-
self,
|
|
82
|
-
conversation_id: str,
|
|
83
|
-
content: list[dict[str, Any]] | dict[str, Any],
|
|
84
|
-
index: int,
|
|
85
|
-
) -> str:
|
|
86
|
-
"""Add a tool/system message to the conversation.
|
|
87
|
-
|
|
88
|
-
Args:
|
|
89
|
-
conversation_id: ID of the conversation
|
|
90
|
-
content: Message content (typically tool results)
|
|
91
|
-
index: Sequential index in conversation
|
|
92
|
-
|
|
93
|
-
Returns:
|
|
94
|
-
Message ID
|
|
95
|
-
"""
|
|
96
|
-
try:
|
|
97
|
-
return await self._storage.add_message(
|
|
98
|
-
conversation_id, "tool", content, index
|
|
99
|
-
)
|
|
100
|
-
except Exception as e:
|
|
101
|
-
logger.warning(f"Failed to add tool message: {e}")
|
|
102
|
-
return str(uuid.uuid4())
|
|
103
|
-
|
|
104
|
-
async def end_conversation(self, conversation_id: str) -> bool:
|
|
105
|
-
"""End a conversation.
|
|
106
|
-
|
|
107
|
-
Args:
|
|
108
|
-
conversation_id: ID of the conversation to end
|
|
109
|
-
|
|
110
|
-
Returns:
|
|
111
|
-
True if successfully ended, False otherwise
|
|
112
|
-
"""
|
|
113
|
-
try:
|
|
114
|
-
return await self._storage.end_conversation(conversation_id)
|
|
115
|
-
except Exception as e:
|
|
116
|
-
logger.warning(f"Failed to end conversation: {e}")
|
|
117
|
-
return False
|
|
118
|
-
|
|
119
|
-
async def get_conversation(self, conversation_id: str) -> Conversation | None:
|
|
120
|
-
"""Get a conversation by ID.
|
|
121
|
-
|
|
122
|
-
Args:
|
|
123
|
-
conversation_id: ID of the conversation
|
|
124
|
-
|
|
125
|
-
Returns:
|
|
126
|
-
Conversation object or None if not found
|
|
127
|
-
"""
|
|
128
|
-
try:
|
|
129
|
-
return await self._storage.get_conversation(conversation_id)
|
|
130
|
-
except Exception as e:
|
|
131
|
-
logger.warning(f"Failed to get conversation: {e}")
|
|
132
|
-
return None
|
|
133
|
-
|
|
134
|
-
async def get_conversation_messages(
|
|
135
|
-
self, conversation_id: str
|
|
136
|
-
) -> list[ConversationMessage]:
|
|
137
|
-
"""Get all messages for a conversation.
|
|
138
|
-
|
|
139
|
-
Args:
|
|
140
|
-
conversation_id: ID of the conversation
|
|
141
|
-
|
|
142
|
-
Returns:
|
|
143
|
-
List of messages ordered by index
|
|
144
|
-
"""
|
|
145
|
-
try:
|
|
146
|
-
return await self._storage.get_conversation_messages(conversation_id)
|
|
147
|
-
except Exception as e:
|
|
148
|
-
logger.warning(f"Failed to get conversation messages: {e}")
|
|
149
|
-
return []
|
|
150
|
-
|
|
151
|
-
async def list_conversations(
|
|
152
|
-
self, database_name: str | None = None, limit: int = 50
|
|
153
|
-
) -> list[Conversation]:
|
|
154
|
-
"""List conversations.
|
|
155
|
-
|
|
156
|
-
Args:
|
|
157
|
-
database_name: Optional database name filter
|
|
158
|
-
limit: Maximum number of conversations to return
|
|
159
|
-
|
|
160
|
-
Returns:
|
|
161
|
-
List of conversations ordered by start time (newest first)
|
|
162
|
-
"""
|
|
163
|
-
try:
|
|
164
|
-
return await self._storage.list_conversations(database_name, limit)
|
|
165
|
-
except Exception as e:
|
|
166
|
-
logger.warning(f"Failed to list conversations: {e}")
|
|
167
|
-
return []
|
|
168
|
-
|
|
169
|
-
async def delete_conversation(self, conversation_id: str) -> bool:
|
|
170
|
-
"""Delete a conversation.
|
|
171
|
-
|
|
172
|
-
Args:
|
|
173
|
-
conversation_id: ID of the conversation to delete
|
|
174
|
-
|
|
175
|
-
Returns:
|
|
176
|
-
True if successfully deleted, False otherwise
|
|
177
|
-
"""
|
|
178
|
-
try:
|
|
179
|
-
return await self._storage.delete_conversation(conversation_id)
|
|
180
|
-
except Exception as e:
|
|
181
|
-
logger.warning(f"Failed to delete conversation: {e}")
|
|
182
|
-
return False
|
|
183
|
-
|
|
184
|
-
async def get_database_names(self) -> list[str]:
|
|
185
|
-
"""Get list of database names with conversations.
|
|
186
|
-
|
|
187
|
-
Returns:
|
|
188
|
-
List of unique database names
|
|
189
|
-
"""
|
|
190
|
-
try:
|
|
191
|
-
return await self._storage.get_database_names()
|
|
192
|
-
except Exception as e:
|
|
193
|
-
logger.warning(f"Failed to get database names: {e}")
|
|
194
|
-
return []
|
|
195
|
-
|
|
196
|
-
async def restore_conversation_to_agent(
|
|
197
|
-
self, conversation_id: str, agent_history: list[dict[str, Any]]
|
|
198
|
-
) -> bool:
|
|
199
|
-
"""Restore a conversation's messages to an agent's in-memory history.
|
|
200
|
-
|
|
201
|
-
Args:
|
|
202
|
-
conversation_id: ID of the conversation to restore
|
|
203
|
-
agent_history: Agent's conversation_history list to populate
|
|
204
|
-
|
|
205
|
-
Returns:
|
|
206
|
-
True if successfully restored, False otherwise
|
|
207
|
-
"""
|
|
208
|
-
try:
|
|
209
|
-
messages = await self.get_conversation_messages(conversation_id)
|
|
210
|
-
|
|
211
|
-
# Clear existing history
|
|
212
|
-
agent_history.clear()
|
|
213
|
-
|
|
214
|
-
# Convert messages back to agent format
|
|
215
|
-
for msg in messages:
|
|
216
|
-
if msg.role in ("user", "assistant", "tool"):
|
|
217
|
-
agent_history.append({"role": msg.role, "content": msg.content})
|
|
218
|
-
|
|
219
|
-
logger.debug(f"Restored {len(messages)} messages to agent history")
|
|
220
|
-
return True
|
|
221
|
-
|
|
222
|
-
except Exception as e:
|
|
223
|
-
logger.warning(f"Failed to restore conversation to agent: {e}")
|
|
224
|
-
return False
|
sqlsaber/conversation/models.py
DELETED
|
@@ -1,120 +0,0 @@
|
|
|
1
|
-
"""Data models for conversation storage."""
|
|
2
|
-
|
|
3
|
-
import json
|
|
4
|
-
import time
|
|
5
|
-
from dataclasses import dataclass
|
|
6
|
-
from typing import Any
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
@dataclass
|
|
10
|
-
class Conversation:
|
|
11
|
-
"""Represents a conversation session."""
|
|
12
|
-
|
|
13
|
-
id: str
|
|
14
|
-
database_name: str
|
|
15
|
-
started_at: float
|
|
16
|
-
ended_at: float | None = None
|
|
17
|
-
|
|
18
|
-
def to_dict(self) -> dict[str, Any]:
|
|
19
|
-
"""Convert to dictionary for JSON serialization."""
|
|
20
|
-
return {
|
|
21
|
-
"id": self.id,
|
|
22
|
-
"database_name": self.database_name,
|
|
23
|
-
"started_at": self.started_at,
|
|
24
|
-
"ended_at": self.ended_at,
|
|
25
|
-
}
|
|
26
|
-
|
|
27
|
-
@classmethod
|
|
28
|
-
def from_dict(cls, data: dict[str, Any]) -> "Conversation":
|
|
29
|
-
"""Create from dictionary."""
|
|
30
|
-
return cls(
|
|
31
|
-
id=data["id"],
|
|
32
|
-
database_name=data["database_name"],
|
|
33
|
-
started_at=data["started_at"],
|
|
34
|
-
ended_at=data.get("ended_at"),
|
|
35
|
-
)
|
|
36
|
-
|
|
37
|
-
def formatted_start_time(self) -> str:
|
|
38
|
-
"""Get human-readable start timestamp."""
|
|
39
|
-
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.started_at))
|
|
40
|
-
|
|
41
|
-
def formatted_end_time(self) -> str | None:
|
|
42
|
-
"""Get human-readable end timestamp."""
|
|
43
|
-
if self.ended_at is None:
|
|
44
|
-
return None
|
|
45
|
-
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.ended_at))
|
|
46
|
-
|
|
47
|
-
def duration_seconds(self) -> float | None:
|
|
48
|
-
"""Get conversation duration in seconds."""
|
|
49
|
-
if self.ended_at is None:
|
|
50
|
-
return None
|
|
51
|
-
return self.ended_at - self.started_at
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
@dataclass
|
|
55
|
-
class ConversationMessage:
|
|
56
|
-
"""Represents a single message in a conversation."""
|
|
57
|
-
|
|
58
|
-
id: str
|
|
59
|
-
conversation_id: str
|
|
60
|
-
role: str
|
|
61
|
-
content: dict[str, Any] | str
|
|
62
|
-
index_in_conv: int
|
|
63
|
-
created_at: float
|
|
64
|
-
|
|
65
|
-
def to_dict(self) -> dict[str, Any]:
|
|
66
|
-
"""Convert to dictionary for JSON serialization."""
|
|
67
|
-
return {
|
|
68
|
-
"id": self.id,
|
|
69
|
-
"conversation_id": self.conversation_id,
|
|
70
|
-
"role": self.role,
|
|
71
|
-
"content": self.content,
|
|
72
|
-
"index_in_conv": self.index_in_conv,
|
|
73
|
-
"created_at": self.created_at,
|
|
74
|
-
}
|
|
75
|
-
|
|
76
|
-
@classmethod
|
|
77
|
-
def from_dict(cls, data: dict[str, Any]) -> "ConversationMessage":
|
|
78
|
-
"""Create from dictionary."""
|
|
79
|
-
return cls(
|
|
80
|
-
id=data["id"],
|
|
81
|
-
conversation_id=data["conversation_id"],
|
|
82
|
-
role=data["role"],
|
|
83
|
-
content=data["content"],
|
|
84
|
-
index_in_conv=data["index_in_conv"],
|
|
85
|
-
created_at=data["created_at"],
|
|
86
|
-
)
|
|
87
|
-
|
|
88
|
-
def formatted_timestamp(self) -> str:
|
|
89
|
-
"""Get human-readable timestamp."""
|
|
90
|
-
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.created_at))
|
|
91
|
-
|
|
92
|
-
def content_json(self) -> str:
|
|
93
|
-
"""Get content as JSON string for storage."""
|
|
94
|
-
return json.dumps(self.content)
|
|
95
|
-
|
|
96
|
-
@classmethod
|
|
97
|
-
def from_storage_data(
|
|
98
|
-
cls,
|
|
99
|
-
id_: str,
|
|
100
|
-
conversation_id: str,
|
|
101
|
-
role: str,
|
|
102
|
-
content_json: str,
|
|
103
|
-
index_in_conv: int,
|
|
104
|
-
created_at: float,
|
|
105
|
-
) -> "ConversationMessage":
|
|
106
|
-
"""Create from SQLite storage data."""
|
|
107
|
-
try:
|
|
108
|
-
content = json.loads(content_json)
|
|
109
|
-
except json.JSONDecodeError:
|
|
110
|
-
# Fallback to string content for malformed JSON
|
|
111
|
-
content = content_json
|
|
112
|
-
|
|
113
|
-
return cls(
|
|
114
|
-
id=id_,
|
|
115
|
-
conversation_id=conversation_id,
|
|
116
|
-
role=role,
|
|
117
|
-
content=content,
|
|
118
|
-
index_in_conv=index_in_conv,
|
|
119
|
-
created_at=created_at,
|
|
120
|
-
)
|