AstrBot 4.5.7__py3-none-any.whl → 4.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. astrbot/core/agent/mcp_client.py +152 -26
  2. astrbot/core/agent/message.py +8 -1
  3. astrbot/core/config/default.py +8 -1
  4. astrbot/core/core_lifecycle.py +8 -0
  5. astrbot/core/db/__init__.py +50 -1
  6. astrbot/core/db/migration/migra_webchat_session.py +131 -0
  7. astrbot/core/db/po.py +49 -13
  8. astrbot/core/db/sqlite.py +102 -3
  9. astrbot/core/knowledge_base/kb_helper.py +314 -33
  10. astrbot/core/knowledge_base/kb_mgr.py +45 -1
  11. astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
  12. astrbot/core/knowledge_base/prompts.py +65 -0
  13. astrbot/core/pipeline/process_stage/method/llm_request.py +28 -14
  14. astrbot/core/pipeline/process_stage/utils.py +60 -16
  15. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +13 -10
  16. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -4
  17. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +0 -4
  18. astrbot/core/provider/entities.py +22 -9
  19. astrbot/core/provider/func_tool_manager.py +12 -9
  20. astrbot/core/provider/sources/gemini_source.py +25 -8
  21. astrbot/core/provider/sources/openai_source.py +9 -16
  22. astrbot/dashboard/routes/chat.py +134 -77
  23. astrbot/dashboard/routes/knowledge_base.py +172 -0
  24. {astrbot-4.5.7.dist-info → astrbot-4.6.0.dist-info}/METADATA +4 -3
  25. {astrbot-4.5.7.dist-info → astrbot-4.6.0.dist-info}/RECORD +28 -25
  26. {astrbot-4.5.7.dist-info → astrbot-4.6.0.dist-info}/WHEEL +0 -0
  27. {astrbot-4.5.7.dist-info → astrbot-4.6.0.dist-info}/entry_points.txt +0 -0
  28. {astrbot-4.5.7.dist-info → astrbot-4.6.0.dist-info}/licenses/LICENSE +0 -0
@@ -4,6 +4,14 @@ from contextlib import AsyncExitStack
4
4
  from datetime import timedelta
5
5
  from typing import Generic
6
6
 
7
+ from tenacity import (
8
+ before_sleep_log,
9
+ retry,
10
+ retry_if_exception_type,
11
+ stop_after_attempt,
12
+ wait_exponential,
13
+ )
14
+
7
15
  from astrbot import logger
8
16
  from astrbot.core.agent.run_context import ContextWrapper
9
17
  from astrbot.core.utils.log_pipe import LogPipe
@@ -12,21 +20,24 @@ from .run_context import TContext
12
20
  from .tool import FunctionTool
13
21
 
14
22
  try:
23
+ import anyio
15
24
  import mcp
16
25
  from mcp.client.sse import sse_client
17
26
  except (ModuleNotFoundError, ImportError):
18
- logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
27
+ logger.warning(
28
+ "Warning: Missing 'mcp' dependency, MCP services will be unavailable."
29
+ )
19
30
 
20
31
  try:
21
32
  from mcp.client.streamable_http import streamablehttp_client
22
33
  except (ModuleNotFoundError, ImportError):
23
34
  logger.warning(
24
- "警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。",
35
+ "Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
25
36
  )
26
37
 
27
38
 
28
39
  def _prepare_config(config: dict) -> dict:
29
- """准备配置,处理嵌套格式"""
40
+ """Prepare configuration, handle nested format"""
30
41
  if config.get("mcpServers"):
31
42
  first_key = next(iter(config["mcpServers"]))
32
43
  config = config["mcpServers"][first_key]
@@ -35,7 +46,7 @@ def _prepare_config(config: dict) -> dict:
35
46
 
36
47
 
37
48
  async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
38
- """快速测试 MCP 服务器可达性"""
49
+ """Quick test MCP server connectivity"""
39
50
  import aiohttp
40
51
 
41
52
  cfg = _prepare_config(config.copy())
@@ -50,7 +61,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
50
61
  elif "type" in cfg:
51
62
  transport_type = cfg["type"]
52
63
  else:
53
- raise Exception("MCP 连接配置缺少 transport type 字段")
64
+ raise Exception("MCP connection config missing transport or type field")
54
65
 
55
66
  async with aiohttp.ClientSession() as session:
56
67
  if transport_type == "streamable_http":
@@ -91,7 +102,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
91
102
  return False, f"HTTP {response.status}: {response.reason}"
92
103
 
93
104
  except asyncio.TimeoutError:
94
- return False, f"连接超时: {timeout}"
105
+ return False, f"Connection timeout: {timeout} seconds"
95
106
  except Exception as e:
96
107
  return False, f"{e!s}"
97
108
 
@@ -101,6 +112,7 @@ class MCPClient:
101
112
  # Initialize session and client objects
102
113
  self.session: mcp.ClientSession | None = None
103
114
  self.exit_stack = AsyncExitStack()
115
+ self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup
104
116
 
105
117
  self.name: str | None = None
106
118
  self.active: bool = True
@@ -108,22 +120,32 @@ class MCPClient:
108
120
  self.server_errlogs: list[str] = []
109
121
  self.running_event = asyncio.Event()
110
122
 
123
+ # Store connection config for reconnection
124
+ self._mcp_server_config: dict | None = None
125
+ self._server_name: str | None = None
126
+ self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection
127
+ self._reconnecting: bool = False # For logging and debugging
128
+
111
129
  async def connect_to_server(self, mcp_server_config: dict, name: str):
112
- """连接到 MCP 服务器
130
+ """Connect to MCP server
113
131
 
114
- 如果 `url` 参数存在:
115
- 1. transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
116
- 1. transport 指定为 `sse` 时,使用 SSE 连接方式。
117
- 2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
132
+ If `url` parameter exists:
133
+ 1. When transport is specified as `streamable_http`, use Streamable HTTP connection.
134
+ 2. When transport is specified as `sse`, use SSE connection.
135
+ 3. If not specified, default to SSE connection to MCP service.
118
136
 
119
137
  Args:
120
138
  mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
121
139
 
122
140
  """
141
+ # Store config for reconnection
142
+ self._mcp_server_config = mcp_server_config
143
+ self._server_name = name
144
+
123
145
  cfg = _prepare_config(mcp_server_config.copy())
124
146
 
125
147
  def logging_callback(msg: str):
126
- # 处理 MCP 服务的错误日志
148
+ # Handle MCP service error logs
127
149
  print(f"MCP Server {name} Error: {msg}")
128
150
  self.server_errlogs.append(msg)
129
151
 
@@ -137,7 +159,7 @@ class MCPClient:
137
159
  elif "type" in cfg:
138
160
  transport_type = cfg["type"]
139
161
  else:
140
- raise Exception("MCP 连接配置缺少 transport type 字段")
162
+ raise Exception("MCP connection config missing transport or type field")
141
163
 
142
164
  if transport_type != "streamable_http":
143
165
  # SSE transport method
@@ -193,7 +215,7 @@ class MCPClient:
193
215
  )
194
216
 
195
217
  def callback(msg: str):
196
- # 处理 MCP 服务的错误日志
218
+ # Handle MCP service error logs
197
219
  self.server_errlogs.append(msg)
198
220
 
199
221
  stdio_transport = await self.exit_stack.enter_async_context(
@@ -222,10 +244,120 @@ class MCPClient:
222
244
  self.tools = response.tools
223
245
  return response
224
246
 
247
+ async def _reconnect(self) -> None:
248
+ """Reconnect to the MCP server using the stored configuration.
249
+
250
+ Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments.
251
+
252
+ Raises:
253
+ Exception: raised when reconnection fails
254
+ """
255
+ async with self._reconnect_lock:
256
+ # Check if already reconnecting (useful for logging)
257
+ if self._reconnecting:
258
+ logger.debug(
259
+ f"MCP Client {self._server_name} is already reconnecting, skipping"
260
+ )
261
+ return
262
+
263
+ if not self._mcp_server_config or not self._server_name:
264
+ raise Exception("Cannot reconnect: missing connection configuration")
265
+
266
+ self._reconnecting = True
267
+ try:
268
+ logger.info(
269
+ f"Attempting to reconnect to MCP server {self._server_name}..."
270
+ )
271
+
272
+ # Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
273
+ if self.exit_stack:
274
+ self._old_exit_stacks.append(self.exit_stack)
275
+
276
+ # Mark old session as invalid
277
+ self.session = None
278
+
279
+ # Create new exit stack for new connection
280
+ self.exit_stack = AsyncExitStack()
281
+
282
+ # Reconnect using stored config
283
+ await self.connect_to_server(self._mcp_server_config, self._server_name)
284
+ await self.list_tools_and_save()
285
+
286
+ logger.info(
287
+ f"Successfully reconnected to MCP server {self._server_name}"
288
+ )
289
+ except Exception as e:
290
+ logger.error(
291
+ f"Failed to reconnect to MCP server {self._server_name}: {e}"
292
+ )
293
+ raise
294
+ finally:
295
+ self._reconnecting = False
296
+
297
+ async def call_tool_with_reconnect(
298
+ self,
299
+ tool_name: str,
300
+ arguments: dict,
301
+ read_timeout_seconds: timedelta,
302
+ ) -> mcp.types.CallToolResult:
303
+ """Call MCP tool with automatic reconnection on failure, max 2 retries.
304
+
305
+ Args:
306
+ tool_name: tool name
307
+ arguments: tool arguments
308
+ read_timeout_seconds: read timeout
309
+
310
+ Returns:
311
+ MCP tool call result
312
+
313
+ Raises:
314
+ ValueError: MCP session is not available
315
+ anyio.ClosedResourceError: raised after reconnection failure
316
+ """
317
+
318
+ @retry(
319
+ retry=retry_if_exception_type(anyio.ClosedResourceError),
320
+ stop=stop_after_attempt(2),
321
+ wait=wait_exponential(multiplier=1, min=1, max=3),
322
+ before_sleep=before_sleep_log(logger, logging.WARNING),
323
+ reraise=True,
324
+ )
325
+ async def _call_with_retry():
326
+ if not self.session:
327
+ raise ValueError("MCP session is not available for MCP function tools.")
328
+
329
+ try:
330
+ return await self.session.call_tool(
331
+ name=tool_name,
332
+ arguments=arguments,
333
+ read_timeout_seconds=read_timeout_seconds,
334
+ )
335
+ except anyio.ClosedResourceError:
336
+ logger.warning(
337
+ f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
338
+ )
339
+ # Attempt to reconnect
340
+ await self._reconnect()
341
+ # Reraise the exception to trigger tenacity retry
342
+ raise
343
+
344
+ return await _call_with_retry()
345
+
225
346
  async def cleanup(self):
226
- """Clean up resources"""
227
- await self.exit_stack.aclose()
228
- self.running_event.set() # Set the running event to indicate cleanup is done
347
+ """Clean up resources including old exit stacks from reconnections"""
348
+ # Set running_event first to unblock any waiting tasks
349
+ self.running_event.set()
350
+
351
+ # Close current exit stack
352
+ try:
353
+ await self.exit_stack.aclose()
354
+ except Exception as e:
355
+ logger.debug(f"Error closing current exit stack: {e}")
356
+
357
+ # Don't close old exit stacks as they may be in different task contexts
358
+ # They will be garbage collected naturally
359
+ # Just clear the list to release references
360
+ self._old_exit_stacks.clear()
229
361
 
230
362
 
231
363
  class MCPTool(FunctionTool, Generic[TContext]):
@@ -246,14 +378,8 @@ class MCPTool(FunctionTool, Generic[TContext]):
246
378
  async def call(
247
379
  self, context: ContextWrapper[TContext], **kwargs
248
380
  ) -> mcp.types.CallToolResult:
249
- session = self.mcp_client.session
250
- if not session:
251
- raise ValueError("MCP session is not available for MCP function tools.")
252
- res = await session.call_tool(
253
- name=self.mcp_tool.name,
381
+ return await self.mcp_client.call_tool_with_reconnect(
382
+ tool_name=self.mcp_tool.name,
254
383
  arguments=kwargs,
255
- read_timeout_seconds=timedelta(
256
- seconds=context.tool_call_timeout,
257
- ),
384
+ read_timeout_seconds=timedelta(seconds=context.tool_call_timeout),
258
385
  )
259
- return res
@@ -76,7 +76,7 @@ class ImageURLPart(ContentPart):
76
76
  """The ID of the image, to allow LLMs to distinguish different images."""
77
77
 
78
78
  type: str = "image_url"
79
- image_url: str
79
+ image_url: ImageURL
80
80
 
81
81
 
82
82
  class AudioURLPart(ContentPart):
@@ -119,6 +119,13 @@ class ToolCall(BaseModel):
119
119
  """The ID of the tool call."""
120
120
  function: FunctionBody
121
121
  """The function body of the tool call."""
122
+ extra_content: dict[str, Any] | None = None
123
+ """Extra metadata for the tool call."""
124
+
125
+ def model_dump(self, **kwargs: Any) -> dict[str, Any]:
126
+ if self.extra_content is None:
127
+ kwargs.setdefault("exclude", set()).add("extra_content")
128
+ return super().model_dump(**kwargs)
122
129
 
123
130
 
124
131
  class ToolCallPart(BaseModel):
@@ -4,7 +4,7 @@ import os
4
4
 
5
5
  from astrbot.core.utils.astrbot_path import get_astrbot_data_path
6
6
 
7
- VERSION = "4.5.7"
7
+ VERSION = "4.6.0"
8
8
  DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
9
9
 
10
10
  # 默认配置
@@ -137,6 +137,7 @@ DEFAULT_CONFIG = {
137
137
  "kb_names": [], # 默认知识库名称列表
138
138
  "kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
139
139
  "kb_final_top_k": 5, # 知识库检索最终返回结果数量
140
+ "kb_agentic_mode": False,
140
141
  }
141
142
 
142
143
 
@@ -2146,6 +2147,7 @@ CONFIG_METADATA_2 = {
2146
2147
  "kb_names": {"type": "list", "items": {"type": "string"}},
2147
2148
  "kb_fusion_top_k": {"type": "int", "default": 20},
2148
2149
  "kb_final_top_k": {"type": "int", "default": 5},
2150
+ "kb_agentic_mode": {"type": "bool"},
2149
2151
  },
2150
2152
  },
2151
2153
  }
@@ -2241,6 +2243,11 @@ CONFIG_METADATA_3 = {
2241
2243
  "type": "int",
2242
2244
  "hint": "从知识库中检索到的结果数量,越大可能获得越多相关信息,但也可能引入噪音。建议根据实际需求调整",
2243
2245
  },
2246
+ "kb_agentic_mode": {
2247
+ "description": "Agentic 知识库检索",
2248
+ "type": "bool",
2249
+ "hint": "启用后,知识库检索将作为 LLM Tool,由模型自主决定何时调用知识库进行查询。需要模型支持函数调用能力。",
2250
+ },
2244
2251
  },
2245
2252
  },
2246
2253
  "websearch": {
@@ -22,6 +22,7 @@ from astrbot.core.config.default import VERSION
22
22
  from astrbot.core.conversation_mgr import ConversationManager
23
23
  from astrbot.core.db import BaseDatabase
24
24
  from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
25
+ from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
25
26
  from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
26
27
  from astrbot.core.persona_mgr import PersonaManager
27
28
  from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler
@@ -103,6 +104,13 @@ class AstrBotCoreLifecycle:
103
104
  logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
104
105
  logger.error(traceback.format_exc())
105
106
 
107
+ # migration for webchat session
108
+ try:
109
+ await migrate_webchat_session(self.db)
110
+ except Exception as e:
111
+ logger.error(f"Migration for webchat session failed: {e!s}")
112
+ logger.error(traceback.format_exc())
113
+
106
114
  # 初始化事件队列
107
115
  self.event_queue = Queue()
108
116
 
@@ -13,6 +13,7 @@ from astrbot.core.db.po import (
13
13
  ConversationV2,
14
14
  Persona,
15
15
  PlatformMessageHistory,
16
+ PlatformSession,
16
17
  PlatformStat,
17
18
  Preference,
18
19
  Stats,
@@ -183,7 +184,7 @@ class BaseDatabase(abc.ABC):
183
184
  user_id: str,
184
185
  offset_sec: int = 86400,
185
186
  ) -> None:
186
- """Delete platform message history records older than the specified offset."""
187
+ """Delete platform message history records newer than the specified offset."""
187
188
  ...
188
189
 
189
190
  @abc.abstractmethod
@@ -313,3 +314,51 @@ class BaseDatabase(abc.ABC):
313
314
  ) -> tuple[list[dict], int]:
314
315
  """Get paginated session conversations with joined conversation and persona details, support search and platform filter."""
315
316
  ...
317
+
318
+ # ====
319
+ # Platform Session Management
320
+ # ====
321
+
322
+ @abc.abstractmethod
323
+ async def create_platform_session(
324
+ self,
325
+ creator: str,
326
+ platform_id: str = "webchat",
327
+ session_id: str | None = None,
328
+ display_name: str | None = None,
329
+ is_group: int = 0,
330
+ ) -> PlatformSession:
331
+ """Create a new Platform session."""
332
+ ...
333
+
334
+ @abc.abstractmethod
335
+ async def get_platform_session_by_id(
336
+ self, session_id: str
337
+ ) -> PlatformSession | None:
338
+ """Get a Platform session by its ID."""
339
+ ...
340
+
341
+ @abc.abstractmethod
342
+ async def get_platform_sessions_by_creator(
343
+ self,
344
+ creator: str,
345
+ platform_id: str | None = None,
346
+ page: int = 1,
347
+ page_size: int = 20,
348
+ ) -> list[PlatformSession]:
349
+ """Get all Platform sessions for a specific creator (username) and optionally platform."""
350
+ ...
351
+
352
+ @abc.abstractmethod
353
+ async def update_platform_session(
354
+ self,
355
+ session_id: str,
356
+ display_name: str | None = None,
357
+ ) -> None:
358
+ """Update a Platform session's updated_at timestamp and optionally display_name."""
359
+ ...
360
+
361
+ @abc.abstractmethod
362
+ async def delete_platform_session(self, session_id: str) -> None:
363
+ """Delete a Platform session by its ID."""
364
+ ...
@@ -0,0 +1,131 @@
1
+ """Migration script for WebChat sessions.
2
+
3
+ This migration creates PlatformSession from existing platform_message_history records.
4
+
5
+ Changes:
6
+ - Creates platform_sessions table
7
+ - Adds platform_id field (default: 'webchat')
8
+ - Adds display_name field
9
+ - Session_id format: {platform_id}_{uuid}
10
+ """
11
+
12
+ from sqlalchemy import func, select
13
+ from sqlmodel import col
14
+
15
+ from astrbot.api import logger, sp
16
+ from astrbot.core.db import BaseDatabase
17
+ from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession
18
+
19
+
20
+ async def migrate_webchat_session(db_helper: BaseDatabase):
21
+ """Create PlatformSession records from platform_message_history.
22
+
23
+ This migration extracts all unique user_ids from platform_message_history
24
+ where platform_id='webchat' and creates corresponding PlatformSession records.
25
+ """
26
+ # 检查是否已经完成迁移
27
+ migration_done = await db_helper.get_preference(
28
+ "global", "global", "migration_done_webchat_session"
29
+ )
30
+ if migration_done:
31
+ return
32
+
33
+ logger.info("开始执行数据库迁移(WebChat 会话迁移)...")
34
+
35
+ try:
36
+ async with db_helper.get_db() as session:
37
+ # 从 platform_message_history 创建 PlatformSession
38
+ query = (
39
+ select(
40
+ col(PlatformMessageHistory.user_id),
41
+ col(PlatformMessageHistory.sender_name),
42
+ func.min(PlatformMessageHistory.created_at).label("earliest"),
43
+ func.max(PlatformMessageHistory.updated_at).label("latest"),
44
+ )
45
+ .where(col(PlatformMessageHistory.platform_id) == "webchat")
46
+ .where(col(PlatformMessageHistory.sender_id) == "astrbot")
47
+ .group_by(col(PlatformMessageHistory.user_id))
48
+ )
49
+
50
+ result = await session.execute(query)
51
+ webchat_users = result.all()
52
+
53
+ if not webchat_users:
54
+ logger.info("没有找到需要迁移的 WebChat 数据")
55
+ await sp.put_async(
56
+ "global", "global", "migration_done_webchat_session", True
57
+ )
58
+ return
59
+
60
+ logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移")
61
+
62
+ # 检查已存在的会话
63
+ existing_query = select(col(PlatformSession.session_id))
64
+ existing_result = await session.execute(existing_query)
65
+ existing_session_ids = {row[0] for row in existing_result.fetchall()}
66
+
67
+ # 查询 Conversations 表中的 title,用于设置 display_name
68
+ # 对于每个 user_id,对应的 conversation user_id 格式为: webchat:FriendMessage:webchat!astrbot!{user_id}
69
+ user_ids_to_query = [
70
+ f"webchat:FriendMessage:webchat!astrbot!{user_id}"
71
+ for user_id, _, _, _ in webchat_users
72
+ ]
73
+ conv_query = select(
74
+ col(ConversationV2.user_id), col(ConversationV2.title)
75
+ ).where(col(ConversationV2.user_id).in_(user_ids_to_query))
76
+ conv_result = await session.execute(conv_query)
77
+ # 创建 user_id -> title 的映射字典
78
+ title_map = {
79
+ user_id.replace("webchat:FriendMessage:webchat!astrbot!", ""): title
80
+ for user_id, title in conv_result.fetchall()
81
+ }
82
+
83
+ # 批量创建 PlatformSession 记录
84
+ sessions_to_add = []
85
+ skipped_count = 0
86
+
87
+ for user_id, sender_name, created_at, updated_at in webchat_users:
88
+ # user_id 就是 webchat_conv_id (session_id)
89
+ session_id = user_id
90
+
91
+ # sender_name 通常是 username,但可能为 None
92
+ creator = sender_name if sender_name else "guest"
93
+
94
+ # 检查是否已经存在该会话
95
+ if session_id in existing_session_ids:
96
+ logger.debug(f"会话 {session_id} 已存在,跳过")
97
+ skipped_count += 1
98
+ continue
99
+
100
+ # 从 Conversations 表中获取 display_name
101
+ display_name = title_map.get(user_id)
102
+
103
+ # 创建新的 PlatformSession(保留原有的时间戳)
104
+ new_session = PlatformSession(
105
+ session_id=session_id,
106
+ platform_id="webchat",
107
+ creator=creator,
108
+ is_group=0,
109
+ created_at=created_at,
110
+ updated_at=updated_at,
111
+ display_name=display_name,
112
+ )
113
+ sessions_to_add.append(new_session)
114
+
115
+ # 批量插入
116
+ if sessions_to_add:
117
+ session.add_all(sessions_to_add)
118
+ await session.commit()
119
+
120
+ logger.info(
121
+ f"WebChat 会话迁移完成!成功迁移: {len(sessions_to_add)}, 跳过: {skipped_count}",
122
+ )
123
+ else:
124
+ logger.info("没有新会话需要迁移")
125
+
126
+ # 标记迁移完成
127
+ await sp.put_async("global", "global", "migration_done_webchat_session", True)
128
+
129
+ except Exception as e:
130
+ logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
131
+ raise
astrbot/core/db/po.py CHANGED
@@ -3,13 +3,7 @@ from dataclasses import dataclass, field
3
3
  from datetime import datetime, timezone
4
4
  from typing import TypedDict
5
5
 
6
- from sqlmodel import (
7
- JSON,
8
- Field,
9
- SQLModel,
10
- Text,
11
- UniqueConstraint,
12
- )
6
+ from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint
13
7
 
14
8
 
15
9
  class PlatformStat(SQLModel, table=True):
@@ -18,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
18
12
  Note: In astrbot v4, we moved `platform` table to here.
19
13
  """
20
14
 
21
- __tablename__ = "platform_stats"
15
+ __tablename__ = "platform_stats" # type: ignore
22
16
 
23
17
  id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
24
18
  timestamp: datetime = Field(nullable=False)
@@ -37,7 +31,7 @@ class PlatformStat(SQLModel, table=True):
37
31
 
38
32
 
39
33
  class ConversationV2(SQLModel, table=True):
40
- __tablename__ = "conversations"
34
+ __tablename__ = "conversations" # type: ignore
41
35
 
42
36
  inner_conversation_id: int = Field(
43
37
  primary_key=True,
@@ -74,7 +68,7 @@ class Persona(SQLModel, table=True):
74
68
  It can be used to customize the behavior of LLMs.
75
69
  """
76
70
 
77
- __tablename__ = "personas"
71
+ __tablename__ = "personas" # type: ignore
78
72
 
79
73
  id: int | None = Field(
80
74
  primary_key=True,
@@ -104,7 +98,7 @@ class Persona(SQLModel, table=True):
104
98
  class Preference(SQLModel, table=True):
105
99
  """This class represents preferences for bots."""
106
100
 
107
- __tablename__ = "preferences"
101
+ __tablename__ = "preferences" # type: ignore
108
102
 
109
103
  id: int | None = Field(
110
104
  default=None,
@@ -140,7 +134,7 @@ class PlatformMessageHistory(SQLModel, table=True):
140
134
  or platform-specific messages.
141
135
  """
142
136
 
143
- __tablename__ = "platform_message_history"
137
+ __tablename__ = "platform_message_history" # type: ignore
144
138
 
145
139
  id: int | None = Field(
146
140
  primary_key=True,
@@ -161,13 +155,55 @@ class PlatformMessageHistory(SQLModel, table=True):
161
155
  )
162
156
 
163
157
 
158
+ class PlatformSession(SQLModel, table=True):
159
+ """Platform session table for managing user sessions across different platforms.
160
+
161
+ A session represents a chat window for a specific user on a specific platform.
162
+ Each session can have multiple conversations (对话) associated with it.
163
+ """
164
+
165
+ __tablename__ = "platform_sessions" # type: ignore
166
+
167
+ inner_id: int | None = Field(
168
+ primary_key=True,
169
+ sa_column_kwargs={"autoincrement": True},
170
+ default=None,
171
+ )
172
+ session_id: str = Field(
173
+ max_length=100,
174
+ nullable=False,
175
+ unique=True,
176
+ default_factory=lambda: f"webchat_{uuid.uuid4()}",
177
+ )
178
+ platform_id: str = Field(default="webchat", nullable=False)
179
+ """Platform identifier (e.g., 'webchat', 'qq', 'discord')"""
180
+ creator: str = Field(nullable=False)
181
+ """Username of the session creator"""
182
+ display_name: str | None = Field(default=None, max_length=255)
183
+ """Display name for the session"""
184
+ is_group: int = Field(default=0, nullable=False)
185
+ """0 for private chat, 1 for group chat (not implemented yet)"""
186
+ created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
187
+ updated_at: datetime = Field(
188
+ default_factory=lambda: datetime.now(timezone.utc),
189
+ sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
190
+ )
191
+
192
+ __table_args__ = (
193
+ UniqueConstraint(
194
+ "session_id",
195
+ name="uix_platform_session_id",
196
+ ),
197
+ )
198
+
199
+
164
200
  class Attachment(SQLModel, table=True):
165
201
  """This class represents attachments for messages in AstrBot.
166
202
 
167
203
  Attachments can be images, files, or other media types.
168
204
  """
169
205
 
170
- __tablename__ = "attachments"
206
+ __tablename__ = "attachments" # type: ignore
171
207
 
172
208
  inner_attachment_id: int | None = Field(
173
209
  primary_key=True,