AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/__init__.py +16 -4
- astrbot/api/all.py +2 -1
- astrbot/api/event/__init__.py +5 -6
- astrbot/api/event/filter/__init__.py +37 -34
- astrbot/api/platform/__init__.py +7 -8
- astrbot/api/provider/__init__.py +8 -7
- astrbot/api/star/__init__.py +3 -4
- astrbot/api/util/__init__.py +2 -2
- astrbot/cli/__init__.py +1 -0
- astrbot/cli/__main__.py +18 -197
- astrbot/cli/commands/__init__.py +6 -0
- astrbot/cli/commands/cmd_conf.py +209 -0
- astrbot/cli/commands/cmd_init.py +56 -0
- astrbot/cli/commands/cmd_plug.py +245 -0
- astrbot/cli/commands/cmd_run.py +62 -0
- astrbot/cli/utils/__init__.py +18 -0
- astrbot/cli/utils/basic.py +76 -0
- astrbot/cli/utils/plugin.py +246 -0
- astrbot/cli/utils/version_comparator.py +90 -0
- astrbot/core/__init__.py +17 -19
- astrbot/core/agent/agent.py +14 -0
- astrbot/core/agent/handoff.py +38 -0
- astrbot/core/agent/hooks.py +30 -0
- astrbot/core/agent/mcp_client.py +385 -0
- astrbot/core/agent/message.py +175 -0
- astrbot/core/agent/response.py +14 -0
- astrbot/core/agent/run_context.py +22 -0
- astrbot/core/agent/runners/__init__.py +3 -0
- astrbot/core/agent/runners/base.py +65 -0
- astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
- astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
- astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
- astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
- astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
- astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
- astrbot/core/agent/tool.py +285 -0
- astrbot/core/agent/tool_executor.py +17 -0
- astrbot/core/astr_agent_context.py +19 -0
- astrbot/core/astr_agent_hooks.py +36 -0
- astrbot/core/astr_agent_run_util.py +80 -0
- astrbot/core/astr_agent_tool_exec.py +246 -0
- astrbot/core/astrbot_config_mgr.py +275 -0
- astrbot/core/config/__init__.py +2 -2
- astrbot/core/config/astrbot_config.py +60 -20
- astrbot/core/config/default.py +1972 -453
- astrbot/core/config/i18n_utils.py +110 -0
- astrbot/core/conversation_mgr.py +285 -75
- astrbot/core/core_lifecycle.py +167 -62
- astrbot/core/db/__init__.py +305 -102
- astrbot/core/db/migration/helper.py +69 -0
- astrbot/core/db/migration/migra_3_to_4.py +357 -0
- astrbot/core/db/migration/migra_45_to_46.py +44 -0
- astrbot/core/db/migration/migra_webchat_session.py +131 -0
- astrbot/core/db/migration/shared_preferences_v3.py +48 -0
- astrbot/core/db/migration/sqlite_v3.py +497 -0
- astrbot/core/db/po.py +259 -55
- astrbot/core/db/sqlite.py +773 -528
- astrbot/core/db/vec_db/base.py +73 -0
- astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
- astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
- astrbot/core/event_bus.py +26 -22
- astrbot/core/exceptions.py +9 -0
- astrbot/core/file_token_service.py +98 -0
- astrbot/core/initial_loader.py +19 -10
- astrbot/core/knowledge_base/chunking/__init__.py +9 -0
- astrbot/core/knowledge_base/chunking/base.py +25 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
- astrbot/core/knowledge_base/chunking/recursive.py +161 -0
- astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
- astrbot/core/knowledge_base/kb_helper.py +642 -0
- astrbot/core/knowledge_base/kb_mgr.py +330 -0
- astrbot/core/knowledge_base/models.py +120 -0
- astrbot/core/knowledge_base/parsers/__init__.py +13 -0
- astrbot/core/knowledge_base/parsers/base.py +51 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
- astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
- astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
- astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
- astrbot/core/knowledge_base/parsers/util.py +13 -0
- astrbot/core/knowledge_base/prompts.py +65 -0
- astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
- astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
- astrbot/core/knowledge_base/retrieval/manager.py +276 -0
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
- astrbot/core/log.py +21 -15
- astrbot/core/message/components.py +413 -287
- astrbot/core/message/message_event_result.py +35 -24
- astrbot/core/persona_mgr.py +192 -0
- astrbot/core/pipeline/__init__.py +14 -14
- astrbot/core/pipeline/content_safety_check/stage.py +13 -9
- astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
- astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
- astrbot/core/pipeline/context.py +7 -1
- astrbot/core/pipeline/context_utils.py +107 -0
- astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
- astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
- astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
- astrbot/core/pipeline/process_stage/stage.py +21 -15
- astrbot/core/pipeline/process_stage/utils.py +125 -0
- astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
- astrbot/core/pipeline/respond/stage.py +142 -101
- astrbot/core/pipeline/result_decorate/stage.py +124 -57
- astrbot/core/pipeline/scheduler.py +21 -16
- astrbot/core/pipeline/session_status_check/stage.py +37 -0
- astrbot/core/pipeline/stage.py +11 -76
- astrbot/core/pipeline/waking_check/stage.py +69 -33
- astrbot/core/pipeline/whitelist_check/stage.py +10 -7
- astrbot/core/platform/__init__.py +6 -6
- astrbot/core/platform/astr_message_event.py +107 -129
- astrbot/core/platform/astrbot_message.py +32 -12
- astrbot/core/platform/manager.py +62 -18
- astrbot/core/platform/message_session.py +30 -0
- astrbot/core/platform/platform.py +16 -24
- astrbot/core/platform/platform_metadata.py +9 -4
- astrbot/core/platform/register.py +12 -7
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
- astrbot/core/platform/sources/discord/client.py +129 -0
- astrbot/core/platform/sources/discord/components.py +139 -0
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
- astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
- astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
- astrbot/core/platform/sources/lark/lark_event.py +39 -13
- astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
- astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
- astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
- astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
- astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
- astrbot/core/platform/sources/satori/satori_event.py +432 -0
- astrbot/core/platform/sources/slack/client.py +164 -0
- astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
- astrbot/core/platform/sources/slack/slack_event.py +253 -0
- astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
- astrbot/core/platform/sources/telegram/tg_event.py +136 -36
- astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
- astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
- astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
- astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
- astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
- astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
- astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
- astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
- astrbot/core/platform_message_history_mgr.py +49 -0
- astrbot/core/provider/__init__.py +2 -3
- astrbot/core/provider/entites.py +8 -8
- astrbot/core/provider/entities.py +154 -98
- astrbot/core/provider/func_tool_manager.py +446 -458
- astrbot/core/provider/manager.py +345 -207
- astrbot/core/provider/provider.py +188 -73
- astrbot/core/provider/register.py +9 -7
- astrbot/core/provider/sources/anthropic_source.py +295 -115
- astrbot/core/provider/sources/azure_tts_source.py +224 -0
- astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
- astrbot/core/provider/sources/dashscope_tts.py +138 -14
- astrbot/core/provider/sources/edge_tts_source.py +24 -19
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
- astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
- astrbot/core/provider/sources/gemini_source.py +310 -132
- astrbot/core/provider/sources/gemini_tts_source.py +81 -0
- astrbot/core/provider/sources/groq_source.py +15 -0
- astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
- astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
- astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
- astrbot/core/provider/sources/openai_embedding_source.py +40 -0
- astrbot/core/provider/sources/openai_source.py +241 -145
- astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
- astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
- astrbot/core/provider/sources/volcengine_tts.py +115 -0
- astrbot/core/provider/sources/whisper_api_source.py +18 -13
- astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
- astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
- astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
- astrbot/core/provider/sources/zhipu_source.py +6 -73
- astrbot/core/star/__init__.py +43 -11
- astrbot/core/star/config.py +17 -18
- astrbot/core/star/context.py +362 -138
- astrbot/core/star/filter/__init__.py +4 -3
- astrbot/core/star/filter/command.py +111 -35
- astrbot/core/star/filter/command_group.py +46 -34
- astrbot/core/star/filter/custom_filter.py +6 -5
- astrbot/core/star/filter/event_message_type.py +4 -2
- astrbot/core/star/filter/permission.py +4 -2
- astrbot/core/star/filter/platform_adapter_type.py +45 -12
- astrbot/core/star/filter/regex.py +4 -2
- astrbot/core/star/register/__init__.py +19 -15
- astrbot/core/star/register/star.py +41 -13
- astrbot/core/star/register/star_handler.py +236 -86
- astrbot/core/star/session_llm_manager.py +280 -0
- astrbot/core/star/session_plugin_manager.py +170 -0
- astrbot/core/star/star.py +36 -43
- astrbot/core/star/star_handler.py +47 -85
- astrbot/core/star/star_manager.py +442 -260
- astrbot/core/star/star_tools.py +167 -45
- astrbot/core/star/updator.py +17 -20
- astrbot/core/umop_config_router.py +106 -0
- astrbot/core/updator.py +38 -13
- astrbot/core/utils/astrbot_path.py +39 -0
- astrbot/core/utils/command_parser.py +1 -1
- astrbot/core/utils/io.py +119 -60
- astrbot/core/utils/log_pipe.py +1 -1
- astrbot/core/utils/metrics.py +11 -10
- astrbot/core/utils/migra_helper.py +73 -0
- astrbot/core/utils/path_util.py +63 -62
- astrbot/core/utils/pip_installer.py +37 -15
- astrbot/core/utils/session_lock.py +29 -0
- astrbot/core/utils/session_waiter.py +19 -20
- astrbot/core/utils/shared_preferences.py +174 -34
- astrbot/core/utils/t2i/__init__.py +4 -1
- astrbot/core/utils/t2i/local_strategy.py +386 -238
- astrbot/core/utils/t2i/network_strategy.py +109 -49
- astrbot/core/utils/t2i/renderer.py +29 -14
- astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
- astrbot/core/utils/t2i/template_manager.py +111 -0
- astrbot/core/utils/tencent_record_helper.py +115 -1
- astrbot/core/utils/version_comparator.py +10 -13
- astrbot/core/zip_updator.py +112 -65
- astrbot/dashboard/routes/__init__.py +20 -13
- astrbot/dashboard/routes/auth.py +20 -9
- astrbot/dashboard/routes/chat.py +297 -141
- astrbot/dashboard/routes/config.py +652 -55
- astrbot/dashboard/routes/conversation.py +107 -37
- astrbot/dashboard/routes/file.py +26 -0
- astrbot/dashboard/routes/knowledge_base.py +1244 -0
- astrbot/dashboard/routes/log.py +27 -2
- astrbot/dashboard/routes/persona.py +202 -0
- astrbot/dashboard/routes/plugin.py +197 -139
- astrbot/dashboard/routes/route.py +27 -7
- astrbot/dashboard/routes/session_management.py +354 -0
- astrbot/dashboard/routes/stat.py +85 -18
- astrbot/dashboard/routes/static_file.py +5 -2
- astrbot/dashboard/routes/t2i.py +233 -0
- astrbot/dashboard/routes/tools.py +184 -120
- astrbot/dashboard/routes/update.py +59 -36
- astrbot/dashboard/server.py +96 -36
- astrbot/dashboard/utils.py +165 -0
- astrbot-4.7.0.dist-info/METADATA +294 -0
- astrbot-4.7.0.dist-info/RECORD +274 -0
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
- astrbot/core/db/plugin/sqlite_impl.py +0 -112
- astrbot/core/db/sqlite_init.sql +0 -50
- astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
- astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
- astrbot/core/platform/sources/gewechat/client.py +0 -806
- astrbot/core/platform/sources/gewechat/downloader.py +0 -55
- astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
- astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
- astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
- astrbot/core/provider/sources/dashscope_source.py +0 -203
- astrbot/core/provider/sources/dify_source.py +0 -281
- astrbot/core/provider/sources/llmtuner_source.py +0 -132
- astrbot/core/rag/embedding/openai_source.py +0 -20
- astrbot/core/rag/knowledge_db_mgr.py +0 -94
- astrbot/core/rag/store/__init__.py +0 -9
- astrbot/core/rag/store/chroma_db.py +0 -42
- astrbot/core/utils/dify_api_client.py +0 -152
- astrbot-3.5.6.dist-info/METADATA +0 -249
- astrbot-3.5.6.dist-info/RECORD +0 -158
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,203 +0,0 @@
|
|
|
1
|
-
import re
|
|
2
|
-
import asyncio
|
|
3
|
-
import functools
|
|
4
|
-
from typing import List
|
|
5
|
-
from .. import Provider, Personality
|
|
6
|
-
from ..entities import LLMResponse
|
|
7
|
-
from ..func_tool_manager import FuncCall
|
|
8
|
-
from astrbot.core.db import BaseDatabase
|
|
9
|
-
from ..register import register_provider_adapter
|
|
10
|
-
from astrbot.core.message.message_event_result import MessageChain
|
|
11
|
-
from .openai_source import ProviderOpenAIOfficial
|
|
12
|
-
from astrbot.core import logger, sp
|
|
13
|
-
from dashscope import Application
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
class ProviderDashscope(ProviderOpenAIOfficial):
    """Provider adapter for Alibaba Cloud Bailian (DashScope) applications.

    Every request goes through the blocking ``dashscope.Application.call`` SDK
    function, which is executed in the event loop's default executor so the
    loop itself is never blocked.
    """

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
        db_helper: BaseDatabase,
        persistant_history=False,
        default_persona: Personality = None,
    ) -> None:
        """Read and validate the DashScope settings from *provider_config*.

        Raises:
            Exception: if the API key, the app id or the app type is missing
                or empty.
        """
        # Provider.__init__ is called directly, bypassing
        # ProviderOpenAIOfficial.__init__ (presumably so no OpenAI client is
        # built for this adapter -- confirm before changing).
        Provider.__init__(
            self,
            provider_config,
            provider_settings,
            persistant_history,
            db_helper,
            default_persona,
        )
        self.api_key = provider_config.get("dashscope_api_key", "")
        if not self.api_key:
            raise Exception("阿里云百炼 API Key 不能为空。")
        self.app_id = provider_config.get("dashscope_app_id", "")
        if not self.app_id:
            raise Exception("阿里云百炼 APP ID 不能为空。")
        self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
        if not self.dashscope_app_type:
            raise Exception("阿里云百炼 APP 类型不能为空。")
        self.model_name = "dashscope"
        # `or {}` also tolerates an explicit null in the config.
        self.variables: dict = provider_config.get("variables") or {}
        rag_options: dict = provider_config.get("rag_options") or {}
        # "output_reference" only controls whether text_chat appends the
        # document references to the answer, so it is stored separately and
        # left out of the options sent to DashScope (a copy: the config dict
        # itself is not modified).
        self.output_reference = rag_options.get("output_reference", False)
        self.rag_options = {
            key: value
            for key, value in rag_options.items()
            if key != "output_reference"
        }

        self.timeout = provider_config.get("timeout", 120)
        if isinstance(self.timeout, str):
            self.timeout = int(self.timeout)

    def has_rag_options(self) -> bool:
        """Return whether any RAG knowledge source is configured.

        Returns:
            bool: True when ``pipeline_ids`` or ``file_ids`` is non-empty.
        """
        if not self.rag_options:
            return False
        # Truthiness instead of len(): a null value counts as "not configured"
        # rather than raising TypeError.
        return bool(
            self.rag_options.get("pipeline_ids") or self.rag_options.get("file_ids")
        )

    async def _call_application(self, payload: dict):
        """Run the blocking ``Application.call(**payload)`` in the default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, functools.partial(Application.call, **payload)
        )

    async def text_chat(
        self,
        prompt: str,
        session_id: str = None,
        image_urls: List[str] = None,
        func_tool: FuncCall = None,
        contexts: List = None,
        system_prompt: str = None,
        **kwargs,
    ) -> LLMResponse:
        """Send *prompt* to the DashScope application and return its answer.

        ``agent`` and ``dialog-workflow`` apps without RAG options receive the
        whole message list (optional system prompt, *contexts* with images
        removed, then *prompt*). All other apps receive only *prompt*, plus the
        configured RAG options. Static ``variables`` merged with the per-session
        variables are sent as ``biz_params``. Images are never sent.

        Returns:
            LLMResponse: role ``"assistant"`` with the answer text (RAG
            footnotes normalised, references appended when enabled), or role
            ``"err"`` when DashScope answers with a non-200 status code.
        """
        # Static variables from the config, overridden by per-session ones.
        payload_vars = self.variables.copy()
        session_vars = sp.get("session_variables", {})
        payload_vars.update(session_vars.get(session_id, {}))

        if (
            self.dashscope_app_type in ("agent", "dialog-workflow")
            and not self.has_rag_options()
        ):
            # These app types support multi-turn conversations.
            if image_urls:
                logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
            contexts_no_img = await self._remove_image_from_context(contexts or [])
            messages = [*contexts_no_img, {"role": "user", "content": prompt}]
            if system_prompt:
                messages.insert(0, {"role": "system", "content": system_prompt})
            # Strip the internal "_no_save" marker from *copies*: deleting it
            # in place would modify the caller's context dicts.
            messages = [
                {key: value for key, value in part.items() if key != "_no_save"}
                for part in messages
            ]
            payload = {
                "app_id": self.app_id,
                "api_key": self.api_key,
                "messages": messages,
                "biz_params": payload_vars or None,
            }
        else:
            # Single-turn request.
            payload = {
                "app_id": self.app_id,
                "prompt": prompt,
                "api_key": self.api_key,
                "biz_params": payload_vars or None,
            }
            if self.rag_options:
                payload["rag_options"] = self.rag_options

        response = await self._call_application(payload)

        logger.debug(f"dashscope resp: {response}")

        if response.status_code != 200:
            logger.error(
                f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code"
            )
            return LLMResponse(
                role="err",
                result_chain=MessageChain().message(
                    f"阿里云百炼请求失败: message={response.message} code={response.status_code}"
                ),
            )

        output = response.output or {}
        # A missing or null "text" becomes "" so re.sub never sees None.
        output_text = output.get("text") or ""
        # Turn RAG footnote markup "<ref>[n]</ref>" into plain "[n]".
        output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
        doc_references = output.get("doc_references")
        if self.output_reference and doc_references:
            ref_lines = []
            for ref in doc_references:
                ref_title = ref.get("title") or ref.get("doc_name", "")
                ref_lines.append(f"{ref.get('index_id', '')}. {ref_title}\n")
            output_text += f"\n\n回答来源:\n{''.join(ref_lines)}"

        llm_response = LLMResponse("assistant")
        llm_response.result_chain = MessageChain().message(output_text)

        return llm_response

    async def text_chat_stream(
        self,
        prompt,
        session_id=None,
        image_urls=None,
        func_tool=None,
        contexts=None,
        system_prompt=None,
        tool_calls_result=None,
        **kwargs,
    ):
        """Emulate streaming by yielding the complete ``text_chat`` result twice.

        The same response object is yielded first with ``is_chunk=True`` and then
        with ``is_chunk=False``. Defaults are ``None``: the former ``...``
        (Ellipsis) defaults were forwarded to ``text_chat`` and treated as real
        values.
        """
        llm_response = await self.text_chat(
            prompt=prompt,
            session_id=session_id,
            image_urls=image_urls,
            func_tool=func_tool,
            contexts=contexts,
            system_prompt=system_prompt,
            tool_calls_result=tool_calls_result,
            **kwargs,
        )
        llm_response.is_chunk = True
        yield llm_response
        llm_response.is_chunk = False
        yield llm_response

    async def forget(self, session_id):
        """No-op: this adapter keeps no per-session state of its own."""
        return True

    async def get_current_key(self):
        """Return the configured DashScope API key."""
        return self.api_key

    async def set_key(self, key):
        """Changing the API key at runtime is not supported."""
        raise Exception("阿里云百炼 适配器不支持设置 API Key。")

    async def get_models(self):
        """Return a one-element list containing ``self.get_model()``."""
        return [self.get_model()]

    async def get_human_readable_context(self, session_id, page, page_size):
        """Reading conversation history is not supported for this adapter."""
        raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。")

    async def terminate(self):
        """Nothing to release."""
        pass
|
|
@@ -1,281 +0,0 @@
|
|
|
1
|
-
import astrbot.core.message.components as Comp
|
|
2
|
-
|
|
3
|
-
from typing import List
|
|
4
|
-
from .. import Provider, Personality
|
|
5
|
-
from ..entities import LLMResponse
|
|
6
|
-
from ..func_tool_manager import FuncCall
|
|
7
|
-
from astrbot.core.db import BaseDatabase
|
|
8
|
-
from ..register import register_provider_adapter
|
|
9
|
-
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
|
10
|
-
from astrbot.core.utils.io import download_image_by_url, download_file
|
|
11
|
-
from astrbot.core import logger, sp
|
|
12
|
-
from astrbot.core.message.message_event_result import MessageChain
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
@register_provider_adapter("dify", "Dify APP 适配器。")
|
|
16
|
-
class ProviderDify(Provider):
|
|
17
|
-
def __init__(
    self,
    provider_config: dict,
    provider_settings: dict,
    db_helper: BaseDatabase,
    persistant_history=False,
    default_persona: Personality = None,
) -> None:
    """Validate the Dify settings in *provider_config* and create the API client.

    Empty strings in the optional settings (API base, workflow output key,
    query input key) fall back to their defaults, the same as missing keys.

    Raises:
        Exception: if ``dify_api_key`` or ``dify_api_type`` is missing or empty.
    """
    super().__init__(
        provider_config,
        provider_settings,
        persistant_history,
        db_helper,
        default_persona,
    )
    self.api_key = provider_config.get("dify_api_key", "")
    if not self.api_key:
        raise Exception("Dify API Key 不能为空。")
    # An empty base URL would give the client an unusable endpoint, so treat it
    # like a missing value.
    api_base = provider_config.get("dify_api_base") or "https://api.dify.ai/v1"
    self.api_type = provider_config.get("dify_api_type", "")
    if not self.api_type:
        raise Exception("Dify API 类型不能为空。")
    self.model_name = "dify"
    self.workflow_output_key = (
        provider_config.get("dify_workflow_output_key") or "astrbot_wf_output"
    )
    self.dify_query_input_key = (
        provider_config.get("dify_query_input_key") or "astrbot_text_query"
    )
    # `or {}` also tolerates an explicit null in the config.
    self.variables: dict = provider_config.get("variables") or {}
    self.timeout = provider_config.get("timeout", 120)
    if isinstance(self.timeout, str):
        self.timeout = int(self.timeout)
    # Maps session id -> Dify conversation id of the ongoing conversation.
    self.conversation_ids = {}

    self.api_client = DifyAPIClient(self.api_key, api_base)
|
|
58
|
-
|
|
59
|
-
async def text_chat(
|
|
60
|
-
self,
|
|
61
|
-
prompt: str,
|
|
62
|
-
session_id: str = None,
|
|
63
|
-
image_urls: List[str] = [],
|
|
64
|
-
func_tool: FuncCall = None,
|
|
65
|
-
contexts: List = None,
|
|
66
|
-
system_prompt: str = None,
|
|
67
|
-
**kwargs,
|
|
68
|
-
) -> LLMResponse:
|
|
69
|
-
result = ""
|
|
70
|
-
conversation_id = self.conversation_ids.get(session_id, "")
|
|
71
|
-
|
|
72
|
-
files_payload = []
|
|
73
|
-
for image_url in image_urls:
|
|
74
|
-
image_path = (
|
|
75
|
-
await download_image_by_url(image_url)
|
|
76
|
-
if image_url.startswith("http")
|
|
77
|
-
else image_url
|
|
78
|
-
)
|
|
79
|
-
file_response = await self.api_client.file_upload(
|
|
80
|
-
image_path, user=session_id
|
|
81
|
-
)
|
|
82
|
-
logger.debug(f"Dify 上传图片响应:{file_response}")
|
|
83
|
-
if "id" not in file_response:
|
|
84
|
-
logger.warning(
|
|
85
|
-
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
|
|
86
|
-
)
|
|
87
|
-
continue
|
|
88
|
-
files_payload.append(
|
|
89
|
-
{
|
|
90
|
-
"type": "image",
|
|
91
|
-
"transfer_method": "local_file",
|
|
92
|
-
"upload_file_id": file_response["id"],
|
|
93
|
-
}
|
|
94
|
-
)
|
|
95
|
-
|
|
96
|
-
# 获得会话变量
|
|
97
|
-
payload_vars = self.variables.copy()
|
|
98
|
-
# 动态变量
|
|
99
|
-
session_vars = sp.get("session_variables", {})
|
|
100
|
-
session_var = session_vars.get(session_id, {})
|
|
101
|
-
payload_vars.update(session_var)
|
|
102
|
-
|
|
103
|
-
try:
|
|
104
|
-
match self.api_type:
|
|
105
|
-
case "chat" | "agent" | "chatflow":
|
|
106
|
-
if not prompt:
|
|
107
|
-
prompt = "请描述这张图片。"
|
|
108
|
-
|
|
109
|
-
async for chunk in self.api_client.chat_messages(
|
|
110
|
-
inputs={
|
|
111
|
-
**payload_vars,
|
|
112
|
-
},
|
|
113
|
-
query=prompt,
|
|
114
|
-
user=session_id,
|
|
115
|
-
conversation_id=conversation_id,
|
|
116
|
-
files=files_payload,
|
|
117
|
-
timeout=self.timeout,
|
|
118
|
-
):
|
|
119
|
-
logger.debug(f"dify resp chunk: {chunk}")
|
|
120
|
-
if (
|
|
121
|
-
chunk["event"] == "message"
|
|
122
|
-
or chunk["event"] == "agent_message"
|
|
123
|
-
):
|
|
124
|
-
result += chunk["answer"]
|
|
125
|
-
if not conversation_id:
|
|
126
|
-
self.conversation_ids[session_id] = chunk[
|
|
127
|
-
"conversation_id"
|
|
128
|
-
]
|
|
129
|
-
conversation_id = chunk["conversation_id"]
|
|
130
|
-
elif chunk["event"] == "message_end":
|
|
131
|
-
logger.debug("Dify message end")
|
|
132
|
-
break
|
|
133
|
-
elif chunk["event"] == "error":
|
|
134
|
-
logger.error(f"Dify 出现错误:{chunk}")
|
|
135
|
-
raise Exception(
|
|
136
|
-
f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}"
|
|
137
|
-
)
|
|
138
|
-
|
|
139
|
-
case "workflow":
|
|
140
|
-
async for chunk in self.api_client.workflow_run(
|
|
141
|
-
inputs={
|
|
142
|
-
self.dify_query_input_key: prompt,
|
|
143
|
-
"astrbot_session_id": session_id,
|
|
144
|
-
**payload_vars,
|
|
145
|
-
},
|
|
146
|
-
user=session_id,
|
|
147
|
-
files=files_payload,
|
|
148
|
-
timeout=self.timeout,
|
|
149
|
-
):
|
|
150
|
-
match chunk["event"]:
|
|
151
|
-
case "workflow_started":
|
|
152
|
-
logger.info(
|
|
153
|
-
f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。"
|
|
154
|
-
)
|
|
155
|
-
case "node_finished":
|
|
156
|
-
logger.debug(
|
|
157
|
-
f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。"
|
|
158
|
-
)
|
|
159
|
-
case "workflow_finished":
|
|
160
|
-
logger.info(
|
|
161
|
-
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束"
|
|
162
|
-
)
|
|
163
|
-
logger.debug(f"Dify 工作流结果:{chunk}")
|
|
164
|
-
if chunk["data"]["error"]:
|
|
165
|
-
logger.error(
|
|
166
|
-
f"Dify 工作流出现错误:{chunk['data']['error']}"
|
|
167
|
-
)
|
|
168
|
-
raise Exception(
|
|
169
|
-
f"Dify 工作流出现错误:{chunk['data']['error']}"
|
|
170
|
-
)
|
|
171
|
-
if (
|
|
172
|
-
self.workflow_output_key
|
|
173
|
-
not in chunk["data"]["outputs"]
|
|
174
|
-
):
|
|
175
|
-
raise Exception(
|
|
176
|
-
f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}"
|
|
177
|
-
)
|
|
178
|
-
result = chunk
|
|
179
|
-
case _:
|
|
180
|
-
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
|
|
181
|
-
except Exception as e:
|
|
182
|
-
logger.error(f"Dify 请求失败:{str(e)}")
|
|
183
|
-
return LLMResponse(role="err", completion_text=f"Dify 请求失败:{str(e)}")
|
|
184
|
-
|
|
185
|
-
if not result:
|
|
186
|
-
logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
|
|
187
|
-
|
|
188
|
-
chain = await self.parse_dify_result(result)
|
|
189
|
-
|
|
190
|
-
return LLMResponse(role="assistant", result_chain=chain)
|
|
191
|
-
|
|
192
|
-
async def text_chat_stream(
    self,
    prompt,
    session_id=None,
    image_urls=None,
    func_tool=None,
    contexts=None,
    system_prompt=None,
    tool_calls_result=None,
    **kwargs,
):
    """Emulate streaming by yielding the complete ``text_chat`` result twice.

    The same response object is yielded first with ``is_chunk=True`` and then
    with ``is_chunk=False``. The defaults are ``None``. The former ``...``
    (Ellipsis) defaults were forwarded to ``text_chat``, whose image loop then
    tried to iterate ``...`` and failed.
    """
    llm_response = await self.text_chat(
        prompt=prompt,
        session_id=session_id,
        image_urls=image_urls,
        func_tool=func_tool,
        contexts=contexts,
        system_prompt=system_prompt,
        tool_calls_result=tool_calls_result,
        **kwargs,
    )
    llm_response.is_chunk = True
    yield llm_response
    llm_response.is_chunk = False
    yield llm_response
|
|
218
|
-
|
|
219
|
-
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
|
|
220
|
-
if isinstance(chunk, str):
|
|
221
|
-
# Chat
|
|
222
|
-
return MessageChain(chain=[Comp.Plain(chunk)])
|
|
223
|
-
|
|
224
|
-
async def parse_file(item: dict) -> Comp:
|
|
225
|
-
match item["type"]:
|
|
226
|
-
case "image":
|
|
227
|
-
return Comp.Image(file=item["url"], url=item["url"])
|
|
228
|
-
case "audio":
|
|
229
|
-
# 仅支持 wav
|
|
230
|
-
path = f"data/temp/{item['filename']}.wav"
|
|
231
|
-
await download_file(item["url"], path)
|
|
232
|
-
return Comp.Image(file=item["url"], url=item["url"])
|
|
233
|
-
case "video":
|
|
234
|
-
return Comp.Video(file=item["url"])
|
|
235
|
-
case _:
|
|
236
|
-
return Comp.File(name=item["filename"], file=item["url"])
|
|
237
|
-
|
|
238
|
-
output = chunk["data"]["outputs"][self.workflow_output_key]
|
|
239
|
-
chains = []
|
|
240
|
-
if isinstance(output, str):
|
|
241
|
-
# 纯文本输出
|
|
242
|
-
chains.append(Comp.Plain(output))
|
|
243
|
-
elif isinstance(output, list):
|
|
244
|
-
# 主要适配 Dify 的 HTTP 请求结点的多模态输出
|
|
245
|
-
for item in output:
|
|
246
|
-
# handle Array[File]
|
|
247
|
-
if (
|
|
248
|
-
not isinstance(item, dict)
|
|
249
|
-
or item.get("dify_model_identity", "") != "__dify__file__"
|
|
250
|
-
):
|
|
251
|
-
chains.append(Comp.Plain(str(output)))
|
|
252
|
-
break
|
|
253
|
-
else:
|
|
254
|
-
chains.append(Comp.Plain(str(output)))
|
|
255
|
-
|
|
256
|
-
# scan file
|
|
257
|
-
files = chunk["data"].get("files", [])
|
|
258
|
-
for item in files:
|
|
259
|
-
comp = await parse_file(item)
|
|
260
|
-
chains.append(comp)
|
|
261
|
-
|
|
262
|
-
return MessageChain(chain=chains)
|
|
263
|
-
|
|
264
|
-
async def forget(self, session_id):
    """Reset the stored Dify conversation id for ``session_id`` to ``""``.

    Always returns ``True``.
    """
    cleared_id = ""
    self.conversation_ids[session_id] = cleared_id
    return True
|
|
267
|
-
|
|
268
|
-
async def get_current_key(self):
    """Return the adapter's stored ``api_key``."""
    key = self.api_key
    return key
|
|
270
|
-
|
|
271
|
-
async def set_key(self, key):
    """Refuse to change the key: this adapter does not support setting an API key.

    Raises:
        Exception: unconditionally.
    """
    unsupported = "Dify 适配器不支持设置 API Key。"
    raise Exception(unsupported)
|
|
273
|
-
|
|
274
|
-
async def get_models(self):
    """Return a one-element list containing the currently selected model."""
    current = self.get_model()
    return [current]
|
|
276
|
-
|
|
277
|
-
async def get_human_readable_context(self, session_id, page, page_size):
    """Not supported: Dify message history cannot be retrieved by this adapter.

    Raises:
        Exception: unconditionally.
    """
    unsupported = "暂不支持获得 Dify 的历史消息记录。"
    raise Exception(unsupported)
|
|
279
|
-
|
|
280
|
-
async def terminate(self):
    """Shut the adapter down by awaiting ``close()`` on its API client."""
    client = self.api_client
    await client.close()
|
|
@@ -1,132 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
from llmtuner.chat import ChatModel
|
|
3
|
-
from typing import List
|
|
4
|
-
from .. import Provider
|
|
5
|
-
from ..entities import LLMResponse
|
|
6
|
-
from ..func_tool_manager import FuncCall
|
|
7
|
-
from astrbot.core.db import BaseDatabase
|
|
8
|
-
from ..register import register_provider_adapter
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
@register_provider_adapter(
    "llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型"
)
class LLMTunerModelLoader(Provider):
    """Provider adapter for a local model fine-tuned with LlamaFactory.

    The base model and adapter weights are loaded in-process through
    ``llmtuner.chat.ChatModel``. Chat turns are answered by that model
    directly, so the key-related methods are placeholders.
    """

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
        db_helper: BaseDatabase,
        persistant_history=True,
        default_persona=None,
    ) -> None:
        """Load the fine-tuned model described by ``provider_config``.

        ``provider_config`` must contain ``base_model_path``,
        ``adapter_model_path``, ``llmtuner_template``, ``finetuning_type`` and
        ``quantization_bit``.

        Raises:
            FileNotFoundError: if the base model or adapter path does not exist.
        """
        # NOTE(review): persistant_history is passed before db_helper, unlike
        # this method's own parameter order; presumably this matches
        # Provider.__init__ — confirm.
        super().__init__(
            provider_config,
            provider_settings,
            persistant_history,
            db_helper,
            default_persona,
        )
        if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists(
            provider_config["adapter_model_path"]
        ):
            raise FileNotFoundError("模型文件路径不存在。")
        self.base_model_path = provider_config["base_model_path"]
        self.adapter_model_path = provider_config["adapter_model_path"]
        self.model = ChatModel(
            {
                "model_name_or_path": self.base_model_path,
                "adapter_name_or_path": self.adapter_model_path,
                "template": provider_config["llmtuner_template"],
                "finetuning_type": provider_config["finetuning_type"],
                "quantization_bit": provider_config["quantization_bit"],
            }
        )
        # The model name is "<base model dir>_<adapter dir>".
        self.set_model(
            os.path.basename(self.base_model_path)
            + "_"
            + os.path.basename(self.adapter_model_path)
        )

    async def assemble_context(self, text: str, image_urls: List[str] = None):
        """Build a user message from ``text``.

        ``image_urls`` is accepted for interface compatibility but ignored;
        only text is sent to this model.
        """
        return {"role": "user", "content": text}

    async def text_chat(
        self,
        prompt: str,
        session_id: str = None,
        image_urls: List[str] = None,
        func_tool: FuncCall = None,
        contexts: List = None,
        system_prompt: str = None,
        **kwargs,
    ) -> LLMResponse:
        """Run one chat turn against the local model.

        Args:
            prompt: The new user message.
            session_id: Not used by this adapter.
            image_urls: Not used by this adapter.
            func_tool: Optional tool manager; its OpenAI-style descriptions are
                forwarded to the model as ``tools`` when non-empty.
            contexts: Prior messages (dicts with ``role``/``content``). The
                list and its dicts are left unmodified.
            system_prompt: Optional system prompt, combined with any ``system``
                messages found in ``contexts``.

        Returns:
            An assistant ``LLMResponse`` holding the last generated text, or
            empty text if the model produced no responses.
        """
        contexts = [] if contexts is None else contexts
        # Copy every message so that stripping the internal "_no_save" marker
        # does not mutate the caller's context dicts.
        query_context = [
            {key: value for key, value in message.items() if key != "_no_save"}
            for message in [*contexts, {"role": "user", "content": prompt}]
        ]

        # llmtuner takes the system prompt as a separate argument rather than
        # as a message. Merge the explicit prompt with every system message,
        # keeping their original order, and drop those messages from the history.
        system_parts = [system_prompt] if system_prompt else []
        system_parts.extend(
            message["content"]
            for message in query_context
            if message["role"] == "system"
        )
        messages = [
            message for message in query_context if message["role"] != "system"
        ]

        conf = {
            "messages": messages,
            "system": " ".join(system_parts),
        }
        if func_tool:
            tool_list = func_tool.get_func_desc_openai_style()
            if tool_list:
                conf["tools"] = tool_list

        responses = await self.model.achat(**conf)
        # Return empty text rather than raising IndexError on an empty result.
        text = responses[-1].response_text if responses else ""
        return LLMResponse("assistant", text)

    async def text_chat_stream(
        self,
        prompt,
        session_id=None,
        image_urls=...,
        func_tool=None,
        contexts=...,
        system_prompt=None,
        tool_calls_result=None,
        **kwargs,
    ):
        """Emulate streaming by yielding the single ``text_chat`` result twice.

        The same response object is yielded first with ``is_chunk = True`` and
        then with ``is_chunk = False``.
        """
        llm_response = await self.text_chat(
            prompt=prompt,
            session_id=session_id,
            image_urls=image_urls,
            func_tool=func_tool,
            contexts=contexts,
            system_prompt=system_prompt,
            tool_calls_result=tool_calls_result,
        )
        llm_response.is_chunk = True
        yield llm_response
        llm_response.is_chunk = False
        yield llm_response

    async def get_current_key(self):
        """Return the placeholder key ``"none"`` (a local model needs no key)."""
        return "none"

    async def set_key(self, key):
        """No-op: this adapter has no API key to set."""
        pass

    async def get_models(self):
        """Return a one-element list containing the loaded model's name."""
        return [self.get_model()]
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
from typing import List
|
|
2
|
-
from openai import AsyncOpenAI
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
class SimpleOpenAIEmbedding:
    """Minimal async helper that embeds text via an OpenAI-compatible API.

    Args:
        model: Embedding model name sent with every request.
        api_key: API key for the endpoint.
        api_base: Optional base URL; ``None`` uses the client's default.
    """

    def __init__(
        self,
        model,
        api_key,
        api_base=None,
    ) -> None:
        self.client = AsyncOpenAI(api_key=api_key, base_url=api_base)
        self.model = model

    async def get_embedding(self, text) -> List[float]:
        """Return the embedding vector of the first result for ``text``."""
        response = await self.client.embeddings.create(input=text, model=self.model)
        first_item = response.data[0]
        return first_item.embedding
|