AstrBot 4.3.2__py3-none-any.whl → 4.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. astrbot/core/agent/mcp_client.py +18 -4
  2. astrbot/core/astr_agent_context.py +1 -0
  3. astrbot/core/config/default.py +66 -9
  4. astrbot/core/db/sqlite.py +7 -0
  5. astrbot/core/pipeline/context_utils.py +1 -0
  6. astrbot/core/pipeline/process_stage/method/llm_request.py +32 -14
  7. astrbot/core/pipeline/result_decorate/stage.py +44 -45
  8. astrbot/core/pipeline/scheduler.py +1 -1
  9. astrbot/core/platform/manager.py +4 -0
  10. astrbot/core/platform/sources/satori/satori_event.py +23 -1
  11. astrbot/core/platform/sources/webchat/webchat_adapter.py +0 -1
  12. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +289 -0
  13. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +17 -0
  14. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +20 -0
  15. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +445 -0
  16. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +378 -0
  17. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +149 -0
  18. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +148 -0
  19. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +166 -0
  20. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +199 -0
  21. astrbot/core/provider/provider.py +2 -1
  22. astrbot/core/provider/sources/anthropic_source.py +13 -7
  23. astrbot/core/provider/sources/dashscope_tts.py +120 -12
  24. astrbot/core/provider/sources/gemini_source.py +21 -17
  25. astrbot/core/provider/sources/openai_source.py +1 -1
  26. astrbot/dashboard/routes/session_management.py +6 -6
  27. astrbot/dashboard/routes/tools.py +14 -0
  28. astrbot/dashboard/routes/update.py +8 -5
  29. {astrbot-4.3.2.dist-info → astrbot-4.3.5.dist-info}/METADATA +64 -44
  30. {astrbot-4.3.2.dist-info → astrbot-4.3.5.dist-info}/RECORD +33 -24
  31. {astrbot-4.3.2.dist-info → astrbot-4.3.5.dist-info}/WHEEL +0 -0
  32. {astrbot-4.3.2.dist-info → astrbot-4.3.5.dist-info}/entry_points.txt +0 -0
  33. {astrbot-4.3.2.dist-info → astrbot-4.3.5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,166 @@
1
+ """
2
+ 企业微信智能机器人 HTTP 服务器
3
+ 处理企业微信智能机器人的 HTTP 回调请求
4
+ """
5
+
6
+ import asyncio
7
+ from typing import Dict, Any, Optional, Callable
8
+
9
+ import quart
10
+ from astrbot.api import logger
11
+
12
+ from .wecomai_api import WecomAIBotAPIClient
13
+ from .wecomai_utils import WecomAIBotConstants
14
+
15
+
16
class WecomAIBotServer:
    """HTTP server for WeCom (Enterprise WeChat) AI-bot webhook callbacks.

    Serves ``/webhook/wecom-ai-bot`` on a Quart app: GET requests perform
    callback-URL verification, POST requests carry encrypted messages. The
    server runs until :meth:`shutdown` is called.
    """

    def __init__(
        self,
        host: str,
        port: int,
        api_client: WecomAIBotAPIClient,
        message_handler: Optional[
            Callable[[Dict[str, Any], Dict[str, str]], Any]
        ] = None,
    ):
        """Create the server.

        Args:
            host: Address to listen on.
            port: Port to listen on.
            api_client: Client used to verify the callback URL and to
                decrypt incoming messages.
            message_handler: Optional callback invoked with the decrypted
                message and ``{"nonce": ..., "timestamp": ...}``. Its return
                value is awaited, so it must be a coroutine function (or
                otherwise return an awaitable). A truthy result is sent back
                as the response body.
        """
        self.host = host
        self.port = port
        self.api_client = api_client
        self.message_handler = message_handler

        self.app = quart.Quart(__name__)
        self._setup_routes()

        # Set by shutdown(); awaited by shutdown_trigger() to stop run_task().
        self.shutdown_event = asyncio.Event()

    def _setup_routes(self):
        """Register the GET (verification) and POST (message) webhook routes."""
        self.app.add_url_rule(
            "/webhook/wecom-ai-bot",
            view_func=self.verify_url,
            methods=["GET"],
        )

        self.app.add_url_rule(
            "/webhook/wecom-ai-bot",
            view_func=self.handle_message,
            methods=["POST"],
        )

    async def verify_url(self):
        """Answer a callback-URL verification request.

        Returns ``("verify fail", 400)`` when any of ``msg_signature``,
        ``timestamp``, ``nonce`` or ``echostr`` is missing; otherwise returns
        the result of ``api_client.verify_url`` as ``text/plain`` with 200.
        """
        args = quart.request.args
        msg_signature = args.get("msg_signature")
        timestamp = args.get("timestamp")
        nonce = args.get("nonce")
        echostr = args.get("echostr")

        if not all([msg_signature, timestamp, nonce, echostr]):
            logger.error("URL 验证参数缺失")
            return "verify fail", 400

        # Narrow Optional[str] -> str for type checkers (guaranteed by all()).
        assert msg_signature is not None
        assert timestamp is not None
        assert nonce is not None
        assert echostr is not None

        logger.info("收到企业微信智能机器人 WebHook URL 验证请求。")
        result = self.api_client.verify_url(msg_signature, timestamp, nonce, echostr)
        return result, 200, {"Content-Type": "text/plain"}

    async def handle_message(self):
        """Decrypt an incoming message callback and dispatch it.

        Responds 400 on missing query parameters or failed decryption, 500
        when the handler (or anything else) raises, and otherwise 200 with
        the handler's truthy result or ``"success"``.
        """
        args = quart.request.args
        msg_signature = args.get("msg_signature")
        timestamp = args.get("timestamp")
        nonce = args.get("nonce")

        if not all([msg_signature, timestamp, nonce]):
            logger.error("消息回调参数缺失")
            return "缺少必要参数", 400

        # Narrow Optional[str] -> str for type checkers (guaranteed by all()).
        assert msg_signature is not None
        assert timestamp is not None
        assert nonce is not None

        logger.debug(
            "收到消息回调,msg_signature=%s, timestamp=%s, nonce=%s",
            msg_signature,
            timestamp,
            nonce,
        )

        try:
            post_data = await quart.request.get_data()

            # decrypt_message expects bytes.
            if isinstance(post_data, str):
                post_data = post_data.encode("utf-8")

            ret_code, message_data = await self.api_client.decrypt_message(
                post_data, msg_signature, timestamp, nonce
            )

            if ret_code != WecomAIBotConstants.SUCCESS or not message_data:
                logger.error("消息解密失败,错误码: %d", ret_code)
                return "消息解密失败", 400

            response = None
            if self.message_handler:
                try:
                    response = await self.message_handler(
                        message_data, {"nonce": nonce, "timestamp": timestamp}
                    )
                except Exception as e:
                    # logger.exception keeps the traceback of handler failures.
                    logger.exception("消息处理器执行异常: %s", e)
                    return "消息处理异常", 500

            if response:
                return response, 200, {"Content-Type": "text/plain"}
            return "success", 200, {"Content-Type": "text/plain"}

        except Exception as e:
            # Top-level boundary of the request: log with traceback, answer 500.
            logger.exception("处理消息时发生异常: %s", e)
            return "内部服务器错误", 500

    async def start_server(self):
        """Run the Quart app until shutdown() is called; re-raise run errors."""
        logger.info("启动企业微信智能机器人服务器,监听 %s:%d", self.host, self.port)

        try:
            await self.app.run_task(
                host=self.host,
                port=self.port,
                shutdown_trigger=self.shutdown_trigger,
            )
        except Exception as e:
            logger.error("服务器运行异常: %s", e)
            raise

    async def shutdown_trigger(self):
        """Wait until shutdown() sets the shutdown event."""
        await self.shutdown_event.wait()

    async def shutdown(self):
        """Request server shutdown by setting the shutdown event."""
        logger.info("企业微信智能机器人服务器正在关闭...")
        self.shutdown_event.set()

    def get_app(self):
        """Return the underlying Quart application."""
        return self.app
@@ -0,0 +1,199 @@
1
+ """
2
+ 企业微信智能机器人工具模块
3
+ 提供常量定义、工具函数和辅助方法
4
+ """
5
+
6
+ import string
7
+ import random
8
+ import hashlib
9
+ import base64
10
+ import aiohttp
11
+ import asyncio
12
+ from Crypto.Cipher import AES
13
+ from typing import Any, Tuple
14
+ from astrbot.api import logger
15
+
16
+
17
+ # 常量定义
18
class WecomAIBotConstants:
    """Constants shared by the WeCom AI-bot integration."""

    # Message type identifiers.
    MSG_TYPE_TEXT: str = "text"
    MSG_TYPE_IMAGE: str = "image"
    MSG_TYPE_MIXED: str = "mixed"
    MSG_TYPE_STREAM: str = "stream"
    MSG_TYPE_EVENT: str = "event"

    # Stream "finished" flag values.
    STREAM_CONTINUE: bool = False
    STREAM_FINISH: bool = True

    # Return codes: 0 means success, negative values name crypto/parse failures.
    SUCCESS: int = 0
    DECRYPT_ERROR: int = -40001
    VALIDATE_SIGNATURE_ERROR: int = -40002
    PARSE_XML_ERROR: int = -40003
    COMPUTE_SIGNATURE_ERROR: int = -40004
    ILLEGAL_AES_KEY: int = -40005
    VALIDATE_APPID_ERROR: int = -40006
    ENCRYPT_AES_ERROR: int = -40007
    ILLEGAL_BUFFER: int = -40008
42
+
43
+
44
def generate_random_string(length: int = 10) -> str:
    """Return a random string of ASCII letters and digits.

    Uses the non-cryptographic ``random`` module.

    Args:
        length: Number of characters to produce (default 10).

    Returns:
        The generated string.
    """
    alphabet = string.ascii_letters + string.digits
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
55
+
56
+
57
def calculate_image_md5(image_data: bytes) -> str:
    """Return the hex MD5 digest of ``image_data``.

    Args:
        image_data: Raw image bytes.

    Returns:
        Lower-case hexadecimal MD5 string.
    """
    digest = hashlib.md5()
    digest.update(image_data)
    return digest.hexdigest()
67
+
68
+
69
def encode_image_base64(image_data: bytes) -> str:
    """Base64-encode ``image_data`` and return it as text.

    Args:
        image_data: Raw image bytes.

    Returns:
        Standard Base64 string.
    """
    encoded = base64.b64encode(image_data)
    return encoded.decode("utf-8")
79
+
80
+
81
def format_session_id(session_type: str, session_id: str) -> str:
    """Build a namespaced session ID: ``wecom_ai_bot_<type>_<id>``.

    Args:
        session_type: Session kind, e.g. ``"user"`` or ``"group"``.
        session_id: Raw session ID.

    Returns:
        The namespaced session ID.
    """
    namespace = "wecom_ai_bot"
    return f"{namespace}_{session_type}_{session_id}"
92
+
93
+
94
def parse_session_id(formatted_session_id: str) -> Tuple[str, str]:
    """Split an ID built by ``format_session_id`` back into its parts.

    ``"wecom_ai_bot_<type>_<id>"`` yields ``(<type>, <id>)``; the raw ID may
    itself contain underscores, and a missing ID part yields ``""``. Strings
    without the ``wecom_ai_bot_`` prefix are treated as a raw user ID.

    Args:
        formatted_session_id: Session ID, namespaced or raw.

    Returns:
        ``(session_type, raw_session_id)``; ``("user", formatted_session_id)``
        when the prefix is absent.
    """
    prefix = "wecom_ai_bot_"
    if not formatted_session_id.startswith(prefix):
        return "user", formatted_session_id
    # The type is the first "_"-separated token after the prefix; everything
    # after it (underscores included) is the raw ID.
    session_type, _, raw_id = formatted_session_id[len(prefix):].partition("_")
    return session_type, raw_id
112
+
113
+
114
def safe_json_loads(json_str: str, default: Any = None) -> Any:
    """Parse ``json_str`` as JSON, falling back to ``default`` on failure.

    Malformed JSON or a non-string argument is logged as a warning instead
    of raising.

    Args:
        json_str: Text to parse.
        default: Value returned when parsing fails.

    Returns:
        The decoded object, or ``default``.
    """
    import json

    try:
        parsed = json.loads(json_str)
    except (json.JSONDecodeError, TypeError) as err:
        logger.warning(f"JSON 解析失败: {err}, 原始字符串: {json_str}")
        return default
    return parsed
131
+
132
+
133
def format_error_response(error_code: int, error_msg: str) -> str:
    """Render an error as ``"Error <code>: <message>"``.

    Args:
        error_code: Numeric error code.
        error_msg: Human-readable message.

    Returns:
        The formatted error string.
    """
    return "Error {}: {}".format(error_code, error_msg)
144
+
145
+
146
async def process_encrypted_image(
    image_url: str, aes_key_base64: str
) -> Tuple[bool, str]:
    """Download an encrypted image and decrypt it.

    The payload is decrypted with AES-CBC using the 32-byte key decoded from
    ``aes_key_base64`` (the same key used for callback encryption) and an IV
    equal to the key's first 16 bytes. Trailing PKCS#7-style padding of 1-32
    bytes is then stripped.

    Args:
        image_url: URL of the encrypted image.
        aes_key_base64: Base64-encoded AES key; missing "=" padding is
            tolerated.

    Returns:
        ``(True, data)`` where ``data`` is the decrypted image encoded as
        base64, or ``(False, message)`` when the download fails.

    Raises:
        ValueError: If the key is empty, is not valid base64, or does not
            decode to 32 bytes; if the ciphertext is not block-aligned; or if
            the decrypted payload is empty or has an invalid padding length.
    """
    # 1. Validate the key before any network I/O so misconfiguration fails fast.
    if not aes_key_base64:
        raise ValueError("AES密钥不能为空")
    # Re-add any "=" padding stripped from the key before decoding.
    aes_key = base64.b64decode(aes_key_base64 + "=" * (-len(aes_key_base64) % 4))
    if len(aes_key) != 32:
        raise ValueError("无效的AES密钥长度: 应为32字节")
    iv = aes_key[:16]  # IV is the first 16 bytes of the key

    # 2. Download the encrypted image.
    logger.info("开始下载加密图片: %s", image_url)
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                image_url, timeout=aiohttp.ClientTimeout(total=15)
            ) as response:
                response.raise_for_status()
                encrypted_data = await response.read()
        logger.info("图片下载成功,大小: %d 字节", len(encrypted_data))
    except (aiohttp.ClientError, asyncio.TimeoutError) as e:
        error_msg = f"下载图片失败: {str(e)}"
        logger.error(error_msg)
        return False, error_msg

    # 3. Decrypt. AES blocks are 16 bytes; the padding limit of 32 below is
    #    presumably the payload's PKCS#7 padding block size (kept as-is).
    cipher = AES.new(aes_key, AES.MODE_CBC, iv)
    decrypted_data = cipher.decrypt(encrypted_data)
    if not decrypted_data:
        raise ValueError("解密后的图片数据为空")

    # 4. Strip padding; the last byte holds the pad length.
    pad_len = decrypted_data[-1]
    if pad_len > 32:
        raise ValueError("无效的填充长度 (大于32字节)")
    if pad_len < 1 or pad_len > len(decrypted_data):
        # pad_len == 0 would make [:-0] return b"" and silently drop the image.
        raise ValueError("无效的填充长度")
    decrypted_data = decrypted_data[:-pad_len]
    logger.info("图片解密成功,解密后大小: %d 字节", len(decrypted_data))

    # 5. Re-encode as base64 for the caller.
    base64_data = base64.b64encode(decrypted_data).decode("utf-8")
    logger.info("图片已转换为base64编码,编码后长度: %d", len(base64_data))

    return True, base64_data
@@ -68,7 +68,8 @@ class Provider(AbstractProvider):
68
68
 
69
69
  def get_keys(self) -> List[str]:
70
70
  """获得提供商 Key"""
71
- return self.provider_config.get("key", [])
71
+ keys = self.provider_config.get("key", [""])
72
+ return keys or [""]
72
73
 
73
74
  @abc.abstractmethod
74
75
  def set_key(self, key: str):
@@ -33,7 +33,7 @@ class ProviderAnthropic(Provider):
33
33
  )
34
34
 
35
35
  self.chosen_api_key: str = ""
36
- self.api_keys: List = provider_config.get("key", [])
36
+ self.api_keys: List = super().get_keys()
37
37
  self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
38
38
  self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
39
39
  self.timeout = provider_config.get("timeout", 120)
@@ -70,9 +70,13 @@ class ProviderAnthropic(Provider):
70
70
  {
71
71
  "type": "tool_use",
72
72
  "name": tool_call["function"]["name"],
73
- "input": json.loads(tool_call["function"]["arguments"])
74
- if isinstance(tool_call["function"]["arguments"], str)
75
- else tool_call["function"]["arguments"],
73
+ "input": (
74
+ json.loads(tool_call["function"]["arguments"])
75
+ if isinstance(
76
+ tool_call["function"]["arguments"], str
77
+ )
78
+ else tool_call["function"]["arguments"]
79
+ ),
76
80
  "id": tool_call["id"],
77
81
  }
78
82
  )
@@ -355,9 +359,11 @@ class ProviderAnthropic(Provider):
355
359
  "source": {
356
360
  "type": "base64",
357
361
  "media_type": mime_type,
358
- "data": image_data.split("base64,")[1]
359
- if "base64," in image_data
360
- else image_data,
362
+ "data": (
363
+ image_data.split("base64,")[1]
364
+ if "base64," in image_data
365
+ else image_data
366
+ ),
361
367
  },
362
368
  }
363
369
  )
@@ -1,10 +1,22 @@
1
+ import asyncio
2
+ import base64
3
+ import logging
1
4
  import os
2
- import dashscope
3
5
  import uuid
4
- import asyncio
5
- from dashscope.audio.tts_v2 import *
6
- from ..provider import TTSProvider
6
+ from typing import Optional, Tuple
7
+ import aiohttp
8
+ import dashscope
9
+ from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer
10
+
11
+ try:
12
+ from dashscope.aigc.multimodal_conversation import MultiModalConversation
13
+ except (
14
+ ImportError
15
+ ): # pragma: no cover - older dashscope versions without Qwen TTS support
16
+ MultiModalConversation = None
17
+
7
18
  from ..entities import ProviderType
19
+ from ..provider import TTSProvider
8
20
  from ..register import register_provider_adapter
9
21
  from astrbot.core.utils.astrbot_path import get_astrbot_data_path
10
22
 
@@ -26,16 +38,112 @@ class ProviderDashscopeTTSAPI(TTSProvider):
26
38
  dashscope.api_key = self.chosen_api_key
27
39
 
28
40
  async def get_audio(self, text: str) -> str:
41
+ model = self.get_model()
42
+ if not model:
43
+ raise RuntimeError("Dashscope TTS model is not configured.")
44
+
29
45
  temp_dir = os.path.join(get_astrbot_data_path(), "temp")
30
- path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}.wav")
31
- self.synthesizer = SpeechSynthesizer(
32
- model=self.get_model(),
46
+ os.makedirs(temp_dir, exist_ok=True)
47
+
48
+ if self._is_qwen_tts_model(model):
49
+ audio_bytes, ext = await self._synthesize_with_qwen_tts(model, text)
50
+ else:
51
+ audio_bytes, ext = await self._synthesize_with_cosyvoice(model, text)
52
+
53
+ if not audio_bytes:
54
+ raise RuntimeError(
55
+ "Audio synthesis failed, returned empty content. The model may not be supported or the service is unavailable."
56
+ )
57
+
58
+ path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}{ext}")
59
+ with open(path, "wb") as f:
60
+ f.write(audio_bytes)
61
+ return path
62
+
63
def _call_qwen_tts(self, model: str, text: str):
    """Synchronously request Qwen TTS synthesis via MultiModalConversation.

    Uses ``self.voice`` when set, otherwise ``"Cherry"`` (with a warning).
    This call blocks; the async wrapper runs it in an executor.

    Raises:
        RuntimeError: If the installed dashscope SDK lacks
            ``MultiModalConversation``.
    """
    if MultiModalConversation is None:
        raise RuntimeError(
            "dashscope SDK missing MultiModalConversation. Please upgrade the dashscope package to use Qwen TTS models."
        )

    api_key = self.chosen_api_key
    configured_voice = self.voice
    if not configured_voice:
        logging.warning(
            "No voice specified for Qwen TTS model, using default 'Cherry'."
        )
    return MultiModalConversation.call(
        model=model,
        text=text,
        api_key=api_key,
        voice=configured_voice or "Cherry",
    )
80
+
81
+ async def _synthesize_with_qwen_tts(
82
+ self, model: str, text: str
83
+ ) -> Tuple[Optional[bytes], str]:
84
+ loop = asyncio.get_event_loop()
85
+ response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
86
+ audio_bytes = await self._extract_audio_from_response(response)
87
+ if not audio_bytes:
88
+ raise RuntimeError(
89
+ f"Audio synthesis failed for model '{model}'. {response}"
90
+ )
91
+ ext = ".wav"
92
+ return audio_bytes, ext
93
+
94
+ async def _extract_audio_from_response(self, response) -> Optional[bytes]:
95
+ output = getattr(response, "output", None)
96
+ audio_obj = getattr(output, "audio", None) if output is not None else None
97
+ if not audio_obj:
98
+ return None
99
+
100
+ data_b64 = getattr(audio_obj, "data", None)
101
+ if data_b64:
102
+ try:
103
+ return base64.b64decode(data_b64)
104
+ except (ValueError, TypeError):
105
+ logging.error("Failed to decode base64 audio data.")
106
+ return None
107
+
108
+ url = getattr(audio_obj, "url", None)
109
+ if url:
110
+ return await self._download_audio_from_url(url)
111
+ return None
112
+
113
+ async def _download_audio_from_url(self, url: str) -> Optional[bytes]:
114
+ if not url:
115
+ return None
116
+ timeout = max(self.timeout_ms / 1000, 1) if self.timeout_ms else 20
117
+ try:
118
+ async with aiohttp.ClientSession() as session:
119
+ async with session.get(
120
+ url, timeout=aiohttp.ClientTimeout(total=timeout)
121
+ ) as response:
122
+ return await response.read()
123
+ except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
124
+ logging.error(f"Failed to download audio from URL {url}: {e}")
125
+ return None
126
+
127
+ async def _synthesize_with_cosyvoice(
128
+ self, model: str, text: str
129
+ ) -> Tuple[Optional[bytes], str]:
130
+ synthesizer = SpeechSynthesizer(
131
+ model=model,
33
132
  voice=self.voice,
34
133
  format=AudioFormat.WAV_24000HZ_MONO_16BIT,
35
134
  )
36
- audio = await asyncio.get_event_loop().run_in_executor(
37
- None, self.synthesizer.call, text, self.timeout_ms
135
+ loop = asyncio.get_event_loop()
136
+ audio_bytes = await loop.run_in_executor(
137
+ None, synthesizer.call, text, self.timeout_ms
38
138
  )
39
- with open(path, "wb") as f:
40
- f.write(audio)
41
- return path
139
+ if not audio_bytes:
140
+ resp = synthesizer.get_response()
141
+ if resp and isinstance(resp, dict):
142
+ raise RuntimeError(
143
+ f"Audio synthesis failed for model '{model}'. {resp}".strip()
144
+ )
145
+ return audio_bytes, ".wav"
146
+
147
+ def _is_qwen_tts_model(self, model: str) -> bool:
148
+ model_lower = model.lower()
149
+ return "tts" in model_lower and model_lower.startswith("qwen")
@@ -3,7 +3,7 @@ import base64
3
3
  import json
4
4
  import logging
5
5
  import random
6
- from typing import Optional
6
+ from typing import Optional, List
7
7
  from collections.abc import AsyncGenerator
8
8
 
9
9
  from google import genai
@@ -60,7 +60,7 @@ class ProviderGoogleGenAI(Provider):
60
60
  provider_settings,
61
61
  default_persona,
62
62
  )
63
- self.api_keys: list = provider_config.get("key", [])
63
+ self.api_keys: List = super().get_keys()
64
64
  self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
65
65
  self.timeout: int = int(provider_config.get("timeout", 180))
66
66
 
@@ -218,19 +218,21 @@ class ProviderGoogleGenAI(Provider):
218
218
  response_modalities=modalities,
219
219
  tools=tool_list,
220
220
  safety_settings=self.safety_settings if self.safety_settings else None,
221
- thinking_config=types.ThinkingConfig(
222
- thinking_budget=min(
223
- int(
224
- self.provider_config.get("gm_thinking_config", {}).get(
225
- "budget", 0
226
- )
221
+ thinking_config=(
222
+ types.ThinkingConfig(
223
+ thinking_budget=min(
224
+ int(
225
+ self.provider_config.get("gm_thinking_config", {}).get(
226
+ "budget", 0
227
+ )
228
+ ),
229
+ 24576,
227
230
  ),
228
- 24576,
229
- ),
230
- )
231
- if "gemini-2.5-flash" in self.get_model()
232
- and hasattr(types.ThinkingConfig, "thinking_budget")
233
- else None,
231
+ )
232
+ if "gemini-2.5-flash" in self.get_model()
233
+ and hasattr(types.ThinkingConfig, "thinking_budget")
234
+ else None
235
+ ),
234
236
  automatic_function_calling=types.AutomaticFunctionCallingConfig(
235
237
  disable=True
236
238
  ),
@@ -274,9 +276,11 @@ class ProviderGoogleGenAI(Provider):
274
276
  if role == "user":
275
277
  if isinstance(content, list):
276
278
  parts = [
277
- types.Part.from_text(text=item["text"] or " ")
278
- if item["type"] == "text"
279
- else process_image_url(item["image_url"])
279
+ (
280
+ types.Part.from_text(text=item["text"] or " ")
281
+ if item["type"] == "text"
282
+ else process_image_url(item["image_url"])
283
+ )
280
284
  for item in content
281
285
  ]
282
286
  else:
@@ -38,7 +38,7 @@ class ProviderOpenAIOfficial(Provider):
38
38
  default_persona,
39
39
  )
40
40
  self.chosen_api_key = None
41
- self.api_keys: List = provider_config.get("key", [])
41
+ self.api_keys: List = super().get_keys()
42
42
  self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
43
43
  self.timeout = provider_config.get("timeout", 120)
44
44
  if isinstance(self.timeout, str):
@@ -65,12 +65,12 @@ class SessionManagementRoute(Route):
65
65
  persona_name = data["persona_name"]
66
66
 
67
67
  # 处理 persona 显示
68
- if conv_persona_id == "[%None]":
69
- persona_name = "无人格"
70
- else:
71
- default_persona = persona_mgr.selected_default_persona_v3
72
- if default_persona:
73
- persona_name = default_persona["name"]
68
+ if persona_name is None:
69
+ if conv_persona_id is None:
70
+ if default_persona := persona_mgr.selected_default_persona_v3:
71
+ persona_name = default_persona["name"]
72
+ else:
73
+ persona_name = "[%None]"
74
74
 
75
75
  session_info = {
76
76
  "session_id": session_id,
@@ -273,6 +273,20 @@ class ToolsRoute(Route):
273
273
  server_data = await request.json
274
274
  config = server_data.get("mcp_server_config", None)
275
275
 
276
+ if not isinstance(config, dict) or not config:
277
+ return Response().error("无效的 MCP 服务器配置").__dict__
278
+
279
+ if "mcpServers" in config:
280
+ keys = list(config["mcpServers"].keys())
281
+ if not keys:
282
+ return Response().error("MCP 服务器配置不能为空").__dict__
283
+ if len(keys) > 1:
284
+ return Response().error("一次只能配置一个 MCP 服务器配置").__dict__
285
+ config = config["mcpServers"][keys[0]]
286
+ else:
287
+ if not config:
288
+ return Response().error("MCP 服务器配置不能为空").__dict__
289
+
276
290
  tools_name = await self.tool_mgr.test_mcp_server_connection(config)
277
291
  return (
278
292
  Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__