AstrBot 4.3.5__py3-none-any.whl → 4.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
- astrbot/core/astrbot_config_mgr.py +23 -51
- astrbot/core/config/default.py +132 -12
- astrbot/core/conversation_mgr.py +36 -1
- astrbot/core/core_lifecycle.py +24 -5
- astrbot/core/db/migration/helper.py +6 -3
- astrbot/core/db/migration/migra_45_to_46.py +44 -0
- astrbot/core/db/vec_db/base.py +33 -2
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
- astrbot/core/file_token_service.py +6 -1
- astrbot/core/initial_loader.py +6 -3
- astrbot/core/knowledge_base/chunking/__init__.py +11 -0
- astrbot/core/knowledge_base/chunking/base.py +24 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
- astrbot/core/knowledge_base/chunking/recursive.py +155 -0
- astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
- astrbot/core/knowledge_base/kb_helper.py +348 -0
- astrbot/core/knowledge_base/kb_mgr.py +287 -0
- astrbot/core/knowledge_base/models.py +114 -0
- astrbot/core/knowledge_base/parsers/__init__.py +15 -0
- astrbot/core/knowledge_base/parsers/base.py +50 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
- astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
- astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
- astrbot/core/knowledge_base/parsers/util.py +13 -0
- astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
- astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
- astrbot/core/knowledge_base/retrieval/manager.py +273 -0
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
- astrbot/core/pipeline/process_stage/method/llm_request.py +29 -7
- astrbot/core/pipeline/process_stage/utils.py +80 -0
- astrbot/core/platform/astr_message_event.py +8 -7
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +5 -2
- astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
- astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
- astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
- astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
- astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
- astrbot/core/platform/sources/satori/satori_event.py +270 -99
- astrbot/core/provider/manager.py +22 -9
- astrbot/core/provider/provider.py +67 -0
- astrbot/core/provider/sources/anthropic_source.py +4 -4
- astrbot/core/provider/sources/dashscope_source.py +10 -9
- astrbot/core/provider/sources/dify_source.py +6 -8
- astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
- astrbot/core/provider/sources/openai_embedding_source.py +1 -2
- astrbot/core/provider/sources/openai_source.py +43 -15
- astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
- astrbot/core/provider/sources/xinference_rerank_source.py +108 -0
- astrbot/core/provider/sources/xinference_stt_provider.py +187 -0
- astrbot/core/star/context.py +19 -13
- astrbot/core/star/star.py +6 -0
- astrbot/core/star/star_manager.py +13 -7
- astrbot/core/umop_config_router.py +81 -0
- astrbot/core/updator.py +1 -1
- astrbot/core/utils/io.py +23 -12
- astrbot/dashboard/routes/__init__.py +2 -0
- astrbot/dashboard/routes/config.py +137 -9
- astrbot/dashboard/routes/knowledge_base.py +1065 -0
- astrbot/dashboard/routes/plugin.py +24 -5
- astrbot/dashboard/routes/update.py +1 -1
- astrbot/dashboard/server.py +6 -0
- astrbot/dashboard/utils.py +161 -0
- {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/METADATA +30 -13
- {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/RECORD +72 -46
- {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/WHEEL +0 -0
- {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/entry_points.txt +0 -0
- {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,9 +1,7 @@
|
|
|
1
1
|
import astrbot.core.message.components as Comp
|
|
2
2
|
import os
|
|
3
|
-
from typing import List
|
|
4
3
|
from .. import Provider
|
|
5
4
|
from ..entities import LLMResponse
|
|
6
|
-
from ..func_tool_manager import FuncCall
|
|
7
5
|
from ..register import register_provider_adapter
|
|
8
6
|
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
|
9
7
|
from astrbot.core.utils.io import download_image_by_url, download_file
|
|
@@ -55,11 +53,11 @@ class ProviderDify(Provider):
|
|
|
55
53
|
async def text_chat(
|
|
56
54
|
self,
|
|
57
55
|
prompt: str,
|
|
58
|
-
session_id
|
|
59
|
-
image_urls
|
|
60
|
-
func_tool
|
|
61
|
-
contexts
|
|
62
|
-
system_prompt
|
|
56
|
+
session_id=None,
|
|
57
|
+
image_urls=None,
|
|
58
|
+
func_tool=None,
|
|
59
|
+
contexts=None,
|
|
60
|
+
system_prompt=None,
|
|
63
61
|
tool_calls_result=None,
|
|
64
62
|
model=None,
|
|
65
63
|
**kwargs,
|
|
@@ -223,7 +221,7 @@ class ProviderDify(Provider):
|
|
|
223
221
|
# Chat
|
|
224
222
|
return MessageChain(chain=[Comp.Plain(chunk)])
|
|
225
223
|
|
|
226
|
-
async def parse_file(item: dict)
|
|
224
|
+
async def parse_file(item: dict):
|
|
227
225
|
match item["type"]:
|
|
228
226
|
case "image":
|
|
229
227
|
return Comp.Image(file=item["url"], url=item["url"])
|
|
@@ -32,7 +32,6 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
|
|
32
32
|
self.model = provider_config.get(
|
|
33
33
|
"embedding_model", "gemini-embedding-exp-03-07"
|
|
34
34
|
)
|
|
35
|
-
self.dimension = provider_config.get("embedding_dimensions", 768)
|
|
36
35
|
|
|
37
36
|
async def get_embedding(self, text: str) -> list[float]:
|
|
38
37
|
"""
|
|
@@ -60,4 +59,4 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
|
|
60
59
|
|
|
61
60
|
def get_dim(self) -> int:
|
|
62
61
|
"""获取向量的维度"""
|
|
63
|
-
return self.
|
|
62
|
+
return self.provider_config.get("embedding_dimensions", 768)
|
|
@@ -22,7 +22,6 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|
|
22
22
|
timeout=int(provider_config.get("timeout", 20)),
|
|
23
23
|
)
|
|
24
24
|
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
|
|
25
|
-
self.dimension = provider_config.get("embedding_dimensions", 1024)
|
|
26
25
|
|
|
27
26
|
async def get_embedding(self, text: str) -> list[float]:
|
|
28
27
|
"""
|
|
@@ -40,4 +39,4 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|
|
40
39
|
|
|
41
40
|
def get_dim(self) -> int:
|
|
42
41
|
"""获取向量的维度"""
|
|
43
|
-
return self.
|
|
42
|
+
return self.provider_config.get("embedding_dimensions", 1024)
|
|
@@ -16,7 +16,7 @@ from astrbot.core.message.message_event_result import MessageChain
|
|
|
16
16
|
|
|
17
17
|
from astrbot.api.provider import Provider
|
|
18
18
|
from astrbot import logger
|
|
19
|
-
from astrbot.core.provider.func_tool_manager import
|
|
19
|
+
from astrbot.core.provider.func_tool_manager import ToolSet
|
|
20
20
|
from typing import List, AsyncGenerator
|
|
21
21
|
from ..register import register_provider_adapter
|
|
22
22
|
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
|
@@ -49,7 +49,7 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
49
49
|
self.client = AsyncAzureOpenAI(
|
|
50
50
|
api_key=self.chosen_api_key,
|
|
51
51
|
api_version=provider_config.get("api_version", None),
|
|
52
|
-
base_url=provider_config.get("api_base",
|
|
52
|
+
base_url=provider_config.get("api_base", ""),
|
|
53
53
|
timeout=self.timeout,
|
|
54
54
|
)
|
|
55
55
|
else:
|
|
@@ -68,6 +68,28 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
68
68
|
model = model_config.get("model", "unknown")
|
|
69
69
|
self.set_model(model)
|
|
70
70
|
|
|
71
|
+
def _maybe_inject_xai_search(self, payloads: dict, **kwargs):
|
|
72
|
+
"""当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
|
|
73
|
+
|
|
74
|
+
- 仅在 provider_config.xai_native_search 为 True 时生效
|
|
75
|
+
- 默认注入 {"mode": "auto"}
|
|
76
|
+
- 允许通过 kwargs 使用 xai_search_mode 覆盖(on/auto/off)
|
|
77
|
+
"""
|
|
78
|
+
if not bool(self.provider_config.get("xai_native_search", False)):
|
|
79
|
+
return
|
|
80
|
+
|
|
81
|
+
mode = kwargs.get("xai_search_mode", "auto")
|
|
82
|
+
mode = str(mode).lower()
|
|
83
|
+
if mode not in ("auto", "on", "off"):
|
|
84
|
+
mode = "auto"
|
|
85
|
+
|
|
86
|
+
# off 时不注入,保持与未开启一致
|
|
87
|
+
if mode == "off":
|
|
88
|
+
return
|
|
89
|
+
|
|
90
|
+
# OpenAI SDK 不识别的字段会在 _query/_query_stream 中放入 extra_body
|
|
91
|
+
payloads["search_parameters"] = {"mode": mode}
|
|
92
|
+
|
|
71
93
|
async def get_models(self):
|
|
72
94
|
try:
|
|
73
95
|
models_str = []
|
|
@@ -79,7 +101,7 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
79
101
|
except NotFoundError as e:
|
|
80
102
|
raise Exception(f"获取模型列表失败:{e}")
|
|
81
103
|
|
|
82
|
-
async def _query(self, payloads: dict, tools:
|
|
104
|
+
async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse:
|
|
83
105
|
if tools:
|
|
84
106
|
model = payloads.get("model", "").lower()
|
|
85
107
|
omit_empty_param_field = "gemini" in model
|
|
@@ -126,7 +148,7 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
126
148
|
return llm_response
|
|
127
149
|
|
|
128
150
|
async def _query_stream(
|
|
129
|
-
self, payloads: dict, tools:
|
|
151
|
+
self, payloads: dict, tools: ToolSet
|
|
130
152
|
) -> AsyncGenerator[LLMResponse, None]:
|
|
131
153
|
"""流式查询API,逐步返回结果"""
|
|
132
154
|
if tools:
|
|
@@ -183,9 +205,7 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
183
205
|
|
|
184
206
|
yield llm_response
|
|
185
207
|
|
|
186
|
-
async def parse_openai_completion(
|
|
187
|
-
self, completion: ChatCompletion, tools: FuncCall
|
|
188
|
-
):
|
|
208
|
+
async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet):
|
|
189
209
|
"""解析 OpenAI 的 ChatCompletion 响应"""
|
|
190
210
|
llm_response = LLMResponse("assistant")
|
|
191
211
|
|
|
@@ -208,7 +228,10 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
208
228
|
# workaround for #1359
|
|
209
229
|
tool_call = json.loads(tool_call)
|
|
210
230
|
for tool in tools.func_list:
|
|
211
|
-
if
|
|
231
|
+
if (
|
|
232
|
+
tool_call.type == "function"
|
|
233
|
+
and tool.name == tool_call.function.name
|
|
234
|
+
):
|
|
212
235
|
# workaround for #1454
|
|
213
236
|
if isinstance(tool_call.function.arguments, str):
|
|
214
237
|
args = json.loads(tool_call.function.arguments)
|
|
@@ -270,6 +293,9 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
270
293
|
|
|
271
294
|
payloads = {"messages": context_query, **model_config}
|
|
272
295
|
|
|
296
|
+
# xAI 原生搜索参数(最小侵入地在此处注入)
|
|
297
|
+
self._maybe_inject_xai_search(payloads, **kwargs)
|
|
298
|
+
|
|
273
299
|
return payloads, context_query
|
|
274
300
|
|
|
275
301
|
async def _handle_api_error(
|
|
@@ -277,7 +303,7 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
277
303
|
e: Exception,
|
|
278
304
|
payloads: dict,
|
|
279
305
|
context_query: list,
|
|
280
|
-
func_tool:
|
|
306
|
+
func_tool: ToolSet,
|
|
281
307
|
chosen_key: str,
|
|
282
308
|
available_api_keys: List[str],
|
|
283
309
|
retry_cnt: int,
|
|
@@ -420,7 +446,7 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
420
446
|
if success:
|
|
421
447
|
break
|
|
422
448
|
|
|
423
|
-
if retry_cnt == max_retries - 1:
|
|
449
|
+
if retry_cnt == max_retries - 1 or llm_response is None:
|
|
424
450
|
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
|
425
451
|
if last_exception is None:
|
|
426
452
|
raise Exception("未知错误")
|
|
@@ -430,10 +456,10 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
430
456
|
async def text_chat_stream(
|
|
431
457
|
self,
|
|
432
458
|
prompt: str,
|
|
433
|
-
session_id
|
|
434
|
-
image_urls
|
|
435
|
-
func_tool
|
|
436
|
-
contexts=
|
|
459
|
+
session_id=None,
|
|
460
|
+
image_urls=None,
|
|
461
|
+
func_tool=None,
|
|
462
|
+
contexts=None,
|
|
437
463
|
system_prompt=None,
|
|
438
464
|
tool_calls_result=None,
|
|
439
465
|
model=None,
|
|
@@ -526,7 +552,9 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
526
552
|
def set_key(self, key):
|
|
527
553
|
self.client.api_key = key
|
|
528
554
|
|
|
529
|
-
async def assemble_context(
|
|
555
|
+
async def assemble_context(
|
|
556
|
+
self, text: str, image_urls: List[str] | None = None
|
|
557
|
+
) -> dict:
|
|
530
558
|
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
|
531
559
|
if image_urls:
|
|
532
560
|
user_content = {
|
|
@@ -30,7 +30,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
|
|
30
30
|
timeout=timeout,
|
|
31
31
|
)
|
|
32
32
|
|
|
33
|
-
self.set_model(provider_config.get("model",
|
|
33
|
+
self.set_model(provider_config.get("model", ""))
|
|
34
34
|
|
|
35
35
|
async def get_audio(self, text: str) -> str:
|
|
36
36
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
from xinference_client.client.restful.async_restful_client import (
|
|
2
|
+
AsyncClient as Client,
|
|
3
|
+
)
|
|
4
|
+
from astrbot import logger
|
|
5
|
+
from ..provider import RerankProvider
|
|
6
|
+
from ..register import register_provider_adapter
|
|
7
|
+
from ..entities import ProviderType, RerankResult
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@register_provider_adapter(
|
|
11
|
+
"xinference_rerank",
|
|
12
|
+
"Xinference Rerank 适配器",
|
|
13
|
+
provider_type=ProviderType.RERANK,
|
|
14
|
+
)
|
|
15
|
+
class XinferenceRerankProvider(RerankProvider):
|
|
16
|
+
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
|
17
|
+
super().__init__(provider_config, provider_settings)
|
|
18
|
+
self.provider_config = provider_config
|
|
19
|
+
self.provider_settings = provider_settings
|
|
20
|
+
self.base_url = provider_config.get("rerank_api_base", "http://127.0.0.1:8000")
|
|
21
|
+
self.base_url = self.base_url.rstrip("/")
|
|
22
|
+
self.timeout = provider_config.get("timeout", 20)
|
|
23
|
+
self.model_name = provider_config.get("rerank_model", "BAAI/bge-reranker-base")
|
|
24
|
+
self.api_key = provider_config.get("rerank_api_key")
|
|
25
|
+
self.launch_model_if_not_running = provider_config.get(
|
|
26
|
+
"launch_model_if_not_running", False
|
|
27
|
+
)
|
|
28
|
+
self.client = None
|
|
29
|
+
self.model = None
|
|
30
|
+
self.model_uid = None
|
|
31
|
+
|
|
32
|
+
async def initialize(self):
|
|
33
|
+
if self.api_key:
|
|
34
|
+
logger.info("Xinference Rerank: Using API key for authentication.")
|
|
35
|
+
self.client = Client(self.base_url, api_key=self.api_key)
|
|
36
|
+
else:
|
|
37
|
+
logger.info("Xinference Rerank: No API key provided.")
|
|
38
|
+
self.client = Client(self.base_url)
|
|
39
|
+
|
|
40
|
+
try:
|
|
41
|
+
running_models = await self.client.list_models()
|
|
42
|
+
for uid, model_spec in running_models.items():
|
|
43
|
+
if model_spec.get("model_name") == self.model_name:
|
|
44
|
+
logger.info(
|
|
45
|
+
f"Model '{self.model_name}' is already running with UID: {uid}"
|
|
46
|
+
)
|
|
47
|
+
self.model_uid = uid
|
|
48
|
+
break
|
|
49
|
+
|
|
50
|
+
if self.model_uid is None:
|
|
51
|
+
if self.launch_model_if_not_running:
|
|
52
|
+
logger.info(f"Launching {self.model_name} model...")
|
|
53
|
+
self.model_uid = await self.client.launch_model(
|
|
54
|
+
model_name=self.model_name, model_type="rerank"
|
|
55
|
+
)
|
|
56
|
+
logger.info("Model launched.")
|
|
57
|
+
else:
|
|
58
|
+
logger.warning(
|
|
59
|
+
f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available."
|
|
60
|
+
)
|
|
61
|
+
return
|
|
62
|
+
|
|
63
|
+
if self.model_uid:
|
|
64
|
+
self.model = await self.client.get_model(self.model_uid)
|
|
65
|
+
|
|
66
|
+
except Exception as e:
|
|
67
|
+
logger.error(f"Failed to initialize Xinference model: {e}")
|
|
68
|
+
logger.debug(
|
|
69
|
+
f"Xinference initialization failed with exception: {e}", exc_info=True
|
|
70
|
+
)
|
|
71
|
+
self.model = None
|
|
72
|
+
|
|
73
|
+
async def rerank(
|
|
74
|
+
self, query: str, documents: list[str], top_n: int | None = None
|
|
75
|
+
) -> list[RerankResult]:
|
|
76
|
+
if not self.model:
|
|
77
|
+
logger.error("Xinference rerank model is not initialized.")
|
|
78
|
+
return []
|
|
79
|
+
try:
|
|
80
|
+
response = await self.model.rerank(documents, query, top_n)
|
|
81
|
+
results = response.get("results", [])
|
|
82
|
+
logger.debug(f"Rerank API response: {response}")
|
|
83
|
+
|
|
84
|
+
if not results:
|
|
85
|
+
logger.warning(
|
|
86
|
+
f"Rerank API returned an empty list. Original response: {response}"
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
return [
|
|
90
|
+
RerankResult(
|
|
91
|
+
index=result["index"],
|
|
92
|
+
relevance_score=result["relevance_score"],
|
|
93
|
+
)
|
|
94
|
+
for result in results
|
|
95
|
+
]
|
|
96
|
+
except Exception as e:
|
|
97
|
+
logger.error(f"Xinference rerank failed: {e}")
|
|
98
|
+
logger.debug(f"Xinference rerank failed with exception: {e}", exc_info=True)
|
|
99
|
+
return []
|
|
100
|
+
|
|
101
|
+
async def terminate(self) -> None:
|
|
102
|
+
"""关闭客户端会话"""
|
|
103
|
+
if self.client:
|
|
104
|
+
logger.info("Closing Xinference rerank client...")
|
|
105
|
+
try:
|
|
106
|
+
await self.client.close()
|
|
107
|
+
except Exception as e:
|
|
108
|
+
logger.error(f"Failed to close Xinference client: {e}", exc_info=True)
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
import os
|
|
3
|
+
import aiohttp
|
|
4
|
+
from xinference_client.client.restful.async_restful_client import (
|
|
5
|
+
AsyncClient as Client,
|
|
6
|
+
)
|
|
7
|
+
from ..provider import STTProvider
|
|
8
|
+
from ..entities import ProviderType
|
|
9
|
+
from ..register import register_provider_adapter
|
|
10
|
+
from astrbot.core import logger
|
|
11
|
+
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
|
12
|
+
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@register_provider_adapter(
|
|
16
|
+
"xinference_stt",
|
|
17
|
+
"Xinference STT",
|
|
18
|
+
provider_type=ProviderType.SPEECH_TO_TEXT,
|
|
19
|
+
)
|
|
20
|
+
class ProviderXinferenceSTT(STTProvider):
|
|
21
|
+
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
|
22
|
+
super().__init__(provider_config, provider_settings)
|
|
23
|
+
self.provider_config = provider_config
|
|
24
|
+
self.provider_settings = provider_settings
|
|
25
|
+
self.base_url = provider_config.get("api_base", "http://127.0.0.1:9997")
|
|
26
|
+
self.base_url = self.base_url.rstrip("/")
|
|
27
|
+
self.timeout = provider_config.get("timeout", 180)
|
|
28
|
+
self.model_name = provider_config.get("model", "whisper-large-v3")
|
|
29
|
+
self.api_key = provider_config.get("api_key")
|
|
30
|
+
self.launch_model_if_not_running = provider_config.get(
|
|
31
|
+
"launch_model_if_not_running", False
|
|
32
|
+
)
|
|
33
|
+
self.client = None
|
|
34
|
+
self.model_uid = None
|
|
35
|
+
|
|
36
|
+
async def initialize(self):
|
|
37
|
+
if self.api_key:
|
|
38
|
+
logger.info("Xinference STT: Using API key for authentication.")
|
|
39
|
+
self.client = Client(self.base_url, api_key=self.api_key)
|
|
40
|
+
else:
|
|
41
|
+
logger.info("Xinference STT: No API key provided.")
|
|
42
|
+
self.client = Client(self.base_url)
|
|
43
|
+
|
|
44
|
+
try:
|
|
45
|
+
running_models = await self.client.list_models()
|
|
46
|
+
for uid, model_spec in running_models.items():
|
|
47
|
+
if model_spec.get("model_name") == self.model_name:
|
|
48
|
+
logger.info(
|
|
49
|
+
f"Model '{self.model_name}' is already running with UID: {uid}"
|
|
50
|
+
)
|
|
51
|
+
self.model_uid = uid
|
|
52
|
+
break
|
|
53
|
+
|
|
54
|
+
if self.model_uid is None:
|
|
55
|
+
if self.launch_model_if_not_running:
|
|
56
|
+
logger.info(f"Launching {self.model_name} model...")
|
|
57
|
+
self.model_uid = await self.client.launch_model(
|
|
58
|
+
model_name=self.model_name, model_type="audio"
|
|
59
|
+
)
|
|
60
|
+
logger.info("Model launched.")
|
|
61
|
+
else:
|
|
62
|
+
logger.warning(
|
|
63
|
+
f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available."
|
|
64
|
+
)
|
|
65
|
+
return
|
|
66
|
+
|
|
67
|
+
except Exception as e:
|
|
68
|
+
logger.error(f"Failed to initialize Xinference model: {e}")
|
|
69
|
+
logger.debug(
|
|
70
|
+
f"Xinference initialization failed with exception: {e}", exc_info=True
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
async def get_text(self, audio_url: str) -> str:
|
|
74
|
+
if not self.model_uid or self.client is None or self.client.session is None:
|
|
75
|
+
logger.error("Xinference STT model is not initialized.")
|
|
76
|
+
return ""
|
|
77
|
+
|
|
78
|
+
audio_bytes = None
|
|
79
|
+
temp_files = []
|
|
80
|
+
is_tencent = False
|
|
81
|
+
|
|
82
|
+
try:
|
|
83
|
+
# 1. Get audio bytes
|
|
84
|
+
if audio_url.startswith("http"):
|
|
85
|
+
if "multimedia.nt.qq.com.cn" in audio_url:
|
|
86
|
+
is_tencent = True
|
|
87
|
+
async with aiohttp.ClientSession() as session:
|
|
88
|
+
async with session.get(audio_url, timeout=self.timeout) as resp:
|
|
89
|
+
if resp.status == 200:
|
|
90
|
+
audio_bytes = await resp.read()
|
|
91
|
+
else:
|
|
92
|
+
logger.error(
|
|
93
|
+
f"Failed to download audio from {audio_url}, status: {resp.status}"
|
|
94
|
+
)
|
|
95
|
+
return ""
|
|
96
|
+
else:
|
|
97
|
+
if os.path.exists(audio_url):
|
|
98
|
+
with open(audio_url, "rb") as f:
|
|
99
|
+
audio_bytes = f.read()
|
|
100
|
+
else:
|
|
101
|
+
logger.error(f"File not found: {audio_url}")
|
|
102
|
+
return ""
|
|
103
|
+
|
|
104
|
+
if not audio_bytes:
|
|
105
|
+
logger.error("Audio bytes are empty.")
|
|
106
|
+
return ""
|
|
107
|
+
|
|
108
|
+
# 2. Check for conversion
|
|
109
|
+
needs_conversion = False
|
|
110
|
+
if (
|
|
111
|
+
audio_url.endswith((".amr", ".silk"))
|
|
112
|
+
or is_tencent
|
|
113
|
+
or b"SILK" in audio_bytes[:8]
|
|
114
|
+
):
|
|
115
|
+
needs_conversion = True
|
|
116
|
+
|
|
117
|
+
# 3. Perform conversion if needed
|
|
118
|
+
if needs_conversion:
|
|
119
|
+
logger.info("Audio requires conversion, using temporary files...")
|
|
120
|
+
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
121
|
+
os.makedirs(temp_dir, exist_ok=True)
|
|
122
|
+
|
|
123
|
+
input_path = os.path.join(temp_dir, str(uuid.uuid4()))
|
|
124
|
+
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
|
|
125
|
+
temp_files.extend([input_path, output_path])
|
|
126
|
+
|
|
127
|
+
with open(input_path, "wb") as f:
|
|
128
|
+
f.write(audio_bytes)
|
|
129
|
+
|
|
130
|
+
logger.info("Converting silk/amr file to wav ...")
|
|
131
|
+
await tencent_silk_to_wav(input_path, output_path)
|
|
132
|
+
|
|
133
|
+
with open(output_path, "rb") as f:
|
|
134
|
+
audio_bytes = f.read()
|
|
135
|
+
|
|
136
|
+
# 4. Transcribe
|
|
137
|
+
# 官方asyncCLient的客户端似乎实现有点问题,这里直接用aiohttp实现openai标准兼容请求,提交issue等待官方修复后再改回来
|
|
138
|
+
url = f"{self.base_url}/v1/audio/transcriptions"
|
|
139
|
+
headers = {
|
|
140
|
+
"accept": "application/json",
|
|
141
|
+
}
|
|
142
|
+
if self.client and self.client._headers:
|
|
143
|
+
headers.update(self.client._headers)
|
|
144
|
+
|
|
145
|
+
data = aiohttp.FormData()
|
|
146
|
+
data.add_field("model", self.model_uid)
|
|
147
|
+
data.add_field(
|
|
148
|
+
"file", audio_bytes, filename="audio.wav", content_type="audio/wav"
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
async with self.client.session.post(
|
|
152
|
+
url, data=data, headers=headers, timeout=self.timeout
|
|
153
|
+
) as resp:
|
|
154
|
+
if resp.status == 200:
|
|
155
|
+
result = await resp.json()
|
|
156
|
+
text = result.get("text", "")
|
|
157
|
+
logger.debug(f"Xinference STT result: {text}")
|
|
158
|
+
return text
|
|
159
|
+
else:
|
|
160
|
+
error_text = await resp.text()
|
|
161
|
+
logger.error(
|
|
162
|
+
f"Xinference STT transcription failed with status {resp.status}: {error_text}"
|
|
163
|
+
)
|
|
164
|
+
return ""
|
|
165
|
+
|
|
166
|
+
except Exception as e:
|
|
167
|
+
logger.error(f"Xinference STT failed: {e}")
|
|
168
|
+
logger.debug(f"Xinference STT failed with exception: {e}", exc_info=True)
|
|
169
|
+
return ""
|
|
170
|
+
finally:
|
|
171
|
+
# 5. Cleanup
|
|
172
|
+
for temp_file in temp_files:
|
|
173
|
+
try:
|
|
174
|
+
if os.path.exists(temp_file):
|
|
175
|
+
os.remove(temp_file)
|
|
176
|
+
logger.debug(f"Removed temporary file: {temp_file}")
|
|
177
|
+
except Exception as e:
|
|
178
|
+
logger.error(f"Failed to remove temporary file {temp_file}: {e}")
|
|
179
|
+
|
|
180
|
+
async def terminate(self) -> None:
|
|
181
|
+
"""关闭客户端会话"""
|
|
182
|
+
if self.client:
|
|
183
|
+
logger.info("Closing Xinference STT client...")
|
|
184
|
+
try:
|
|
185
|
+
await self.client.close()
|
|
186
|
+
except Exception as e:
|
|
187
|
+
logger.error(f"Failed to close Xinference client: {e}", exc_info=True)
|
astrbot/core/star/context.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
from asyncio import Queue
|
|
2
|
-
from typing import List, Union
|
|
3
2
|
|
|
4
3
|
from astrbot.core.provider.provider import (
|
|
5
4
|
Provider,
|
|
@@ -11,7 +10,7 @@ from astrbot.core.provider.provider import (
|
|
|
11
10
|
from astrbot.core.provider.entities import ProviderType
|
|
12
11
|
from astrbot.core.db import BaseDatabase
|
|
13
12
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
|
14
|
-
from astrbot.core.provider.func_tool_manager import FunctionToolManager
|
|
13
|
+
from astrbot.core.provider.func_tool_manager import FunctionToolManager, FunctionTool
|
|
15
14
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
|
16
15
|
from astrbot.core.message.message_event_result import MessageChain
|
|
17
16
|
from astrbot.core.provider.manager import ProviderManager
|
|
@@ -19,12 +18,14 @@ from astrbot.core.platform import Platform
|
|
|
19
18
|
from astrbot.core.platform.manager import PlatformManager
|
|
20
19
|
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
|
21
20
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
|
21
|
+
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
|
22
22
|
from astrbot.core.persona_mgr import PersonaManager
|
|
23
23
|
from .star import star_registry, StarMetadata, star_map
|
|
24
24
|
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
|
25
25
|
from .filter.command import CommandFilter
|
|
26
26
|
from .filter.regex import RegexFilter
|
|
27
|
-
from typing import
|
|
27
|
+
from typing import Any
|
|
28
|
+
from collections.abc import Awaitable, Callable
|
|
28
29
|
from astrbot.core.conversation_mgr import ConversationManager
|
|
29
30
|
from astrbot.core.star.filter.platform_adapter_type import (
|
|
30
31
|
PlatformAdapterType,
|
|
@@ -41,7 +42,7 @@ class Context:
|
|
|
41
42
|
registered_web_apis: list = []
|
|
42
43
|
|
|
43
44
|
# back compatibility
|
|
44
|
-
_register_tasks:
|
|
45
|
+
_register_tasks: list[Awaitable] = []
|
|
45
46
|
_star_manager = None
|
|
46
47
|
|
|
47
48
|
def __init__(
|
|
@@ -55,6 +56,7 @@ class Context:
|
|
|
55
56
|
message_history_manager: PlatformMessageHistoryManager,
|
|
56
57
|
persona_manager: PersonaManager,
|
|
57
58
|
astrbot_config_mgr: AstrBotConfigManager,
|
|
59
|
+
knowledge_base_manager: KnowledgeBaseManager,
|
|
58
60
|
):
|
|
59
61
|
self._event_queue = event_queue
|
|
60
62
|
"""事件队列。消息平台通过事件队列传递消息事件。"""
|
|
@@ -68,6 +70,7 @@ class Context:
|
|
|
68
70
|
self.message_history_manager = message_history_manager
|
|
69
71
|
self.persona_manager = persona_manager
|
|
70
72
|
self.astrbot_config_mgr = astrbot_config_mgr
|
|
73
|
+
self.kb_manager = knowledge_base_manager
|
|
71
74
|
|
|
72
75
|
def get_registered_star(self, star_name: str) -> StarMetadata | None:
|
|
73
76
|
"""根据插件名获取插件的 Metadata"""
|
|
@@ -75,7 +78,7 @@ class Context:
|
|
|
75
78
|
if star.name == star_name:
|
|
76
79
|
return star
|
|
77
80
|
|
|
78
|
-
def get_all_stars(self) ->
|
|
81
|
+
def get_all_stars(self) -> list[StarMetadata]:
|
|
79
82
|
"""获取当前载入的所有插件 Metadata 的列表"""
|
|
80
83
|
return star_registry
|
|
81
84
|
|
|
@@ -113,19 +116,19 @@ class Context:
|
|
|
113
116
|
prov = self.provider_manager.inst_map.get(provider_id)
|
|
114
117
|
return prov
|
|
115
118
|
|
|
116
|
-
def get_all_providers(self) ->
|
|
119
|
+
def get_all_providers(self) -> list[Provider]:
|
|
117
120
|
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
|
118
121
|
return self.provider_manager.provider_insts
|
|
119
122
|
|
|
120
|
-
def get_all_tts_providers(self) ->
|
|
123
|
+
def get_all_tts_providers(self) -> list[TTSProvider]:
|
|
121
124
|
"""获取所有用于 TTS 任务的 Provider。"""
|
|
122
125
|
return self.provider_manager.tts_provider_insts
|
|
123
126
|
|
|
124
|
-
def get_all_stt_providers(self) ->
|
|
127
|
+
def get_all_stt_providers(self) -> list[STTProvider]:
|
|
125
128
|
"""获取所有用于 STT 任务的 Provider。"""
|
|
126
129
|
return self.provider_manager.stt_provider_insts
|
|
127
130
|
|
|
128
|
-
def get_all_embedding_providers(self) ->
|
|
131
|
+
def get_all_embedding_providers(self) -> list[EmbeddingProvider]:
|
|
129
132
|
"""获取所有用于 Embedding 任务的 Provider。"""
|
|
130
133
|
return self.provider_manager.embedding_provider_insts
|
|
131
134
|
|
|
@@ -193,9 +196,7 @@ class Context:
|
|
|
193
196
|
return self._event_queue
|
|
194
197
|
|
|
195
198
|
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
|
|
196
|
-
def get_platform(
|
|
197
|
-
self, platform_type: Union[PlatformAdapterType, str]
|
|
198
|
-
) -> Platform | None:
|
|
199
|
+
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
|
|
199
200
|
"""
|
|
200
201
|
获取指定类型的平台适配器。
|
|
201
202
|
|
|
@@ -228,7 +229,7 @@ class Context:
|
|
|
228
229
|
return platform
|
|
229
230
|
|
|
230
231
|
async def send_message(
|
|
231
|
-
self, session:
|
|
232
|
+
self, session: str | MessageSesion, message_chain: MessageChain
|
|
232
233
|
) -> bool:
|
|
233
234
|
"""
|
|
234
235
|
根据 session(unified_msg_origin) 主动发送消息。
|
|
@@ -255,6 +256,11 @@ class Context:
|
|
|
255
256
|
return True
|
|
256
257
|
return False
|
|
257
258
|
|
|
259
|
+
def add_llm_tools(self, *tools: FunctionTool) -> None:
|
|
260
|
+
"""添加 LLM 工具。"""
|
|
261
|
+
for tool in tools:
|
|
262
|
+
self.provider_manager.llm_tools.func_list.append(tool)
|
|
263
|
+
|
|
258
264
|
"""
|
|
259
265
|
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
|
|
260
266
|
"""
|
astrbot/core/star/star.py
CHANGED
|
@@ -56,6 +56,12 @@ class StarMetadata:
|
|
|
56
56
|
star_handler_full_names: list[str] = field(default_factory=list)
|
|
57
57
|
"""注册的 Handler 的全名列表"""
|
|
58
58
|
|
|
59
|
+
display_name: str | None = None
|
|
60
|
+
"""用于展示的插件名称"""
|
|
61
|
+
|
|
62
|
+
logo_path: str | None = None
|
|
63
|
+
"""插件 Logo 的路径"""
|
|
64
|
+
|
|
59
65
|
def __str__(self) -> str:
|
|
60
66
|
return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
|
|
61
67
|
|