AstrBot 4.12.2__py3-none-any.whl → 4.12.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/builtin_stars/builtin_commands/commands/__init__.py +0 -2
- astrbot/builtin_stars/builtin_commands/commands/persona.py +68 -6
- astrbot/builtin_stars/builtin_commands/main.py +0 -26
- astrbot/cli/__init__.py +1 -1
- astrbot/core/astr_agent_hooks.py +5 -3
- astrbot/core/astr_agent_run_util.py +243 -1
- astrbot/core/config/default.py +30 -1
- astrbot/core/db/__init__.py +91 -1
- astrbot/core/db/po.py +42 -0
- astrbot/core/db/sqlite.py +230 -0
- astrbot/core/persona_mgr.py +154 -2
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +57 -4
- astrbot/core/pipeline/process_stage/utils.py +13 -1
- astrbot/core/pipeline/waking_check/stage.py +0 -1
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +32 -14
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +61 -2
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +57 -11
- astrbot/core/platform/sources/webchat/webchat_adapter.py +1 -0
- astrbot/core/platform/sources/webchat/webchat_event.py +24 -0
- astrbot/core/provider/manager.py +7 -0
- astrbot/core/provider/provider.py +54 -0
- astrbot/core/provider/sources/gemini_embedding_source.py +1 -1
- astrbot/core/provider/sources/genie_tts.py +128 -0
- astrbot/core/provider/sources/openai_embedding_source.py +1 -1
- astrbot/core/star/context.py +9 -8
- astrbot/core/star/register/star_handler.py +2 -4
- astrbot/core/star/star_handler.py +2 -1
- astrbot/dashboard/routes/live_chat.py +423 -0
- astrbot/dashboard/routes/persona.py +258 -1
- astrbot/dashboard/server.py +2 -0
- {astrbot-4.12.2.dist-info → astrbot-4.12.4.dist-info}/METADATA +1 -1
- {astrbot-4.12.2.dist-info → astrbot-4.12.4.dist-info}/RECORD +35 -34
- astrbot/builtin_stars/builtin_commands/commands/tool.py +0 -31
- {astrbot-4.12.2.dist-info → astrbot-4.12.4.dist-info}/WHEEL +0 -0
- {astrbot-4.12.2.dist-info → astrbot-4.12.4.dist-info}/entry_points.txt +0 -0
- {astrbot-4.12.2.dist-info → astrbot-4.12.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -39,7 +39,7 @@ class MyEventHandler(dingtalk_stream.EventHandler):
|
|
|
39
39
|
|
|
40
40
|
|
|
41
41
|
@register_platform_adapter(
|
|
42
|
-
"dingtalk", "钉钉机器人官方 API 适配器", support_streaming_message=
|
|
42
|
+
"dingtalk", "钉钉机器人官方 API 适配器", support_streaming_message=True
|
|
43
43
|
)
|
|
44
44
|
class DingtalkPlatformAdapter(Platform):
|
|
45
45
|
def __init__(
|
|
@@ -75,6 +75,8 @@ class DingtalkPlatformAdapter(Platform):
|
|
|
75
75
|
)
|
|
76
76
|
self.client_ = client # 用于 websockets 的 client
|
|
77
77
|
self._shutdown_event: threading.Event | None = None
|
|
78
|
+
self.card_template_id = platform_config.get("card_template_id")
|
|
79
|
+
self.card_instance_id_dict = {}
|
|
78
80
|
|
|
79
81
|
def _id_to_sid(self, dingtalk_id: str | None) -> str:
|
|
80
82
|
if not dingtalk_id:
|
|
@@ -96,9 +98,65 @@ class DingtalkPlatformAdapter(Platform):
|
|
|
96
98
|
name="dingtalk",
|
|
97
99
|
description="钉钉机器人官方 API 适配器",
|
|
98
100
|
id=cast(str, self.config.get("id")),
|
|
99
|
-
support_streaming_message=
|
|
101
|
+
support_streaming_message=True,
|
|
100
102
|
)
|
|
101
103
|
|
|
104
|
+
async def create_message_card(
|
|
105
|
+
self, message_id: str, incoming_message: dingtalk_stream.ChatbotMessage
|
|
106
|
+
):
|
|
107
|
+
if not self.card_template_id:
|
|
108
|
+
return False
|
|
109
|
+
|
|
110
|
+
card_instance = dingtalk_stream.AICardReplier(self.client_, incoming_message)
|
|
111
|
+
card_data = {"content": ""} # Initial content empty
|
|
112
|
+
|
|
113
|
+
try:
|
|
114
|
+
card_instance_id = await card_instance.async_create_and_deliver_card(
|
|
115
|
+
self.card_template_id,
|
|
116
|
+
card_data,
|
|
117
|
+
)
|
|
118
|
+
self.card_instance_id_dict[message_id] = (card_instance, card_instance_id)
|
|
119
|
+
return True
|
|
120
|
+
except Exception as e:
|
|
121
|
+
logger.error(f"创建钉钉卡片失败: {e}")
|
|
122
|
+
return False
|
|
123
|
+
|
|
124
|
+
async def send_card_message(self, message_id: str, content: str, is_final: bool):
|
|
125
|
+
if message_id not in self.card_instance_id_dict:
|
|
126
|
+
return
|
|
127
|
+
|
|
128
|
+
card_instance, card_instance_id = self.card_instance_id_dict[message_id]
|
|
129
|
+
content_key = "content"
|
|
130
|
+
|
|
131
|
+
try:
|
|
132
|
+
# 钉钉卡片流式更新
|
|
133
|
+
|
|
134
|
+
await card_instance.async_streaming(
|
|
135
|
+
card_instance_id,
|
|
136
|
+
content_key=content_key,
|
|
137
|
+
content_value=content,
|
|
138
|
+
append=False,
|
|
139
|
+
finished=is_final,
|
|
140
|
+
failed=False,
|
|
141
|
+
)
|
|
142
|
+
except Exception as e:
|
|
143
|
+
logger.error(f"发送钉钉卡片消息失败: {e}")
|
|
144
|
+
# Try to report failure
|
|
145
|
+
try:
|
|
146
|
+
await card_instance.async_streaming(
|
|
147
|
+
card_instance_id,
|
|
148
|
+
content_key=content_key,
|
|
149
|
+
content_value=content, # Keep existing content
|
|
150
|
+
append=False,
|
|
151
|
+
finished=True,
|
|
152
|
+
failed=True,
|
|
153
|
+
)
|
|
154
|
+
except Exception:
|
|
155
|
+
pass
|
|
156
|
+
|
|
157
|
+
if is_final:
|
|
158
|
+
self.card_instance_id_dict.pop(message_id, None)
|
|
159
|
+
|
|
102
160
|
async def convert_msg(
|
|
103
161
|
self,
|
|
104
162
|
message: dingtalk_stream.ChatbotMessage,
|
|
@@ -224,6 +282,7 @@ class DingtalkPlatformAdapter(Platform):
|
|
|
224
282
|
platform_meta=self.meta(),
|
|
225
283
|
session_id=abm.session_id,
|
|
226
284
|
client=self.client,
|
|
285
|
+
adapter=self,
|
|
227
286
|
)
|
|
228
287
|
|
|
229
288
|
self._event_queue.put_nowait(event)
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
from typing import cast
|
|
2
|
+
from typing import Any, cast
|
|
3
3
|
|
|
4
4
|
import dingtalk_stream
|
|
5
5
|
|
|
@@ -16,9 +16,11 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
|
|
16
16
|
platform_meta,
|
|
17
17
|
session_id,
|
|
18
18
|
client: dingtalk_stream.ChatbotHandler,
|
|
19
|
+
adapter: "Any" = None,
|
|
19
20
|
):
|
|
20
21
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
|
21
22
|
self.client = client
|
|
23
|
+
self.adapter = adapter
|
|
22
24
|
|
|
23
25
|
async def send_with_client(
|
|
24
26
|
self,
|
|
@@ -83,14 +85,58 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
|
|
83
85
|
await super().send(message)
|
|
84
86
|
|
|
85
87
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
|
86
|
-
|
|
87
|
-
|
|
88
|
+
if not self.adapter or not self.adapter.card_template_id:
|
|
89
|
+
logger.warning(
|
|
90
|
+
f"DingTalk streaming is enabled, but 'card_template_id' is not configured for platform '{self.platform_meta.id}'. Falling back to text streaming."
|
|
91
|
+
)
|
|
92
|
+
# Fallback to default behavior (buffer and send)
|
|
93
|
+
buffer = None
|
|
94
|
+
async for chain in generator:
|
|
95
|
+
if not buffer:
|
|
96
|
+
buffer = chain
|
|
97
|
+
else:
|
|
98
|
+
buffer.chain.extend(chain.chain)
|
|
99
|
+
if not buffer:
|
|
100
|
+
return None
|
|
101
|
+
buffer.squash_plain()
|
|
102
|
+
await self.send(buffer)
|
|
103
|
+
return await super().send_streaming(generator, use_fallback)
|
|
104
|
+
|
|
105
|
+
# Create card
|
|
106
|
+
msg_id = self.message_obj.message_id
|
|
107
|
+
incoming_msg = self.message_obj.raw_message
|
|
108
|
+
created = await self.adapter.create_message_card(msg_id, incoming_msg)
|
|
109
|
+
|
|
110
|
+
if not created:
|
|
111
|
+
# Fallback to default behavior (buffer and send)
|
|
112
|
+
buffer = None
|
|
113
|
+
async for chain in generator:
|
|
114
|
+
if not buffer:
|
|
115
|
+
buffer = chain
|
|
116
|
+
else:
|
|
117
|
+
buffer.chain.extend(chain.chain)
|
|
88
118
|
if not buffer:
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
119
|
+
return None
|
|
120
|
+
buffer.squash_plain()
|
|
121
|
+
await self.send(buffer)
|
|
122
|
+
return await super().send_streaming(generator, use_fallback)
|
|
123
|
+
|
|
124
|
+
full_content = ""
|
|
125
|
+
seq = 0
|
|
126
|
+
try:
|
|
127
|
+
async for chain in generator:
|
|
128
|
+
for segment in chain.chain:
|
|
129
|
+
if isinstance(segment, Comp.Plain):
|
|
130
|
+
full_content += segment.text
|
|
131
|
+
|
|
132
|
+
seq += 1
|
|
133
|
+
if seq % 2 == 0: # Update every 2 chunks to be more responsive than 8
|
|
134
|
+
await self.adapter.send_card_message(
|
|
135
|
+
msg_id, full_content, is_final=False
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
await self.adapter.send_card_message(msg_id, full_content, is_final=True)
|
|
139
|
+
except Exception as e:
|
|
140
|
+
logger.error(f"DingTalk streaming error: {e}")
|
|
141
|
+
# Try to ensure final state is sent or cleaned up?
|
|
142
|
+
await self.adapter.send_card_message(msg_id, full_content, is_final=True)
|
|
@@ -128,6 +128,30 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|
|
128
128
|
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
|
129
129
|
message_id = self.message_obj.message_id
|
|
130
130
|
async for chain in generator:
|
|
131
|
+
# 处理音频流(Live Mode)
|
|
132
|
+
if chain.type == "audio_chunk":
|
|
133
|
+
# 音频流数据,直接发送
|
|
134
|
+
audio_b64 = ""
|
|
135
|
+
text = None
|
|
136
|
+
|
|
137
|
+
if chain.chain and isinstance(chain.chain[0], Plain):
|
|
138
|
+
audio_b64 = chain.chain[0].text
|
|
139
|
+
|
|
140
|
+
if len(chain.chain) > 1 and isinstance(chain.chain[1], Json):
|
|
141
|
+
text = chain.chain[1].data.get("text")
|
|
142
|
+
|
|
143
|
+
payload = {
|
|
144
|
+
"type": "audio_chunk",
|
|
145
|
+
"data": audio_b64,
|
|
146
|
+
"streaming": True,
|
|
147
|
+
"message_id": message_id,
|
|
148
|
+
}
|
|
149
|
+
if text:
|
|
150
|
+
payload["text"] = text
|
|
151
|
+
|
|
152
|
+
await web_chat_back_queue.put(payload)
|
|
153
|
+
continue
|
|
154
|
+
|
|
131
155
|
# if chain.type == "break" and final_data:
|
|
132
156
|
# # 分割符
|
|
133
157
|
# await web_chat_back_queue.put(
|
astrbot/core/provider/manager.py
CHANGED
|
@@ -322,6 +322,10 @@ class ProviderManager:
|
|
|
322
322
|
from .sources.openai_tts_api_source import (
|
|
323
323
|
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
|
|
324
324
|
)
|
|
325
|
+
case "genie_tts":
|
|
326
|
+
from .sources.genie_tts import (
|
|
327
|
+
GenieTTSProvider as GenieTTSProvider,
|
|
328
|
+
)
|
|
325
329
|
case "edge_tts":
|
|
326
330
|
from .sources.edge_tts_source import (
|
|
327
331
|
ProviderEdgeTTS as ProviderEdgeTTS,
|
|
@@ -422,17 +426,20 @@ class ProviderManager:
|
|
|
422
426
|
except (ImportError, ModuleNotFoundError) as e:
|
|
423
427
|
logger.critical(
|
|
424
428
|
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
|
|
429
|
+
exc_info=True,
|
|
425
430
|
)
|
|
426
431
|
return
|
|
427
432
|
except Exception as e:
|
|
428
433
|
logger.critical(
|
|
429
434
|
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因",
|
|
435
|
+
exc_info=True,
|
|
430
436
|
)
|
|
431
437
|
return
|
|
432
438
|
|
|
433
439
|
if provider_config["type"] not in provider_cls_map:
|
|
434
440
|
logger.error(
|
|
435
441
|
f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。",
|
|
442
|
+
exc_info=True,
|
|
436
443
|
)
|
|
437
444
|
return
|
|
438
445
|
|
|
@@ -221,11 +221,65 @@ class TTSProvider(AbstractProvider):
|
|
|
221
221
|
self.provider_config = provider_config
|
|
222
222
|
self.provider_settings = provider_settings
|
|
223
223
|
|
|
224
|
+
def support_stream(self) -> bool:
|
|
225
|
+
"""是否支持流式 TTS
|
|
226
|
+
|
|
227
|
+
Returns:
|
|
228
|
+
bool: True 表示支持流式处理,False 表示不支持(默认)
|
|
229
|
+
|
|
230
|
+
Notes:
|
|
231
|
+
子类可以重写此方法返回 True 来启用流式 TTS 支持
|
|
232
|
+
"""
|
|
233
|
+
return False
|
|
234
|
+
|
|
224
235
|
@abc.abstractmethod
|
|
225
236
|
async def get_audio(self, text: str) -> str:
|
|
226
237
|
"""获取文本的音频,返回音频文件路径"""
|
|
227
238
|
raise NotImplementedError
|
|
228
239
|
|
|
240
|
+
async def get_audio_stream(
|
|
241
|
+
self,
|
|
242
|
+
text_queue: asyncio.Queue[str | None],
|
|
243
|
+
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
|
|
244
|
+
) -> None:
|
|
245
|
+
"""流式 TTS 处理方法。
|
|
246
|
+
|
|
247
|
+
从 text_queue 中读取文本片段,将生成的音频数据(WAV 格式的 in-memory bytes)放入 audio_queue。
|
|
248
|
+
当 text_queue 收到 None 时,表示文本输入结束,此时应该处理完所有剩余文本并向 audio_queue 发送 None 表示结束。
|
|
249
|
+
|
|
250
|
+
Args:
|
|
251
|
+
text_queue: 输入文本队列,None 表示输入结束
|
|
252
|
+
audio_queue: 输出音频队列(bytes 或 (text, bytes)),None 表示输出结束
|
|
253
|
+
|
|
254
|
+
Notes:
|
|
255
|
+
- 默认实现会将文本累积后一次性调用 get_audio 生成完整音频
|
|
256
|
+
- 子类可以重写此方法实现真正的流式 TTS
|
|
257
|
+
- 音频数据应该是 WAV 格式的 bytes
|
|
258
|
+
"""
|
|
259
|
+
accumulated_text = ""
|
|
260
|
+
|
|
261
|
+
while True:
|
|
262
|
+
text_part = await text_queue.get()
|
|
263
|
+
|
|
264
|
+
if text_part is None:
|
|
265
|
+
# 输入结束,处理累积的文本
|
|
266
|
+
if accumulated_text:
|
|
267
|
+
try:
|
|
268
|
+
# 调用原有的 get_audio 方法获取音频文件路径
|
|
269
|
+
audio_path = await self.get_audio(accumulated_text)
|
|
270
|
+
# 读取音频文件内容
|
|
271
|
+
with open(audio_path, "rb") as f:
|
|
272
|
+
audio_data = f.read()
|
|
273
|
+
await audio_queue.put((accumulated_text, audio_data))
|
|
274
|
+
except Exception:
|
|
275
|
+
# 出错时也要发送 None 结束标记
|
|
276
|
+
pass
|
|
277
|
+
# 发送结束标记
|
|
278
|
+
await audio_queue.put(None)
|
|
279
|
+
break
|
|
280
|
+
|
|
281
|
+
accumulated_text += text_part
|
|
282
|
+
|
|
229
283
|
async def test(self):
|
|
230
284
|
await self.get_audio("hi")
|
|
231
285
|
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import os
|
|
3
|
+
import uuid
|
|
4
|
+
|
|
5
|
+
from astrbot.core import logger
|
|
6
|
+
from astrbot.core.provider.entities import ProviderType
|
|
7
|
+
from astrbot.core.provider.provider import TTSProvider
|
|
8
|
+
from astrbot.core.provider.register import register_provider_adapter
|
|
9
|
+
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
import genie_tts as genie # type: ignore
|
|
13
|
+
except ImportError:
|
|
14
|
+
genie = None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@register_provider_adapter(
|
|
18
|
+
"genie_tts",
|
|
19
|
+
"Genie TTS",
|
|
20
|
+
provider_type=ProviderType.TEXT_TO_SPEECH,
|
|
21
|
+
)
|
|
22
|
+
class GenieTTSProvider(TTSProvider):
|
|
23
|
+
def __init__(
|
|
24
|
+
self,
|
|
25
|
+
provider_config: dict,
|
|
26
|
+
provider_settings: dict,
|
|
27
|
+
) -> None:
|
|
28
|
+
super().__init__(provider_config, provider_settings)
|
|
29
|
+
if not genie:
|
|
30
|
+
raise ImportError("Please install genie_tts first.")
|
|
31
|
+
|
|
32
|
+
self.character_name = provider_config.get("genie_character_name", "mika")
|
|
33
|
+
language = provider_config.get("genie_language", "Japanese")
|
|
34
|
+
model_dir = provider_config.get("genie_onnx_model_dir", "")
|
|
35
|
+
refer_audio_path = provider_config.get("genie_refer_audio_path", "")
|
|
36
|
+
refer_text = provider_config.get("genie_refer_text", "")
|
|
37
|
+
|
|
38
|
+
try:
|
|
39
|
+
genie.load_character(
|
|
40
|
+
character_name=self.character_name,
|
|
41
|
+
language=language,
|
|
42
|
+
onnx_model_dir=model_dir,
|
|
43
|
+
)
|
|
44
|
+
genie.set_reference_audio(
|
|
45
|
+
character_name=self.character_name,
|
|
46
|
+
audio_path=refer_audio_path,
|
|
47
|
+
audio_text=refer_text,
|
|
48
|
+
language=language,
|
|
49
|
+
)
|
|
50
|
+
except Exception as e:
|
|
51
|
+
raise RuntimeError(f"Failed to load character {self.character_name}: {e}")
|
|
52
|
+
|
|
53
|
+
def support_stream(self) -> bool:
|
|
54
|
+
return True
|
|
55
|
+
|
|
56
|
+
async def get_audio(self, text: str) -> str:
|
|
57
|
+
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
58
|
+
os.makedirs(temp_dir, exist_ok=True)
|
|
59
|
+
filename = f"genie_tts_{uuid.uuid4()}.wav"
|
|
60
|
+
path = os.path.join(temp_dir, filename)
|
|
61
|
+
|
|
62
|
+
loop = asyncio.get_event_loop()
|
|
63
|
+
|
|
64
|
+
def _generate(save_path: str):
|
|
65
|
+
assert genie is not None
|
|
66
|
+
genie.tts(
|
|
67
|
+
character_name=self.character_name,
|
|
68
|
+
text=text,
|
|
69
|
+
save_path=save_path,
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
try:
|
|
73
|
+
await loop.run_in_executor(None, _generate, path)
|
|
74
|
+
|
|
75
|
+
if os.path.exists(path):
|
|
76
|
+
return path
|
|
77
|
+
|
|
78
|
+
raise RuntimeError("Genie TTS did not save to file.")
|
|
79
|
+
|
|
80
|
+
except Exception as e:
|
|
81
|
+
raise RuntimeError(f"Genie TTS generation failed: {e}")
|
|
82
|
+
|
|
83
|
+
async def get_audio_stream(
|
|
84
|
+
self,
|
|
85
|
+
text_queue: asyncio.Queue[str | None],
|
|
86
|
+
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
|
|
87
|
+
) -> None:
|
|
88
|
+
loop = asyncio.get_event_loop()
|
|
89
|
+
|
|
90
|
+
while True:
|
|
91
|
+
text = await text_queue.get()
|
|
92
|
+
if text is None:
|
|
93
|
+
await audio_queue.put(None)
|
|
94
|
+
break
|
|
95
|
+
|
|
96
|
+
try:
|
|
97
|
+
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
98
|
+
os.makedirs(temp_dir, exist_ok=True)
|
|
99
|
+
filename = f"genie_tts_{uuid.uuid4()}.wav"
|
|
100
|
+
path = os.path.join(temp_dir, filename)
|
|
101
|
+
|
|
102
|
+
def _generate(save_path: str, t: str):
|
|
103
|
+
assert genie is not None
|
|
104
|
+
genie.tts(
|
|
105
|
+
character_name=self.character_name,
|
|
106
|
+
text=t,
|
|
107
|
+
save_path=save_path,
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
await loop.run_in_executor(None, _generate, path, text)
|
|
111
|
+
|
|
112
|
+
if os.path.exists(path):
|
|
113
|
+
with open(path, "rb") as f:
|
|
114
|
+
audio_data = f.read()
|
|
115
|
+
|
|
116
|
+
# Put (text, bytes) into queue so frontend can display text
|
|
117
|
+
await audio_queue.put((text, audio_data))
|
|
118
|
+
|
|
119
|
+
# Clean up
|
|
120
|
+
try:
|
|
121
|
+
os.remove(path)
|
|
122
|
+
except OSError:
|
|
123
|
+
pass
|
|
124
|
+
else:
|
|
125
|
+
logger.error(f"Genie TTS failed to generate audio for: {text}")
|
|
126
|
+
|
|
127
|
+
except Exception as e:
|
|
128
|
+
logger.error(f"Genie TTS stream error: {e}")
|
astrbot/core/star/context.py
CHANGED
|
@@ -328,28 +328,29 @@ class Context:
|
|
|
328
328
|
"""获取所有用于 Embedding 任务的 Provider。"""
|
|
329
329
|
return self.provider_manager.embedding_provider_insts
|
|
330
330
|
|
|
331
|
-
def get_using_provider(self, umo: str | None = None) -> Provider:
|
|
331
|
+
def get_using_provider(self, umo: str | None = None) -> Provider | None:
|
|
332
332
|
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
|
|
333
333
|
|
|
334
334
|
Args:
|
|
335
335
|
umo: unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,
|
|
336
|
-
|
|
336
|
+
则使用该会话偏好的对话模型(提供商)。
|
|
337
337
|
|
|
338
338
|
Returns:
|
|
339
|
-
|
|
339
|
+
当前使用的对话模型(提供商),如果未设置则返回 None。
|
|
340
340
|
|
|
341
341
|
Raises:
|
|
342
|
-
ValueError:
|
|
343
|
-
|
|
344
|
-
Note:
|
|
345
|
-
通过 /provider 指令可以切换提供者。
|
|
342
|
+
ValueError: 该会话来源配置的的对话模型(提供商)的类型不正确。
|
|
346
343
|
"""
|
|
347
344
|
prov = self.provider_manager.get_using_provider(
|
|
348
345
|
provider_type=ProviderType.CHAT_COMPLETION,
|
|
349
346
|
umo=umo,
|
|
350
347
|
)
|
|
348
|
+
if prov is None:
|
|
349
|
+
return None
|
|
351
350
|
if not isinstance(prov, Provider):
|
|
352
|
-
raise ValueError(
|
|
351
|
+
raise ValueError(
|
|
352
|
+
f"该会话来源的对话模型(提供商)的类型不正确: {type(prov)}"
|
|
353
|
+
)
|
|
353
354
|
return prov
|
|
354
355
|
|
|
355
356
|
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None:
|
|
@@ -427,7 +427,7 @@ def register_on_using_llm_tool(**kwargs):
|
|
|
427
427
|
"""
|
|
428
428
|
|
|
429
429
|
def decorator(awaitable):
|
|
430
|
-
_ = get_handler_or_create(awaitable, EventType.
|
|
430
|
+
_ = get_handler_or_create(awaitable, EventType.OnUsingLLMToolEvent, **kwargs)
|
|
431
431
|
return awaitable
|
|
432
432
|
|
|
433
433
|
return decorator
|
|
@@ -452,9 +452,7 @@ def register_on_llm_tool_respond(**kwargs):
|
|
|
452
452
|
"""
|
|
453
453
|
|
|
454
454
|
def decorator(awaitable):
|
|
455
|
-
_ = get_handler_or_create(
|
|
456
|
-
awaitable, EventType.OnAfterCallingFuncToolEvent, **kwargs
|
|
457
|
-
)
|
|
455
|
+
_ = get_handler_or_create(awaitable, EventType.OnLLMToolRespondEvent, **kwargs)
|
|
458
456
|
return awaitable
|
|
459
457
|
|
|
460
458
|
return decorator
|
|
@@ -189,7 +189,8 @@ class EventType(enum.Enum):
|
|
|
189
189
|
OnLLMResponseEvent = enum.auto() # LLM 响应后
|
|
190
190
|
OnDecoratingResultEvent = enum.auto() # 发送消息前
|
|
191
191
|
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
|
|
192
|
-
|
|
192
|
+
OnUsingLLMToolEvent = enum.auto() # 使用 LLM 工具
|
|
193
|
+
OnLLMToolRespondEvent = enum.auto() # 调用函数工具后
|
|
193
194
|
OnAfterMessageSentEvent = enum.auto() # 发送消息后
|
|
194
195
|
|
|
195
196
|
|