AstrBot 4.3.5__py3-none-any.whl → 4.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
- astrbot/core/astrbot_config_mgr.py +23 -51
- astrbot/core/config/default.py +92 -12
- astrbot/core/conversation_mgr.py +36 -1
- astrbot/core/core_lifecycle.py +24 -5
- astrbot/core/db/migration/migra_45_to_46.py +44 -0
- astrbot/core/db/vec_db/base.py +33 -2
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
- astrbot/core/file_token_service.py +6 -1
- astrbot/core/initial_loader.py +6 -3
- astrbot/core/knowledge_base/chunking/__init__.py +11 -0
- astrbot/core/knowledge_base/chunking/base.py +24 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
- astrbot/core/knowledge_base/chunking/recursive.py +155 -0
- astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
- astrbot/core/knowledge_base/kb_helper.py +348 -0
- astrbot/core/knowledge_base/kb_mgr.py +287 -0
- astrbot/core/knowledge_base/models.py +114 -0
- astrbot/core/knowledge_base/parsers/__init__.py +15 -0
- astrbot/core/knowledge_base/parsers/base.py +50 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
- astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
- astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
- astrbot/core/knowledge_base/parsers/util.py +13 -0
- astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
- astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
- astrbot/core/knowledge_base/retrieval/manager.py +273 -0
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
- astrbot/core/pipeline/process_stage/method/llm_request.py +29 -7
- astrbot/core/pipeline/process_stage/utils.py +80 -0
- astrbot/core/platform/astr_message_event.py +8 -7
- astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
- astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
- astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
- astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
- astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
- astrbot/core/platform/sources/satori/satori_event.py +270 -99
- astrbot/core/provider/manager.py +14 -9
- astrbot/core/provider/provider.py +67 -0
- astrbot/core/provider/sources/anthropic_source.py +4 -4
- astrbot/core/provider/sources/dashscope_source.py +10 -9
- astrbot/core/provider/sources/dify_source.py +6 -8
- astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
- astrbot/core/provider/sources/openai_embedding_source.py +1 -2
- astrbot/core/provider/sources/openai_source.py +18 -15
- astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
- astrbot/core/star/context.py +3 -0
- astrbot/core/star/star.py +6 -0
- astrbot/core/star/star_manager.py +13 -7
- astrbot/core/umop_config_router.py +81 -0
- astrbot/core/updator.py +1 -1
- astrbot/core/utils/io.py +23 -12
- astrbot/dashboard/routes/__init__.py +2 -0
- astrbot/dashboard/routes/config.py +137 -9
- astrbot/dashboard/routes/knowledge_base.py +1065 -0
- astrbot/dashboard/routes/plugin.py +24 -5
- astrbot/dashboard/routes/update.py +1 -1
- astrbot/dashboard/server.py +6 -0
- astrbot/dashboard/utils.py +161 -0
- {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/METADATA +29 -13
- {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/RECORD +68 -44
- {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/WHEEL +0 -0
- {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/entry_points.txt +0 -0
- {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -22,7 +22,6 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|
|
22
22
|
timeout=int(provider_config.get("timeout", 20)),
|
|
23
23
|
)
|
|
24
24
|
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
|
|
25
|
-
self.dimension = provider_config.get("embedding_dimensions", 1024)
|
|
26
25
|
|
|
27
26
|
async def get_embedding(self, text: str) -> list[float]:
|
|
28
27
|
"""
|
|
@@ -40,4 +39,4 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|
|
40
39
|
|
|
41
40
|
def get_dim(self) -> int:
|
|
42
41
|
"""获取向量的维度"""
|
|
43
|
-
return self.
|
|
42
|
+
return self.provider_config.get("embedding_dimensions", 1024)
|
|
@@ -16,7 +16,7 @@ from astrbot.core.message.message_event_result import MessageChain
|
|
|
16
16
|
|
|
17
17
|
from astrbot.api.provider import Provider
|
|
18
18
|
from astrbot import logger
|
|
19
|
-
from astrbot.core.provider.func_tool_manager import
|
|
19
|
+
from astrbot.core.provider.func_tool_manager import ToolSet
|
|
20
20
|
from typing import List, AsyncGenerator
|
|
21
21
|
from ..register import register_provider_adapter
|
|
22
22
|
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
|
@@ -49,7 +49,7 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
49
49
|
self.client = AsyncAzureOpenAI(
|
|
50
50
|
api_key=self.chosen_api_key,
|
|
51
51
|
api_version=provider_config.get("api_version", None),
|
|
52
|
-
base_url=provider_config.get("api_base",
|
|
52
|
+
base_url=provider_config.get("api_base", ""),
|
|
53
53
|
timeout=self.timeout,
|
|
54
54
|
)
|
|
55
55
|
else:
|
|
@@ -79,7 +79,7 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
79
79
|
except NotFoundError as e:
|
|
80
80
|
raise Exception(f"获取模型列表失败:{e}")
|
|
81
81
|
|
|
82
|
-
async def _query(self, payloads: dict, tools:
|
|
82
|
+
async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse:
|
|
83
83
|
if tools:
|
|
84
84
|
model = payloads.get("model", "").lower()
|
|
85
85
|
omit_empty_param_field = "gemini" in model
|
|
@@ -126,7 +126,7 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
126
126
|
return llm_response
|
|
127
127
|
|
|
128
128
|
async def _query_stream(
|
|
129
|
-
self, payloads: dict, tools:
|
|
129
|
+
self, payloads: dict, tools: ToolSet
|
|
130
130
|
) -> AsyncGenerator[LLMResponse, None]:
|
|
131
131
|
"""流式查询API,逐步返回结果"""
|
|
132
132
|
if tools:
|
|
@@ -183,9 +183,7 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
183
183
|
|
|
184
184
|
yield llm_response
|
|
185
185
|
|
|
186
|
-
async def parse_openai_completion(
|
|
187
|
-
self, completion: ChatCompletion, tools: FuncCall
|
|
188
|
-
):
|
|
186
|
+
async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet):
|
|
189
187
|
"""解析 OpenAI 的 ChatCompletion 响应"""
|
|
190
188
|
llm_response = LLMResponse("assistant")
|
|
191
189
|
|
|
@@ -208,7 +206,10 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
208
206
|
# workaround for #1359
|
|
209
207
|
tool_call = json.loads(tool_call)
|
|
210
208
|
for tool in tools.func_list:
|
|
211
|
-
if
|
|
209
|
+
if (
|
|
210
|
+
tool_call.type == "function"
|
|
211
|
+
and tool.name == tool_call.function.name
|
|
212
|
+
):
|
|
212
213
|
# workaround for #1454
|
|
213
214
|
if isinstance(tool_call.function.arguments, str):
|
|
214
215
|
args = json.loads(tool_call.function.arguments)
|
|
@@ -277,7 +278,7 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
277
278
|
e: Exception,
|
|
278
279
|
payloads: dict,
|
|
279
280
|
context_query: list,
|
|
280
|
-
func_tool:
|
|
281
|
+
func_tool: ToolSet,
|
|
281
282
|
chosen_key: str,
|
|
282
283
|
available_api_keys: List[str],
|
|
283
284
|
retry_cnt: int,
|
|
@@ -420,7 +421,7 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
420
421
|
if success:
|
|
421
422
|
break
|
|
422
423
|
|
|
423
|
-
if retry_cnt == max_retries - 1:
|
|
424
|
+
if retry_cnt == max_retries - 1 or llm_response is None:
|
|
424
425
|
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
|
425
426
|
if last_exception is None:
|
|
426
427
|
raise Exception("未知错误")
|
|
@@ -430,10 +431,10 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
430
431
|
async def text_chat_stream(
|
|
431
432
|
self,
|
|
432
433
|
prompt: str,
|
|
433
|
-
session_id
|
|
434
|
-
image_urls
|
|
435
|
-
func_tool
|
|
436
|
-
contexts=
|
|
434
|
+
session_id=None,
|
|
435
|
+
image_urls=None,
|
|
436
|
+
func_tool=None,
|
|
437
|
+
contexts=None,
|
|
437
438
|
system_prompt=None,
|
|
438
439
|
tool_calls_result=None,
|
|
439
440
|
model=None,
|
|
@@ -526,7 +527,9 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
526
527
|
def set_key(self, key):
|
|
527
528
|
self.client.api_key = key
|
|
528
529
|
|
|
529
|
-
async def assemble_context(
|
|
530
|
+
async def assemble_context(
|
|
531
|
+
self, text: str, image_urls: List[str] | None = None
|
|
532
|
+
) -> dict:
|
|
530
533
|
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
|
531
534
|
if image_urls:
|
|
532
535
|
user_content = {
|
|
@@ -30,7 +30,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
|
|
30
30
|
timeout=timeout,
|
|
31
31
|
)
|
|
32
32
|
|
|
33
|
-
self.set_model(provider_config.get("model",
|
|
33
|
+
self.set_model(provider_config.get("model", ""))
|
|
34
34
|
|
|
35
35
|
async def get_audio(self, text: str) -> str:
|
|
36
36
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
astrbot/core/star/context.py
CHANGED
|
@@ -19,6 +19,7 @@ from astrbot.core.platform import Platform
|
|
|
19
19
|
from astrbot.core.platform.manager import PlatformManager
|
|
20
20
|
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
|
21
21
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
|
22
|
+
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
|
22
23
|
from astrbot.core.persona_mgr import PersonaManager
|
|
23
24
|
from .star import star_registry, StarMetadata, star_map
|
|
24
25
|
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
|
@@ -55,6 +56,7 @@ class Context:
|
|
|
55
56
|
message_history_manager: PlatformMessageHistoryManager,
|
|
56
57
|
persona_manager: PersonaManager,
|
|
57
58
|
astrbot_config_mgr: AstrBotConfigManager,
|
|
59
|
+
knowledge_base_manager: KnowledgeBaseManager,
|
|
58
60
|
):
|
|
59
61
|
self._event_queue = event_queue
|
|
60
62
|
"""事件队列。消息平台通过事件队列传递消息事件。"""
|
|
@@ -68,6 +70,7 @@ class Context:
|
|
|
68
70
|
self.message_history_manager = message_history_manager
|
|
69
71
|
self.persona_manager = persona_manager
|
|
70
72
|
self.astrbot_config_mgr = astrbot_config_mgr
|
|
73
|
+
self.kb_manager = knowledge_base_manager
|
|
71
74
|
|
|
72
75
|
def get_registered_star(self, star_name: str) -> StarMetadata | None:
|
|
73
76
|
"""根据插件名获取插件的 Metadata"""
|
astrbot/core/star/star.py
CHANGED
|
@@ -56,6 +56,12 @@ class StarMetadata:
|
|
|
56
56
|
star_handler_full_names: list[str] = field(default_factory=list)
|
|
57
57
|
"""注册的 Handler 的全名列表"""
|
|
58
58
|
|
|
59
|
+
display_name: str | None = None
|
|
60
|
+
"""用于展示的插件名称"""
|
|
61
|
+
|
|
62
|
+
logo_path: str | None = None
|
|
63
|
+
"""插件 Logo 的路径"""
|
|
64
|
+
|
|
59
65
|
def __str__(self) -> str:
|
|
60
66
|
return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
|
|
61
67
|
|
|
@@ -57,6 +57,7 @@ class PluginManager:
|
|
|
57
57
|
)
|
|
58
58
|
"""保留插件的路径。在 packages 目录下"""
|
|
59
59
|
self.conf_schema_fname = "_conf_schema.json"
|
|
60
|
+
self.logo_fname = "logo.png"
|
|
60
61
|
"""插件配置 Schema 文件名"""
|
|
61
62
|
self._pm_lock = asyncio.Lock()
|
|
62
63
|
"""StarManager操作互斥锁"""
|
|
@@ -200,7 +201,7 @@ class PluginManager:
|
|
|
200
201
|
|
|
201
202
|
if os.path.exists(os.path.join(plugin_path, "metadata.yaml")):
|
|
202
203
|
with open(
|
|
203
|
-
os.path.join(plugin_path, "metadata.yaml"),
|
|
204
|
+
os.path.join(plugin_path, "metadata.yaml"), encoding="utf-8"
|
|
204
205
|
) as f:
|
|
205
206
|
metadata = yaml.safe_load(f)
|
|
206
207
|
elif plugin_obj and hasattr(plugin_obj, "info"):
|
|
@@ -226,6 +227,7 @@ class PluginManager:
|
|
|
226
227
|
desc=metadata["desc"],
|
|
227
228
|
version=metadata["version"],
|
|
228
229
|
repo=metadata["repo"] if "repo" in metadata else None,
|
|
230
|
+
display_name=metadata.get("display_name", None),
|
|
229
231
|
)
|
|
230
232
|
|
|
231
233
|
return metadata
|
|
@@ -407,13 +409,14 @@ class PluginManager:
|
|
|
407
409
|
)
|
|
408
410
|
if os.path.exists(plugin_schema_path):
|
|
409
411
|
# 加载插件配置
|
|
410
|
-
with open(plugin_schema_path,
|
|
412
|
+
with open(plugin_schema_path, encoding="utf-8") as f:
|
|
411
413
|
plugin_config = AstrBotConfig(
|
|
412
414
|
config_path=os.path.join(
|
|
413
415
|
self.plugin_config_path, f"{root_dir_name}_config.json"
|
|
414
416
|
),
|
|
415
417
|
schema=json.loads(f.read()),
|
|
416
418
|
)
|
|
419
|
+
logo_path = os.path.join(plugin_dir_path, self.logo_fname)
|
|
417
420
|
|
|
418
421
|
if path in star_map:
|
|
419
422
|
# 通过 __init__subclass__ 注册插件
|
|
@@ -430,6 +433,7 @@ class PluginManager:
|
|
|
430
433
|
metadata.desc = metadata_yaml.desc
|
|
431
434
|
metadata.version = metadata_yaml.version
|
|
432
435
|
metadata.repo = metadata_yaml.repo
|
|
436
|
+
metadata.display_name = metadata_yaml.display_name
|
|
433
437
|
except Exception as e:
|
|
434
438
|
logger.warning(
|
|
435
439
|
f"插件 {root_dir_name} 元数据载入失败: {str(e)}。使用默认元数据。"
|
|
@@ -540,9 +544,11 @@ class PluginManager:
|
|
|
540
544
|
if metadata.module_path in inactivated_plugins:
|
|
541
545
|
metadata.activated = False
|
|
542
546
|
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
547
|
+
# Plugin logo path
|
|
548
|
+
if os.path.exists(logo_path):
|
|
549
|
+
metadata.logo_path = logo_path
|
|
550
|
+
|
|
551
|
+
assert metadata.module_path, f"插件 {metadata.name} 模块路径为空"
|
|
546
552
|
|
|
547
553
|
full_names = []
|
|
548
554
|
for handler in star_handlers_registry.get_handlers_by_module_name(
|
|
@@ -642,7 +648,7 @@ class PluginManager:
|
|
|
642
648
|
|
|
643
649
|
if os.path.exists(readme_path):
|
|
644
650
|
try:
|
|
645
|
-
with open(readme_path,
|
|
651
|
+
with open(readme_path, encoding="utf-8") as f:
|
|
646
652
|
readme_content = f.read()
|
|
647
653
|
except Exception as e:
|
|
648
654
|
logger.warning(
|
|
@@ -857,7 +863,7 @@ class PluginManager:
|
|
|
857
863
|
|
|
858
864
|
if os.path.exists(readme_path):
|
|
859
865
|
try:
|
|
860
|
-
with open(readme_path,
|
|
866
|
+
with open(readme_path, encoding="utf-8") as f:
|
|
861
867
|
readme_content = f.read()
|
|
862
868
|
except Exception as e:
|
|
863
869
|
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from astrbot.core.utils.shared_preferences import SharedPreferences
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class UmopConfigRouter:
|
|
5
|
+
"""UMOP 配置路由器"""
|
|
6
|
+
|
|
7
|
+
def __init__(self, sp: SharedPreferences):
|
|
8
|
+
self.umop_to_conf_id: dict[str, str] = {}
|
|
9
|
+
"""UMOP 到配置文件 ID 的映射"""
|
|
10
|
+
self.sp = sp
|
|
11
|
+
|
|
12
|
+
self._load_routing_table()
|
|
13
|
+
|
|
14
|
+
def _load_routing_table(self):
|
|
15
|
+
"""加载路由表"""
|
|
16
|
+
# 从 SharedPreferences 中加载 umop_to_conf_id 映射
|
|
17
|
+
sp_data = self.sp.get(
|
|
18
|
+
"umop_config_routing", {}, scope="global", scope_id="global"
|
|
19
|
+
)
|
|
20
|
+
self.umop_to_conf_id = sp_data
|
|
21
|
+
|
|
22
|
+
def _is_umo_match(self, p1: str, p2: str) -> bool:
|
|
23
|
+
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
|
|
24
|
+
p1_ls = p1.split(":")
|
|
25
|
+
p2_ls = p2.split(":")
|
|
26
|
+
|
|
27
|
+
if len(p1_ls) != 3 or len(p2_ls) != 3:
|
|
28
|
+
return False # 非法格式
|
|
29
|
+
|
|
30
|
+
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
|
|
31
|
+
|
|
32
|
+
def get_conf_id_for_umop(self, umo: str) -> str | None:
|
|
33
|
+
"""根据 UMO 获取对应的配置文件 ID
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
umo (str): UMO 字符串
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
str | None: 配置文件 ID,如果没有找到则返回 None
|
|
40
|
+
"""
|
|
41
|
+
for pattern, conf_id in self.umop_to_conf_id.items():
|
|
42
|
+
if self._is_umo_match(pattern, umo):
|
|
43
|
+
return conf_id
|
|
44
|
+
return None
|
|
45
|
+
|
|
46
|
+
async def update_routing_data(self, new_routing: dict[str, str]):
|
|
47
|
+
"""更新路由表
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
new_routing (dict[str, str]): 新的 UMOP 到配置文件 ID 的映射。umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
|
|
51
|
+
umop 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。
|
|
52
|
+
|
|
53
|
+
Raises:
|
|
54
|
+
ValueError: 如果 new_routing 中的 key 格式不正确
|
|
55
|
+
"""
|
|
56
|
+
for part in new_routing.keys():
|
|
57
|
+
if not isinstance(part, str) or len(part.split(":")) != 3:
|
|
58
|
+
raise ValueError(
|
|
59
|
+
"umop keys must be strings in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all"
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
self.umop_to_conf_id = new_routing
|
|
63
|
+
await self.sp.global_put("umop_config_routing", self.umop_to_conf_id)
|
|
64
|
+
|
|
65
|
+
async def update_route(self, umo: str, conf_id: str):
|
|
66
|
+
"""更新一条路由
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
umo (str): UMO 字符串
|
|
70
|
+
conf_id (str): 配置文件 ID
|
|
71
|
+
|
|
72
|
+
Raises:
|
|
73
|
+
ValueError: 如果 umo 格式不正确
|
|
74
|
+
"""
|
|
75
|
+
if not isinstance(umo, str) or len(umo.split(":")) != 3:
|
|
76
|
+
raise ValueError(
|
|
77
|
+
"umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all"
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
self.umop_to_conf_id[umo] = conf_id
|
|
81
|
+
await self.sp.global_put("umop_config_routing", self.umop_to_conf_id)
|
astrbot/core/updator.py
CHANGED
|
@@ -99,7 +99,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
|
|
99
99
|
else:
|
|
100
100
|
if len(str(version)) != 40:
|
|
101
101
|
raise Exception("commit hash 长度不正确,应为 40")
|
|
102
|
-
file_url = f"https://github.com/
|
|
102
|
+
file_url = f"https://github.com/AstrBotDevs/AstrBot/archive/{version}.zip"
|
|
103
103
|
logger.info(f"准备更新至指定版本的 AstrBot Core: {version}")
|
|
104
104
|
|
|
105
105
|
if proxy:
|
astrbot/core/utils/io.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import os
|
|
2
|
+
from pathlib import Path
|
|
2
3
|
import ssl
|
|
3
4
|
import shutil
|
|
4
5
|
import socket
|
|
@@ -12,7 +13,6 @@ import logging
|
|
|
12
13
|
|
|
13
14
|
import certifi
|
|
14
15
|
|
|
15
|
-
from typing import Union
|
|
16
16
|
|
|
17
17
|
from PIL import Image
|
|
18
18
|
from .astrbot_path import get_astrbot_data_path
|
|
@@ -52,7 +52,7 @@ def port_checker(port: int, host: str = "localhost"):
|
|
|
52
52
|
return False
|
|
53
53
|
|
|
54
54
|
|
|
55
|
-
def save_temp_img(img:
|
|
55
|
+
def save_temp_img(img: Image.Image | str) -> str:
|
|
56
56
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
57
57
|
# 获得文件创建时间,清除超过 12 小时的
|
|
58
58
|
try:
|
|
@@ -150,7 +150,11 @@ async def download_file(url: str, path: str, show_progress: bool = False):
|
|
|
150
150
|
f.write(chunk)
|
|
151
151
|
downloaded_size += len(chunk)
|
|
152
152
|
if show_progress:
|
|
153
|
-
elapsed_time =
|
|
153
|
+
elapsed_time = (
|
|
154
|
+
time.time() - start_time
|
|
155
|
+
if time.time() - start_time > 0
|
|
156
|
+
else 1
|
|
157
|
+
)
|
|
154
158
|
speed = downloaded_size / 1024 / elapsed_time # KB/s
|
|
155
159
|
print(
|
|
156
160
|
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
|
|
@@ -209,7 +213,7 @@ async def get_dashboard_version():
|
|
|
209
213
|
if os.path.exists(dist_dir):
|
|
210
214
|
version_file = os.path.join(dist_dir, "assets", "version")
|
|
211
215
|
if os.path.exists(version_file):
|
|
212
|
-
with open(version_file, "
|
|
216
|
+
with open(version_file, encoding="utf-8") as f:
|
|
213
217
|
v = f.read().strip()
|
|
214
218
|
return v
|
|
215
219
|
return None
|
|
@@ -221,10 +225,13 @@ async def download_dashboard(
|
|
|
221
225
|
latest: bool = True,
|
|
222
226
|
version: str | None = None,
|
|
223
227
|
proxy: str | None = None,
|
|
224
|
-
):
|
|
228
|
+
) -> None:
|
|
225
229
|
"""下载管理面板文件"""
|
|
230
|
+
|
|
226
231
|
if path is None:
|
|
227
|
-
|
|
232
|
+
zip_path = Path(get_astrbot_data_path()).absolute() / "dashboard.zip"
|
|
233
|
+
else:
|
|
234
|
+
zip_path = Path(path).absolute()
|
|
228
235
|
|
|
229
236
|
if latest or len(str(version)) != 40:
|
|
230
237
|
ver_name = "latest" if latest else version
|
|
@@ -233,20 +240,24 @@ async def download_dashboard(
|
|
|
233
240
|
f"准备下载指定发行版本的 AstrBot WebUI 文件: {dashboard_release_url}"
|
|
234
241
|
)
|
|
235
242
|
try:
|
|
236
|
-
await download_file(
|
|
243
|
+
await download_file(
|
|
244
|
+
dashboard_release_url, str(zip_path), show_progress=True
|
|
245
|
+
)
|
|
237
246
|
except BaseException as _:
|
|
238
247
|
if latest:
|
|
239
|
-
dashboard_release_url = "https://github.com/
|
|
248
|
+
dashboard_release_url = "https://github.com/AstrBotDevs/AstrBot/releases/latest/download/dist.zip"
|
|
240
249
|
else:
|
|
241
|
-
dashboard_release_url = f"https://github.com/
|
|
250
|
+
dashboard_release_url = f"https://github.com/AstrBotDevs/AstrBot/releases/download/{version}/dist.zip"
|
|
242
251
|
if proxy:
|
|
243
252
|
dashboard_release_url = f"{proxy}/{dashboard_release_url}"
|
|
244
|
-
await download_file(
|
|
253
|
+
await download_file(
|
|
254
|
+
dashboard_release_url, str(zip_path), show_progress=True
|
|
255
|
+
)
|
|
245
256
|
else:
|
|
246
257
|
url = f"https://github.com/AstrBotDevs/astrbot-release-harbour/releases/download/release-{version}/dist.zip"
|
|
247
258
|
logger.info(f"准备下载指定版本的 AstrBot WebUI: {url}")
|
|
248
259
|
if proxy:
|
|
249
260
|
url = f"{proxy}/{url}"
|
|
250
|
-
await download_file(url,
|
|
251
|
-
with zipfile.ZipFile(
|
|
261
|
+
await download_file(url, str(zip_path), show_progress=True)
|
|
262
|
+
with zipfile.ZipFile(zip_path, "r") as z:
|
|
252
263
|
z.extractall(extract_path)
|
|
@@ -11,6 +11,7 @@ from .conversation import ConversationRoute
|
|
|
11
11
|
from .file import FileRoute
|
|
12
12
|
from .session_management import SessionManagementRoute
|
|
13
13
|
from .persona import PersonaRoute
|
|
14
|
+
from .knowledge_base import KnowledgeBaseRoute
|
|
14
15
|
|
|
15
16
|
__all__ = [
|
|
16
17
|
"AuthRoute",
|
|
@@ -26,4 +27,5 @@ __all__ = [
|
|
|
26
27
|
"FileRoute",
|
|
27
28
|
"SessionManagementRoute",
|
|
28
29
|
"PersonaRoute",
|
|
30
|
+
"KnowledgeBaseRoute",
|
|
29
31
|
]
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import typing
|
|
2
1
|
import traceback
|
|
3
2
|
import os
|
|
4
3
|
import inspect
|
|
@@ -6,6 +5,7 @@ from .route import Route, Response, RouteContext
|
|
|
6
5
|
from astrbot.core.provider.entities import ProviderType
|
|
7
6
|
from quart import request
|
|
8
7
|
from astrbot.core.config.default import (
|
|
8
|
+
DEFAULT_CONFIG,
|
|
9
9
|
CONFIG_METADATA_2,
|
|
10
10
|
DEFAULT_VALUE_MAP,
|
|
11
11
|
CONFIG_METADATA_3,
|
|
@@ -44,9 +44,7 @@ def try_cast(value: str, type_: str):
|
|
|
44
44
|
return None
|
|
45
45
|
|
|
46
46
|
|
|
47
|
-
def validate_config(
|
|
48
|
-
data, schema: dict, is_core: bool
|
|
49
|
-
) -> typing.Tuple[typing.List[str], typing.Dict]:
|
|
47
|
+
def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]:
|
|
50
48
|
errors = []
|
|
51
49
|
|
|
52
50
|
def validate(data: dict, metadata: dict = schema, path=""):
|
|
@@ -152,13 +150,19 @@ class ConfigRoute(Route):
|
|
|
152
150
|
self.config: AstrBotConfig = core_lifecycle.astrbot_config
|
|
153
151
|
self._logo_token_cache = {} # 缓存logo token,避免重复注册
|
|
154
152
|
self.acm = core_lifecycle.astrbot_config_mgr
|
|
153
|
+
self.ucr = core_lifecycle.umop_config_router
|
|
155
154
|
self.routes = {
|
|
156
155
|
"/config/abconf/new": ("POST", self.create_abconf),
|
|
157
156
|
"/config/abconf": ("GET", self.get_abconf),
|
|
158
157
|
"/config/abconfs": ("GET", self.get_abconf_list),
|
|
159
158
|
"/config/abconf/delete": ("POST", self.delete_abconf),
|
|
160
159
|
"/config/abconf/update": ("POST", self.update_abconf),
|
|
160
|
+
"/config/umo_abconf_routes": ("GET", self.get_uc_table),
|
|
161
|
+
"/config/umo_abconf_route/update_all": ("POST", self.update_ucr_all),
|
|
162
|
+
"/config/umo_abconf_route/update": ("POST", self.update_ucr),
|
|
163
|
+
"/config/umo_abconf_route/delete": ("POST", self.delete_ucr),
|
|
161
164
|
"/config/get": ("GET", self.get_configs),
|
|
165
|
+
"/config/default": ("GET", self.get_default_config),
|
|
162
166
|
"/config/astrbot/update": ("POST", self.post_astrbot_configs),
|
|
163
167
|
"/config/plugin/update": ("POST", self.post_plugin_configs),
|
|
164
168
|
"/config/platform/new": ("POST", self.post_new_platform),
|
|
@@ -171,9 +175,79 @@ class ConfigRoute(Route):
|
|
|
171
175
|
"/config/provider/check_one": ("GET", self.check_one_provider_status),
|
|
172
176
|
"/config/provider/list": ("GET", self.get_provider_config_list),
|
|
173
177
|
"/config/provider/model_list": ("GET", self.get_provider_model_list),
|
|
178
|
+
"/config/provider/get_embedding_dim": ("POST", self.get_embedding_dim),
|
|
174
179
|
}
|
|
175
180
|
self.register_routes()
|
|
176
181
|
|
|
182
|
+
async def get_uc_table(self):
|
|
183
|
+
"""获取 UMOP 配置路由表"""
|
|
184
|
+
return Response().ok({"routing": self.ucr.umop_to_conf_id}).__dict__
|
|
185
|
+
|
|
186
|
+
async def update_ucr_all(self):
|
|
187
|
+
"""更新 UMOP 配置路由表的全部内容"""
|
|
188
|
+
post_data = await request.json
|
|
189
|
+
if not post_data:
|
|
190
|
+
return Response().error("缺少配置数据").__dict__
|
|
191
|
+
|
|
192
|
+
new_routing = post_data.get("routing", None)
|
|
193
|
+
|
|
194
|
+
if not new_routing or not isinstance(new_routing, dict):
|
|
195
|
+
return Response().error("缺少或错误的路由表数据").__dict__
|
|
196
|
+
|
|
197
|
+
try:
|
|
198
|
+
await self.ucr.update_routing_data(new_routing)
|
|
199
|
+
return Response().ok(message="更新成功").__dict__
|
|
200
|
+
except Exception as e:
|
|
201
|
+
logger.error(traceback.format_exc())
|
|
202
|
+
return Response().error(f"更新路由表失败: {str(e)}").__dict__
|
|
203
|
+
|
|
204
|
+
async def update_ucr(self):
|
|
205
|
+
"""更新 UMOP 配置路由表"""
|
|
206
|
+
post_data = await request.json
|
|
207
|
+
if not post_data:
|
|
208
|
+
return Response().error("缺少配置数据").__dict__
|
|
209
|
+
|
|
210
|
+
umo = post_data.get("umo", None)
|
|
211
|
+
conf_id = post_data.get("conf_id", None)
|
|
212
|
+
|
|
213
|
+
if not umo or not conf_id:
|
|
214
|
+
return Response().error("缺少 UMO 或配置文件 ID").__dict__
|
|
215
|
+
|
|
216
|
+
try:
|
|
217
|
+
await self.ucr.update_route(umo, conf_id)
|
|
218
|
+
return Response().ok(message="更新成功").__dict__
|
|
219
|
+
except Exception as e:
|
|
220
|
+
logger.error(traceback.format_exc())
|
|
221
|
+
return Response().error(f"更新路由表失败: {str(e)}").__dict__
|
|
222
|
+
|
|
223
|
+
async def delete_ucr(self):
|
|
224
|
+
"""删除 UMOP 配置路由表中的一项"""
|
|
225
|
+
post_data = await request.json
|
|
226
|
+
if not post_data:
|
|
227
|
+
return Response().error("缺少配置数据").__dict__
|
|
228
|
+
|
|
229
|
+
umo = post_data.get("umo", None)
|
|
230
|
+
|
|
231
|
+
if not umo:
|
|
232
|
+
return Response().error("缺少 UMO").__dict__
|
|
233
|
+
|
|
234
|
+
try:
|
|
235
|
+
if umo in self.ucr.umop_to_conf_id:
|
|
236
|
+
del self.ucr.umop_to_conf_id[umo]
|
|
237
|
+
await self.ucr.update_routing_data(self.ucr.umop_to_conf_id)
|
|
238
|
+
return Response().ok(message="删除成功").__dict__
|
|
239
|
+
except Exception as e:
|
|
240
|
+
logger.error(traceback.format_exc())
|
|
241
|
+
return Response().error(f"删除路由表项失败: {str(e)}").__dict__
|
|
242
|
+
|
|
243
|
+
async def get_default_config(self):
|
|
244
|
+
"""获取默认配置文件"""
|
|
245
|
+
return (
|
|
246
|
+
Response()
|
|
247
|
+
.ok({"config": DEFAULT_CONFIG, "metadata": CONFIG_METADATA_3})
|
|
248
|
+
.__dict__
|
|
249
|
+
)
|
|
250
|
+
|
|
177
251
|
async def get_abconf_list(self):
|
|
178
252
|
"""获取所有 AstrBot 配置文件的列表"""
|
|
179
253
|
abconf_list = self.acm.get_conf_list()
|
|
@@ -184,11 +258,11 @@ class ConfigRoute(Route):
|
|
|
184
258
|
post_data = await request.json
|
|
185
259
|
if not post_data:
|
|
186
260
|
return Response().error("缺少配置数据").__dict__
|
|
187
|
-
umo_parts = post_data["umo_parts"]
|
|
188
261
|
name = post_data.get("name", None)
|
|
262
|
+
config = post_data.get("config", DEFAULT_CONFIG)
|
|
189
263
|
|
|
190
264
|
try:
|
|
191
|
-
conf_id = self.acm.create_conf(
|
|
265
|
+
conf_id = self.acm.create_conf(name=name, config=config)
|
|
192
266
|
return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__
|
|
193
267
|
except ValueError as e:
|
|
194
268
|
return Response().error(str(e)).__dict__
|
|
@@ -250,10 +324,9 @@ class ConfigRoute(Route):
|
|
|
250
324
|
return Response().error("缺少配置文件 ID").__dict__
|
|
251
325
|
|
|
252
326
|
name = post_data.get("name")
|
|
253
|
-
umo_parts = post_data.get("umo_parts")
|
|
254
327
|
|
|
255
328
|
try:
|
|
256
|
-
success = self.acm.update_conf_info(conf_id, name=name
|
|
329
|
+
success = self.acm.update_conf_info(conf_id, name=name)
|
|
257
330
|
if success:
|
|
258
331
|
return Response().ok(message="更新成功").__dict__
|
|
259
332
|
else:
|
|
@@ -526,6 +599,61 @@ class ConfigRoute(Route):
|
|
|
526
599
|
logger.error(traceback.format_exc())
|
|
527
600
|
return Response().error(str(e)).__dict__
|
|
528
601
|
|
|
602
|
+
async def get_embedding_dim(self):
|
|
603
|
+
"""获取嵌入模型的维度"""
|
|
604
|
+
post_data = await request.json
|
|
605
|
+
provider_config = post_data.get("provider_config", None)
|
|
606
|
+
if not provider_config:
|
|
607
|
+
return Response().error("缺少参数 provider_config").__dict__
|
|
608
|
+
|
|
609
|
+
try:
|
|
610
|
+
# 动态导入 EmbeddingProvider
|
|
611
|
+
from astrbot.core.provider.provider import EmbeddingProvider
|
|
612
|
+
from astrbot.core.provider.register import provider_cls_map
|
|
613
|
+
|
|
614
|
+
# 获取 provider 类型
|
|
615
|
+
provider_type = provider_config.get("type", None)
|
|
616
|
+
if not provider_type:
|
|
617
|
+
return Response().error("provider_config 缺少 type 字段").__dict__
|
|
618
|
+
|
|
619
|
+
# 获取对应的 provider 类
|
|
620
|
+
if provider_type not in provider_cls_map:
|
|
621
|
+
return (
|
|
622
|
+
Response()
|
|
623
|
+
.error(f"未找到适用于 {provider_type} 的提供商适配器")
|
|
624
|
+
.__dict__
|
|
625
|
+
)
|
|
626
|
+
|
|
627
|
+
provider_metadata = provider_cls_map[provider_type]
|
|
628
|
+
cls_type = provider_metadata.cls_type
|
|
629
|
+
|
|
630
|
+
if not cls_type:
|
|
631
|
+
return Response().error(f"无法找到 {provider_type} 的类").__dict__
|
|
632
|
+
|
|
633
|
+
# 实例化 provider
|
|
634
|
+
inst = cls_type(provider_config, {})
|
|
635
|
+
|
|
636
|
+
# 检查是否是 EmbeddingProvider
|
|
637
|
+
if not isinstance(inst, EmbeddingProvider):
|
|
638
|
+
return Response().error("提供商不是 EmbeddingProvider 类型").__dict__
|
|
639
|
+
|
|
640
|
+
# 初始化
|
|
641
|
+
if getattr(inst, "initialize", None):
|
|
642
|
+
await inst.initialize()
|
|
643
|
+
|
|
644
|
+
# 获取嵌入向量维度
|
|
645
|
+
vec = await inst.get_embedding("echo")
|
|
646
|
+
dim = len(vec)
|
|
647
|
+
|
|
648
|
+
logger.info(
|
|
649
|
+
f"检测到 {provider_config.get('id', 'unknown')} 的嵌入向量维度为 {dim}"
|
|
650
|
+
)
|
|
651
|
+
|
|
652
|
+
return Response().ok({"embedding_dimensions": dim}).__dict__
|
|
653
|
+
except Exception as e:
|
|
654
|
+
logger.error(traceback.format_exc())
|
|
655
|
+
return Response().error(f"获取嵌入维度失败: {str(e)}").__dict__
|
|
656
|
+
|
|
529
657
|
async def get_platform_list(self):
|
|
530
658
|
"""获取所有平台的列表"""
|
|
531
659
|
platform_list = []
|
|
@@ -722,7 +850,7 @@ class ConfigRoute(Route):
|
|
|
722
850
|
logger.warning(
|
|
723
851
|
f"Failed to import required modules for platform {platform.name}: {e}"
|
|
724
852
|
)
|
|
725
|
-
except
|
|
853
|
+
except OSError as e:
|
|
726
854
|
logger.warning(f"File system error for platform {platform.name} logo: {e}")
|
|
727
855
|
except Exception as e:
|
|
728
856
|
logger.warning(
|