AstrBot 4.12.2__py3-none-any.whl → 4.12.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. astrbot/builtin_stars/builtin_commands/commands/__init__.py +0 -2
  2. astrbot/builtin_stars/builtin_commands/commands/persona.py +68 -6
  3. astrbot/builtin_stars/builtin_commands/main.py +0 -26
  4. astrbot/cli/__init__.py +1 -1
  5. astrbot/core/astr_agent_hooks.py +5 -3
  6. astrbot/core/astr_agent_run_util.py +243 -1
  7. astrbot/core/config/default.py +30 -1
  8. astrbot/core/db/__init__.py +91 -1
  9. astrbot/core/db/po.py +42 -0
  10. astrbot/core/db/sqlite.py +230 -0
  11. astrbot/core/persona_mgr.py +154 -2
  12. astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +57 -4
  13. astrbot/core/pipeline/process_stage/utils.py +13 -1
  14. astrbot/core/pipeline/waking_check/stage.py +0 -1
  15. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +32 -14
  16. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +61 -2
  17. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +57 -11
  18. astrbot/core/platform/sources/webchat/webchat_adapter.py +1 -0
  19. astrbot/core/platform/sources/webchat/webchat_event.py +24 -0
  20. astrbot/core/provider/manager.py +7 -0
  21. astrbot/core/provider/provider.py +54 -0
  22. astrbot/core/provider/sources/gemini_embedding_source.py +1 -1
  23. astrbot/core/provider/sources/genie_tts.py +128 -0
  24. astrbot/core/provider/sources/openai_embedding_source.py +1 -1
  25. astrbot/core/star/context.py +9 -8
  26. astrbot/core/star/register/star_handler.py +2 -4
  27. astrbot/core/star/star_handler.py +2 -1
  28. astrbot/dashboard/routes/live_chat.py +423 -0
  29. astrbot/dashboard/routes/persona.py +258 -1
  30. astrbot/dashboard/server.py +2 -0
  31. {astrbot-4.12.2.dist-info → astrbot-4.12.4.dist-info}/METADATA +1 -1
  32. {astrbot-4.12.2.dist-info → astrbot-4.12.4.dist-info}/RECORD +35 -34
  33. astrbot/builtin_stars/builtin_commands/commands/tool.py +0 -31
  34. {astrbot-4.12.2.dist-info → astrbot-4.12.4.dist-info}/WHEEL +0 -0
  35. {astrbot-4.12.2.dist-info → astrbot-4.12.4.dist-info}/entry_points.txt +0 -0
  36. {astrbot-4.12.2.dist-info → astrbot-4.12.4.dist-info}/licenses/LICENSE +0 -0
@@ -39,7 +39,7 @@ class MyEventHandler(dingtalk_stream.EventHandler):
39
39
 
40
40
 
41
41
  @register_platform_adapter(
42
- "dingtalk", "钉钉机器人官方 API 适配器", support_streaming_message=False
42
+ "dingtalk", "钉钉机器人官方 API 适配器", support_streaming_message=True
43
43
  )
44
44
  class DingtalkPlatformAdapter(Platform):
45
45
  def __init__(
@@ -75,6 +75,8 @@ class DingtalkPlatformAdapter(Platform):
75
75
  )
76
76
  self.client_ = client # 用于 websockets 的 client
77
77
  self._shutdown_event: threading.Event | None = None
78
+ self.card_template_id = platform_config.get("card_template_id")
79
+ self.card_instance_id_dict = {}
78
80
 
79
81
  def _id_to_sid(self, dingtalk_id: str | None) -> str:
80
82
  if not dingtalk_id:
@@ -96,9 +98,65 @@ class DingtalkPlatformAdapter(Platform):
96
98
  name="dingtalk",
97
99
  description="钉钉机器人官方 API 适配器",
98
100
  id=cast(str, self.config.get("id")),
99
- support_streaming_message=False,
101
+ support_streaming_message=True,
100
102
  )
101
103
 
104
+ async def create_message_card(
105
+ self, message_id: str, incoming_message: dingtalk_stream.ChatbotMessage
106
+ ):
107
+ if not self.card_template_id:
108
+ return False
109
+
110
+ card_instance = dingtalk_stream.AICardReplier(self.client_, incoming_message)
111
+ card_data = {"content": ""} # Initial content empty
112
+
113
+ try:
114
+ card_instance_id = await card_instance.async_create_and_deliver_card(
115
+ self.card_template_id,
116
+ card_data,
117
+ )
118
+ self.card_instance_id_dict[message_id] = (card_instance, card_instance_id)
119
+ return True
120
+ except Exception as e:
121
+ logger.error(f"创建钉钉卡片失败: {e}")
122
+ return False
123
+
124
+ async def send_card_message(self, message_id: str, content: str, is_final: bool):
125
+ if message_id not in self.card_instance_id_dict:
126
+ return
127
+
128
+ card_instance, card_instance_id = self.card_instance_id_dict[message_id]
129
+ content_key = "content"
130
+
131
+ try:
132
+ # 钉钉卡片流式更新
133
+
134
+ await card_instance.async_streaming(
135
+ card_instance_id,
136
+ content_key=content_key,
137
+ content_value=content,
138
+ append=False,
139
+ finished=is_final,
140
+ failed=False,
141
+ )
142
+ except Exception as e:
143
+ logger.error(f"发送钉钉卡片消息失败: {e}")
144
+ # Try to report failure
145
+ try:
146
+ await card_instance.async_streaming(
147
+ card_instance_id,
148
+ content_key=content_key,
149
+ content_value=content, # Keep existing content
150
+ append=False,
151
+ finished=True,
152
+ failed=True,
153
+ )
154
+ except Exception:
155
+ pass
156
+
157
+ if is_final:
158
+ self.card_instance_id_dict.pop(message_id, None)
159
+
102
160
  async def convert_msg(
103
161
  self,
104
162
  message: dingtalk_stream.ChatbotMessage,
@@ -224,6 +282,7 @@ class DingtalkPlatformAdapter(Platform):
224
282
  platform_meta=self.meta(),
225
283
  session_id=abm.session_id,
226
284
  client=self.client,
285
+ adapter=self,
227
286
  )
228
287
 
229
288
  self._event_queue.put_nowait(event)
@@ -1,5 +1,5 @@
1
1
  import asyncio
2
- from typing import cast
2
+ from typing import Any, cast
3
3
 
4
4
  import dingtalk_stream
5
5
 
@@ -16,9 +16,11 @@ class DingtalkMessageEvent(AstrMessageEvent):
16
16
  platform_meta,
17
17
  session_id,
18
18
  client: dingtalk_stream.ChatbotHandler,
19
+ adapter: "Any" = None,
19
20
  ):
20
21
  super().__init__(message_str, message_obj, platform_meta, session_id)
21
22
  self.client = client
23
+ self.adapter = adapter
22
24
 
23
25
  async def send_with_client(
24
26
  self,
@@ -83,14 +85,58 @@ class DingtalkMessageEvent(AstrMessageEvent):
83
85
  await super().send(message)
84
86
 
85
87
  async def send_streaming(self, generator, use_fallback: bool = False):
86
- buffer = None
87
- async for chain in generator:
88
+ if not self.adapter or not self.adapter.card_template_id:
89
+ logger.warning(
90
+ f"DingTalk streaming is enabled, but 'card_template_id' is not configured for platform '{self.platform_meta.id}'. Falling back to text streaming."
91
+ )
92
+ # Fallback to default behavior (buffer and send)
93
+ buffer = None
94
+ async for chain in generator:
95
+ if not buffer:
96
+ buffer = chain
97
+ else:
98
+ buffer.chain.extend(chain.chain)
99
+ if not buffer:
100
+ return None
101
+ buffer.squash_plain()
102
+ await self.send(buffer)
103
+ return await super().send_streaming(generator, use_fallback)
104
+
105
+ # Create card
106
+ msg_id = self.message_obj.message_id
107
+ incoming_msg = self.message_obj.raw_message
108
+ created = await self.adapter.create_message_card(msg_id, incoming_msg)
109
+
110
+ if not created:
111
+ # Fallback to default behavior (buffer and send)
112
+ buffer = None
113
+ async for chain in generator:
114
+ if not buffer:
115
+ buffer = chain
116
+ else:
117
+ buffer.chain.extend(chain.chain)
88
118
  if not buffer:
89
- buffer = chain
90
- else:
91
- buffer.chain.extend(chain.chain)
92
- if not buffer:
93
- return None
94
- buffer.squash_plain()
95
- await self.send(buffer)
96
- return await super().send_streaming(generator, use_fallback)
119
+ return None
120
+ buffer.squash_plain()
121
+ await self.send(buffer)
122
+ return await super().send_streaming(generator, use_fallback)
123
+
124
+ full_content = ""
125
+ seq = 0
126
+ try:
127
+ async for chain in generator:
128
+ for segment in chain.chain:
129
+ if isinstance(segment, Comp.Plain):
130
+ full_content += segment.text
131
+
132
+ seq += 1
133
+ if seq % 2 == 0: # Update every 2 chunks to be more responsive than 8
134
+ await self.adapter.send_card_message(
135
+ msg_id, full_content, is_final=False
136
+ )
137
+
138
+ await self.adapter.send_card_message(msg_id, full_content, is_final=True)
139
+ except Exception as e:
140
+ logger.error(f"DingTalk streaming error: {e}")
141
+ # Try to ensure final state is sent or cleaned up?
142
+ await self.adapter.send_card_message(msg_id, full_content, is_final=True)
@@ -235,6 +235,7 @@ class WebChatAdapter(Platform):
235
235
  message_event.set_extra(
236
236
  "enable_streaming", payload.get("enable_streaming", True)
237
237
  )
238
+ message_event.set_extra("action_type", payload.get("action_type"))
238
239
 
239
240
  self.commit_event(message_event)
240
241
 
@@ -128,6 +128,30 @@ class WebChatMessageEvent(AstrMessageEvent):
128
128
  web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
129
129
  message_id = self.message_obj.message_id
130
130
  async for chain in generator:
131
+ # 处理音频流(Live Mode)
132
+ if chain.type == "audio_chunk":
133
+ # 音频流数据,直接发送
134
+ audio_b64 = ""
135
+ text = None
136
+
137
+ if chain.chain and isinstance(chain.chain[0], Plain):
138
+ audio_b64 = chain.chain[0].text
139
+
140
+ if len(chain.chain) > 1 and isinstance(chain.chain[1], Json):
141
+ text = chain.chain[1].data.get("text")
142
+
143
+ payload = {
144
+ "type": "audio_chunk",
145
+ "data": audio_b64,
146
+ "streaming": True,
147
+ "message_id": message_id,
148
+ }
149
+ if text:
150
+ payload["text"] = text
151
+
152
+ await web_chat_back_queue.put(payload)
153
+ continue
154
+
131
155
  # if chain.type == "break" and final_data:
132
156
  # # 分割符
133
157
  # await web_chat_back_queue.put(
@@ -322,6 +322,10 @@ class ProviderManager:
322
322
  from .sources.openai_tts_api_source import (
323
323
  ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
324
324
  )
325
+ case "genie_tts":
326
+ from .sources.genie_tts import (
327
+ GenieTTSProvider as GenieTTSProvider,
328
+ )
325
329
  case "edge_tts":
326
330
  from .sources.edge_tts_source import (
327
331
  ProviderEdgeTTS as ProviderEdgeTTS,
@@ -422,17 +426,20 @@ class ProviderManager:
422
426
  except (ImportError, ModuleNotFoundError) as e:
423
427
  logger.critical(
424
428
  f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
429
+ exc_info=True,
425
430
  )
426
431
  return
427
432
  except Exception as e:
428
433
  logger.critical(
429
434
  f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因",
435
+ exc_info=True,
430
436
  )
431
437
  return
432
438
 
433
439
  if provider_config["type"] not in provider_cls_map:
434
440
  logger.error(
435
441
  f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。",
442
+ exc_info=True,
436
443
  )
437
444
  return
438
445
 
@@ -221,11 +221,65 @@ class TTSProvider(AbstractProvider):
221
221
  self.provider_config = provider_config
222
222
  self.provider_settings = provider_settings
223
223
 
224
+ def support_stream(self) -> bool:
225
+ """是否支持流式 TTS
226
+
227
+ Returns:
228
+ bool: True 表示支持流式处理,False 表示不支持(默认)
229
+
230
+ Notes:
231
+ 子类可以重写此方法返回 True 来启用流式 TTS 支持
232
+ """
233
+ return False
234
+
224
235
  @abc.abstractmethod
225
236
  async def get_audio(self, text: str) -> str:
226
237
  """获取文本的音频,返回音频文件路径"""
227
238
  raise NotImplementedError
228
239
 
240
+ async def get_audio_stream(
241
+ self,
242
+ text_queue: asyncio.Queue[str | None],
243
+ audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
244
+ ) -> None:
245
+ """流式 TTS 处理方法。
246
+
247
+ 从 text_queue 中读取文本片段,将生成的音频数据(WAV 格式的 in-memory bytes)放入 audio_queue。
248
+ 当 text_queue 收到 None 时,表示文本输入结束,此时应该处理完所有剩余文本并向 audio_queue 发送 None 表示结束。
249
+
250
+ Args:
251
+ text_queue: 输入文本队列,None 表示输入结束
252
+ audio_queue: 输出音频队列(bytes 或 (text, bytes)),None 表示输出结束
253
+
254
+ Notes:
255
+ - 默认实现会将文本累积后一次性调用 get_audio 生成完整音频
256
+ - 子类可以重写此方法实现真正的流式 TTS
257
+ - 音频数据应该是 WAV 格式的 bytes
258
+ """
259
+ accumulated_text = ""
260
+
261
+ while True:
262
+ text_part = await text_queue.get()
263
+
264
+ if text_part is None:
265
+ # 输入结束,处理累积的文本
266
+ if accumulated_text:
267
+ try:
268
+ # 调用原有的 get_audio 方法获取音频文件路径
269
+ audio_path = await self.get_audio(accumulated_text)
270
+ # 读取音频文件内容
271
+ with open(audio_path, "rb") as f:
272
+ audio_data = f.read()
273
+ await audio_queue.put((accumulated_text, audio_data))
274
+ except Exception:
275
+ # 出错时也要发送 None 结束标记
276
+ pass
277
+ # 发送结束标记
278
+ await audio_queue.put(None)
279
+ break
280
+
281
+ accumulated_text += text_part
282
+
229
283
  async def test(self):
230
284
  await self.get_audio("hi")
231
285
 
@@ -68,4 +68,4 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
68
68
 
69
69
  def get_dim(self) -> int:
70
70
  """获取向量的维度"""
71
- return self.provider_config.get("embedding_dimensions", 768)
71
+ return int(self.provider_config.get("embedding_dimensions", 768))
@@ -0,0 +1,128 @@
1
+ import asyncio
2
+ import os
3
+ import uuid
4
+
5
+ from astrbot.core import logger
6
+ from astrbot.core.provider.entities import ProviderType
7
+ from astrbot.core.provider.provider import TTSProvider
8
+ from astrbot.core.provider.register import register_provider_adapter
9
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
10
+
11
+ try:
12
+ import genie_tts as genie # type: ignore
13
+ except ImportError:
14
+ genie = None
15
+
16
+
17
+ @register_provider_adapter(
18
+ "genie_tts",
19
+ "Genie TTS",
20
+ provider_type=ProviderType.TEXT_TO_SPEECH,
21
+ )
22
+ class GenieTTSProvider(TTSProvider):
23
+ def __init__(
24
+ self,
25
+ provider_config: dict,
26
+ provider_settings: dict,
27
+ ) -> None:
28
+ super().__init__(provider_config, provider_settings)
29
+ if not genie:
30
+ raise ImportError("Please install genie_tts first.")
31
+
32
+ self.character_name = provider_config.get("genie_character_name", "mika")
33
+ language = provider_config.get("genie_language", "Japanese")
34
+ model_dir = provider_config.get("genie_onnx_model_dir", "")
35
+ refer_audio_path = provider_config.get("genie_refer_audio_path", "")
36
+ refer_text = provider_config.get("genie_refer_text", "")
37
+
38
+ try:
39
+ genie.load_character(
40
+ character_name=self.character_name,
41
+ language=language,
42
+ onnx_model_dir=model_dir,
43
+ )
44
+ genie.set_reference_audio(
45
+ character_name=self.character_name,
46
+ audio_path=refer_audio_path,
47
+ audio_text=refer_text,
48
+ language=language,
49
+ )
50
+ except Exception as e:
51
+ raise RuntimeError(f"Failed to load character {self.character_name}: {e}")
52
+
53
+ def support_stream(self) -> bool:
54
+ return True
55
+
56
+ async def get_audio(self, text: str) -> str:
57
+ temp_dir = os.path.join(get_astrbot_data_path(), "temp")
58
+ os.makedirs(temp_dir, exist_ok=True)
59
+ filename = f"genie_tts_{uuid.uuid4()}.wav"
60
+ path = os.path.join(temp_dir, filename)
61
+
62
+ loop = asyncio.get_event_loop()
63
+
64
+ def _generate(save_path: str):
65
+ assert genie is not None
66
+ genie.tts(
67
+ character_name=self.character_name,
68
+ text=text,
69
+ save_path=save_path,
70
+ )
71
+
72
+ try:
73
+ await loop.run_in_executor(None, _generate, path)
74
+
75
+ if os.path.exists(path):
76
+ return path
77
+
78
+ raise RuntimeError("Genie TTS did not save to file.")
79
+
80
+ except Exception as e:
81
+ raise RuntimeError(f"Genie TTS generation failed: {e}")
82
+
83
+ async def get_audio_stream(
84
+ self,
85
+ text_queue: asyncio.Queue[str | None],
86
+ audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
87
+ ) -> None:
88
+ loop = asyncio.get_event_loop()
89
+
90
+ while True:
91
+ text = await text_queue.get()
92
+ if text is None:
93
+ await audio_queue.put(None)
94
+ break
95
+
96
+ try:
97
+ temp_dir = os.path.join(get_astrbot_data_path(), "temp")
98
+ os.makedirs(temp_dir, exist_ok=True)
99
+ filename = f"genie_tts_{uuid.uuid4()}.wav"
100
+ path = os.path.join(temp_dir, filename)
101
+
102
+ def _generate(save_path: str, t: str):
103
+ assert genie is not None
104
+ genie.tts(
105
+ character_name=self.character_name,
106
+ text=t,
107
+ save_path=save_path,
108
+ )
109
+
110
+ await loop.run_in_executor(None, _generate, path, text)
111
+
112
+ if os.path.exists(path):
113
+ with open(path, "rb") as f:
114
+ audio_data = f.read()
115
+
116
+ # Put (text, bytes) into queue so frontend can display text
117
+ await audio_queue.put((text, audio_data))
118
+
119
+ # Clean up
120
+ try:
121
+ os.remove(path)
122
+ except OSError:
123
+ pass
124
+ else:
125
+ logger.error(f"Genie TTS failed to generate audio for: {text}")
126
+
127
+ except Exception as e:
128
+ logger.error(f"Genie TTS stream error: {e}")
@@ -37,4 +37,4 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
37
37
 
38
38
  def get_dim(self) -> int:
39
39
  """获取向量的维度"""
40
- return self.provider_config.get("embedding_dimensions", 1024)
40
+ return int(self.provider_config.get("embedding_dimensions", 1024))
@@ -328,28 +328,29 @@ class Context:
328
328
  """获取所有用于 Embedding 任务的 Provider。"""
329
329
  return self.provider_manager.embedding_provider_insts
330
330
 
331
- def get_using_provider(self, umo: str | None = None) -> Provider:
331
+ def get_using_provider(self, umo: str | None = None) -> Provider | None:
332
332
  """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
333
333
 
334
334
  Args:
335
335
  umo: unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,
336
- 则使用该会话偏好的提供商。
336
+ 则使用该会话偏好的对话模型(提供商)。
337
337
 
338
338
  Returns:
339
- 当前使用的文本生成提供者。
339
+ 当前使用的对话模型(提供商),如果未设置则返回 None。
340
340
 
341
341
  Raises:
342
- ValueError: 返回的提供者不是 Provider 类型。
343
-
344
- Note:
345
- 通过 /provider 指令可以切换提供者。
342
+ ValueError: 该会话来源配置的的对话模型(提供商)的类型不正确。
346
343
  """
347
344
  prov = self.provider_manager.get_using_provider(
348
345
  provider_type=ProviderType.CHAT_COMPLETION,
349
346
  umo=umo,
350
347
  )
348
+ if prov is None:
349
+ return None
351
350
  if not isinstance(prov, Provider):
352
- raise ValueError("返回的 Provider 不是 Provider 类型")
351
+ raise ValueError(
352
+ f"该会话来源的对话模型(提供商)的类型不正确: {type(prov)}"
353
+ )
353
354
  return prov
354
355
 
355
356
  def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None:
@@ -427,7 +427,7 @@ def register_on_using_llm_tool(**kwargs):
427
427
  """
428
428
 
429
429
  def decorator(awaitable):
430
- _ = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent, **kwargs)
430
+ _ = get_handler_or_create(awaitable, EventType.OnUsingLLMToolEvent, **kwargs)
431
431
  return awaitable
432
432
 
433
433
  return decorator
@@ -452,9 +452,7 @@ def register_on_llm_tool_respond(**kwargs):
452
452
  """
453
453
 
454
454
  def decorator(awaitable):
455
- _ = get_handler_or_create(
456
- awaitable, EventType.OnAfterCallingFuncToolEvent, **kwargs
457
- )
455
+ _ = get_handler_or_create(awaitable, EventType.OnLLMToolRespondEvent, **kwargs)
458
456
  return awaitable
459
457
 
460
458
  return decorator
@@ -189,7 +189,8 @@ class EventType(enum.Enum):
189
189
  OnLLMResponseEvent = enum.auto() # LLM 响应后
190
190
  OnDecoratingResultEvent = enum.auto() # 发送消息前
191
191
  OnCallingFuncToolEvent = enum.auto() # 调用函数工具
192
- OnAfterCallingFuncToolEvent = enum.auto() # 调用函数工具后
192
+ OnUsingLLMToolEvent = enum.auto() # 使用 LLM 工具
193
+ OnLLMToolRespondEvent = enum.auto() # 调用函数工具后
193
194
  OnAfterMessageSentEvent = enum.auto() # 发送消息后
194
195
 
195
196