AstrBot 4.1.7__py3-none-any.whl → 4.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -22,6 +22,9 @@ class CommandGroupFilter(HandlerFilter):
         self.custom_filter_list: List[CustomFilter] = []
         self.parent_group = parent_group
 
+        # Cache for complete command names list
+        self._cmpl_cmd_names: list | None = None
+
     def add_sub_command_filter(
         self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
     ):
@@ -34,6 +37,9 @@ class CommandGroupFilter(HandlerFilter):
         """遍历父节点获取完整的指令名。
 
         新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
+        if self._cmpl_cmd_names is not None:
+            return self._cmpl_cmd_names
+
         parent_cmd_names = (
             self.parent_group.get_complete_command_names() if self.parent_group else []
         )
@@ -47,6 +53,7 @@ class CommandGroupFilter(HandlerFilter):
         for parent_cmd_name in parent_cmd_names:
             for candidate in candidates:
                 result.append(parent_cmd_name + " " + candidate)
+        self._cmpl_cmd_names = result
         return result
 
     # 以树的形式打印出来
@@ -97,6 +104,12 @@ class CommandGroupFilter(HandlerFilter):
                 return False
         return True
 
+    def startswith(self, message_str: str) -> bool:
+        return message_str.startswith(tuple(self.get_complete_command_names()))
+
+    def equals(self, message_str: str) -> bool:
+        return message_str in self.get_complete_command_names()
+
     def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
         if not event.is_at_or_wake_command:
             return False
@@ -105,8 +118,7 @@ class CommandGroupFilter(HandlerFilter):
         if not self.custom_filter_ok(event, cfg):
             return False
 
-        complete_command_names = self.get_complete_command_names()
-        if event.message_str.strip() in complete_command_names:
+        if self.equals(event.message_str.strip()):
             tree = (
                 self.group_name
                 + "\n"
@@ -116,6 +128,4 @@ class CommandGroupFilter(HandlerFilter):
                 f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
             )
 
-        # complete_command_names = [name + " " for name in complete_command_names]
-        # return event.message_str.startswith(tuple(complete_command_names))
-        return False
+        return self.startswith(event.message_str)
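
These CommandGroupFilter hunks memoize get_complete_command_names() in the new _cmpl_cmd_names field (the name list is built by walking parent groups, so caching it avoids repeated traversal on every message) and move the prefix/equality checks into the new startswith() and equals() helpers that filter() now calls. A minimal, self-contained sketch of the same caching-and-matching pattern; the class and method names below are illustrative, not AstrBot's actual API:

    class CommandGroup:
        def __init__(self, names: list, parent=None):
            self.names = names
            self.parent = parent
            self._cached: list | None = None  # computed once, reused afterwards

        def complete_names(self) -> list:
            if self._cached is not None:      # fast path: reuse the cached list
                return self._cached
            prefixes = self.parent.complete_names() if self.parent else [""]
            result = [(p + " " + n).strip() for p in prefixes for n in self.names]
            self._cached = result
            return result

        def startswith(self, message: str) -> bool:
            return message.startswith(tuple(self.complete_names()))

        def equals(self, message: str) -> bool:
            return message in self.complete_names()

    group = CommandGroup(["help", "h"], parent=CommandGroup(["bot"]))
    assert group.equals("bot help")
    assert group.startswith("bot h extra args")

One caveat of this pattern, also visible in the diff: the cache is never invalidated, so it assumes all sub-commands and aliases are registered before the first lookup.
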
@@ -52,10 +52,6 @@ class SessionServiceManager:
             "session_service_config", session_config, scope="umo", scope_id=session_id
         )
 
-        logger.info(
-            f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
-        )
-
     @staticmethod
     def should_process_llm_request(event: AstrMessageEvent) -> bool:
         """检查是否应该处理LLM请求
@@ -1,9 +1,33 @@
+import codecs
 import json
 from astrbot.core import logger
-from aiohttp import ClientSession
+from aiohttp import ClientSession, ClientResponse
 from typing import Dict, List, Any, AsyncGenerator
 
 
+async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]:
+    decoder = codecs.getincrementaldecoder("utf-8")()
+    buffer = ""
+    async for chunk in resp.content.iter_chunked(8192):
+        buffer += decoder.decode(chunk)
+        while "\n\n" in buffer:
+            block, buffer = buffer.split("\n\n", 1)
+            if block.strip().startswith("data:"):
+                try:
+                    yield json.loads(block[5:])
+                except json.JSONDecodeError:
+                    logger.warning(f"Drop invalid dify json data: {block[5:]}")
+                    continue
+    # flush any remaining text
+    buffer += decoder.decode(b"", final=True)
+    if buffer.strip().startswith("data:"):
+        try:
+            yield json.loads(buffer[5:])
+        except json.JSONDecodeError:
+            logger.warning(f"Drop invalid dify json data: {buffer[5:]}")
+            pass
+
+
 class DifyAPIClient:
     def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
         self.api_key = api_key
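
The new module-level _stream_sse helper centralizes the Server-Sent Events parsing that was previously duplicated in chat_messages and workflow_run. Two details matter: it splits the stream on blank lines ("\n\n", the SSE event delimiter), and it decodes bytes with an incremental UTF-8 decoder, so a multi-byte character split across two network chunks no longer breaks decoding. A self-contained sketch of why the incremental decoder matters (illustrative data, not Dify's actual payload):

    import codecs
    import json

    payload = 'data:{"answer": "你好"}\n\n'.encode("utf-8")
    # Split the byte stream inside the multi-byte character "你";
    # a plain chunk.decode("utf-8") would raise UnicodeDecodeError here.
    chunks = [payload[:18], payload[18:]]

    decoder = codecs.getincrementaldecoder("utf-8")()
    buffer = ""
    events = []
    for chunk in chunks:
        buffer += decoder.decode(chunk)  # holds the partial character until complete
        while "\n\n" in buffer:
            block, buffer = buffer.split("\n\n", 1)
            if block.strip().startswith("data:"):
                events.append(json.loads(block[5:]))

    print(events)  # [{'answer': '你好'}]
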
@@ -33,31 +57,11 @@ class DifyAPIClient:
         ) as resp:
             if resp.status != 200:
                 text = await resp.text()
-                raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
-
-            buffer = ""
-            while True:
-                # 保持原有的8192字节限制,防止数据过大导致高水位报错
-                chunk = await resp.content.read(8192)
-                if not chunk:
-                    break
-
-                buffer += chunk.decode("utf-8")
-                blocks = buffer.split("\n\n")
-
-                # 处理完整的数据块
-                for block in blocks[:-1]:
-                    if block.strip() and block.startswith("data:"):
-                        try:
-                            json_str = block[5:]  # 移除 "data:" 前缀
-                            json_obj = json.loads(json_str)
-                            yield json_obj
-                        except json.JSONDecodeError as e:
-                            logger.error(f"JSON解析错误: {str(e)}")
-                            logger.error(f"原始数据块: {json_str}")
-
-                # 保留最后一个可能不完整的块
-                buffer = blocks[-1] if blocks else ""
+                raise Exception(
+                    f"Dify /chat-messages 接口请求失败:{resp.status}. {text}"
+                )
+            async for event in _stream_sse(resp):
+                yield event
 
     async def workflow_run(
         self,
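
chat_messages (and workflow_run below) now delegate chunk reassembly to _stream_sse and raise errors that name the concrete Dify endpoint. A hypothetical usage sketch of consuming the resulting async generator; the chat_messages parameters shown here follow Dify's public API and are not visible in this hunk, so treat them as assumptions:

    async def collect_answer(client: "DifyAPIClient", query: str, user: str) -> str:
        parts = []
        async for event in client.chat_messages(inputs={}, query=query, user=user):
            # Dify's streaming API emits "message" events carrying incremental
            # "answer" text, followed by a final "message_end" event.
            if event.get("event") == "message":
                parts.append(event.get("answer", ""))
            elif event.get("event") == "message_end":
                break
        return "".join(parts)
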
@@ -77,31 +81,11 @@
         ) as resp:
             if resp.status != 200:
                 text = await resp.text()
-                raise Exception(f"workflow_run 请求失败:{resp.status}. {text}")
-
-            buffer = ""
-            while True:
-                # 保持原有的8192字节限制,防止数据过大导致高水位报错
-                chunk = await resp.content.read(8192)
-                if not chunk:
-                    break
-
-                buffer += chunk.decode("utf-8")
-                blocks = buffer.split("\n\n")
-
-                # 处理完整的数据块
-                for block in blocks[:-1]:
-                    if block.strip() and block.startswith("data:"):
-                        try:
-                            json_str = block[5:]  # 移除 "data:" 前缀
-                            json_obj = json.loads(json_str)
-                            yield json_obj
-                        except json.JSONDecodeError as e:
-                            logger.error(f"JSON解析错误: {str(e)}")
-                            logger.error(f"原始数据块: {json_str}")
-
-                # 保留最后一个可能不完整的块
-                buffer = blocks[-1] if blocks else ""
+                raise Exception(
+                    f"Dify /workflows/run 接口请求失败:{resp.status}. {text}"
+                )
+            async for event in _stream_sse(resp):
+                yield event
 
     async def file_upload(
         self,
@@ -109,12 +93,15 @@
         user: str,
     ) -> Dict[str, Any]:
         url = f"{self.api_base}/files/upload"
-        payload = {
-            "user": user,
-            "file": open(file_path, "rb"),
-        }
-        async with self.session.post(url, data=payload, headers=self.headers) as resp:
-            return await resp.json()  # {"id": "xxx", ...}
+        with open(file_path, "rb") as f:
+            payload = {
+                "user": user,
+                "file": f,
+            }
+            async with self.session.post(
+                url, data=payload, headers=self.headers
+            ) as resp:
+                return await resp.json()  # {"id": "xxx", ...}
 
     async def close(self):
         await self.session.close()
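
The old file_upload passed open(file_path, "rb") straight into the form payload and never closed the handle; the new version scopes it with a with block so the file is closed once the request finishes. A hypothetical usage sketch (the positional file-path parameter name is not visible in this hunk and is assumed to be file_path):

    async def upload_and_get_id(client: "DifyAPIClient", path: str, user: str) -> str:
        result = await client.file_upload(path, user=user)
        return result["id"]  # Dify's upload response carries the file id

Note that the file is still opened and read synchronously inside an async method; for very large files, offloading the read to a thread (or a library such as aiofiles) would avoid blocking the event loop, but that is beyond what this change does.
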
@@ -1,17 +1,27 @@
 import uuid
 import json
 import os
+import asyncio
+from contextlib import asynccontextmanager
 from .route import Route, Response, RouteContext
 from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
 from quart import request, Response as QuartResponse, g, make_response
 from astrbot.core.db import BaseDatabase
-import asyncio
 from astrbot.core import logger
 from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
 from astrbot.core.utils.astrbot_path import get_astrbot_data_path
 from astrbot.core.platform.astr_message_event import MessageSession
 
 
+@asynccontextmanager
+async def track_conversation(convs: dict, conv_id: str):
+    convs[conv_id] = True
+    try:
+        yield
+    finally:
+        convs.pop(conv_id, None)
+
+
 class ChatRoute(Route):
     def __init__(
         self,
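
track_conversation is an @asynccontextmanager that marks a conversation id as running for the lifetime of its async with block and removes the mark in a finally clause, so the flag is cleared even if the streaming generator is cancelled or raises. A self-contained sketch of the same pattern with illustrative names:

    import asyncio
    from contextlib import asynccontextmanager

    running: dict[str, bool] = {}

    @asynccontextmanager
    async def track(convs: dict, conv_id: str):
        convs[conv_id] = True
        try:
            yield
        finally:
            convs.pop(conv_id, None)  # runs on normal exit, error, or cancellation

    async def worker(conv_id: str):
        async with track(running, conv_id):
            await asyncio.sleep(10)   # simulate a long-lived stream

    async def main():
        task = asyncio.create_task(worker("conv-1"))
        await asyncio.sleep(0.1)
        print(running.get("conv-1", False))  # True while the stream is alive
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        print(running.get("conv-1", False))  # False after cancellation

    asyncio.run(main())
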
@@ -40,6 +50,8 @@ class ChatRoute(Route):
         self.conv_mgr = core_lifecycle.conversation_manager
         self.platform_history_mgr = core_lifecycle.platform_message_history_manager
 
+        self.running_convs: dict[str, bool] = {}
+
     async def get_file(self):
         filename = request.args.get("filename")
         if not filename:
@@ -139,42 +151,63 @@ class ChatRoute(Route):
         )
 
         async def stream():
+            client_disconnected = False
+
             try:
-                while True:
-                    try:
-                        result = await asyncio.wait_for(back_queue.get(), timeout=10)
-                    except asyncio.TimeoutError:
-                        continue
-
-                    if not result:
-                        continue
-
-                    result_text = result["data"]
-                    type = result.get("type")
-                    streaming = result.get("streaming", False)
-                    yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
-                    await asyncio.sleep(0.05)
-
-                    if type == "end":
-                        break
-                    elif (
-                        (streaming and type == "complete")
-                        or not streaming
-                        or type == "break"
-                    ):
-                        # append bot message
-                        new_his = {"type": "bot", "message": result_text}
-                        await self.platform_history_mgr.insert(
-                            platform_id="webchat",
-                            user_id=webchat_conv_id,
-                            content=new_his,
-                            sender_id="bot",
-                            sender_name="bot",
-                        )
-
-            except BaseException as _:
-                logger.debug(f"用户 {username} 断开聊天长连接。")
-                return
+                async with track_conversation(self.running_convs, webchat_conv_id):
+                    while True:
+                        try:
+                            result = await asyncio.wait_for(back_queue.get(), timeout=1)
+                        except asyncio.TimeoutError:
+                            continue
+                        except asyncio.CancelledError:
+                            logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
+                            client_disconnected = True
+                        except Exception as e:
+                            logger.error(f"WebChat stream error: {e}")
+
+                        if not result:
+                            continue
+
+                        result_text = result["data"]
+                        type = result.get("type")
+                        streaming = result.get("streaming", False)
+
+                        try:
+                            if not client_disconnected:
+                                yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
+                        except Exception as e:
+                            if not client_disconnected:
+                                logger.debug(
+                                    f"[WebChat] 用户 {username} 断开聊天长连接。 {e}"
+                                )
+                            client_disconnected = True
+
+                        try:
+                            if not client_disconnected:
+                                await asyncio.sleep(0.05)
+                        except asyncio.CancelledError:
+                            logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
+                            client_disconnected = True
+
+                        if type == "end":
+                            break
+                        elif (
+                            (streaming and type == "complete")
+                            or not streaming
+                            or type == "break"
+                        ):
+                            # append bot message
+                            new_his = {"type": "bot", "message": result_text}
+                            await self.platform_history_mgr.insert(
+                                platform_id="webchat",
+                                user_id=webchat_conv_id,
+                                content=new_his,
+                                sender_id="bot",
+                                sender_name="bot",
+                            )
+            except BaseException as e:
+                logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
 
         # Put message to conversation-specific queue
         chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
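
The rewritten stream() generator registers the conversation in self.running_convs via track_conversation, polls the back queue with a 1-second timeout instead of 10, and uses a client_disconnected flag so that once the HTTP client goes away it stops yielding and sleeping but keeps draining results, which means bot replies are still persisted to the platform message history. The unconditional return on any exception is gone; unexpected errors are now logged instead of silently ending the stream. For context, a hedged sketch of how an async generator like this is typically wired into a Server-Sent Events response in Quart (the actual route wiring is outside the hunks shown above, so treat this as an assumption, not AstrBot's code):

    from quart import Quart, make_response

    app = Quart(__name__)

    @app.route("/chat/stream")
    async def chat_stream():
        async def stream():
            # Each SSE event is a "data: ..." line terminated by a blank line.
            yield 'data: {"type": "message", "data": "hello"}\n\n'
            yield 'data: {"type": "end", "data": ""}\n\n'

        response = await make_response(
            stream(),
            {"Content-Type": "text/event-stream", "Cache-Control": "no-cache"},
        )
        response.timeout = None  # long-lived streams should not hit the response timeout
        return response
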
@@ -291,6 +324,7 @@ class ChatRoute(Route):
             .ok(
                 data={
                     "history": history_res,
+                    "is_running": self.running_convs.get(webchat_conv_id, False),
                 }
             )
             .__dict__