AstrBot 4.1.6__py3-none-any.whl → 4.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -22,6 +22,9 @@ class CommandGroupFilter(HandlerFilter):
22
22
  self.custom_filter_list: List[CustomFilter] = []
23
23
  self.parent_group = parent_group
24
24
 
25
+ # Cache for complete command names list
26
+ self._cmpl_cmd_names: list | None = None
27
+
25
28
  def add_sub_command_filter(
26
29
  self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
27
30
  ):
@@ -34,6 +37,9 @@ class CommandGroupFilter(HandlerFilter):
34
37
  """遍历父节点获取完整的指令名。
35
38
 
36
39
  新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
40
+ if self._cmpl_cmd_names is not None:
41
+ return self._cmpl_cmd_names
42
+
37
43
  parent_cmd_names = (
38
44
  self.parent_group.get_complete_command_names() if self.parent_group else []
39
45
  )
@@ -47,6 +53,7 @@ class CommandGroupFilter(HandlerFilter):
47
53
  for parent_cmd_name in parent_cmd_names:
48
54
  for candidate in candidates:
49
55
  result.append(parent_cmd_name + " " + candidate)
56
+ self._cmpl_cmd_names = result
50
57
  return result
51
58
 
52
59
  # 以树的形式打印出来
@@ -97,6 +104,12 @@ class CommandGroupFilter(HandlerFilter):
97
104
  return False
98
105
  return True
99
106
 
107
def startswith(self, message_str: str) -> bool:
    """Report whether ``message_str`` begins with one of this group's command names.

    Each name returned by ``get_complete_command_names()`` is tried as a plain
    string prefix; an empty name list never matches.

    NOTE(review): this is a raw prefix test with no word boundary, so a group
    named ``cmd`` also matches ``cmdx``. Confirm that this is intended.
    """
    candidate_names = self.get_complete_command_names()
    return any(message_str.startswith(name) for name in candidate_names)
109
+
110
def equals(self, message_str: str) -> bool:
    """Report whether ``message_str`` is exactly one of this group's command names.

    The comparison is against the full list from ``get_complete_command_names()``
    (no stripping or case folding is done here).
    """
    known_names = self.get_complete_command_names()
    return any(name == message_str for name in known_names)
112
+
100
113
  def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
101
114
  if not event.is_at_or_wake_command:
102
115
  return False
@@ -105,8 +118,7 @@ class CommandGroupFilter(HandlerFilter):
105
118
  if not self.custom_filter_ok(event, cfg):
106
119
  return False
107
120
 
108
- complete_command_names = self.get_complete_command_names()
109
- if event.message_str.strip() in complete_command_names:
121
+ if self.equals(event.message_str.strip()):
110
122
  tree = (
111
123
  self.group_name
112
124
  + "\n"
@@ -116,6 +128,4 @@ class CommandGroupFilter(HandlerFilter):
116
128
  f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
117
129
  )
118
130
 
119
- # complete_command_names = [name + " " for name in complete_command_names]
120
- # return event.message_str.startswith(tuple(complete_command_names))
121
- return False
131
+ return self.startswith(event.message_str)
@@ -52,10 +52,6 @@ class SessionServiceManager:
52
52
  "session_service_config", session_config, scope="umo", scope_id=session_id
53
53
  )
54
54
 
55
- logger.info(
56
- f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
57
- )
58
-
59
55
  @staticmethod
60
56
  def should_process_llm_request(event: AstrMessageEvent) -> bool:
61
57
  """检查是否应该处理LLM请求
@@ -1,9 +1,33 @@
1
+ import codecs
1
2
  import json
2
3
  from astrbot.core import logger
3
- from aiohttp import ClientSession
4
+ from aiohttp import ClientSession, ClientResponse
4
5
  from typing import Dict, List, Any, AsyncGenerator
5
6
 
6
7
 
8
def _parse_sse_block(block: str) -> list:
    """Parse one Server-Sent Events block into a list holding its JSON payload.

    Every ``data:`` line of the block contributes its value; multiple data lines
    are joined with ``"\\n"``, as the SSE format specifies. Other fields
    (``event:``, ``id:``) and comment lines (starting with ``:``) are ignored.
    Lines are stripped first, so stray leading newlines or whitespace do not
    corrupt the payload.

    Returns:
        A one-element list with the decoded JSON value, or an empty list when
        the block has no ``data:`` line or its payload is not valid JSON (the
        latter is logged and dropped).
    """
    data_lines = []
    for raw_line in block.split("\n"):
        line = raw_line.strip()
        if line.startswith("data:"):
            data_lines.append(line[len("data:"):])
    if not data_lines:
        return []
    payload = "\n".join(data_lines)
    try:
        return [json.loads(payload)]
    except json.JSONDecodeError:
        logger.warning(f"Drop invalid dify json data: {payload}")
        return []


async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]:
    """Yield the JSON objects carried by a Dify Server-Sent Events response.

    The body is read in 8192-byte chunks and decoded with an incremental UTF-8
    decoder, so multi-byte characters split across chunks are handled. Events
    are separated by a blank line (LF or CRLF). Each event's ``data:`` payload
    is parsed as JSON. Events without data (e.g. ``event: ping``) are skipped,
    and invalid JSON is logged and dropped.
    """
    decoder = codecs.getincrementaldecoder("utf-8")()
    buffer = ""
    async for chunk in resp.content.iter_chunked(8192):
        # Normalize CRLF so "\r\n\r\n" separators are recognized. A "\r" left at
        # the end of the buffer is joined with a "\n" arriving in the next chunk.
        buffer = (buffer + decoder.decode(chunk)).replace("\r\n", "\n")
        while "\n\n" in buffer:
            block, buffer = buffer.split("\n\n", 1)
            # Parse before yielding so exceptions thrown into this generator by
            # the consumer are never mistaken for JSON errors.
            for event in _parse_sse_block(block):
                yield event
    # Flush bytes still held by the decoder, then parse the final block, which
    # may not be terminated by a blank line.
    buffer = (buffer + decoder.decode(b"", final=True)).replace("\r\n", "\n")
    for event in _parse_sse_block(buffer):
        yield event
29
+
30
+
7
31
  class DifyAPIClient:
8
32
  def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
9
33
  self.api_key = api_key
@@ -33,31 +57,11 @@ class DifyAPIClient:
33
57
  ) as resp:
34
58
  if resp.status != 200:
35
59
  text = await resp.text()
36
- raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
37
-
38
- buffer = ""
39
- while True:
40
- # 保持原有的8192字节限制,防止数据过大导致高水位报错
41
- chunk = await resp.content.read(8192)
42
- if not chunk:
43
- break
44
-
45
- buffer += chunk.decode("utf-8")
46
- blocks = buffer.split("\n\n")
47
-
48
- # 处理完整的数据块
49
- for block in blocks[:-1]:
50
- if block.strip() and block.startswith("data:"):
51
- try:
52
- json_str = block[5:] # 移除 "data:" 前缀
53
- json_obj = json.loads(json_str)
54
- yield json_obj
55
- except json.JSONDecodeError as e:
56
- logger.error(f"JSON解析错误: {str(e)}")
57
- logger.error(f"原始数据块: {json_str}")
58
-
59
- # 保留最后一个可能不完整的块
60
- buffer = blocks[-1] if blocks else ""
60
+ raise Exception(
61
+ f"Dify /chat-messages 接口请求失败:{resp.status}. {text}"
62
+ )
63
+ async for event in _stream_sse(resp):
64
+ yield event
61
65
 
62
66
  async def workflow_run(
63
67
  self,
@@ -77,31 +81,11 @@ class DifyAPIClient:
77
81
  ) as resp:
78
82
  if resp.status != 200:
79
83
  text = await resp.text()
80
- raise Exception(f"workflow_run 请求失败:{resp.status}. {text}")
81
-
82
- buffer = ""
83
- while True:
84
- # 保持原有的8192字节限制,防止数据过大导致高水位报错
85
- chunk = await resp.content.read(8192)
86
- if not chunk:
87
- break
88
-
89
- buffer += chunk.decode("utf-8")
90
- blocks = buffer.split("\n\n")
91
-
92
- # 处理完整的数据块
93
- for block in blocks[:-1]:
94
- if block.strip() and block.startswith("data:"):
95
- try:
96
- json_str = block[5:] # 移除 "data:" 前缀
97
- json_obj = json.loads(json_str)
98
- yield json_obj
99
- except json.JSONDecodeError as e:
100
- logger.error(f"JSON解析错误: {str(e)}")
101
- logger.error(f"原始数据块: {json_str}")
102
-
103
- # 保留最后一个可能不完整的块
104
- buffer = blocks[-1] if blocks else ""
84
+ raise Exception(
85
+ f"Dify /workflows/run 接口请求失败:{resp.status}. {text}"
86
+ )
87
+ async for event in _stream_sse(resp):
88
+ yield event
105
89
 
106
90
  async def file_upload(
107
91
  self,
@@ -109,12 +93,15 @@ class DifyAPIClient:
109
93
  user: str,
110
94
  ) -> Dict[str, Any]:
111
95
  url = f"{self.api_base}/files/upload"
112
- payload = {
113
- "user": user,
114
- "file": open(file_path, "rb"),
115
- }
116
- async with self.session.post(url, data=payload, headers=self.headers) as resp:
117
- return await resp.json() # {"id": "xxx", ...}
96
+ with open(file_path, "rb") as f:
97
+ payload = {
98
+ "user": user,
99
+ "file": f,
100
+ }
101
+ async with self.session.post(
102
+ url, data=payload, headers=self.headers
103
+ ) as resp:
104
+ return await resp.json() # {"id": "xxx", ...}
118
105
 
119
106
  async def close(self):
120
107
  await self.session.close()
astrbot/core/utils/io.py CHANGED
@@ -227,9 +227,11 @@ async def download_dashboard(
227
227
  path = os.path.join(get_astrbot_data_path(), "dashboard.zip")
228
228
 
229
229
  if latest or len(str(version)) != 40:
230
- logger.info(f"准备下载 {version} 发行版本的 AstrBot WebUI 文件")
231
230
  ver_name = "latest" if latest else version
232
231
  dashboard_release_url = f"https://astrbot-registry.soulter.top/download/astrbot-dashboard/{ver_name}/dist.zip"
232
+ logger.info(
233
+ f"准备下载指定发行版本的 AstrBot WebUI 文件: {dashboard_release_url}"
234
+ )
233
235
  try:
234
236
  await download_file(dashboard_release_url, path, show_progress=True)
235
237
  except BaseException as _:
@@ -241,24 +243,10 @@ async def download_dashboard(
241
243
  dashboard_release_url = f"{proxy}/{dashboard_release_url}"
242
244
  await download_file(dashboard_release_url, path, show_progress=True)
243
245
  else:
244
- logger.info(f"准备下载指定版本的 AstrBot WebUI: {version}")
245
-
246
- url = (
247
- "https://api.github.com/repos/AstrBotDevs/astrbot-release-harbour/releases"
248
- )
246
+ url = f"https://github.com/AstrBotDevs/astrbot-release-harbour/releases/download/release-{version}/dist.zip"
247
+ logger.info(f"准备下载指定版本的 AstrBot WebUI: {url}")
249
248
  if proxy:
250
249
  url = f"{proxy}/{url}"
251
- async with aiohttp.ClientSession(trust_env=True) as session:
252
- async with session.get(url) as resp:
253
- if resp.status == 200:
254
- releases = await resp.json()
255
- for release in releases:
256
- if version in release["tag_name"]:
257
- download_url = release["assets"][0]["browser_download_url"]
258
- await download_file(download_url, path, show_progress=True)
259
- else:
260
- logger.warning(f"未找到指定的版本的 Dashboard 构建文件: {version}")
261
- return
262
-
250
+ await download_file(url, path, show_progress=True)
263
251
  with zipfile.ZipFile(path, "r") as z:
264
252
  z.extractall(extract_path)
@@ -1,17 +1,27 @@
1
1
  import uuid
2
2
  import json
3
3
  import os
4
+ import asyncio
5
+ from contextlib import asynccontextmanager
4
6
  from .route import Route, Response, RouteContext
5
7
  from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
6
8
  from quart import request, Response as QuartResponse, g, make_response
7
9
  from astrbot.core.db import BaseDatabase
8
- import asyncio
9
10
  from astrbot.core import logger
10
11
  from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
11
12
  from astrbot.core.utils.astrbot_path import get_astrbot_data_path
12
13
  from astrbot.core.platform.astr_message_event import MessageSession
13
14
 
14
15
 
16
@asynccontextmanager
async def track_conversation(convs: dict, conv_id: str):
    """Mark ``conv_id`` as running in ``convs`` while the ``async with`` body runs.

    On entry the key is set to ``True``. On exit it is removed, whether the body
    returns normally, raises, or is cancelled. Removal is a no-op when the key
    is already gone.
    """
    convs.update({conv_id: True})
    try:
        yield None
    finally:
        if conv_id in convs:
            del convs[conv_id]
23
+
24
+
15
25
  class ChatRoute(Route):
16
26
  def __init__(
17
27
  self,
@@ -40,6 +50,8 @@ class ChatRoute(Route):
40
50
  self.conv_mgr = core_lifecycle.conversation_manager
41
51
  self.platform_history_mgr = core_lifecycle.platform_message_history_manager
42
52
 
53
+ self.running_convs: dict[str, bool] = {}
54
+
43
55
  async def get_file(self):
44
56
  filename = request.args.get("filename")
45
57
  if not filename:
@@ -139,42 +151,63 @@ class ChatRoute(Route):
139
151
  )
140
152
 
141
153
  async def stream():
154
+ client_disconnected = False
155
+
142
156
  try:
143
- while True:
144
- try:
145
- result = await asyncio.wait_for(back_queue.get(), timeout=10)
146
- except asyncio.TimeoutError:
147
- continue
148
-
149
- if not result:
150
- continue
151
-
152
- result_text = result["data"]
153
- type = result.get("type")
154
- streaming = result.get("streaming", False)
155
- yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
156
- await asyncio.sleep(0.05)
157
-
158
- if type == "end":
159
- break
160
- elif (
161
- (streaming and type == "complete")
162
- or not streaming
163
- or type == "break"
164
- ):
165
- # append bot message
166
- new_his = {"type": "bot", "message": result_text}
167
- await self.platform_history_mgr.insert(
168
- platform_id="webchat",
169
- user_id=webchat_conv_id,
170
- content=new_his,
171
- sender_id="bot",
172
- sender_name="bot",
173
- )
174
-
175
- except BaseException as _:
176
- logger.debug(f"用户 {username} 断开聊天长连接。")
177
- return
157
+ async with track_conversation(self.running_convs, webchat_conv_id):
158
+ while True:
159
+ try:
160
+ result = await asyncio.wait_for(back_queue.get(), timeout=1)
161
+ except asyncio.TimeoutError:
162
+ continue
163
+ except asyncio.CancelledError:
164
+ logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
165
+ client_disconnected = True
166
+ except Exception as e:
167
+ logger.error(f"WebChat stream error: {e}")
168
+
169
+ if not result:
170
+ continue
171
+
172
+ result_text = result["data"]
173
+ type = result.get("type")
174
+ streaming = result.get("streaming", False)
175
+
176
+ try:
177
+ if not client_disconnected:
178
+ yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
179
+ except Exception as e:
180
+ if not client_disconnected:
181
+ logger.debug(
182
+ f"[WebChat] 用户 {username} 断开聊天长连接。 {e}"
183
+ )
184
+ client_disconnected = True
185
+
186
+ try:
187
+ if not client_disconnected:
188
+ await asyncio.sleep(0.05)
189
+ except asyncio.CancelledError:
190
+ logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
191
+ client_disconnected = True
192
+
193
+ if type == "end":
194
+ break
195
+ elif (
196
+ (streaming and type == "complete")
197
+ or not streaming
198
+ or type == "break"
199
+ ):
200
+ # append bot message
201
+ new_his = {"type": "bot", "message": result_text}
202
+ await self.platform_history_mgr.insert(
203
+ platform_id="webchat",
204
+ user_id=webchat_conv_id,
205
+ content=new_his,
206
+ sender_id="bot",
207
+ sender_name="bot",
208
+ )
209
+ except BaseException as e:
210
+ logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
178
211
 
179
212
  # Put message to conversation-specific queue
180
213
  chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
@@ -291,6 +324,7 @@ class ChatRoute(Route):
291
324
  .ok(
292
325
  data={
293
326
  "history": history_res,
327
+ "is_running": self.running_convs.get(webchat_conv_id, False),
294
328
  }
295
329
  )
296
330
  .__dict__
@@ -51,24 +51,6 @@ def validate_config(
51
51
  def validate(data: dict, metadata: dict = schema, path=""):
52
52
  for key, value in data.items():
53
53
  if key not in metadata:
54
- # 无 schema 的配置项,执行类型猜测
55
- if isinstance(value, str):
56
- try:
57
- data[key] = int(value)
58
- continue
59
- except ValueError:
60
- pass
61
-
62
- try:
63
- data[key] = float(value)
64
- continue
65
- except ValueError:
66
- pass
67
-
68
- if value.lower() == "true":
69
- data[key] = True
70
- elif value.lower() == "false":
71
- data[key] = False
72
54
  continue
73
55
  meta = metadata[key]
74
56
  if "type" not in meta:
@@ -127,12 +109,12 @@ def validate_config(
127
109
  )
128
110
 
129
111
  if is_core:
130
- for key, group in schema.items():
131
- group_meta = group.get("metadata")
132
- if not group_meta:
133
- continue
134
- # logger.info(f"验证配置: 组 {key} ...")
135
- validate(data, group_meta, path=f"{key}.")
112
+ meta_all = {
113
+ **schema["platform_group"]["metadata"],
114
+ **schema["provider_group"]["metadata"],
115
+ **schema["misc_config_group"]["metadata"],
116
+ }
117
+ validate(data, meta_all)
136
118
  else:
137
119
  validate(data, schema)
138
120
 
@@ -142,6 +124,7 @@ def validate_config(
142
124
  def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
143
125
  """验证并保存配置"""
144
126
  errors = None
127
+ logger.info(f"Saving config, is_core={is_core}")
145
128
  try:
146
129
  if is_core:
147
130
  errors, post_config = validate_config(
@@ -169,15 +169,65 @@ class ConversationRoute(Route):
169
169
  """删除对话"""
170
170
  try:
171
171
  data = await request.get_json()
172
- user_id = data.get("user_id")
173
- cid = data.get("cid")
174
172
 
175
- if not user_id or not cid:
176
- return Response().error("缺少必要参数: user_id 和 cid").__dict__
177
- await self.core_lifecycle.conversation_manager.delete_conversation(
178
- unified_msg_origin=user_id, conversation_id=cid
179
- )
180
- return Response().ok({"message": "对话删除成功"}).__dict__
173
+ # 检查是否是批量删除
174
+ if "conversations" in data:
175
+ # 批量删除
176
+ conversations = data.get("conversations", [])
177
+ if not conversations:
178
+ return (
179
+ Response().error("批量删除时conversations参数不能为空").__dict__
180
+ )
181
+
182
+ deleted_count = 0
183
+ failed_items = []
184
+
185
+ for conv in conversations:
186
+ user_id = conv.get("user_id")
187
+ cid = conv.get("cid")
188
+
189
+ if not user_id or not cid:
190
+ failed_items.append(
191
+ f"user_id:{user_id}, cid:{cid} - 缺少必要参数"
192
+ )
193
+ continue
194
+
195
+ try:
196
+ await self.core_lifecycle.conversation_manager.delete_conversation(
197
+ unified_msg_origin=user_id, conversation_id=cid
198
+ )
199
+ deleted_count += 1
200
+ except Exception as e:
201
+ failed_items.append(f"user_id:{user_id}, cid:{cid} - {str(e)}")
202
+
203
+ message = f"成功删除 {deleted_count} 个对话"
204
+ if failed_items:
205
+ message += f",失败 {len(failed_items)} 个"
206
+
207
+ return (
208
+ Response()
209
+ .ok(
210
+ {
211
+ "message": message,
212
+ "deleted_count": deleted_count,
213
+ "failed_count": len(failed_items),
214
+ "failed_items": failed_items,
215
+ }
216
+ )
217
+ .__dict__
218
+ )
219
+ else:
220
+ # 单个删除
221
+ user_id = data.get("user_id")
222
+ cid = data.get("cid")
223
+
224
+ if not user_id or not cid:
225
+ return Response().error("缺少必要参数: user_id 和 cid").__dict__
226
+
227
+ await self.core_lifecycle.conversation_manager.delete_conversation(
228
+ unified_msg_origin=user_id, conversation_id=cid
229
+ )
230
+ return Response().ok({"message": "对话删除成功"}).__dict__
181
231
 
182
232
  except Exception as e:
183
233
  logger.error(f"删除对话失败: {str(e)}\n{traceback.format_exc()}")