AstrBot 4.3.5__py3-none-any.whl → 4.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
  2. astrbot/core/astrbot_config_mgr.py +23 -51
  3. astrbot/core/config/default.py +132 -12
  4. astrbot/core/conversation_mgr.py +36 -1
  5. astrbot/core/core_lifecycle.py +24 -5
  6. astrbot/core/db/migration/helper.py +6 -3
  7. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  8. astrbot/core/db/vec_db/base.py +33 -2
  9. astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
  10. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
  11. astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
  12. astrbot/core/file_token_service.py +6 -1
  13. astrbot/core/initial_loader.py +6 -3
  14. astrbot/core/knowledge_base/chunking/__init__.py +11 -0
  15. astrbot/core/knowledge_base/chunking/base.py +24 -0
  16. astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
  17. astrbot/core/knowledge_base/chunking/recursive.py +155 -0
  18. astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
  19. astrbot/core/knowledge_base/kb_helper.py +348 -0
  20. astrbot/core/knowledge_base/kb_mgr.py +287 -0
  21. astrbot/core/knowledge_base/models.py +114 -0
  22. astrbot/core/knowledge_base/parsers/__init__.py +15 -0
  23. astrbot/core/knowledge_base/parsers/base.py +50 -0
  24. astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
  25. astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
  26. astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
  27. astrbot/core/knowledge_base/parsers/util.py +13 -0
  28. astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
  29. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  30. astrbot/core/knowledge_base/retrieval/manager.py +273 -0
  31. astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
  32. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
  33. astrbot/core/pipeline/process_stage/method/llm_request.py +29 -7
  34. astrbot/core/pipeline/process_stage/utils.py +80 -0
  35. astrbot/core/platform/astr_message_event.py +8 -7
  36. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +5 -2
  37. astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
  38. astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
  39. astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
  40. astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
  41. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
  42. astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
  43. astrbot/core/platform/sources/satori/satori_event.py +270 -99
  44. astrbot/core/provider/manager.py +22 -9
  45. astrbot/core/provider/provider.py +67 -0
  46. astrbot/core/provider/sources/anthropic_source.py +4 -4
  47. astrbot/core/provider/sources/dashscope_source.py +10 -9
  48. astrbot/core/provider/sources/dify_source.py +6 -8
  49. astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
  50. astrbot/core/provider/sources/openai_embedding_source.py +1 -2
  51. astrbot/core/provider/sources/openai_source.py +43 -15
  52. astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
  53. astrbot/core/provider/sources/xinference_rerank_source.py +108 -0
  54. astrbot/core/provider/sources/xinference_stt_provider.py +187 -0
  55. astrbot/core/star/context.py +19 -13
  56. astrbot/core/star/star.py +6 -0
  57. astrbot/core/star/star_manager.py +13 -7
  58. astrbot/core/umop_config_router.py +81 -0
  59. astrbot/core/updator.py +1 -1
  60. astrbot/core/utils/io.py +23 -12
  61. astrbot/dashboard/routes/__init__.py +2 -0
  62. astrbot/dashboard/routes/config.py +137 -9
  63. astrbot/dashboard/routes/knowledge_base.py +1065 -0
  64. astrbot/dashboard/routes/plugin.py +24 -5
  65. astrbot/dashboard/routes/update.py +1 -1
  66. astrbot/dashboard/server.py +6 -0
  67. astrbot/dashboard/utils.py +161 -0
  68. {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/METADATA +30 -13
  69. {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/RECORD +72 -46
  70. {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/WHEEL +0 -0
  71. {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/entry_points.txt +0 -0
  72. {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,9 +1,7 @@
1
1
  import astrbot.core.message.components as Comp
2
2
  import os
3
- from typing import List
4
3
  from .. import Provider
5
4
  from ..entities import LLMResponse
6
- from ..func_tool_manager import FuncCall
7
5
  from ..register import register_provider_adapter
8
6
  from astrbot.core.utils.dify_api_client import DifyAPIClient
9
7
  from astrbot.core.utils.io import download_image_by_url, download_file
@@ -55,11 +53,11 @@ class ProviderDify(Provider):
55
53
  async def text_chat(
56
54
  self,
57
55
  prompt: str,
58
- session_id: str = None,
59
- image_urls: List[str] = None,
60
- func_tool: FuncCall = None,
61
- contexts: List = None,
62
- system_prompt: str = None,
56
+ session_id=None,
57
+ image_urls=None,
58
+ func_tool=None,
59
+ contexts=None,
60
+ system_prompt=None,
63
61
  tool_calls_result=None,
64
62
  model=None,
65
63
  **kwargs,
@@ -223,7 +221,7 @@ class ProviderDify(Provider):
223
221
  # Chat
224
222
  return MessageChain(chain=[Comp.Plain(chunk)])
225
223
 
226
- async def parse_file(item: dict) -> Comp:
224
+ async def parse_file(item: dict):
227
225
  match item["type"]:
228
226
  case "image":
229
227
  return Comp.Image(file=item["url"], url=item["url"])
@@ -32,7 +32,6 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
32
32
  self.model = provider_config.get(
33
33
  "embedding_model", "gemini-embedding-exp-03-07"
34
34
  )
35
- self.dimension = provider_config.get("embedding_dimensions", 768)
36
35
 
37
36
  async def get_embedding(self, text: str) -> list[float]:
38
37
  """
@@ -60,4 +59,4 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
60
59
 
61
60
  def get_dim(self) -> int:
62
61
  """获取向量的维度"""
63
- return self.dimension
62
+ return self.provider_config.get("embedding_dimensions", 768)
@@ -22,7 +22,6 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
22
22
  timeout=int(provider_config.get("timeout", 20)),
23
23
  )
24
24
  self.model = provider_config.get("embedding_model", "text-embedding-3-small")
25
- self.dimension = provider_config.get("embedding_dimensions", 1024)
26
25
 
27
26
  async def get_embedding(self, text: str) -> list[float]:
28
27
  """
@@ -40,4 +39,4 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
40
39
 
41
40
  def get_dim(self) -> int:
42
41
  """获取向量的维度"""
43
- return self.dimension
42
+ return self.provider_config.get("embedding_dimensions", 1024)
@@ -16,7 +16,7 @@ from astrbot.core.message.message_event_result import MessageChain
16
16
 
17
17
  from astrbot.api.provider import Provider
18
18
  from astrbot import logger
19
- from astrbot.core.provider.func_tool_manager import FuncCall
19
+ from astrbot.core.provider.func_tool_manager import ToolSet
20
20
  from typing import List, AsyncGenerator
21
21
  from ..register import register_provider_adapter
22
22
  from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
@@ -49,7 +49,7 @@ class ProviderOpenAIOfficial(Provider):
49
49
  self.client = AsyncAzureOpenAI(
50
50
  api_key=self.chosen_api_key,
51
51
  api_version=provider_config.get("api_version", None),
52
- base_url=provider_config.get("api_base", None),
52
+ base_url=provider_config.get("api_base", ""),
53
53
  timeout=self.timeout,
54
54
  )
55
55
  else:
@@ -68,6 +68,28 @@ class ProviderOpenAIOfficial(Provider):
68
68
  model = model_config.get("model", "unknown")
69
69
  self.set_model(model)
70
70
 
71
+ def _maybe_inject_xai_search(self, payloads: dict, **kwargs):
72
+ """当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
73
+
74
+ - 仅在 provider_config.xai_native_search 为 True 时生效
75
+ - 默认注入 {"mode": "auto"}
76
+ - 允许通过 kwargs 使用 xai_search_mode 覆盖(on/auto/off)
77
+ """
78
+ if not bool(self.provider_config.get("xai_native_search", False)):
79
+ return
80
+
81
+ mode = kwargs.get("xai_search_mode", "auto")
82
+ mode = str(mode).lower()
83
+ if mode not in ("auto", "on", "off"):
84
+ mode = "auto"
85
+
86
+ # off 时不注入,保持与未开启一致
87
+ if mode == "off":
88
+ return
89
+
90
+ # OpenAI SDK 不识别的字段会在 _query/_query_stream 中放入 extra_body
91
+ payloads["search_parameters"] = {"mode": mode}
92
+
71
93
  async def get_models(self):
72
94
  try:
73
95
  models_str = []
@@ -79,7 +101,7 @@ class ProviderOpenAIOfficial(Provider):
79
101
  except NotFoundError as e:
80
102
  raise Exception(f"获取模型列表失败:{e}")
81
103
 
82
- async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
104
+ async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse:
83
105
  if tools:
84
106
  model = payloads.get("model", "").lower()
85
107
  omit_empty_param_field = "gemini" in model
@@ -126,7 +148,7 @@ class ProviderOpenAIOfficial(Provider):
126
148
  return llm_response
127
149
 
128
150
  async def _query_stream(
129
- self, payloads: dict, tools: FuncCall
151
+ self, payloads: dict, tools: ToolSet
130
152
  ) -> AsyncGenerator[LLMResponse, None]:
131
153
  """流式查询API,逐步返回结果"""
132
154
  if tools:
@@ -183,9 +205,7 @@ class ProviderOpenAIOfficial(Provider):
183
205
 
184
206
  yield llm_response
185
207
 
186
- async def parse_openai_completion(
187
- self, completion: ChatCompletion, tools: FuncCall
188
- ):
208
+ async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet):
189
209
  """解析 OpenAI 的 ChatCompletion 响应"""
190
210
  llm_response = LLMResponse("assistant")
191
211
 
@@ -208,7 +228,10 @@ class ProviderOpenAIOfficial(Provider):
208
228
  # workaround for #1359
209
229
  tool_call = json.loads(tool_call)
210
230
  for tool in tools.func_list:
211
- if tool.name == tool_call.function.name:
231
+ if (
232
+ tool_call.type == "function"
233
+ and tool.name == tool_call.function.name
234
+ ):
212
235
  # workaround for #1454
213
236
  if isinstance(tool_call.function.arguments, str):
214
237
  args = json.loads(tool_call.function.arguments)
@@ -270,6 +293,9 @@ class ProviderOpenAIOfficial(Provider):
270
293
 
271
294
  payloads = {"messages": context_query, **model_config}
272
295
 
296
+ # xAI 原生搜索参数(最小侵入地在此处注入)
297
+ self._maybe_inject_xai_search(payloads, **kwargs)
298
+
273
299
  return payloads, context_query
274
300
 
275
301
  async def _handle_api_error(
@@ -277,7 +303,7 @@ class ProviderOpenAIOfficial(Provider):
277
303
  e: Exception,
278
304
  payloads: dict,
279
305
  context_query: list,
280
- func_tool: FuncCall,
306
+ func_tool: ToolSet,
281
307
  chosen_key: str,
282
308
  available_api_keys: List[str],
283
309
  retry_cnt: int,
@@ -420,7 +446,7 @@ class ProviderOpenAIOfficial(Provider):
420
446
  if success:
421
447
  break
422
448
 
423
- if retry_cnt == max_retries - 1:
449
+ if retry_cnt == max_retries - 1 or llm_response is None:
424
450
  logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
425
451
  if last_exception is None:
426
452
  raise Exception("未知错误")
@@ -430,10 +456,10 @@ class ProviderOpenAIOfficial(Provider):
430
456
  async def text_chat_stream(
431
457
  self,
432
458
  prompt: str,
433
- session_id: str = None,
434
- image_urls: List[str] = [],
435
- func_tool: FuncCall = None,
436
- contexts=[],
459
+ session_id=None,
460
+ image_urls=None,
461
+ func_tool=None,
462
+ contexts=None,
437
463
  system_prompt=None,
438
464
  tool_calls_result=None,
439
465
  model=None,
@@ -526,7 +552,9 @@ class ProviderOpenAIOfficial(Provider):
526
552
  def set_key(self, key):
527
553
  self.client.api_key = key
528
554
 
529
- async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict:
555
+ async def assemble_context(
556
+ self, text: str, image_urls: List[str] | None = None
557
+ ) -> dict:
530
558
  """组装成符合 OpenAI 格式的 role 为 user 的消息段"""
531
559
  if image_urls:
532
560
  user_content = {
@@ -30,7 +30,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
30
30
  timeout=timeout,
31
31
  )
32
32
 
33
- self.set_model(provider_config.get("model", None))
33
+ self.set_model(provider_config.get("model", ""))
34
34
 
35
35
  async def get_audio(self, text: str) -> str:
36
36
  temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -0,0 +1,108 @@
1
+ from xinference_client.client.restful.async_restful_client import (
2
+ AsyncClient as Client,
3
+ )
4
+ from astrbot import logger
5
+ from ..provider import RerankProvider
6
+ from ..register import register_provider_adapter
7
+ from ..entities import ProviderType, RerankResult
8
+
9
+
10
+ @register_provider_adapter(
11
+ "xinference_rerank",
12
+ "Xinference Rerank 适配器",
13
+ provider_type=ProviderType.RERANK,
14
+ )
15
+ class XinferenceRerankProvider(RerankProvider):
16
+ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
17
+ super().__init__(provider_config, provider_settings)
18
+ self.provider_config = provider_config
19
+ self.provider_settings = provider_settings
20
+ self.base_url = provider_config.get("rerank_api_base", "http://127.0.0.1:8000")
21
+ self.base_url = self.base_url.rstrip("/")
22
+ self.timeout = provider_config.get("timeout", 20)
23
+ self.model_name = provider_config.get("rerank_model", "BAAI/bge-reranker-base")
24
+ self.api_key = provider_config.get("rerank_api_key")
25
+ self.launch_model_if_not_running = provider_config.get(
26
+ "launch_model_if_not_running", False
27
+ )
28
+ self.client = None
29
+ self.model = None
30
+ self.model_uid = None
31
+
32
+ async def initialize(self):
33
+ if self.api_key:
34
+ logger.info("Xinference Rerank: Using API key for authentication.")
35
+ self.client = Client(self.base_url, api_key=self.api_key)
36
+ else:
37
+ logger.info("Xinference Rerank: No API key provided.")
38
+ self.client = Client(self.base_url)
39
+
40
+ try:
41
+ running_models = await self.client.list_models()
42
+ for uid, model_spec in running_models.items():
43
+ if model_spec.get("model_name") == self.model_name:
44
+ logger.info(
45
+ f"Model '{self.model_name}' is already running with UID: {uid}"
46
+ )
47
+ self.model_uid = uid
48
+ break
49
+
50
+ if self.model_uid is None:
51
+ if self.launch_model_if_not_running:
52
+ logger.info(f"Launching {self.model_name} model...")
53
+ self.model_uid = await self.client.launch_model(
54
+ model_name=self.model_name, model_type="rerank"
55
+ )
56
+ logger.info("Model launched.")
57
+ else:
58
+ logger.warning(
59
+ f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available."
60
+ )
61
+ return
62
+
63
+ if self.model_uid:
64
+ self.model = await self.client.get_model(self.model_uid)
65
+
66
+ except Exception as e:
67
+ logger.error(f"Failed to initialize Xinference model: {e}")
68
+ logger.debug(
69
+ f"Xinference initialization failed with exception: {e}", exc_info=True
70
+ )
71
+ self.model = None
72
+
73
+ async def rerank(
74
+ self, query: str, documents: list[str], top_n: int | None = None
75
+ ) -> list[RerankResult]:
76
+ if not self.model:
77
+ logger.error("Xinference rerank model is not initialized.")
78
+ return []
79
+ try:
80
+ response = await self.model.rerank(documents, query, top_n)
81
+ results = response.get("results", [])
82
+ logger.debug(f"Rerank API response: {response}")
83
+
84
+ if not results:
85
+ logger.warning(
86
+ f"Rerank API returned an empty list. Original response: {response}"
87
+ )
88
+
89
+ return [
90
+ RerankResult(
91
+ index=result["index"],
92
+ relevance_score=result["relevance_score"],
93
+ )
94
+ for result in results
95
+ ]
96
+ except Exception as e:
97
+ logger.error(f"Xinference rerank failed: {e}")
98
+ logger.debug(f"Xinference rerank failed with exception: {e}", exc_info=True)
99
+ return []
100
+
101
+ async def terminate(self) -> None:
102
+ """关闭客户端会话"""
103
+ if self.client:
104
+ logger.info("Closing Xinference rerank client...")
105
+ try:
106
+ await self.client.close()
107
+ except Exception as e:
108
+ logger.error(f"Failed to close Xinference client: {e}", exc_info=True)
@@ -0,0 +1,187 @@
1
+ import uuid
2
+ import os
3
+ import aiohttp
4
+ from xinference_client.client.restful.async_restful_client import (
5
+ AsyncClient as Client,
6
+ )
7
+ from ..provider import STTProvider
8
+ from ..entities import ProviderType
9
+ from ..register import register_provider_adapter
10
+ from astrbot.core import logger
11
+ from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
12
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
13
+
14
+
15
+ @register_provider_adapter(
16
+ "xinference_stt",
17
+ "Xinference STT",
18
+ provider_type=ProviderType.SPEECH_TO_TEXT,
19
+ )
20
+ class ProviderXinferenceSTT(STTProvider):
21
+ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
22
+ super().__init__(provider_config, provider_settings)
23
+ self.provider_config = provider_config
24
+ self.provider_settings = provider_settings
25
+ self.base_url = provider_config.get("api_base", "http://127.0.0.1:9997")
26
+ self.base_url = self.base_url.rstrip("/")
27
+ self.timeout = provider_config.get("timeout", 180)
28
+ self.model_name = provider_config.get("model", "whisper-large-v3")
29
+ self.api_key = provider_config.get("api_key")
30
+ self.launch_model_if_not_running = provider_config.get(
31
+ "launch_model_if_not_running", False
32
+ )
33
+ self.client = None
34
+ self.model_uid = None
35
+
36
+ async def initialize(self):
37
+ if self.api_key:
38
+ logger.info("Xinference STT: Using API key for authentication.")
39
+ self.client = Client(self.base_url, api_key=self.api_key)
40
+ else:
41
+ logger.info("Xinference STT: No API key provided.")
42
+ self.client = Client(self.base_url)
43
+
44
+ try:
45
+ running_models = await self.client.list_models()
46
+ for uid, model_spec in running_models.items():
47
+ if model_spec.get("model_name") == self.model_name:
48
+ logger.info(
49
+ f"Model '{self.model_name}' is already running with UID: {uid}"
50
+ )
51
+ self.model_uid = uid
52
+ break
53
+
54
+ if self.model_uid is None:
55
+ if self.launch_model_if_not_running:
56
+ logger.info(f"Launching {self.model_name} model...")
57
+ self.model_uid = await self.client.launch_model(
58
+ model_name=self.model_name, model_type="audio"
59
+ )
60
+ logger.info("Model launched.")
61
+ else:
62
+ logger.warning(
63
+ f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available."
64
+ )
65
+ return
66
+
67
+ except Exception as e:
68
+ logger.error(f"Failed to initialize Xinference model: {e}")
69
+ logger.debug(
70
+ f"Xinference initialization failed with exception: {e}", exc_info=True
71
+ )
72
+
73
+ async def get_text(self, audio_url: str) -> str:
74
+ if not self.model_uid or self.client is None or self.client.session is None:
75
+ logger.error("Xinference STT model is not initialized.")
76
+ return ""
77
+
78
+ audio_bytes = None
79
+ temp_files = []
80
+ is_tencent = False
81
+
82
+ try:
83
+ # 1. Get audio bytes
84
+ if audio_url.startswith("http"):
85
+ if "multimedia.nt.qq.com.cn" in audio_url:
86
+ is_tencent = True
87
+ async with aiohttp.ClientSession() as session:
88
+ async with session.get(audio_url, timeout=self.timeout) as resp:
89
+ if resp.status == 200:
90
+ audio_bytes = await resp.read()
91
+ else:
92
+ logger.error(
93
+ f"Failed to download audio from {audio_url}, status: {resp.status}"
94
+ )
95
+ return ""
96
+ else:
97
+ if os.path.exists(audio_url):
98
+ with open(audio_url, "rb") as f:
99
+ audio_bytes = f.read()
100
+ else:
101
+ logger.error(f"File not found: {audio_url}")
102
+ return ""
103
+
104
+ if not audio_bytes:
105
+ logger.error("Audio bytes are empty.")
106
+ return ""
107
+
108
+ # 2. Check for conversion
109
+ needs_conversion = False
110
+ if (
111
+ audio_url.endswith((".amr", ".silk"))
112
+ or is_tencent
113
+ or b"SILK" in audio_bytes[:8]
114
+ ):
115
+ needs_conversion = True
116
+
117
+ # 3. Perform conversion if needed
118
+ if needs_conversion:
119
+ logger.info("Audio requires conversion, using temporary files...")
120
+ temp_dir = os.path.join(get_astrbot_data_path(), "temp")
121
+ os.makedirs(temp_dir, exist_ok=True)
122
+
123
+ input_path = os.path.join(temp_dir, str(uuid.uuid4()))
124
+ output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
125
+ temp_files.extend([input_path, output_path])
126
+
127
+ with open(input_path, "wb") as f:
128
+ f.write(audio_bytes)
129
+
130
+ logger.info("Converting silk/amr file to wav ...")
131
+ await tencent_silk_to_wav(input_path, output_path)
132
+
133
+ with open(output_path, "rb") as f:
134
+ audio_bytes = f.read()
135
+
136
+ # 4. Transcribe
137
+ # 官方asyncCLient的客户端似乎实现有点问题,这里直接用aiohttp实现openai标准兼容请求,提交issue等待官方修复后再改回来
138
+ url = f"{self.base_url}/v1/audio/transcriptions"
139
+ headers = {
140
+ "accept": "application/json",
141
+ }
142
+ if self.client and self.client._headers:
143
+ headers.update(self.client._headers)
144
+
145
+ data = aiohttp.FormData()
146
+ data.add_field("model", self.model_uid)
147
+ data.add_field(
148
+ "file", audio_bytes, filename="audio.wav", content_type="audio/wav"
149
+ )
150
+
151
+ async with self.client.session.post(
152
+ url, data=data, headers=headers, timeout=self.timeout
153
+ ) as resp:
154
+ if resp.status == 200:
155
+ result = await resp.json()
156
+ text = result.get("text", "")
157
+ logger.debug(f"Xinference STT result: {text}")
158
+ return text
159
+ else:
160
+ error_text = await resp.text()
161
+ logger.error(
162
+ f"Xinference STT transcription failed with status {resp.status}: {error_text}"
163
+ )
164
+ return ""
165
+
166
+ except Exception as e:
167
+ logger.error(f"Xinference STT failed: {e}")
168
+ logger.debug(f"Xinference STT failed with exception: {e}", exc_info=True)
169
+ return ""
170
+ finally:
171
+ # 5. Cleanup
172
+ for temp_file in temp_files:
173
+ try:
174
+ if os.path.exists(temp_file):
175
+ os.remove(temp_file)
176
+ logger.debug(f"Removed temporary file: {temp_file}")
177
+ except Exception as e:
178
+ logger.error(f"Failed to remove temporary file {temp_file}: {e}")
179
+
180
+ async def terminate(self) -> None:
181
+ """关闭客户端会话"""
182
+ if self.client:
183
+ logger.info("Closing Xinference STT client...")
184
+ try:
185
+ await self.client.close()
186
+ except Exception as e:
187
+ logger.error(f"Failed to close Xinference client: {e}", exc_info=True)
@@ -1,5 +1,4 @@
1
1
  from asyncio import Queue
2
- from typing import List, Union
3
2
 
4
3
  from astrbot.core.provider.provider import (
5
4
  Provider,
@@ -11,7 +10,7 @@ from astrbot.core.provider.provider import (
11
10
  from astrbot.core.provider.entities import ProviderType
12
11
  from astrbot.core.db import BaseDatabase
13
12
  from astrbot.core.config.astrbot_config import AstrBotConfig
14
- from astrbot.core.provider.func_tool_manager import FunctionToolManager
13
+ from astrbot.core.provider.func_tool_manager import FunctionToolManager, FunctionTool
15
14
  from astrbot.core.platform.astr_message_event import MessageSesion
16
15
  from astrbot.core.message.message_event_result import MessageChain
17
16
  from astrbot.core.provider.manager import ProviderManager
@@ -19,12 +18,14 @@ from astrbot.core.platform import Platform
19
18
  from astrbot.core.platform.manager import PlatformManager
20
19
  from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
21
20
  from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
21
+ from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
22
22
  from astrbot.core.persona_mgr import PersonaManager
23
23
  from .star import star_registry, StarMetadata, star_map
24
24
  from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
25
25
  from .filter.command import CommandFilter
26
26
  from .filter.regex import RegexFilter
27
- from typing import Awaitable, Any, Callable
27
+ from typing import Any
28
+ from collections.abc import Awaitable, Callable
28
29
  from astrbot.core.conversation_mgr import ConversationManager
29
30
  from astrbot.core.star.filter.platform_adapter_type import (
30
31
  PlatformAdapterType,
@@ -41,7 +42,7 @@ class Context:
41
42
  registered_web_apis: list = []
42
43
 
43
44
  # back compatibility
44
- _register_tasks: List[Awaitable] = []
45
+ _register_tasks: list[Awaitable] = []
45
46
  _star_manager = None
46
47
 
47
48
  def __init__(
@@ -55,6 +56,7 @@ class Context:
55
56
  message_history_manager: PlatformMessageHistoryManager,
56
57
  persona_manager: PersonaManager,
57
58
  astrbot_config_mgr: AstrBotConfigManager,
59
+ knowledge_base_manager: KnowledgeBaseManager,
58
60
  ):
59
61
  self._event_queue = event_queue
60
62
  """事件队列。消息平台通过事件队列传递消息事件。"""
@@ -68,6 +70,7 @@ class Context:
68
70
  self.message_history_manager = message_history_manager
69
71
  self.persona_manager = persona_manager
70
72
  self.astrbot_config_mgr = astrbot_config_mgr
73
+ self.kb_manager = knowledge_base_manager
71
74
 
72
75
  def get_registered_star(self, star_name: str) -> StarMetadata | None:
73
76
  """根据插件名获取插件的 Metadata"""
@@ -75,7 +78,7 @@ class Context:
75
78
  if star.name == star_name:
76
79
  return star
77
80
 
78
- def get_all_stars(self) -> List[StarMetadata]:
81
+ def get_all_stars(self) -> list[StarMetadata]:
79
82
  """获取当前载入的所有插件 Metadata 的列表"""
80
83
  return star_registry
81
84
 
@@ -113,19 +116,19 @@ class Context:
113
116
  prov = self.provider_manager.inst_map.get(provider_id)
114
117
  return prov
115
118
 
116
- def get_all_providers(self) -> List[Provider]:
119
+ def get_all_providers(self) -> list[Provider]:
117
120
  """获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
118
121
  return self.provider_manager.provider_insts
119
122
 
120
- def get_all_tts_providers(self) -> List[TTSProvider]:
123
+ def get_all_tts_providers(self) -> list[TTSProvider]:
121
124
  """获取所有用于 TTS 任务的 Provider。"""
122
125
  return self.provider_manager.tts_provider_insts
123
126
 
124
- def get_all_stt_providers(self) -> List[STTProvider]:
127
+ def get_all_stt_providers(self) -> list[STTProvider]:
125
128
  """获取所有用于 STT 任务的 Provider。"""
126
129
  return self.provider_manager.stt_provider_insts
127
130
 
128
- def get_all_embedding_providers(self) -> List[EmbeddingProvider]:
131
+ def get_all_embedding_providers(self) -> list[EmbeddingProvider]:
129
132
  """获取所有用于 Embedding 任务的 Provider。"""
130
133
  return self.provider_manager.embedding_provider_insts
131
134
 
@@ -193,9 +196,7 @@ class Context:
193
196
  return self._event_queue
194
197
 
195
198
  @deprecated(version="4.0.0", reason="Use get_platform_inst instead")
196
- def get_platform(
197
- self, platform_type: Union[PlatformAdapterType, str]
198
- ) -> Platform | None:
199
+ def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
199
200
  """
200
201
  获取指定类型的平台适配器。
201
202
 
@@ -228,7 +229,7 @@ class Context:
228
229
  return platform
229
230
 
230
231
  async def send_message(
231
- self, session: Union[str, MessageSesion], message_chain: MessageChain
232
+ self, session: str | MessageSesion, message_chain: MessageChain
232
233
  ) -> bool:
233
234
  """
234
235
  根据 session(unified_msg_origin) 主动发送消息。
@@ -255,6 +256,11 @@ class Context:
255
256
  return True
256
257
  return False
257
258
 
259
+ def add_llm_tools(self, *tools: FunctionTool) -> None:
260
+ """添加 LLM 工具。"""
261
+ for tool in tools:
262
+ self.provider_manager.llm_tools.func_list.append(tool)
263
+
258
264
  """
259
265
  以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
260
266
  """
astrbot/core/star/star.py CHANGED
@@ -56,6 +56,12 @@ class StarMetadata:
56
56
  star_handler_full_names: list[str] = field(default_factory=list)
57
57
  """注册的 Handler 的全名列表"""
58
58
 
59
+ display_name: str | None = None
60
+ """用于展示的插件名称"""
61
+
62
+ logo_path: str | None = None
63
+ """插件 Logo 的路径"""
64
+
59
65
  def __str__(self) -> str:
60
66
  return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
61
67