AstrBot 4.0.0b4__py3-none-any.whl → 4.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/event/filter/__init__.py +2 -0
- astrbot/cli/utils/basic.py +12 -3
- astrbot/core/astrbot_config_mgr.py +16 -9
- astrbot/core/config/default.py +82 -4
- astrbot/core/initial_loader.py +4 -1
- astrbot/core/message/components.py +59 -50
- astrbot/core/pipeline/process_stage/method/llm_request.py +6 -2
- astrbot/core/pipeline/result_decorate/stage.py +5 -1
- astrbot/core/platform/manager.py +25 -3
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +26 -14
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +11 -4
- astrbot/core/platform/sources/satori/satori_adapter.py +482 -0
- astrbot/core/platform/sources/satori/satori_event.py +221 -0
- astrbot/core/platform/sources/telegram/tg_adapter.py +0 -1
- astrbot/core/provider/entities.py +17 -15
- astrbot/core/provider/sources/gemini_source.py +57 -18
- astrbot/core/provider/sources/openai_source.py +12 -5
- astrbot/core/provider/sources/vllm_rerank_source.py +6 -0
- astrbot/core/star/__init__.py +7 -5
- astrbot/core/star/filter/command.py +9 -3
- astrbot/core/star/filter/platform_adapter_type.py +3 -0
- astrbot/core/star/register/__init__.py +2 -0
- astrbot/core/star/register/star_handler.py +18 -4
- astrbot/core/star/star_handler.py +9 -1
- astrbot/core/star/star_tools.py +116 -21
- astrbot/core/updator.py +7 -5
- astrbot/core/utils/io.py +1 -1
- astrbot/core/utils/t2i/network_strategy.py +11 -18
- astrbot/core/utils/t2i/renderer.py +8 -2
- astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
- astrbot/core/utils/t2i/template_manager.py +112 -0
- astrbot/core/zip_updator.py +26 -4
- astrbot/dashboard/routes/chat.py +6 -1
- astrbot/dashboard/routes/config.py +24 -49
- astrbot/dashboard/routes/route.py +19 -2
- astrbot/dashboard/routes/t2i.py +230 -0
- astrbot/dashboard/routes/update.py +3 -5
- astrbot/dashboard/server.py +13 -4
- {astrbot-4.0.0b4.dist-info → astrbot-4.1.0.dist-info}/METADATA +40 -53
- {astrbot-4.0.0b4.dist-info → astrbot-4.1.0.dist-info}/RECORD +43 -38
- {astrbot-4.0.0b4.dist-info → astrbot-4.1.0.dist-info}/WHEEL +0 -0
- {astrbot-4.0.0b4.dist-info → astrbot-4.1.0.dist-info}/entry_points.txt +0 -0
- {astrbot-4.0.0b4.dist-info → astrbot-4.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,482 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import time
|
|
4
|
+
import websockets
|
|
5
|
+
from websockets.asyncio.client import connect
|
|
6
|
+
from typing import Optional
|
|
7
|
+
from aiohttp import ClientSession, ClientTimeout
|
|
8
|
+
from websockets.asyncio.client import ClientConnection
|
|
9
|
+
from astrbot.api import logger
|
|
10
|
+
from astrbot.api.event import MessageChain
|
|
11
|
+
from astrbot.api.platform import (
|
|
12
|
+
AstrBotMessage,
|
|
13
|
+
MessageMember,
|
|
14
|
+
MessageType,
|
|
15
|
+
Platform,
|
|
16
|
+
PlatformMetadata,
|
|
17
|
+
register_platform_adapter,
|
|
18
|
+
)
|
|
19
|
+
from astrbot.core.platform.astr_message_event import MessageSession
|
|
20
|
+
from astrbot.api.message_components import Plain, Image, At, File, Record
|
|
21
|
+
from xml.etree import ElementTree as ET
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@register_platform_adapter(
    "satori",
    "Satori 协议适配器",
)
class SatoriPlatformAdapter(Platform):
    """Platform adapter for the Satori protocol.

    Signals and events arrive over a WebSocket connection (``satori_endpoint``);
    API calls such as ``/message.create`` are sent over HTTP to
    ``satori_api_base_url``.
    """

    def __init__(
        self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
    ) -> None:
        super().__init__(event_queue)
        self.config = platform_config
        self.settings = platform_settings

        self.api_base_url = self.config.get(
            "satori_api_base_url", "http://localhost:5140/satori/v1"
        )
        self.token = self.config.get("satori_token", "")
        self.endpoint = self.config.get(
            "satori_endpoint", "ws://127.0.0.1:5140/satori/v1/events"
        )
        self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
        # Seconds between PING signals (see heartbeat_loop).
        self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
        # Base delay in seconds for the exponential reconnect back-off in run().
        self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)

        self.ws: Optional[ClientConnection] = None
        self.session: Optional[ClientSession] = None
        # Last sequence number ("sn") seen; included in IDENTIFY when > 0.
        self.sequence = 0
        # Login objects from the latest READY signal. logins[0] is used as the
        # default identity for outgoing HTTP requests.
        self.logins = []
        self.running = False
        self.heartbeat_task: Optional[asyncio.Task] = None
        # True once READY has been received on the current connection.
        self.ready_received = False

    async def send_by_session(
        self, session: MessageSession, message_chain: MessageChain
    ):
        """Send *message_chain* to the channel identified by ``session.session_id``."""
        from .satori_event import SatoriPlatformEvent

        await SatoriPlatformEvent.send_with_adapter(
            self, message_chain, session.session_id
        )
        await super().send_by_session(session, message_chain)

    def meta(self) -> PlatformMetadata:
        return PlatformMetadata(name="satori", description="Satori 通用协议适配器")

    def _is_websocket_closed(self, ws) -> bool:
        """Return True if *ws* is missing or reports itself as closed.

        Handles connection objects that expose either ``closed`` or
        ``close_code``. Any other object is assumed to be open.
        """
        if not ws:
            return True
        try:
            if hasattr(ws, "closed"):
                return ws.closed
            elif hasattr(ws, "close_code"):
                return ws.close_code is not None
            else:
                return False
        except AttributeError:
            return False

    async def run(self):
        """Run the adapter until terminate() is called or reconnecting stops.

        Each loop iteration opens one WebSocket session through
        connect_websocket(). When ``auto_reconnect`` is enabled, the adapter
        reconnects with an exponential back-off capped at 60 seconds.

        Only consecutive attempts that never reach READY count toward
        ``max_retries``. A connection that became ready resets the counter, so
        occasional drops do not permanently disable a long-running bot. The
        HTTP session is closed on every exit path, including cancellation.
        """
        self.running = True
        self.session = ClientSession(timeout=ClientTimeout(total=30))

        retry_count = 0
        max_retries = 10

        try:
            while self.running:
                try:
                    await self.connect_websocket()
                except websockets.exceptions.ConnectionClosed as e:
                    logger.warning(f"Satori WebSocket 连接关闭: {e}")
                except Exception as e:
                    logger.error(f"Satori WebSocket 连接失败: {e}")

                if not self.running:
                    break

                if self.ready_received:
                    # The connection was healthy before it ended, so start a
                    # fresh back-off sequence instead of accumulating failures.
                    retry_count = 0
                else:
                    retry_count += 1
                    if retry_count >= max_retries:
                        logger.error(f"达到最大重试次数 ({max_retries}),停止重试")
                        break

                if not self.auto_reconnect:
                    break

                # Clamp the exponent at 0: retry_count is 0 after a healthy
                # session, and 2 ** -1 would halve the base delay.
                delay = min(self.reconnect_delay * (2 ** max(retry_count - 1, 0)), 60)
                await asyncio.sleep(delay)
        finally:
            if self.session:
                await self.session.close()

    async def connect_websocket(self):
        """Open one WebSocket session and process incoming frames until it ends.

        Sends IDENTIFY, starts the heartbeat task, then passes every frame to
        handle_message(). Returns when the message iterator ends. Connection
        errors are logged and re-raised to run(). The heartbeat task and the
        socket are always torn down, and their references are cleared.

        Raises:
            ValueError: if ``self.endpoint`` is not a ws:// or wss:// URL.
        """
        # READY has not been seen on this (new) connection yet.
        self.ready_received = False

        logger.info(f"Satori 适配器正在连接到 WebSocket: {self.endpoint}")
        logger.info(f"Satori 适配器 HTTP API 地址: {self.api_base_url}")

        if not self.endpoint.startswith(("ws://", "wss://")):
            logger.error(f"无效的WebSocket URL: {self.endpoint}")
            raise ValueError(f"WebSocket URL必须以ws://或wss://开头: {self.endpoint}")

        try:
            websocket = await connect(self.endpoint, additional_headers={})
            self.ws = websocket

            # Brief pause before sending IDENTIFY.
            await asyncio.sleep(0.1)

            await self.send_identify()

            self.heartbeat_task = asyncio.create_task(self.heartbeat_loop())

            async for message in websocket:
                try:
                    await self.handle_message(message)  # type: ignore
                except Exception as e:
                    logger.error(f"Satori 处理消息异常: {e}")

        except websockets.exceptions.ConnectionClosed as e:
            logger.warning(f"Satori WebSocket 连接关闭: {e}")
            raise
        except Exception as e:
            logger.error(f"Satori WebSocket 连接异常: {e}")
            raise
        finally:
            if self.heartbeat_task:
                self.heartbeat_task.cancel()
                try:
                    await self.heartbeat_task
                except asyncio.CancelledError:
                    pass
                # Do not leave a stale task for the next connection attempt.
                self.heartbeat_task = None
            if self.ws:
                try:
                    await self.ws.close()
                except Exception as e:
                    logger.error(f"Satori WebSocket 关闭异常: {e}")
                self.ws = None

    async def send_identify(self):
        """Send the IDENTIFY signal (op 3) with the token and, if known, last sn.

        Raises:
            Exception: if the WebSocket is missing or already closed, or if
                sending fails.
        """
        if not self.ws:
            raise Exception("WebSocket连接未建立")

        if self._is_websocket_closed(self.ws):
            raise Exception("WebSocket连接已关闭")

        identify_payload = {
            "op": 3,  # IDENTIFY
            "body": {
                "token": str(self.token) if self.token else "",  # always a string
            },
        }

        # Only include "sn" once a sequence number has been received.
        if self.sequence > 0:
            identify_payload["body"]["sn"] = self.sequence

        try:
            message_str = json.dumps(identify_payload, ensure_ascii=False)
            await self.ws.send(message_str)
        except websockets.exceptions.ConnectionClosed as e:
            logger.error(f"发送 IDENTIFY 信令时连接关闭: {e}")
            raise
        except Exception as e:
            logger.error(f"发送 IDENTIFY 信令失败: {e}")
            raise

    async def heartbeat_loop(self):
        """Send a PING signal (op 1) every ``heartbeat_interval`` seconds.

        Stops when the adapter stops running, when the socket is gone or closed,
        or when a send fails. Cancellation is treated as a normal exit.
        """
        try:
            while self.running and self.ws:
                await asyncio.sleep(self.heartbeat_interval)

                if self.ws and not self._is_websocket_closed(self.ws):
                    try:
                        ping_payload = {
                            "op": 1,  # PING
                            "body": {},
                        }
                        await self.ws.send(json.dumps(ping_payload, ensure_ascii=False))
                    except websockets.exceptions.ConnectionClosed as e:
                        logger.error(f"Satori WebSocket 连接关闭: {e}")
                        break
                    except Exception as e:
                        logger.error(f"Satori WebSocket 发送心跳失败: {e}")
                        break
                else:
                    break
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"心跳任务异常: {e}")

    async def handle_message(self, message: str):
        """Decode one WebSocket frame and dispatch it by opcode.

        Opcodes handled:
            4 READY: store the logins and mark the connection ready.
            2 PONG: ignored.
            0 EVENT: passed to handle_event().
            5 META: sequence update only.

        ``body["sn"]`` updates ``self.sequence`` for READY, EVENT and META.
        Decode and processing errors are logged, not raised.
        """
        try:
            data = json.loads(message)
            op = data.get("op")
            # "body" may be absent or explicitly null.
            body = data.get("body") or {}

            if op == 4:  # READY
                self.logins = body.get("logins") or []
                self.ready_received = True

                # Log every bot account that came online.
                for i, login in enumerate(self.logins):
                    platform = login.get("platform", "")
                    user = login.get("user") or {}
                    user_id = user.get("id", "")
                    user_name = user.get("name", "")
                    logger.info(
                        f"Satori 连接成功 - Bot {i + 1}: platform={platform}, user_id={user_id}, user_name={user_name}"
                    )

                if "sn" in body:
                    self.sequence = body["sn"]

            elif op == 2:  # PONG
                pass

            elif op == 0:  # EVENT
                await self.handle_event(body)
                if "sn" in body:
                    self.sequence = body["sn"]

            elif op == 5:  # META
                if "sn" in body:
                    self.sequence = body["sn"]

        except json.JSONDecodeError as e:
            logger.error(f"解析 WebSocket 消息失败: {e}, 消息内容: {message}")
        except Exception as e:
            logger.error(f"处理 WebSocket 消息异常: {e}")

    async def handle_event(self, event_data: dict):
        """Handle an EVENT body. Only ``message-created`` events are processed.

        Messages whose author is the receiving login's own user are ignored.
        Errors are logged, not raised.
        """
        try:
            event_type = event_data.get("type")
            sn = event_data.get("sn")
            if sn:
                self.sequence = sn

            if event_type == "message-created":
                # Payload fields may be missing or null; normalize to dicts.
                message = event_data.get("message") or {}
                user = event_data.get("user") or {}
                channel = event_data.get("channel") or {}
                guild = event_data.get("guild")
                login = event_data.get("login") or {}
                timestamp = event_data.get("timestamp")

                # Skip the bot's own messages. Require a non-empty self id so
                # that events lacking both ids are not dropped by None == None.
                self_id = (login.get("user") or {}).get("id")
                if self_id and user.get("id") == self_id:
                    return

                abm = await self.convert_satori_message(
                    message, user, channel, guild, login, timestamp
                )
                if abm:
                    await self.handle_msg(abm)

        except Exception as e:
            logger.error(f"处理事件失败: {e}")

    async def convert_satori_message(
        self,
        message: dict,
        user: dict,
        channel: dict,
        guild: Optional[dict],
        login: dict,
        timestamp: Optional[int] = None,
    ) -> Optional[AstrBotMessage]:
        """Convert a Satori ``message-created`` payload into an AstrBotMessage.

        An event with a guild id becomes a GROUP_MESSAGE (group_id = guild id).
        Any other event becomes a FRIEND_MESSAGE. In both cases session_id is
        the channel id.

        Returns:
            The converted message, or None if conversion fails (the error is
            logged).
        """
        try:
            abm = AstrBotMessage()
            abm.message_id = message.get("id", "")
            abm.raw_message = {
                "message": message,
                "user": user,
                "channel": channel,
                "guild": guild,
                "login": login,
            }

            if guild and guild.get("id"):
                abm.type = MessageType.GROUP_MESSAGE
                abm.group_id = guild.get("id", "")
                abm.session_id = channel.get("id", "")
            else:
                abm.type = MessageType.FRIEND_MESSAGE
                abm.session_id = channel.get("id", "")

            abm.sender = MessageMember(
                user_id=user.get("id", ""),
                nickname=user.get("nick", user.get("name", "")),
            )

            abm.self_id = login.get("user", {}).get("id", "")

            content = message.get("content", "")
            abm.message = await self.parse_satori_elements(content)

            # message_str is the concatenation of all plain-text components.
            abm.message_str = "".join(
                comp.text for comp in abm.message if isinstance(comp, Plain)
            )

            if timestamp is not None:
                ts = int(timestamp)
                # Satori event timestamps are typically epoch milliseconds,
                # while the fallback below is seconds. Normalize large values to
                # seconds; 10**11 s is far in the future, so real second-based
                # values are never divided.
                abm.timestamp = ts // 1000 if ts > 10**11 else ts
            else:
                abm.timestamp = int(time.time())

            return abm

        except Exception as e:
            logger.error(f"转换 Satori 消息失败: {e}")
            return None

    async def parse_satori_elements(self, content: str) -> list:
        """Parse Satori message-element markup into AstrBot components.

        Content that is not well-formed XML (e.g. a bare ``&`` or an undeclared
        namespace prefix) is logged and kept as a single Plain component, so
        the message is still delivered instead of being dropped.

        Returns:
            A list of components; empty for empty content.
        """
        elements = []

        if not content:
            return elements

        try:
            root = ET.fromstring(f"<root>{content}</root>")
            await self._parse_xml_node(root, elements)
        except ET.ParseError as e:
            logger.warning(f"解析 Satori 元素失败,按纯文本处理: {e}")
            elements = []

        # Nothing recognised: treat the whole content as plain text.
        if not elements and content.strip():
            elements.append(Plain(text=content))

        return elements

    async def _parse_xml_node(self, node: ET.Element, elements: list) -> None:
        """Recursively append components for *node*'s text and its children.

        ``node.text`` is emitted here. Each child's trailing text (``tail``) is
        emitted after the child, preserving document order. Whitespace-only
        text is skipped.
        """
        if node.text and node.text.strip():
            elements.append(Plain(text=node.text))

        for child in node:
            tag_name = child.tag.lower()
            attrs = child.attrib

            if tag_name == "at":
                if attrs.get("type") == "all":
                    # Mention-everyone element. NOTE(review): assumes AstrBot
                    # represents @all as At(qq="all") — confirm.
                    elements.append(At(qq="all", name=attrs.get("name") or "all"))
                else:
                    user_id = attrs.get("id") or attrs.get("name", "")
                    elements.append(At(qq=user_id, name=attrs.get("name") or user_id))

            elif tag_name in ("img", "image"):
                src = attrs.get("src", "")
                if not src:
                    continue
                if src.startswith("data:image/"):
                    src = src.split(",")[1]
                    elements.append(Image.fromBase64(src))
                elif src.startswith("http"):
                    elements.append(Image.fromURL(src))
                else:
                    logger.error(f"未知的图片 src 格式: {str(src)[:16]}")

            elif tag_name == "file":
                src = attrs.get("src", "")
                name = attrs.get("name", "文件")
                if src:
                    elements.append(File(file=src, name=name))

            elif tag_name in ("audio", "record"):
                src = attrs.get("src", "")
                if not src:
                    continue
                if src.startswith("data:audio/"):
                    src = src.split(",")[1]
                    elements.append(Record.fromBase64(src))
                elif src.startswith("http"):
                    elements.append(Record.fromURL(src))
                else:
                    logger.error(f"未知的音频 src 格式: {str(src)[:16]}")

            else:
                # Unknown tag: recurse into it. The recursive call already emits
                # child.text, so it must not be appended here as well.
                await self._parse_xml_node(child, elements)

            # Text that follows the child element.
            if child.tail and child.tail.strip():
                elements.append(Plain(text=child.tail))

    async def handle_msg(self, message: AstrBotMessage):
        """Wrap *message* in a SatoriPlatformEvent and commit it."""
        from .satori_event import SatoriPlatformEvent

        message_event = SatoriPlatformEvent(
            message_str=message.message_str,
            message_obj=message,
            platform_meta=self.meta(),
            session_id=message.session_id,
            adapter=self,
        )
        self.commit_event(message_event)

    async def send_http_request(
        self,
        method: str,
        path: str,
        data: dict | None = None,
        platform: str | None = None,
        user_id: str | None = None,
    ) -> dict:
        """Call the Satori HTTP API at ``api_base_url + path``.

        The identity headers come from *platform*/*user_id* when both are given.
        Otherwise they come from the first stored login, if there is one.

        Returns:
            The decoded JSON body on HTTP 200. This may be a non-dict JSON
            value depending on the endpoint. Returns {} on a non-200 status or
            an error; both are logged.

        Raises:
            Exception: if the HTTP session has not been created (run() not
                started).
        """
        if not self.session:
            raise Exception("HTTP session 未初始化")

        headers = {
            "Content-Type": "application/json",
        }

        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"

        if platform and user_id:
            headers["satori-platform"] = platform
            headers["satori-user-id"] = user_id
        elif self.logins:
            current_login = self.logins[0]
            headers["satori-platform"] = current_login.get("platform", "")
            user = current_login.get("user", {})
            headers["satori-user-id"] = user.get("id", "") if user else ""

        if not path.startswith("/"):
            path = "/" + path

        url = f"{self.api_base_url.rstrip('/')}{path}"

        try:
            async with self.session.request(
                method, url, json=data, headers=headers
            ) as response:
                if response.status != 200:
                    # Report API failures instead of returning {} silently.
                    body_text = await response.text()
                    logger.error(
                        f"Satori HTTP 请求失败: {method} {path} 状态码 {response.status}, 响应: {body_text[:200]}"
                    )
                    return {}
                # content_type=None: parse JSON even when the server does not
                # label the response as application/json.
                result = await response.json(content_type=None)
                return result if result is not None else {}
        except Exception as e:
            logger.error(f"Satori HTTP 请求异常: {e}")
            return {}

    async def terminate(self):
        """Stop the adapter: halt run(), cancel the heartbeat, close connections."""
        self.running = False

        if self.heartbeat_task:
            self.heartbeat_task.cancel()

        if self.ws:
            try:
                await self.ws.close()
            except Exception as e:
                logger.error(f"Satori WebSocket 关闭异常: {e}")

        if self.session:
            await self.session.close()
|
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
from astrbot.api import logger
|
|
3
|
+
from astrbot.api.event import AstrMessageEvent, MessageChain
|
|
4
|
+
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
|
5
|
+
from astrbot.api.message_components import Plain, Image, At, File, Record
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from .satori_adapter import SatoriPlatformAdapter
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SatoriPlatformEvent(AstrMessageEvent):
    """Message event for the Satori adapter.

    Outgoing MessageChains are serialized into Satori message-element markup
    (escaped text, ``<at>``, ``<img>``, ``<file>``, ``<audio>``) and posted to
    the adapter's ``/message.create`` HTTP API.
    """

    def __init__(
        self,
        message_str: str,
        message_obj: AstrBotMessage,
        platform_meta: PlatformMetadata,
        session_id: str,
        adapter: "SatoriPlatformAdapter",
    ):
        super().__init__(message_str, message_obj, platform_meta, session_id)
        self.adapter = adapter
        # Platform / user id of the login that received this event. Used as the
        # identity for replies; None when the raw payload lacks them.
        self.platform = None
        self.user_id = None
        raw = getattr(message_obj, "raw_message", None)
        if raw and isinstance(raw, dict):
            # "login" and "user" may be missing or explicitly null.
            login = raw.get("login") or {}
            self.platform = login.get("platform")
            self.user_id = (login.get("user") or {}).get("id")

    @staticmethod
    def _escape_markup(text: str, quote: bool = False) -> str:
        """Escape ``&``, ``<``, ``>`` (and ``"`` when *quote*) for Satori markup.

        ``&`` is replaced first so the entities produced afterwards are not
        double-escaped.
        """
        escaped = text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
        if quote:
            escaped = escaped.replace('"', "&quot;")
        return escaped

    @staticmethod
    def _default_login_identity(adapter: "SatoriPlatformAdapter") -> tuple:
        """Return ``(platform, user_id)`` of the adapter's first login.

        Returns ``(None, None)`` when the adapter has no logins.
        """
        logins = getattr(adapter, "logins", None)
        if not logins:
            return None, None
        current_login = logins[0]
        user = current_login.get("user") or {}
        return current_login.get("platform", ""), user.get("id", "")

    @classmethod
    async def _chain_to_satori_content(cls, message: MessageChain) -> str:
        """Serialize *message* into a Satori message-element string.

        Text and attribute values are escaped. Images and records are inlined
        as base64 data URIs; a component whose conversion fails is logged and
        skipped. Component types other than Plain/At/Image/File/Record are
        ignored.
        """
        parts: list[str] = []
        for component in message.chain:
            if isinstance(component, Plain):
                parts.append(cls._escape_markup(component.text))

            elif isinstance(component, At):
                if component.qq:
                    qq = str(component.qq)
                    if qq == "all":
                        # Mention-everyone maps to the Satori at type="all".
                        parts.append('<at type="all"/>')
                    else:
                        parts.append(f'<at id="{cls._escape_markup(qq, quote=True)}"/>')
                elif component.name:
                    name = cls._escape_markup(str(component.name), quote=True)
                    parts.append(f'<at name="{name}"/>')

            elif isinstance(component, Image):
                try:
                    image_base64 = await component.convert_to_base64()
                    if image_base64:
                        parts.append(
                            f'<img src="data:image/jpeg;base64,{image_base64}"/>'
                        )
                except Exception as e:
                    logger.error(f"图片转换为base64失败: {e}")

            elif isinstance(component, File):
                src = cls._escape_markup(str(component.file), quote=True)
                name = cls._escape_markup(str(component.name or "文件"), quote=True)
                parts.append(f'<file src="{src}" name="{name}"/>')

            elif isinstance(component, Record):
                try:
                    record_base64 = await component.convert_to_base64()
                    if record_base64:
                        parts.append(
                            f'<audio src="data:audio/wav;base64,{record_base64}"/>'
                        )
                except Exception as e:
                    logger.error(f"语音转换为base64失败: {e}")

        return "".join(parts)

    @classmethod
    async def send_with_adapter(
        cls, adapter: "SatoriPlatformAdapter", message: MessageChain, session_id: str
    ):
        """Send *message* to channel *session_id* as the adapter's first login.

        Returns:
            The API response when it is truthy, otherwise None. Errors are
            logged and also yield None.
        """
        try:
            content = await cls._chain_to_satori_content(message)
            data = {"channel_id": session_id, "content": content}
            platform, user_id = cls._default_login_identity(adapter)

            result = await adapter.send_http_request(
                "POST", "/message.create", data, platform, user_id
            )
            return result if result else None

        except Exception as e:
            logger.error(f"Satori 消息发送异常: {e}")
            return None

    async def send(self, message: MessageChain):
        """Send *message* to this event's channel, then run the base send hook.

        The identity of the login that received the event is used. If that is
        incomplete, the adapter's first login is used instead.
        """
        platform = getattr(self, "platform", None)
        user_id = getattr(self, "user_id", None)

        if (not platform or not user_id) and getattr(self.adapter, "logins", None):
            platform, user_id = self._default_login_identity(self.adapter)

        try:
            content = await self._chain_to_satori_content(message)
            data = {"channel_id": self.session_id, "content": content}

            result = await self.adapter.send_http_request(
                "POST", "/message.create", data, platform, user_id
            )
            if not result:
                logger.error("Satori 消息发送失败")
        except Exception as e:
            logger.error(f"Satori 消息发送异常: {e}")

        await super().send(message)

    async def send_streaming(self, generator, use_fallback: bool = False):
        """Consume a stream of MessageChains and send them as Satori messages.

        Text and mentions are buffered and sent together as one message. The
        buffer is flushed when a ``break`` chain arrives, before any other
        component, and at the end of the stream. Other components (images,
        files, records, ...) are sent one per message through send(), which
        performs serialization and escaping. Markup is never smuggled through
        Plain text, and components are never stringified with str().
        """
        try:
            pending: list = []

            async def flush():
                # Send buffered text/mention components as a single message.
                if pending:
                    chain_to_send = MessageChain(list(pending))
                    pending.clear()
                    await self.send(chain_to_send)

            async for chain in generator:
                if not isinstance(chain, MessageChain):
                    continue
                if chain.type == "break":
                    await flush()
                    continue

                for component in chain.chain:
                    if isinstance(component, (Plain, At)):
                        pending.append(component)
                    else:
                        await flush()
                        await self.send(MessageChain([component]))

            await flush()

        except Exception as e:
            logger.error(f"Satori 流式消息发送异常: {e}")

        return await super().send_streaming(generator, use_fallback)
|