AstrBot 4.6.1__py3-none-any.whl → 4.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/agent/mcp_client.py +3 -3
- astrbot/core/agent/runners/base.py +7 -4
- astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
- astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
- astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
- astrbot/core/{utils → agent/runners/dify}/dify_api_client.py +51 -13
- astrbot/core/agent/runners/tool_loop_agent_runner.py +0 -6
- astrbot/core/config/default.py +141 -26
- astrbot/core/config/i18n_utils.py +110 -0
- astrbot/core/core_lifecycle.py +11 -13
- astrbot/core/db/po.py +1 -1
- astrbot/core/db/sqlite.py +2 -2
- astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
- astrbot/core/pipeline/process_stage/method/{llm_request.py → agent_sub_stages/internal.py} +13 -34
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
- astrbot/core/pipeline/process_stage/method/star_request.py +1 -1
- astrbot/core/pipeline/process_stage/stage.py +8 -5
- astrbot/core/pipeline/result_decorate/stage.py +15 -5
- astrbot/core/provider/manager.py +43 -41
- astrbot/core/star/session_llm_manager.py +0 -107
- astrbot/core/star/session_plugin_manager.py +0 -81
- astrbot/core/umop_config_router.py +19 -0
- astrbot/core/utils/migra_helper.py +73 -0
- astrbot/core/utils/shared_preferences.py +1 -28
- astrbot/dashboard/routes/chat.py +13 -1
- astrbot/dashboard/routes/config.py +20 -16
- astrbot/dashboard/routes/knowledge_base.py +0 -156
- astrbot/dashboard/routes/session_management.py +311 -606
- {astrbot-4.6.1.dist-info → astrbot-4.7.1.dist-info}/METADATA +1 -1
- {astrbot-4.6.1.dist-info → astrbot-4.7.1.dist-info}/RECORD +34 -30
- {astrbot-4.6.1.dist-info → astrbot-4.7.1.dist-info}/WHEEL +1 -1
- astrbot/core/provider/sources/coze_source.py +0 -650
- astrbot/core/provider/sources/dashscope_source.py +0 -207
- astrbot/core/provider/sources/dify_source.py +0 -285
- /astrbot/core/{provider/sources → agent/runners/coze}/coze_api_client.py +0 -0
- {astrbot-4.6.1.dist-info → astrbot-4.7.1.dist-info}/entry_points.txt +0 -0
- {astrbot-4.6.1.dist-info → astrbot-4.7.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,403 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import functools
|
|
3
|
+
import queue
|
|
4
|
+
import re
|
|
5
|
+
import sys
|
|
6
|
+
import threading
|
|
7
|
+
import typing as T
|
|
8
|
+
|
|
9
|
+
from dashscope import Application
|
|
10
|
+
from dashscope.app.application_response import ApplicationResponse
|
|
11
|
+
|
|
12
|
+
import astrbot.core.message.components as Comp
|
|
13
|
+
from astrbot.core import logger, sp
|
|
14
|
+
from astrbot.core.message.message_event_result import MessageChain
|
|
15
|
+
from astrbot.core.provider.entities import (
|
|
16
|
+
LLMResponse,
|
|
17
|
+
ProviderRequest,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
from ...hooks import BaseAgentRunHooks
|
|
21
|
+
from ...response import AgentResponseData
|
|
22
|
+
from ...run_context import ContextWrapper, TContext
|
|
23
|
+
from ..base import AgentResponse, AgentState, BaseAgentRunner
|
|
24
|
+
|
|
25
|
+
if sys.version_info >= (3, 12):
|
|
26
|
+
from typing import override
|
|
27
|
+
else:
|
|
28
|
+
from typing_extensions import override
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class DashscopeAgentRunner(BaseAgentRunner[TContext]):
    """Dashscope Agent Runner.

    Runs one request against an Alibaba Cloud Bailian (Dashscope) application
    via ``dashscope.Application.call``. It relays the output as
    ``AgentResponse`` objects and records the final ``LLMResponse`` for
    ``get_final_llm_resp``.
    """

    # Dashscope marks RAG citations as ``<ref>[n]</ref>``; they are rewritten
    # to plain ``[n]`` footnotes before being shown to the user.
    _REF_TAG_PATTERN = re.compile(r"<ref>\[(\d+)\]</ref>")

    @override
    async def reset(
        self,
        request: ProviderRequest,
        run_context: ContextWrapper[TContext],
        agent_hooks: BaseAgentRunHooks[TContext],
        provider_config: dict,
        **kwargs: T.Any,
    ) -> None:
        """Prepare the runner for a new request.

        Args:
            request: The provider request (prompt, session id, images, ...).
            run_context: Context object handed to the agent hooks.
            agent_hooks: Hooks invoked when the run begins and when it is done.
            provider_config: Provider settings. ``dashscope_api_key``,
                ``dashscope_app_id`` and ``dashscope_app_type`` are required;
                ``variables``, ``rag_options`` and ``timeout`` are optional.
            **kwargs: ``streaming`` (default ``False``) selects streaming mode.

        Raises:
            Exception: If the API key, app id or app type is empty.
            ValueError: If ``timeout`` is a string that is not an integer.
        """
        self.req = request
        self.streaming = kwargs.get("streaming", False)
        self.final_llm_resp = None
        self._state = AgentState.IDLE
        self.agent_hooks = agent_hooks
        self.run_context = run_context

        self.api_key = provider_config.get("dashscope_api_key", "")
        if not self.api_key:
            raise Exception("阿里云百炼 API Key 不能为空。")
        self.app_id = provider_config.get("dashscope_app_id", "")
        if not self.app_id:
            raise Exception("阿里云百炼 APP ID 不能为空。")
        self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
        if not self.dashscope_app_type:
            raise Exception("阿里云百炼 APP 类型不能为空。")

        self.variables: dict = provider_config.get("variables", {}) or {}
        # `or {}` tolerates an explicit null in the config (same as `variables`).
        rag_options: dict = provider_config.get("rag_options", {}) or {}
        # `output_reference` only controls local citation rendering; it is
        # removed from the copy that is sent to Dashscope.
        self.output_reference = rag_options.get("output_reference", False)
        self.rag_options = rag_options.copy()
        self.rag_options.pop("output_reference", None)

        # NOTE(review): `timeout` is parsed here but not read anywhere else in
        # this class, so it does not currently bound the request.
        self.timeout = provider_config.get("timeout", 120)
        if isinstance(self.timeout, str):
            self.timeout = int(self.timeout)

    def has_rag_options(self) -> bool:
        """Return whether RAG options name any pipeline ids or file ids.

        Returns:
            bool: ``True`` if ``pipeline_ids`` or ``file_ids`` is non-empty.
        """
        if not self.rag_options:
            return False
        # `or` (instead of len()) also tolerates explicit null values.
        return bool(
            self.rag_options.get("pipeline_ids") or self.rag_options.get("file_ids")
        )

    @override
    async def step(self):
        """Run one step of the Dashscope agent.

        Fires ``on_agent_begin`` on the first step, then performs the request
        and yields its responses. Any failure is converted into an ``err``
        response, and the runner moves to the ERROR state.

        Raises:
            ValueError: If ``reset()`` has not provided a request.
        """
        if not self.req:
            raise ValueError("Request is not set. Please call reset() first.")

        if self._state == AgentState.IDLE:
            try:
                await self.agent_hooks.on_agent_begin(self.run_context)
            except Exception as e:
                logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)

        # Processing starts: move to the running state.
        self._transition_state(AgentState.RUNNING)

        try:
            # Perform the Dashscope request and relay its results.
            async for response in self._execute_dashscope_request():
                yield response
        except Exception as e:
            logger.error(f"阿里云百炼请求失败:{str(e)}", exc_info=True)
            self._transition_state(AgentState.ERROR)
            self.final_llm_resp = LLMResponse(
                role="err", completion_text=f"阿里云百炼请求失败:{str(e)}"
            )
            yield AgentResponse(
                type="err",
                data=AgentResponseData(
                    chain=MessageChain().message(f"阿里云百炼请求失败:{str(e)}")
                ),
            )

    @override
    async def step_until_done(
        self, max_step: int = 30
    ) -> T.AsyncGenerator[AgentResponse, None]:
        """Call ``step`` until the runner is done, at most ``max_step`` times.

        Args:
            max_step: Upper bound on the number of steps; guards against
                looping forever if a step leaves the runner non-terminal.

        Yields:
            Every ``AgentResponse`` produced by the steps.
        """
        for _ in range(max_step):
            if self.done():
                return
            async for resp in self.step():
                yield resp
        if not self.done():
            logger.warning(
                f"Dashscope agent runner did not finish within {max_step} steps."
            )

    def _consume_sync_generator(
        self, response: T.Any, response_queue: queue.Queue
    ) -> None:
        """Drain the synchronous Dashscope response on a worker thread.

        Each item is put on the queue as ``("data", chunk)``. A raised
        exception becomes ``("error", exc)``. ``("done", None)`` is always
        enqueued last.

        Args:
            response: A chunk generator in streaming mode; otherwise the
                single response object.
            response_queue: Queue used to hand items to the event loop.
        """
        try:
            if self.streaming:
                for chunk in response:
                    response_queue.put(("data", chunk))
            else:
                response_queue.put(("data", response))
        except Exception as e:
            response_queue.put(("error", e))
        finally:
            response_queue.put(("done", None))

    async def _process_stream_chunk(
        self, chunk: ApplicationResponse, output_text: str
    ) -> tuple[str, list | None, AgentResponse | None]:
        """Handle a single Dashscope response chunk.

        Args:
            chunk: The Dashscope response chunk.
            output_text: The text accumulated so far.

        Returns:
            A tuple of three items:
            - The updated accumulated text.
            - The chunk's ``doc_references``, if any.
            - An ``AgentResponse`` to emit, or ``None``. On a non-200 status
              this is an ``err`` response, and the state becomes ERROR.
        """
        logger.debug(f"dashscope stream chunk: {chunk}")

        if chunk.status_code != 200:
            logger.error(
                f"阿里云百炼请求失败: request_id={chunk.request_id}, code={chunk.status_code}, message={chunk.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
            )
            self._transition_state(AgentState.ERROR)
            error_msg = (
                f"阿里云百炼请求失败: message={chunk.message} code={chunk.status_code}"
            )
            self.final_llm_resp = LLMResponse(
                role="err",
                result_chain=MessageChain().message(error_msg),
            )
            return (
                output_text,
                None,
                AgentResponse(
                    type="err",
                    data=AgentResponseData(chain=MessageChain().message(error_msg)),
                ),
            )

        chunk_text = chunk.output.get("text", "") or ""
        # Rewrite RAG citation markers into plain footnote numbers.
        chunk_text = self._REF_TAG_PATTERN.sub(r"[\1]", chunk_text)

        delta = None
        if chunk_text:
            output_text += chunk_text
            delta = AgentResponse(
                type="streaming_delta",
                data=AgentResponseData(chain=MessageChain().message(chunk_text)),
            )

        doc_references = chunk.output.get("doc_references", None)
        return output_text, doc_references, delta

    def _format_doc_references(self, doc_references: list) -> str:
        """Render document references as a "sources" footer.

        Args:
            doc_references: Reference dicts returned by Dashscope.

        Returns:
            The footer text, one ``"<index_id>. <title>"`` line per reference.
            ``doc_name`` is used when ``title`` is empty. ``?`` is shown when
            ``index_id`` is missing.
        """
        ref_str = "".join(
            f"{ref.get('index_id', '?')}. "
            f"{ref.get('title') or ref.get('doc_name', '')}\n"
            for ref in doc_references
        )
        return f"\n\n回答来源:\n{ref_str}"

    async def _build_request_payload(
        self, prompt: str, session_id: str, contexts: list, system_prompt: str
    ) -> dict:
        """Build the keyword arguments for ``Application.call``.

        Args:
            prompt: The user input.
            session_id: The AstrBot session (UMO) id used as the storage scope.
            contexts: Conversation history (currently not sent to Dashscope).
            system_prompt: System prompt (currently not sent to Dashscope).

        Returns:
            The request payload dict.
        """
        # Config-level variables, overridden by per-session variables.
        payload_vars = self.variables.copy()
        session_var = await sp.get_async(
            scope="umo",
            scope_id=session_id,
            key="session_variables",
            default={},
        )
        # `or {}` tolerates a stored null value.
        payload_vars.update(session_var or {})

        payload: dict = {
            "app_id": self.app_id,
            "prompt": prompt,
            "api_key": self.api_key,
            "biz_params": payload_vars or None,
            "stream": self.streaming,
            "incremental_output": True,
        }

        multi_turn = (
            self.dashscope_app_type in ("agent", "dialog-workflow")
            and not self.has_rag_options()
        )
        if multi_turn:
            # Multi-turn apps: continue the stored Dashscope conversation.
            conversation_id = await sp.get_async(
                scope="umo",
                scope_id=session_id,
                key="dashscope_conversation_id",
                default="",
            )
            if conversation_id:
                payload["session_id"] = conversation_id
        elif self.rag_options:
            # Single-turn apps: forward the RAG options instead.
            payload["rag_options"] = self.rag_options
        return payload

    async def _handle_streaming_response(
        self, response: T.Any, session_id: str
    ) -> T.AsyncGenerator[AgentResponse, None]:
        """Relay a Dashscope response (streaming or not) as agent responses.

        The synchronous SDK response is drained on a daemon thread so that the
        event loop is not blocked.

        Args:
            response: The Dashscope chunk generator, or a single response
                when not streaming.
            session_id: The AstrBot session id under which the Dashscope
                conversation id is stored.

        Yields:
            ``AgentResponse`` objects: any deltas, then either one ``err``
            response or one final ``llm_result`` response.

        Raises:
            Exception: Any exception raised while consuming the SDK response.
            TypeError: If the SDK yields something other than an
                ``ApplicationResponse``.
        """
        loop = asyncio.get_running_loop()
        response_queue: queue.Queue = queue.Queue()
        consumer_thread = threading.Thread(
            target=self._consume_sync_generator,
            args=(response, response_queue),
            daemon=True,
        )
        consumer_thread.start()

        output_text = ""
        doc_references = None
        # Last conversation id written to storage; avoids a write per chunk.
        saved_conversation_id = None

        while True:
            try:
                item_type, item_data = await loop.run_in_executor(
                    None, response_queue.get, True, 1
                )
            except queue.Empty:
                continue

            if item_type == "done":
                break
            elif item_type == "error":
                raise item_data
            elif item_type == "data":
                chunk = item_data
                if not isinstance(chunk, ApplicationResponse):
                    raise TypeError(
                        f"Unexpected Dashscope response type: {type(chunk)!r}"
                    )

                (
                    output_text,
                    chunk_doc_refs,
                    agent_resp,
                ) = await self._process_stream_chunk(chunk, output_text)

                if agent_resp is not None:
                    yield agent_resp
                    if agent_resp.type == "err":
                        return

                if chunk_doc_refs:
                    doc_references = chunk_doc_refs

                conversation_id = chunk.output.session_id
                if conversation_id and conversation_id != saved_conversation_id:
                    await sp.put_async(
                        scope="umo",
                        scope_id=session_id,
                        key="dashscope_conversation_id",
                        value=conversation_id,
                    )
                    saved_conversation_id = conversation_id

        # Append the RAG "sources" footer if configured.
        if self.output_reference and doc_references:
            ref_text = self._format_doc_references(doc_references)
            output_text += ref_text

            if self.streaming:
                yield AgentResponse(
                    type="streaming_delta",
                    data=AgentResponseData(chain=MessageChain().message(ref_text)),
                )

        # Build the final response.
        chain = MessageChain(chain=[Comp.Plain(output_text)])
        self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
        self._transition_state(AgentState.DONE)

        try:
            await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
        except Exception as e:
            logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)

        # Emit the final result.
        yield AgentResponse(
            type="llm_result",
            data=AgentResponseData(chain=chain),
        )

    async def _execute_dashscope_request(self):
        """Build the payload, call Dashscope, and relay the response."""
        prompt = self.req.prompt or ""
        session_id = self.req.session_id or "unknown"
        image_urls = self.req.image_urls or []
        contexts = self.req.contexts or []
        system_prompt = self.req.system_prompt

        # Image input is not supported by this runner; it is dropped.
        if image_urls:
            logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")

        payload = await self._build_request_payload(
            prompt, session_id, contexts, system_prompt
        )

        if not self.streaming:
            payload["incremental_output"] = False

        # The SDK call is blocking, so it runs in the default executor.
        call = functools.partial(Application.call, **payload)
        response = await asyncio.get_running_loop().run_in_executor(None, call)

        async for resp in self._handle_streaming_response(response, session_id):
            yield resp

    @override
    def done(self) -> bool:
        """Return whether the agent has reached a terminal (DONE/ERROR) state."""
        return self._state in (AgentState.DONE, AgentState.ERROR)

    @override
    def get_final_llm_resp(self) -> LLMResponse | None:
        """Return the final LLM response of the run, if one was produced."""
        return self.final_llm_resp
|