AstrBot 4.1.3__py3-none-any.whl → 4.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. astrbot/core/agent/agent.py +1 -1
  2. astrbot/core/agent/mcp_client.py +3 -1
  3. astrbot/core/agent/runners/tool_loop_agent_runner.py +6 -27
  4. astrbot/core/agent/tool.py +28 -17
  5. astrbot/core/config/default.py +50 -14
  6. astrbot/core/db/sqlite.py +15 -1
  7. astrbot/core/pipeline/content_safety_check/stage.py +1 -1
  8. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +1 -1
  9. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -1
  10. astrbot/core/pipeline/context_utils.py +4 -1
  11. astrbot/core/pipeline/process_stage/method/llm_request.py +23 -4
  12. astrbot/core/pipeline/process_stage/method/star_request.py +8 -6
  13. astrbot/core/platform/manager.py +4 -0
  14. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +2 -1
  15. astrbot/core/platform/sources/misskey/misskey_adapter.py +391 -0
  16. astrbot/core/platform/sources/misskey/misskey_api.py +404 -0
  17. astrbot/core/platform/sources/misskey/misskey_event.py +123 -0
  18. astrbot/core/platform/sources/misskey/misskey_utils.py +327 -0
  19. astrbot/core/platform/sources/satori/satori_adapter.py +290 -24
  20. astrbot/core/platform/sources/satori/satori_event.py +9 -0
  21. astrbot/core/platform/sources/telegram/tg_event.py +0 -1
  22. astrbot/core/provider/entities.py +13 -3
  23. astrbot/core/provider/func_tool_manager.py +4 -4
  24. astrbot/core/provider/manager.py +35 -19
  25. astrbot/core/star/context.py +26 -12
  26. astrbot/core/star/filter/command.py +3 -4
  27. astrbot/core/star/filter/command_group.py +4 -4
  28. astrbot/core/star/filter/platform_adapter_type.py +10 -5
  29. astrbot/core/star/register/star.py +3 -1
  30. astrbot/core/star/register/star_handler.py +65 -36
  31. astrbot/core/star/session_plugin_manager.py +3 -0
  32. astrbot/core/star/star_handler.py +4 -4
  33. astrbot/core/star/star_manager.py +10 -4
  34. astrbot/core/star/star_tools.py +6 -2
  35. astrbot/core/star/updator.py +3 -0
  36. {astrbot-4.1.3.dist-info → astrbot-4.1.5.dist-info}/METADATA +6 -7
  37. {astrbot-4.1.3.dist-info → astrbot-4.1.5.dist-info}/RECORD +40 -36
  38. {astrbot-4.1.3.dist-info → astrbot-4.1.5.dist-info}/WHEEL +0 -0
  39. {astrbot-4.1.3.dist-info → astrbot-4.1.5.dist-info}/entry_points.txt +0 -0
  40. {astrbot-4.1.3.dist-info → astrbot-4.1.5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,404 @@
1
+ import json
2
+ from typing import Any, Optional, Dict, List, Callable, Awaitable
3
+ import uuid
4
+
5
+ try:
6
+ import aiohttp
7
+ import websockets
8
+ except ImportError as e:
9
+ raise ImportError(
10
+ "aiohttp and websockets are required for Misskey API. Please install them with: pip install aiohttp websockets"
11
+ ) from e
12
+
13
+ from astrbot.api import logger
14
+
15
# Constants
API_MAX_RETRIES = 3  # total number of attempts for retried API requests
HTTP_OK = 200  # HTTP status code of a successful response


class APIError(Exception):
    """Base exception for every Misskey API error."""


class APIConnectionError(APIError):
    """Network / connection level failure."""


class APIRateLimitError(APIError):
    """The API rejected the request because of rate limiting."""


class AuthenticationError(APIError):
    """Authentication against the API failed."""


class WebSocketError(APIError):
    """A WebSocket (streaming) connection error."""
48
+
49
+
50
class StreamingClient:
    """Client for the Misskey streaming (WebSocket) API.

    Holds a single WebSocket connection, remembers which channel type each
    locally generated channel id was subscribed to, and dispatches incoming
    frames to async handlers registered with :meth:`add_message_handler`.
    """

    def __init__(self, instance_url: str, access_token: str):
        self.instance_url = instance_url.rstrip("/")
        self.access_token = access_token
        self.websocket: Optional[Any] = None
        self.is_connected = False
        # Handler lookup keys used by _handle_message:
        # "<channel_type>:<event_type>", a bare event type, a top-level
        # message type, or "_debug" as a catch-all for unmatched frames.
        self.message_handlers: Dict[str, Callable] = {}
        # Channel id (generated in subscribe_channel) -> channel type.
        self.channels: Dict[str, str] = {}
        self._running = False
        self._last_pong = None  # not updated anywhere in this class

    async def connect(self) -> bool:
        """Open the streaming WebSocket.

        The instance URL's http(s) scheme is mapped to ws(s) and the access
        token is passed as the ``i`` query parameter.

        Returns:
            True if the connection was established, False if the attempt
            failed (the error is logged, not raised).
        """
        try:
            ws_url = self.instance_url.replace("https://", "wss://").replace(
                "http://", "ws://"
            )
            ws_url += f"/streaming?i={self.access_token}"

            self.websocket = await websockets.connect(
                ws_url, ping_interval=30, ping_timeout=10
            )
            self.is_connected = True
            self._running = True

            logger.info("[Misskey WebSocket] 已连接")
            return True

        except Exception as e:
            logger.error(f"[Misskey WebSocket] 连接失败: {e}")
            self.is_connected = False
            return False

    async def disconnect(self):
        """Stop the listen loop and close the WebSocket.

        Connection state is reset even if ``close()`` raises; that exception
        still propagates to the caller.
        """
        self._running = False
        try:
            if self.websocket:
                await self.websocket.close()
        finally:
            # Never leave a half-closed socket marked as connected.
            self.websocket = None
            self.is_connected = False
        logger.info("[Misskey WebSocket] 连接已断开")

    async def subscribe_channel(
        self, channel_type: str, params: Optional[Dict] = None
    ) -> str:
        """Subscribe to a streaming channel.

        Args:
            channel_type: Channel name sent in the ``connect`` message body.
            params: Optional channel parameters (defaults to an empty dict).

        Returns:
            The locally generated channel id used to route channel frames.

        Raises:
            WebSocketError: If the client is not connected.
        """
        if not self.is_connected or not self.websocket:
            raise WebSocketError("WebSocket 未连接")

        channel_id = str(uuid.uuid4())
        message = {
            "type": "connect",
            "body": {"channel": channel_type, "id": channel_id, "params": params or {}},
        }

        await self.websocket.send(json.dumps(message))
        self.channels[channel_id] = channel_type
        return channel_id

    async def unsubscribe_channel(self, channel_id: str):
        """Unsubscribe from a channel.

        Does nothing when not connected or when ``channel_id`` is unknown.
        """
        if (
            not self.is_connected
            or not self.websocket
            or channel_id not in self.channels
        ):
            return

        message = {"type": "disconnect", "body": {"id": channel_id}}

        await self.websocket.send(json.dumps(message))
        del self.channels[channel_id]

    def add_message_handler(
        self, event_type: str, handler: Callable[[Dict], Awaitable[None]]
    ):
        """Register ``handler`` for ``event_type``, replacing any previous one."""
        self.message_handlers[event_type] = handler

    async def listen(self):
        """Receive and dispatch frames until the connection ends.

        Each frame is decoded as JSON and passed to ``_handle_message``.
        Per-message decode or handler errors are logged and skipped. Connection
        errors are logged rather than raised. The loop also stops after
        ``disconnect()`` clears the running flag. When this method returns,
        ``is_connected`` is False, including after a clean server-side close.

        Raises:
            WebSocketError: If called while not connected.
        """
        if not self.is_connected or not self.websocket:
            raise WebSocketError("WebSocket 未连接")

        try:
            async for message in self.websocket:
                if not self._running:
                    break

                try:
                    data = json.loads(message)
                    await self._handle_message(data)
                except json.JSONDecodeError as e:
                    logger.warning(f"[Misskey WebSocket] 无法解析消息: {e}")
                except Exception as e:
                    logger.error(f"[Misskey WebSocket] 处理消息失败: {e}")

        except websockets.exceptions.ConnectionClosedError as e:
            # ConnectionClosedError subclasses ConnectionClosed, so it must
            # be caught first.
            logger.warning(f"[Misskey WebSocket] 连接意外关闭: {e}")
        except websockets.exceptions.ConnectionClosed as e:
            logger.warning(
                f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})"
            )
        except websockets.exceptions.InvalidHandshake as e:
            logger.error(f"[Misskey WebSocket] 握手失败: {e}")
        except Exception as e:
            logger.error(f"[Misskey WebSocket] 监听消息失败: {e}")
        finally:
            # A clean close (code 1000/1001) ends `async for` without raising.
            # Previously that path left is_connected=True, hiding the drop from
            # reconnect logic.
            self.is_connected = False

    async def _handle_message(self, data: Dict[str, Any]):
        """Route one decoded frame to the matching registered handler.

        For ``"channel"`` frames, the handler key ``"<channel_type>:<event>"``
        is tried first, then the bare event type, then ``"_debug"``. Frames
        whose channel id is not tracked are ignored. Any other frame type is
        dispatched by its ``type`` value, falling back to ``"_debug"``.
        """
        message_type = data.get("type")
        body = data.get("body", {})

        logger.debug(
            f"[Misskey WebSocket] 收到消息类型: {message_type}\n数据: {json.dumps(data, indent=2, ensure_ascii=False)}"
        )

        if message_type == "channel":
            channel_id = body.get("id")
            event_type = body.get("type")
            event_body = body.get("body", {})

            logger.debug(
                f"[Misskey WebSocket] 频道消息: {channel_id}, 事件类型: {event_type}"
            )

            if channel_id in self.channels:
                channel_type = self.channels[channel_id]
                handler_key = f"{channel_type}:{event_type}"

                if handler_key in self.message_handlers:
                    logger.debug(f"[Misskey WebSocket] 使用处理器: {handler_key}")
                    await self.message_handlers[handler_key](event_body)
                elif event_type in self.message_handlers:
                    logger.debug(f"[Misskey WebSocket] 使用事件处理器: {event_type}")
                    await self.message_handlers[event_type](event_body)
                else:
                    logger.debug(
                        f"[Misskey WebSocket] 未找到处理器: {handler_key} 或 {event_type}"
                    )
                    if "_debug" in self.message_handlers:
                        await self.message_handlers["_debug"](
                            {
                                "type": event_type,
                                "body": event_body,
                                "channel": channel_type,
                            }
                        )

        elif message_type in self.message_handlers:
            logger.debug(f"[Misskey WebSocket] 直接消息处理器: {message_type}")
            await self.message_handlers[message_type](body)
        else:
            logger.debug(f"[Misskey WebSocket] 未处理的消息类型: {message_type}")
            if "_debug" in self.message_handlers:
                await self.message_handlers["_debug"](data)
203
+
204
+
205
def retry_async(
    max_retries: int = 3,
    retryable_exceptions: tuple = (),
    backoff: float = 1.0,
):
    """Build a decorator that retries an async function on selected exceptions.

    Args:
        max_retries: Total number of attempts. Values below 1 are treated as 1,
            so the wrapped function is always called at least once.
        retryable_exceptions: Exception types that trigger another attempt.
            Any other exception propagates immediately.
        backoff: Delay in seconds before the second attempt. The delay doubles
            for each further attempt, and ``0`` disables waiting. Retrying
            immediately after e.g. a rate-limit response would almost
            certainly fail again.

    Returns:
        A decorator. The wrapped coroutine function returns the first
        successful result, or re-raises the last retryable exception (with its
        original traceback) once all attempts are used.
    """
    import asyncio
    import functools

    attempts = max(1, max_retries)

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    return await func(*args, **kwargs)
                except retryable_exceptions:
                    if attempt == attempts - 1:
                        raise
                    if backoff > 0:
                        await asyncio.sleep(backoff * (2**attempt))

        return wrapper

    return decorator
221
+
222
+
223
class MisskeyAPI:
    """Async client for the Misskey HTTP API, plus a lazily created
    streaming client.

    Requests are POSTed to ``{instance_url}/api/{endpoint}``. The access
    token is sent both as the ``i`` field of the JSON body and as a Bearer
    ``Authorization`` header.
    """

    def __init__(self, instance_url: str, access_token: str):
        self.instance_url = instance_url.rstrip("/")
        self.access_token = access_token
        self._session: Optional[aiohttp.ClientSession] = None
        self.streaming: Optional[StreamingClient] = None

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
        return False

    async def close(self) -> None:
        """Disconnect the streaming client and close the HTTP session.

        The HTTP session is closed even if disconnecting the streaming client
        raises, so the aiohttp session/connector is not leaked. Any exception
        from the streaming disconnect still propagates.
        """
        try:
            if self.streaming:
                await self.streaming.disconnect()
        finally:
            self.streaming = None
            if self._session:
                await self._session.close()
                self._session = None
        logger.debug("[Misskey API] 客户端已关闭")

    def get_streaming_client(self) -> StreamingClient:
        """Return the streaming client, creating it on first use."""
        if not self.streaming:
            self.streaming = StreamingClient(self.instance_url, self.access_token)
        return self.streaming

    @property
    def session(self) -> aiohttp.ClientSession:
        """The shared HTTP session. It is (re)created if missing or closed."""
        if self._session is None or self._session.closed:
            headers = {"Authorization": f"Bearer {self.access_token}"}
            self._session = aiohttp.ClientSession(headers=headers)
        return self._session

    def _handle_response_status(self, status: int, endpoint: str):
        """Log and raise the exception that matches an error status.

        Always raises:
            APIError: on 400.
            AuthenticationError: on 401 or 403.
            APIRateLimitError: on 429 (retryable by ``_make_request``).
            APIConnectionError: on any other status (retryable).
        """
        if status == 400:
            logger.error(f"API 请求错误: {endpoint} (状态码: {status})")
            raise APIError(f"Bad request for {endpoint}")
        elif status in (401, 403):
            logger.error(f"API 认证失败: {endpoint} (状态码: {status})")
            raise AuthenticationError(f"Authentication failed for {endpoint}")
        elif status == 429:
            logger.warning(f"API 频率限制: {endpoint} (状态码: {status})")
            raise APIRateLimitError(f"Rate limit exceeded for {endpoint}")
        else:
            logger.error(f"API 请求失败: {endpoint} (状态码: {status})")
            raise APIConnectionError(f"HTTP {status} for {endpoint}")

    async def _process_response(
        self, response: aiohttp.ClientResponse, endpoint: str
    ) -> Any:
        """Decode a successful response, or raise for an error status.

        Returns:
            The decoded JSON body for HTTP 200, or an empty dict for
            HTTP 204 No Content.

        Raises:
            APIConnectionError: If a 200 body is not valid JSON.
            APIError (or a subclass): For any other status; see
                ``_handle_response_status``.
        """
        if response.status == 204:
            # Misskey answers 204 No Content for endpoints without a response
            # body. response.json() would raise ContentTypeError on it, which
            # used to be retried and reported as a connection failure.
            logger.debug(f"API 请求成功: {endpoint}")
            return {}

        if response.status == HTTP_OK:
            try:
                result = await response.json()
                if endpoint == "i/notifications":
                    notifications_data = (
                        result
                        if isinstance(result, list)
                        else result.get("notifications", [])
                        if isinstance(result, dict)
                        else []
                    )
                    if notifications_data:
                        logger.debug(f"获取到 {len(notifications_data)} 条新通知")
                else:
                    logger.debug(f"API 请求成功: {endpoint}")
                return result
            except json.JSONDecodeError as e:
                logger.error(f"响应不是有效的 JSON 格式: {e}")
                raise APIConnectionError("Invalid JSON response") from e
        else:
            try:
                error_text = await response.text()
                logger.error(
                    f"API 请求失败: {endpoint} - 状态码: {response.status}, 响应: {error_text}"
                )
            except Exception:
                logger.error(f"API 请求失败: {endpoint} - 状态码: {response.status}")

            self._handle_response_status(response.status, endpoint)
            # Unreachable: _handle_response_status always raises. Kept as a
            # safety net in case that helper ever changes.
            raise APIConnectionError(f"Request failed for {endpoint}")

    @retry_async(
        max_retries=API_MAX_RETRIES,
        retryable_exceptions=(APIConnectionError, APIRateLimitError),
    )
    async def _make_request(
        self, endpoint: str, data: Optional[Dict[str, Any]] = None
    ) -> Any:
        """POST ``data`` (plus the ``i`` token field) to ``/api/{endpoint}``.

        Connection errors and rate-limit responses are retried by the
        ``retry_async`` decorator. Transport-level ``aiohttp.ClientError`` is
        wrapped in ``APIConnectionError``.
        """
        url = f"{self.instance_url}/api/{endpoint}"
        payload = {"i": self.access_token}
        if data:
            payload.update(data)

        try:
            async with self.session.post(url, json=payload) as response:
                return await self._process_response(response, endpoint)
        except aiohttp.ClientError as e:
            logger.error(f"HTTP 请求错误: {e}")
            raise APIConnectionError(f"HTTP request failed: {e}") from e

    async def create_note(
        self,
        text: str,
        visibility: str = "public",
        reply_id: Optional[str] = None,
        visible_user_ids: Optional[List[str]] = None,
        local_only: bool = False,
    ) -> Dict[str, Any]:
        """Create a new note (post).

        ``visible_user_ids`` is only sent when ``visibility`` is
        ``"specified"``.
        """
        data: Dict[str, Any] = {
            "text": text,
            "visibility": visibility,
            "localOnly": local_only,
        }
        if reply_id:
            data["replyId"] = reply_id
        if visible_user_ids and visibility == "specified":
            data["visibleUserIds"] = visible_user_ids

        result = await self._make_request("notes/create", data)
        note_id = result.get("createdNote", {}).get("id", "unknown")
        logger.debug(f"发帖成功,note_id: {note_id}")
        return result

    async def get_current_user(self) -> Dict[str, Any]:
        """Return the authenticated user's profile (the ``i`` endpoint)."""
        return await self._make_request("i", {})

    async def send_message(self, user_id: str, text: str) -> Dict[str, Any]:
        """Send a direct chat message to a user."""
        result = await self._make_request(
            "chat/messages/create-to-user", {"toUserId": user_id, "text": text}
        )
        message_id = result.get("id", "unknown")
        logger.debug(f"聊天发送成功,message_id: {message_id}")
        return result

    async def send_room_message(self, room_id: str, text: str) -> Dict[str, Any]:
        """Send a chat message to a room."""
        result = await self._make_request(
            "chat/messages/create-to-room", {"toRoomId": room_id, "text": text}
        )
        message_id = result.get("id", "unknown")
        logger.debug(f"房间消息发送成功,message_id: {message_id}")
        return result

    async def get_messages(
        self, user_id: str, limit: int = 10, since_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Fetch chat history with a user.

        Returns an empty list if the response is not a list.
        """
        data: Dict[str, Any] = {"userId": user_id, "limit": limit}
        if since_id:
            data["sinceId"] = since_id

        result = await self._make_request("chat/messages/user-timeline", data)
        if isinstance(result, list):
            return result
        else:
            logger.warning(f"获取聊天消息响应格式异常: {type(result)}")
            return []

    async def get_mentions(
        self, limit: int = 10, since_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Fetch mention, reply and quote notifications.

        Accepts either a bare list or a ``{"notifications": [...]}`` wrapper.
        Any other shape yields an empty list.
        """
        data: Dict[str, Any] = {"limit": limit}
        if since_id:
            data["sinceId"] = since_id
        data["includeTypes"] = ["mention", "reply", "quote"]

        result = await self._make_request("i/notifications", data)
        if isinstance(result, list):
            return result
        elif isinstance(result, dict) and "notifications" in result:
            return result["notifications"]
        else:
            logger.warning(f"获取提及通知响应格式异常: {type(result)}")
            return []
@@ -0,0 +1,123 @@
1
+ import asyncio
2
+ import re
3
+ from typing import AsyncGenerator
4
+ from astrbot.api import logger
5
+ from astrbot.api.event import AstrMessageEvent, MessageChain
6
+ from astrbot.api.platform import PlatformMetadata, AstrBotMessage
7
+ from astrbot.api.message_components import Plain
8
+
9
+ from .misskey_utils import (
10
+ serialize_message_chain,
11
+ resolve_visibility_from_raw_message,
12
+ is_valid_user_session_id,
13
+ is_valid_room_session_id,
14
+ add_at_mention_if_needed,
15
+ extract_user_id_from_session_id,
16
+ extract_room_id_from_session_id,
17
+ )
18
+
19
+
20
class MisskeyPlatformEvent(AstrMessageEvent):
    """Message event originating from the Misskey platform.

    Outgoing replies are routed by session id: direct chat, room chat, or a
    note (a reply to the triggering note when its id is known).
    """

    def __init__(
        self,
        message_str: str,
        message_obj: AstrBotMessage,
        platform_meta: PlatformMetadata,
        session_id: str,
        client,
    ):
        super().__init__(message_str, message_obj, platform_meta, session_id)
        # API client used for delivery. send() only calls the methods it
        # actually exposes (probed with hasattr).
        self.client = client

    def _is_system_command(self, message_str: str) -> bool:
        """Return True if the trimmed text starts with a command prefix."""
        stripped = (message_str or "").strip()
        if not stripped:
            return False
        return stripped.startswith(("/", "!", "#", ".", "^"))

    async def send(self, message: MessageChain):
        """Serialize ``message`` and deliver it over the matching channel.

        The sender is @-mentioned unless the chain already contains a mention.
        Delivery errors are logged, not raised.
        """
        text, already_mentions = serialize_message_chain(message.chain)
        if not text:
            logger.debug("[MisskeyEvent] 内容为空,跳过发送")
            return

        try:
            reply_target = getattr(self.message_obj, "message_id", None)
            raw = getattr(self.message_obj, "raw_message", {})

            if raw and not already_mentions:
                author = raw.get("user", {})
                handle = author.get("username", "")
                text = add_at_mention_if_needed(
                    text,
                    {"username": handle, "nickname": author.get("name", handle)},
                    already_mentions,
                )

            client = self.client
            sid = self.session_id
            # Pick the delivery channel from the session type.
            if hasattr(client, "send_message") and is_valid_user_session_id(sid):
                await client.send_message(extract_user_id_from_session_id(sid), text)
            elif hasattr(client, "send_room_message") and is_valid_room_session_id(
                sid
            ):
                await client.send_room_message(
                    extract_room_id_from_session_id(sid), text
                )
            elif hasattr(client, "create_note"):
                if reply_target:
                    # Reply under the original note, mirroring its visibility.
                    visibility, visible_user_ids = resolve_visibility_from_raw_message(
                        raw
                    )
                    await client.create_note(
                        text,
                        reply_id=reply_target,
                        visibility=visibility,
                        visible_user_ids=visible_user_ids,
                    )
                else:
                    logger.debug("[MisskeyEvent] 创建新帖子")
                    await client.create_note(text)

            await super().send(message)

        except Exception as e:
            logger.error(f"[MisskeyEvent] 发送失败: {e}")

    async def send_streaming(
        self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
    ):
        """Deliver a streamed reply.

        Without fallback, all chains are merged and sent once. With fallback,
        plain text is flushed sentence by sentence through
        ``process_buffer``, and each non-text component is sent on its own,
        followed by a 1.5 s pause.
        """
        if not use_fallback:
            merged = None
            async for piece in generator:
                if merged:
                    merged.chain.extend(piece.chain)
                else:
                    merged = piece
            if not merged:
                return
            merged.squash_plain()
            await self.send(merged)
            return await super().send_streaming(generator, use_fallback)

        pending = ""
        sentence_re = re.compile(r"[^。?!~…]+[。?!~…]+")

        async for piece in generator:
            if not isinstance(piece, MessageChain):
                continue
            for comp in piece.chain:
                if not isinstance(comp, Plain):
                    await self.send(MessageChain(chain=[comp]))
                    await asyncio.sleep(1.5)  # throttle non-text sends
                    continue
                pending += comp.text
                if any(mark in pending for mark in "。?!~…"):
                    pending = await self.process_buffer(pending, sentence_re)

        if pending.strip():
            await self.send(MessageChain([Plain(pending)]))
        return await super().send_streaming(generator, use_fallback)