AstrBot 4.5.8__py3-none-any.whl → 4.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/agent/mcp_client.py +152 -26
- astrbot/core/agent/message.py +7 -0
- astrbot/core/config/default.py +8 -1
- astrbot/core/core_lifecycle.py +8 -0
- astrbot/core/db/__init__.py +50 -1
- astrbot/core/db/migration/migra_webchat_session.py +131 -0
- astrbot/core/db/po.py +49 -13
- astrbot/core/db/sqlite.py +102 -3
- astrbot/core/knowledge_base/kb_helper.py +314 -33
- astrbot/core/knowledge_base/kb_mgr.py +45 -1
- astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
- astrbot/core/knowledge_base/prompts.py +65 -0
- astrbot/core/pipeline/process_stage/method/llm_request.py +28 -14
- astrbot/core/pipeline/process_stage/utils.py +60 -16
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +13 -10
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -4
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +0 -4
- astrbot/core/provider/entities.py +22 -9
- astrbot/core/provider/func_tool_manager.py +12 -9
- astrbot/core/provider/sources/gemini_source.py +25 -8
- astrbot/core/provider/sources/openai_source.py +9 -16
- astrbot/dashboard/routes/chat.py +134 -77
- astrbot/dashboard/routes/knowledge_base.py +172 -0
- {astrbot-4.5.8.dist-info → astrbot-4.6.0.dist-info}/METADATA +4 -3
- {astrbot-4.5.8.dist-info → astrbot-4.6.0.dist-info}/RECORD +28 -25
- {astrbot-4.5.8.dist-info → astrbot-4.6.0.dist-info}/WHEEL +0 -0
- {astrbot-4.5.8.dist-info → astrbot-4.6.0.dist-info}/entry_points.txt +0 -0
- {astrbot-4.5.8.dist-info → astrbot-4.6.0.dist-info}/licenses/LICENSE +0 -0
astrbot/core/agent/mcp_client.py
CHANGED
|
@@ -4,6 +4,14 @@ from contextlib import AsyncExitStack
|
|
|
4
4
|
from datetime import timedelta
|
|
5
5
|
from typing import Generic
|
|
6
6
|
|
|
7
|
+
from tenacity import (
|
|
8
|
+
before_sleep_log,
|
|
9
|
+
retry,
|
|
10
|
+
retry_if_exception_type,
|
|
11
|
+
stop_after_attempt,
|
|
12
|
+
wait_exponential,
|
|
13
|
+
)
|
|
14
|
+
|
|
7
15
|
from astrbot import logger
|
|
8
16
|
from astrbot.core.agent.run_context import ContextWrapper
|
|
9
17
|
from astrbot.core.utils.log_pipe import LogPipe
|
|
@@ -12,21 +20,24 @@ from .run_context import TContext
|
|
|
12
20
|
from .tool import FunctionTool
|
|
13
21
|
|
|
14
22
|
try:
|
|
23
|
+
import anyio
|
|
15
24
|
import mcp
|
|
16
25
|
from mcp.client.sse import sse_client
|
|
17
26
|
except (ModuleNotFoundError, ImportError):
|
|
18
|
-
logger.warning(
|
|
27
|
+
logger.warning(
|
|
28
|
+
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
|
|
29
|
+
)
|
|
19
30
|
|
|
20
31
|
try:
|
|
21
32
|
from mcp.client.streamable_http import streamablehttp_client
|
|
22
33
|
except (ModuleNotFoundError, ImportError):
|
|
23
34
|
logger.warning(
|
|
24
|
-
"
|
|
35
|
+
"Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
|
|
25
36
|
)
|
|
26
37
|
|
|
27
38
|
|
|
28
39
|
def _prepare_config(config: dict) -> dict:
|
|
29
|
-
"""
|
|
40
|
+
"""Prepare configuration, handle nested format"""
|
|
30
41
|
if config.get("mcpServers"):
|
|
31
42
|
first_key = next(iter(config["mcpServers"]))
|
|
32
43
|
config = config["mcpServers"][first_key]
|
|
@@ -35,7 +46,7 @@ def _prepare_config(config: dict) -> dict:
|
|
|
35
46
|
|
|
36
47
|
|
|
37
48
|
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
38
|
-
"""
|
|
49
|
+
"""Quick test MCP server connectivity"""
|
|
39
50
|
import aiohttp
|
|
40
51
|
|
|
41
52
|
cfg = _prepare_config(config.copy())
|
|
@@ -50,7 +61,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
|
50
61
|
elif "type" in cfg:
|
|
51
62
|
transport_type = cfg["type"]
|
|
52
63
|
else:
|
|
53
|
-
raise Exception("MCP
|
|
64
|
+
raise Exception("MCP connection config missing transport or type field")
|
|
54
65
|
|
|
55
66
|
async with aiohttp.ClientSession() as session:
|
|
56
67
|
if transport_type == "streamable_http":
|
|
@@ -91,7 +102,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
|
91
102
|
return False, f"HTTP {response.status}: {response.reason}"
|
|
92
103
|
|
|
93
104
|
except asyncio.TimeoutError:
|
|
94
|
-
return False, f"
|
|
105
|
+
return False, f"Connection timeout: {timeout} seconds"
|
|
95
106
|
except Exception as e:
|
|
96
107
|
return False, f"{e!s}"
|
|
97
108
|
|
|
@@ -101,6 +112,7 @@ class MCPClient:
|
|
|
101
112
|
# Initialize session and client objects
|
|
102
113
|
self.session: mcp.ClientSession | None = None
|
|
103
114
|
self.exit_stack = AsyncExitStack()
|
|
115
|
+
self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup
|
|
104
116
|
|
|
105
117
|
self.name: str | None = None
|
|
106
118
|
self.active: bool = True
|
|
@@ -108,22 +120,32 @@ class MCPClient:
|
|
|
108
120
|
self.server_errlogs: list[str] = []
|
|
109
121
|
self.running_event = asyncio.Event()
|
|
110
122
|
|
|
123
|
+
# Store connection config for reconnection
|
|
124
|
+
self._mcp_server_config: dict | None = None
|
|
125
|
+
self._server_name: str | None = None
|
|
126
|
+
self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection
|
|
127
|
+
self._reconnecting: bool = False # For logging and debugging
|
|
128
|
+
|
|
111
129
|
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
|
112
|
-
"""
|
|
130
|
+
"""Connect to MCP server
|
|
113
131
|
|
|
114
|
-
|
|
115
|
-
1.
|
|
116
|
-
|
|
117
|
-
|
|
132
|
+
If `url` parameter exists:
|
|
133
|
+
1. When transport is specified as `streamable_http`, use Streamable HTTP connection.
|
|
134
|
+
2. When transport is specified as `sse`, use SSE connection.
|
|
135
|
+
3. If not specified, default to SSE connection to MCP service.
|
|
118
136
|
|
|
119
137
|
Args:
|
|
120
138
|
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
|
121
139
|
|
|
122
140
|
"""
|
|
141
|
+
# Store config for reconnection
|
|
142
|
+
self._mcp_server_config = mcp_server_config
|
|
143
|
+
self._server_name = name
|
|
144
|
+
|
|
123
145
|
cfg = _prepare_config(mcp_server_config.copy())
|
|
124
146
|
|
|
125
147
|
def logging_callback(msg: str):
|
|
126
|
-
#
|
|
148
|
+
# Handle MCP service error logs
|
|
127
149
|
print(f"MCP Server {name} Error: {msg}")
|
|
128
150
|
self.server_errlogs.append(msg)
|
|
129
151
|
|
|
@@ -137,7 +159,7 @@ class MCPClient:
|
|
|
137
159
|
elif "type" in cfg:
|
|
138
160
|
transport_type = cfg["type"]
|
|
139
161
|
else:
|
|
140
|
-
raise Exception("MCP
|
|
162
|
+
raise Exception("MCP connection config missing transport or type field")
|
|
141
163
|
|
|
142
164
|
if transport_type != "streamable_http":
|
|
143
165
|
# SSE transport method
|
|
@@ -193,7 +215,7 @@ class MCPClient:
|
|
|
193
215
|
)
|
|
194
216
|
|
|
195
217
|
def callback(msg: str):
|
|
196
|
-
#
|
|
218
|
+
# Handle MCP service error logs
|
|
197
219
|
self.server_errlogs.append(msg)
|
|
198
220
|
|
|
199
221
|
stdio_transport = await self.exit_stack.enter_async_context(
|
|
@@ -222,10 +244,120 @@ class MCPClient:
|
|
|
222
244
|
self.tools = response.tools
|
|
223
245
|
return response
|
|
224
246
|
|
|
247
|
+
async def _reconnect(self) -> None:
|
|
248
|
+
"""Reconnect to the MCP server using the stored configuration.
|
|
249
|
+
|
|
250
|
+
Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments.
|
|
251
|
+
|
|
252
|
+
Raises:
|
|
253
|
+
Exception: raised when reconnection fails
|
|
254
|
+
"""
|
|
255
|
+
async with self._reconnect_lock:
|
|
256
|
+
# Check if already reconnecting (useful for logging)
|
|
257
|
+
if self._reconnecting:
|
|
258
|
+
logger.debug(
|
|
259
|
+
f"MCP Client {self._server_name} is already reconnecting, skipping"
|
|
260
|
+
)
|
|
261
|
+
return
|
|
262
|
+
|
|
263
|
+
if not self._mcp_server_config or not self._server_name:
|
|
264
|
+
raise Exception("Cannot reconnect: missing connection configuration")
|
|
265
|
+
|
|
266
|
+
self._reconnecting = True
|
|
267
|
+
try:
|
|
268
|
+
logger.info(
|
|
269
|
+
f"Attempting to reconnect to MCP server {self._server_name}..."
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
# Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
|
|
273
|
+
if self.exit_stack:
|
|
274
|
+
self._old_exit_stacks.append(self.exit_stack)
|
|
275
|
+
|
|
276
|
+
# Mark old session as invalid
|
|
277
|
+
self.session = None
|
|
278
|
+
|
|
279
|
+
# Create new exit stack for new connection
|
|
280
|
+
self.exit_stack = AsyncExitStack()
|
|
281
|
+
|
|
282
|
+
# Reconnect using stored config
|
|
283
|
+
await self.connect_to_server(self._mcp_server_config, self._server_name)
|
|
284
|
+
await self.list_tools_and_save()
|
|
285
|
+
|
|
286
|
+
logger.info(
|
|
287
|
+
f"Successfully reconnected to MCP server {self._server_name}"
|
|
288
|
+
)
|
|
289
|
+
except Exception as e:
|
|
290
|
+
logger.error(
|
|
291
|
+
f"Failed to reconnect to MCP server {self._server_name}: {e}"
|
|
292
|
+
)
|
|
293
|
+
raise
|
|
294
|
+
finally:
|
|
295
|
+
self._reconnecting = False
|
|
296
|
+
|
|
297
|
+
async def call_tool_with_reconnect(
|
|
298
|
+
self,
|
|
299
|
+
tool_name: str,
|
|
300
|
+
arguments: dict,
|
|
301
|
+
read_timeout_seconds: timedelta,
|
|
302
|
+
) -> mcp.types.CallToolResult:
|
|
303
|
+
"""Call MCP tool with automatic reconnection on failure, max 2 retries.
|
|
304
|
+
|
|
305
|
+
Args:
|
|
306
|
+
tool_name: tool name
|
|
307
|
+
arguments: tool arguments
|
|
308
|
+
read_timeout_seconds: read timeout
|
|
309
|
+
|
|
310
|
+
Returns:
|
|
311
|
+
MCP tool call result
|
|
312
|
+
|
|
313
|
+
Raises:
|
|
314
|
+
ValueError: MCP session is not available
|
|
315
|
+
anyio.ClosedResourceError: raised after reconnection failure
|
|
316
|
+
"""
|
|
317
|
+
|
|
318
|
+
@retry(
|
|
319
|
+
retry=retry_if_exception_type(anyio.ClosedResourceError),
|
|
320
|
+
stop=stop_after_attempt(2),
|
|
321
|
+
wait=wait_exponential(multiplier=1, min=1, max=3),
|
|
322
|
+
before_sleep=before_sleep_log(logger, logging.WARNING),
|
|
323
|
+
reraise=True,
|
|
324
|
+
)
|
|
325
|
+
async def _call_with_retry():
|
|
326
|
+
if not self.session:
|
|
327
|
+
raise ValueError("MCP session is not available for MCP function tools.")
|
|
328
|
+
|
|
329
|
+
try:
|
|
330
|
+
return await self.session.call_tool(
|
|
331
|
+
name=tool_name,
|
|
332
|
+
arguments=arguments,
|
|
333
|
+
read_timeout_seconds=read_timeout_seconds,
|
|
334
|
+
)
|
|
335
|
+
except anyio.ClosedResourceError:
|
|
336
|
+
logger.warning(
|
|
337
|
+
f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
|
|
338
|
+
)
|
|
339
|
+
# Attempt to reconnect
|
|
340
|
+
await self._reconnect()
|
|
341
|
+
# Reraise the exception to trigger tenacity retry
|
|
342
|
+
raise
|
|
343
|
+
|
|
344
|
+
return await _call_with_retry()
|
|
345
|
+
|
|
225
346
|
async def cleanup(self):
|
|
226
|
-
"""Clean up resources"""
|
|
227
|
-
|
|
228
|
-
self.running_event.set()
|
|
347
|
+
"""Clean up resources including old exit stacks from reconnections"""
|
|
348
|
+
# Set running_event first to unblock any waiting tasks
|
|
349
|
+
self.running_event.set()
|
|
350
|
+
|
|
351
|
+
# Close current exit stack
|
|
352
|
+
try:
|
|
353
|
+
await self.exit_stack.aclose()
|
|
354
|
+
except Exception as e:
|
|
355
|
+
logger.debug(f"Error closing current exit stack: {e}")
|
|
356
|
+
|
|
357
|
+
# Don't close old exit stacks as they may be in different task contexts
|
|
358
|
+
# They will be garbage collected naturally
|
|
359
|
+
# Just clear the list to release references
|
|
360
|
+
self._old_exit_stacks.clear()
|
|
229
361
|
|
|
230
362
|
|
|
231
363
|
class MCPTool(FunctionTool, Generic[TContext]):
|
|
@@ -246,14 +378,8 @@ class MCPTool(FunctionTool, Generic[TContext]):
|
|
|
246
378
|
async def call(
|
|
247
379
|
self, context: ContextWrapper[TContext], **kwargs
|
|
248
380
|
) -> mcp.types.CallToolResult:
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
raise ValueError("MCP session is not available for MCP function tools.")
|
|
252
|
-
res = await session.call_tool(
|
|
253
|
-
name=self.mcp_tool.name,
|
|
381
|
+
return await self.mcp_client.call_tool_with_reconnect(
|
|
382
|
+
tool_name=self.mcp_tool.name,
|
|
254
383
|
arguments=kwargs,
|
|
255
|
-
read_timeout_seconds=timedelta(
|
|
256
|
-
seconds=context.tool_call_timeout,
|
|
257
|
-
),
|
|
384
|
+
read_timeout_seconds=timedelta(seconds=context.tool_call_timeout),
|
|
258
385
|
)
|
|
259
|
-
return res
|
astrbot/core/agent/message.py
CHANGED
|
@@ -119,6 +119,13 @@ class ToolCall(BaseModel):
|
|
|
119
119
|
"""The ID of the tool call."""
|
|
120
120
|
function: FunctionBody
|
|
121
121
|
"""The function body of the tool call."""
|
|
122
|
+
extra_content: dict[str, Any] | None = None
|
|
123
|
+
"""Extra metadata for the tool call."""
|
|
124
|
+
|
|
125
|
+
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
|
126
|
+
if self.extra_content is None:
|
|
127
|
+
kwargs.setdefault("exclude", set()).add("extra_content")
|
|
128
|
+
return super().model_dump(**kwargs)
|
|
122
129
|
|
|
123
130
|
|
|
124
131
|
class ToolCallPart(BaseModel):
|
astrbot/core/config/default.py
CHANGED
|
@@ -4,7 +4,7 @@ import os
|
|
|
4
4
|
|
|
5
5
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
6
6
|
|
|
7
|
-
VERSION = "4.
|
|
7
|
+
VERSION = "4.6.0"
|
|
8
8
|
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
|
9
9
|
|
|
10
10
|
# 默认配置
|
|
@@ -137,6 +137,7 @@ DEFAULT_CONFIG = {
|
|
|
137
137
|
"kb_names": [], # 默认知识库名称列表
|
|
138
138
|
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
|
|
139
139
|
"kb_final_top_k": 5, # 知识库检索最终返回结果数量
|
|
140
|
+
"kb_agentic_mode": False,
|
|
140
141
|
}
|
|
141
142
|
|
|
142
143
|
|
|
@@ -2146,6 +2147,7 @@ CONFIG_METADATA_2 = {
|
|
|
2146
2147
|
"kb_names": {"type": "list", "items": {"type": "string"}},
|
|
2147
2148
|
"kb_fusion_top_k": {"type": "int", "default": 20},
|
|
2148
2149
|
"kb_final_top_k": {"type": "int", "default": 5},
|
|
2150
|
+
"kb_agentic_mode": {"type": "bool"},
|
|
2149
2151
|
},
|
|
2150
2152
|
},
|
|
2151
2153
|
}
|
|
@@ -2241,6 +2243,11 @@ CONFIG_METADATA_3 = {
|
|
|
2241
2243
|
"type": "int",
|
|
2242
2244
|
"hint": "从知识库中检索到的结果数量,越大可能获得越多相关信息,但也可能引入噪音。建议根据实际需求调整",
|
|
2243
2245
|
},
|
|
2246
|
+
"kb_agentic_mode": {
|
|
2247
|
+
"description": "Agentic 知识库检索",
|
|
2248
|
+
"type": "bool",
|
|
2249
|
+
"hint": "启用后,知识库检索将作为 LLM Tool,由模型自主决定何时调用知识库进行查询。需要模型支持函数调用能力。",
|
|
2250
|
+
},
|
|
2244
2251
|
},
|
|
2245
2252
|
},
|
|
2246
2253
|
"websearch": {
|
astrbot/core/core_lifecycle.py
CHANGED
|
@@ -22,6 +22,7 @@ from astrbot.core.config.default import VERSION
|
|
|
22
22
|
from astrbot.core.conversation_mgr import ConversationManager
|
|
23
23
|
from astrbot.core.db import BaseDatabase
|
|
24
24
|
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
|
|
25
|
+
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
|
|
25
26
|
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
|
26
27
|
from astrbot.core.persona_mgr import PersonaManager
|
|
27
28
|
from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler
|
|
@@ -103,6 +104,13 @@ class AstrBotCoreLifecycle:
|
|
|
103
104
|
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
|
|
104
105
|
logger.error(traceback.format_exc())
|
|
105
106
|
|
|
107
|
+
# migration for webchat session
|
|
108
|
+
try:
|
|
109
|
+
await migrate_webchat_session(self.db)
|
|
110
|
+
except Exception as e:
|
|
111
|
+
logger.error(f"Migration for webchat session failed: {e!s}")
|
|
112
|
+
logger.error(traceback.format_exc())
|
|
113
|
+
|
|
106
114
|
# 初始化事件队列
|
|
107
115
|
self.event_queue = Queue()
|
|
108
116
|
|
astrbot/core/db/__init__.py
CHANGED
|
@@ -13,6 +13,7 @@ from astrbot.core.db.po import (
|
|
|
13
13
|
ConversationV2,
|
|
14
14
|
Persona,
|
|
15
15
|
PlatformMessageHistory,
|
|
16
|
+
PlatformSession,
|
|
16
17
|
PlatformStat,
|
|
17
18
|
Preference,
|
|
18
19
|
Stats,
|
|
@@ -183,7 +184,7 @@ class BaseDatabase(abc.ABC):
|
|
|
183
184
|
user_id: str,
|
|
184
185
|
offset_sec: int = 86400,
|
|
185
186
|
) -> None:
|
|
186
|
-
"""Delete platform message history records
|
|
187
|
+
"""Delete platform message history records newer than the specified offset."""
|
|
187
188
|
...
|
|
188
189
|
|
|
189
190
|
@abc.abstractmethod
|
|
@@ -313,3 +314,51 @@ class BaseDatabase(abc.ABC):
|
|
|
313
314
|
) -> tuple[list[dict], int]:
|
|
314
315
|
"""Get paginated session conversations with joined conversation and persona details, support search and platform filter."""
|
|
315
316
|
...
|
|
317
|
+
|
|
318
|
+
# ====
|
|
319
|
+
# Platform Session Management
|
|
320
|
+
# ====
|
|
321
|
+
|
|
322
|
+
@abc.abstractmethod
|
|
323
|
+
async def create_platform_session(
|
|
324
|
+
self,
|
|
325
|
+
creator: str,
|
|
326
|
+
platform_id: str = "webchat",
|
|
327
|
+
session_id: str | None = None,
|
|
328
|
+
display_name: str | None = None,
|
|
329
|
+
is_group: int = 0,
|
|
330
|
+
) -> PlatformSession:
|
|
331
|
+
"""Create a new Platform session."""
|
|
332
|
+
...
|
|
333
|
+
|
|
334
|
+
@abc.abstractmethod
|
|
335
|
+
async def get_platform_session_by_id(
|
|
336
|
+
self, session_id: str
|
|
337
|
+
) -> PlatformSession | None:
|
|
338
|
+
"""Get a Platform session by its ID."""
|
|
339
|
+
...
|
|
340
|
+
|
|
341
|
+
@abc.abstractmethod
|
|
342
|
+
async def get_platform_sessions_by_creator(
|
|
343
|
+
self,
|
|
344
|
+
creator: str,
|
|
345
|
+
platform_id: str | None = None,
|
|
346
|
+
page: int = 1,
|
|
347
|
+
page_size: int = 20,
|
|
348
|
+
) -> list[PlatformSession]:
|
|
349
|
+
"""Get all Platform sessions for a specific creator (username) and optionally platform."""
|
|
350
|
+
...
|
|
351
|
+
|
|
352
|
+
@abc.abstractmethod
|
|
353
|
+
async def update_platform_session(
|
|
354
|
+
self,
|
|
355
|
+
session_id: str,
|
|
356
|
+
display_name: str | None = None,
|
|
357
|
+
) -> None:
|
|
358
|
+
"""Update a Platform session's updated_at timestamp and optionally display_name."""
|
|
359
|
+
...
|
|
360
|
+
|
|
361
|
+
@abc.abstractmethod
|
|
362
|
+
async def delete_platform_session(self, session_id: str) -> None:
|
|
363
|
+
"""Delete a Platform session by its ID."""
|
|
364
|
+
...
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
"""Migration script for WebChat sessions.
|
|
2
|
+
|
|
3
|
+
This migration creates PlatformSession from existing platform_message_history records.
|
|
4
|
+
|
|
5
|
+
Changes:
|
|
6
|
+
- Creates platform_sessions table
|
|
7
|
+
- Adds platform_id field (default: 'webchat')
|
|
8
|
+
- Adds display_name field
|
|
9
|
+
- Session_id format: {platform_id}_{uuid}
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from sqlalchemy import func, select
|
|
13
|
+
from sqlmodel import col
|
|
14
|
+
|
|
15
|
+
from astrbot.api import logger, sp
|
|
16
|
+
from astrbot.core.db import BaseDatabase
|
|
17
|
+
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
async def migrate_webchat_session(db_helper: BaseDatabase):
|
|
21
|
+
"""Create PlatformSession records from platform_message_history.
|
|
22
|
+
|
|
23
|
+
This migration extracts all unique user_ids from platform_message_history
|
|
24
|
+
where platform_id='webchat' and creates corresponding PlatformSession records.
|
|
25
|
+
"""
|
|
26
|
+
# 检查是否已经完成迁移
|
|
27
|
+
migration_done = await db_helper.get_preference(
|
|
28
|
+
"global", "global", "migration_done_webchat_session"
|
|
29
|
+
)
|
|
30
|
+
if migration_done:
|
|
31
|
+
return
|
|
32
|
+
|
|
33
|
+
logger.info("开始执行数据库迁移(WebChat 会话迁移)...")
|
|
34
|
+
|
|
35
|
+
try:
|
|
36
|
+
async with db_helper.get_db() as session:
|
|
37
|
+
# 从 platform_message_history 创建 PlatformSession
|
|
38
|
+
query = (
|
|
39
|
+
select(
|
|
40
|
+
col(PlatformMessageHistory.user_id),
|
|
41
|
+
col(PlatformMessageHistory.sender_name),
|
|
42
|
+
func.min(PlatformMessageHistory.created_at).label("earliest"),
|
|
43
|
+
func.max(PlatformMessageHistory.updated_at).label("latest"),
|
|
44
|
+
)
|
|
45
|
+
.where(col(PlatformMessageHistory.platform_id) == "webchat")
|
|
46
|
+
.where(col(PlatformMessageHistory.sender_id) == "astrbot")
|
|
47
|
+
.group_by(col(PlatformMessageHistory.user_id))
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
result = await session.execute(query)
|
|
51
|
+
webchat_users = result.all()
|
|
52
|
+
|
|
53
|
+
if not webchat_users:
|
|
54
|
+
logger.info("没有找到需要迁移的 WebChat 数据")
|
|
55
|
+
await sp.put_async(
|
|
56
|
+
"global", "global", "migration_done_webchat_session", True
|
|
57
|
+
)
|
|
58
|
+
return
|
|
59
|
+
|
|
60
|
+
logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移")
|
|
61
|
+
|
|
62
|
+
# 检查已存在的会话
|
|
63
|
+
existing_query = select(col(PlatformSession.session_id))
|
|
64
|
+
existing_result = await session.execute(existing_query)
|
|
65
|
+
existing_session_ids = {row[0] for row in existing_result.fetchall()}
|
|
66
|
+
|
|
67
|
+
# 查询 Conversations 表中的 title,用于设置 display_name
|
|
68
|
+
# 对于每个 user_id,对应的 conversation user_id 格式为: webchat:FriendMessage:webchat!astrbot!{user_id}
|
|
69
|
+
user_ids_to_query = [
|
|
70
|
+
f"webchat:FriendMessage:webchat!astrbot!{user_id}"
|
|
71
|
+
for user_id, _, _, _ in webchat_users
|
|
72
|
+
]
|
|
73
|
+
conv_query = select(
|
|
74
|
+
col(ConversationV2.user_id), col(ConversationV2.title)
|
|
75
|
+
).where(col(ConversationV2.user_id).in_(user_ids_to_query))
|
|
76
|
+
conv_result = await session.execute(conv_query)
|
|
77
|
+
# 创建 user_id -> title 的映射字典
|
|
78
|
+
title_map = {
|
|
79
|
+
user_id.replace("webchat:FriendMessage:webchat!astrbot!", ""): title
|
|
80
|
+
for user_id, title in conv_result.fetchall()
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
# 批量创建 PlatformSession 记录
|
|
84
|
+
sessions_to_add = []
|
|
85
|
+
skipped_count = 0
|
|
86
|
+
|
|
87
|
+
for user_id, sender_name, created_at, updated_at in webchat_users:
|
|
88
|
+
# user_id 就是 webchat_conv_id (session_id)
|
|
89
|
+
session_id = user_id
|
|
90
|
+
|
|
91
|
+
# sender_name 通常是 username,但可能为 None
|
|
92
|
+
creator = sender_name if sender_name else "guest"
|
|
93
|
+
|
|
94
|
+
# 检查是否已经存在该会话
|
|
95
|
+
if session_id in existing_session_ids:
|
|
96
|
+
logger.debug(f"会话 {session_id} 已存在,跳过")
|
|
97
|
+
skipped_count += 1
|
|
98
|
+
continue
|
|
99
|
+
|
|
100
|
+
# 从 Conversations 表中获取 display_name
|
|
101
|
+
display_name = title_map.get(user_id)
|
|
102
|
+
|
|
103
|
+
# 创建新的 PlatformSession(保留原有的时间戳)
|
|
104
|
+
new_session = PlatformSession(
|
|
105
|
+
session_id=session_id,
|
|
106
|
+
platform_id="webchat",
|
|
107
|
+
creator=creator,
|
|
108
|
+
is_group=0,
|
|
109
|
+
created_at=created_at,
|
|
110
|
+
updated_at=updated_at,
|
|
111
|
+
display_name=display_name,
|
|
112
|
+
)
|
|
113
|
+
sessions_to_add.append(new_session)
|
|
114
|
+
|
|
115
|
+
# 批量插入
|
|
116
|
+
if sessions_to_add:
|
|
117
|
+
session.add_all(sessions_to_add)
|
|
118
|
+
await session.commit()
|
|
119
|
+
|
|
120
|
+
logger.info(
|
|
121
|
+
f"WebChat 会话迁移完成!成功迁移: {len(sessions_to_add)}, 跳过: {skipped_count}",
|
|
122
|
+
)
|
|
123
|
+
else:
|
|
124
|
+
logger.info("没有新会话需要迁移")
|
|
125
|
+
|
|
126
|
+
# 标记迁移完成
|
|
127
|
+
await sp.put_async("global", "global", "migration_done_webchat_session", True)
|
|
128
|
+
|
|
129
|
+
except Exception as e:
|
|
130
|
+
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
|
|
131
|
+
raise
|
astrbot/core/db/po.py
CHANGED
|
@@ -3,13 +3,7 @@ from dataclasses import dataclass, field
|
|
|
3
3
|
from datetime import datetime, timezone
|
|
4
4
|
from typing import TypedDict
|
|
5
5
|
|
|
6
|
-
from sqlmodel import
|
|
7
|
-
JSON,
|
|
8
|
-
Field,
|
|
9
|
-
SQLModel,
|
|
10
|
-
Text,
|
|
11
|
-
UniqueConstraint,
|
|
12
|
-
)
|
|
6
|
+
from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint
|
|
13
7
|
|
|
14
8
|
|
|
15
9
|
class PlatformStat(SQLModel, table=True):
|
|
@@ -18,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
|
|
|
18
12
|
Note: In astrbot v4, we moved `platform` table to here.
|
|
19
13
|
"""
|
|
20
14
|
|
|
21
|
-
__tablename__ = "platform_stats"
|
|
15
|
+
__tablename__ = "platform_stats" # type: ignore
|
|
22
16
|
|
|
23
17
|
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
|
24
18
|
timestamp: datetime = Field(nullable=False)
|
|
@@ -37,7 +31,7 @@ class PlatformStat(SQLModel, table=True):
|
|
|
37
31
|
|
|
38
32
|
|
|
39
33
|
class ConversationV2(SQLModel, table=True):
|
|
40
|
-
__tablename__ = "conversations"
|
|
34
|
+
__tablename__ = "conversations" # type: ignore
|
|
41
35
|
|
|
42
36
|
inner_conversation_id: int = Field(
|
|
43
37
|
primary_key=True,
|
|
@@ -74,7 +68,7 @@ class Persona(SQLModel, table=True):
|
|
|
74
68
|
It can be used to customize the behavior of LLMs.
|
|
75
69
|
"""
|
|
76
70
|
|
|
77
|
-
__tablename__ = "personas"
|
|
71
|
+
__tablename__ = "personas" # type: ignore
|
|
78
72
|
|
|
79
73
|
id: int | None = Field(
|
|
80
74
|
primary_key=True,
|
|
@@ -104,7 +98,7 @@ class Persona(SQLModel, table=True):
|
|
|
104
98
|
class Preference(SQLModel, table=True):
|
|
105
99
|
"""This class represents preferences for bots."""
|
|
106
100
|
|
|
107
|
-
__tablename__ = "preferences"
|
|
101
|
+
__tablename__ = "preferences" # type: ignore
|
|
108
102
|
|
|
109
103
|
id: int | None = Field(
|
|
110
104
|
default=None,
|
|
@@ -140,7 +134,7 @@ class PlatformMessageHistory(SQLModel, table=True):
|
|
|
140
134
|
or platform-specific messages.
|
|
141
135
|
"""
|
|
142
136
|
|
|
143
|
-
__tablename__ = "platform_message_history"
|
|
137
|
+
__tablename__ = "platform_message_history" # type: ignore
|
|
144
138
|
|
|
145
139
|
id: int | None = Field(
|
|
146
140
|
primary_key=True,
|
|
@@ -161,13 +155,55 @@ class PlatformMessageHistory(SQLModel, table=True):
|
|
|
161
155
|
)
|
|
162
156
|
|
|
163
157
|
|
|
158
|
+
class PlatformSession(SQLModel, table=True):
|
|
159
|
+
"""Platform session table for managing user sessions across different platforms.
|
|
160
|
+
|
|
161
|
+
A session represents a chat window for a specific user on a specific platform.
|
|
162
|
+
Each session can have multiple conversations (对话) associated with it.
|
|
163
|
+
"""
|
|
164
|
+
|
|
165
|
+
__tablename__ = "platform_sessions" # type: ignore
|
|
166
|
+
|
|
167
|
+
inner_id: int | None = Field(
|
|
168
|
+
primary_key=True,
|
|
169
|
+
sa_column_kwargs={"autoincrement": True},
|
|
170
|
+
default=None,
|
|
171
|
+
)
|
|
172
|
+
session_id: str = Field(
|
|
173
|
+
max_length=100,
|
|
174
|
+
nullable=False,
|
|
175
|
+
unique=True,
|
|
176
|
+
default_factory=lambda: f"webchat_{uuid.uuid4()}",
|
|
177
|
+
)
|
|
178
|
+
platform_id: str = Field(default="webchat", nullable=False)
|
|
179
|
+
"""Platform identifier (e.g., 'webchat', 'qq', 'discord')"""
|
|
180
|
+
creator: str = Field(nullable=False)
|
|
181
|
+
"""Username of the session creator"""
|
|
182
|
+
display_name: str | None = Field(default=None, max_length=255)
|
|
183
|
+
"""Display name for the session"""
|
|
184
|
+
is_group: int = Field(default=0, nullable=False)
|
|
185
|
+
"""0 for private chat, 1 for group chat (not implemented yet)"""
|
|
186
|
+
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
|
187
|
+
updated_at: datetime = Field(
|
|
188
|
+
default_factory=lambda: datetime.now(timezone.utc),
|
|
189
|
+
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
__table_args__ = (
|
|
193
|
+
UniqueConstraint(
|
|
194
|
+
"session_id",
|
|
195
|
+
name="uix_platform_session_id",
|
|
196
|
+
),
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
|
|
164
200
|
class Attachment(SQLModel, table=True):
|
|
165
201
|
"""This class represents attachments for messages in AstrBot.
|
|
166
202
|
|
|
167
203
|
Attachments can be images, files, or other media types.
|
|
168
204
|
"""
|
|
169
205
|
|
|
170
|
-
__tablename__ = "attachments"
|
|
206
|
+
__tablename__ = "attachments" # type: ignore
|
|
171
207
|
|
|
172
208
|
inner_attachment_id: int | None = Field(
|
|
173
209
|
primary_key=True,
|