astrbotmcp 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -98,7 +98,16 @@ class AstrBotClient:
98
98
  pwd = hashlib.md5(pwd.encode("utf-8")).hexdigest()
99
99
 
100
100
  url = f"{self.base_url}/api/auth/login"
101
- client_kwargs = {"timeout": self.timeout}
101
+ # SSE endpoints can legitimately stay quiet for a long time while work is happening.
102
+ # Use an infinite read timeout, while keeping connect/write/pool bounded.
103
+ client_kwargs = {
104
+ "timeout": httpx.Timeout(
105
+ connect=self.timeout,
106
+ read=None,
107
+ write=self.timeout,
108
+ pool=self.timeout,
109
+ )
110
+ }
102
111
  if self.settings.disable_proxy:
103
112
  client_kwargs["trust_env"] = False # 禁用代理,忽略环境变量设置
104
113
 
@@ -175,9 +184,12 @@ class AstrBotClient:
175
184
  max_events: Optional[int] = None,
176
185
  ) -> List[Dict[str, Any]]:
177
186
  """
178
- Consume a simple SSE endpoint and return parsed JSON payloads.
187
+ Consume an SSE endpoint and return parsed event payloads.
179
188
 
180
- AstrBot's SSE endpoints use `data: {...}\\n\\n` format per event.
189
+ AstrBot's SSE endpoints typically use `data: {...}\\n\\n` format per event.
190
+ This parser is tolerant:
191
+ - Supports multi-line `data:` frames per SSE spec (joined with `\\n`).
192
+ - If `data:` is not valid JSON, returns it as `{\"type\":\"raw\",\"data\":...}`.
181
193
 
182
194
  `max_seconds` is a soft upper bound for how long we wait:
183
195
  - 如果持续有事件流入,最多等待约 `max_seconds` 秒;
@@ -223,31 +235,63 @@ class AstrBotClient:
223
235
  )
224
236
 
225
237
  async def consume() -> None:
226
- async for line in response.aiter_lines():
227
- if not line:
228
- # Heartbeats / blank lines
229
- continue
238
+ current_event: str | None = None
239
+ data_lines: List[str] = []
230
240
 
231
- if not line.startswith("data:"):
232
- continue
233
-
234
- _, data_str = line.split("data:", 1)
235
- data_str = data_str.strip()
241
+ def flush() -> None:
242
+ nonlocal current_event, data_lines
243
+ if not data_lines:
244
+ current_event = None
245
+ return
246
+ data_str = "\n".join(data_lines).strip()
247
+ data_lines = []
236
248
 
237
249
  if not data_str:
238
- continue
250
+ current_event = None
251
+ return
239
252
 
240
253
  try:
241
254
  payload = json.loads(data_str)
242
255
  except json.JSONDecodeError:
243
- continue
256
+ payload = None
244
257
 
245
258
  if isinstance(payload, dict):
259
+ if current_event and "event" not in payload:
260
+ payload = {**payload, "event": current_event}
246
261
  events.append(payload)
262
+ else:
263
+ events.append(
264
+ {
265
+ "type": "raw",
266
+ "event": current_event,
267
+ "data": data_str,
268
+ }
269
+ )
270
+ current_event = None
271
+
272
+ async for line in response.aiter_lines():
273
+ # Blank line terminates an SSE event.
274
+ if line == "":
275
+ flush()
276
+ if max_events is not None and len(events) >= max_events:
277
+ break
278
+ continue
279
+
280
+ # Comments / heartbeats
281
+ if line.startswith(":"):
282
+ continue
283
+
284
+ if line.startswith("event:"):
285
+ current_event = line.split("event:", 1)[1].strip() or None
286
+ continue
287
+
288
+ if line.startswith("data:"):
289
+ data_lines.append(line.split("data:", 1)[1].lstrip())
290
+ continue
291
+
292
+ continue
247
293
 
248
- if max_events is not None and len(events) >= max_events:
249
- # Enough events collected; stop consuming.
250
- break
294
+ flush()
251
295
 
252
296
  if max_seconds is not None and max_seconds > 0:
253
297
  try:
@@ -18,7 +18,7 @@ All functions are re-exported from this module for convenience.
18
18
  # 导入所有工具函数,保持向后兼容
19
19
  from .control_tools import restart_astrbot
20
20
  from .log_tools import get_astrbot_logs
21
- from .message_tools import (
21
+ from .message import (
22
22
  send_platform_message,
23
23
  send_platform_message_direct,
24
24
  )
@@ -0,0 +1,4 @@
1
+ from .direct import send_platform_message_direct
2
+ from .webchat import send_platform_message
3
+
4
+ __all__ = ["send_platform_message", "send_platform_message_direct"]
@@ -0,0 +1,23 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from typing import Dict, Tuple
5
+
6
+ from ...astrbot_client import AstrBotClient
7
+
8
+ _SESSION_CACHE_LOCK = asyncio.Lock()
9
+ _SESSION_CACHE: Dict[Tuple[str, str, str], str] = {}
10
+
11
+ _LAST_SAVED_MESSAGE_ID_LOCK = asyncio.Lock()
12
+ _LAST_SAVED_MESSAGE_ID_BY_SESSION: Dict[Tuple[str, str, str], str] = {}
13
+
14
+ _LAST_USER_MESSAGE_ID_LOCK = asyncio.Lock()
15
+ _LAST_USER_MESSAGE_ID_BY_SESSION: Dict[Tuple[str, str, str], str] = {}
16
+
17
+
18
+ def _session_cache_key(client: AstrBotClient, platform_id: str) -> Tuple[str, str, str]:
19
+ return (client.base_url, client.settings.username or "", platform_id)
20
+
21
+
22
+ def _last_saved_key(client: AstrBotClient, session_id: str) -> Tuple[str, str, str]:
23
+ return (client.base_url, client.settings.username or "", session_id)
@@ -0,0 +1,252 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import Any, Dict, List, Literal, Optional
5
+
6
+ from ...astrbot_client import AstrBotClient
7
+ from ..helpers import (
8
+ _as_file_uri,
9
+ _attachment_download_url,
10
+ _direct_media_mode,
11
+ _httpx_error_detail,
12
+ _resolve_local_file_path,
13
+ )
14
+ from ..types import MessagePart
15
+
16
+
17
+ async def send_platform_message_direct(
18
+ platform_id: str,
19
+ target_id: str,
20
+ message_chain: Optional[List[MessagePart]] = None,
21
+ message: Optional[str] = None,
22
+ images: Optional[List[str]] = None,
23
+ files: Optional[List[str]] = None,
24
+ videos: Optional[List[str]] = None,
25
+ records: Optional[List[str]] = None,
26
+ message_type: Literal["GroupMessage", "FriendMessage"] = "GroupMessage",
27
+ ) -> Dict[str, Any]:
28
+ """
29
+ Directly send a message chain to a platform group/user (bypass LLM).
30
+
31
+ This calls AstrBot dashboard endpoint: POST /api/platform/send_message
32
+
33
+ Notes:
34
+ - This is for sending to a real platform target (group/user), not WebChat.
35
+ - Media parts:
36
+ - If `file_path` is a local path, this tool will upload it to AstrBot first, then send it as an AstrBot-hosted URL.
37
+ - If `file_path`/`url` is an http(s) URL, it will be forwarded as-is.
38
+ """
39
+ client = AstrBotClient.from_env()
40
+ onebot_like = platform_id.strip().lower() in {
41
+ "napcat",
42
+ "onebot",
43
+ "cqhttp",
44
+ "gocqhttp",
45
+ "llonebot",
46
+ }
47
+
48
+ if message_chain is None:
49
+ message_chain = []
50
+ if message:
51
+ message_chain.append({"type": "plain", "text": message})
52
+ for src in images or []:
53
+ message_chain.append({"type": "image", "file_path": src})
54
+ for src in files or []:
55
+ message_chain.append({"type": "file", "file_path": src})
56
+ for src in records or []:
57
+ message_chain.append({"type": "record", "file_path": src})
58
+ for src in videos or []:
59
+ message_chain.append({"type": "video", "file_path": src})
60
+
61
+ async def build_chain(mode: str) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
62
+ normalized_chain: List[Dict[str, Any]] = []
63
+ uploaded_attachments: List[Dict[str, Any]] = []
64
+
65
+ for part in message_chain or []:
66
+ p_type = part.get("type")
67
+ if p_type in ("image", "file", "record", "video"):
68
+ file_path = part.get("file_path")
69
+ url = part.get("url")
70
+ file_name = part.get("file_name")
71
+ mime_type = part.get("mime_type")
72
+ src = url or file_path
73
+ if not src:
74
+ continue
75
+
76
+ normalized = dict(part)
77
+ if not isinstance(src, str):
78
+ raise ValueError(f"Invalid media source (expected str): {src!r}")
79
+
80
+ if src.startswith(("http://", "https://")):
81
+ normalized["file_path"] = src
82
+ if onebot_like:
83
+ normalized.setdefault("file", src)
84
+ normalized.pop("url", None)
85
+ normalized_chain.append(normalized)
86
+ continue
87
+
88
+ try:
89
+ local_path = _resolve_local_file_path(client, src)
90
+ except ValueError as e:
91
+ raise ValueError(str(e)) from e
92
+ except FileNotFoundError as e:
93
+ raise FileNotFoundError(f"Local file_path does not exist: {src!r}") from e
94
+
95
+ if mode == "local":
96
+ normalized["file_path"] = local_path
97
+ normalized.pop("url", None)
98
+ if onebot_like:
99
+ uri = _as_file_uri(local_path)
100
+ normalized.setdefault("file", uri or local_path)
101
+ normalized_chain.append(normalized)
102
+ continue
103
+
104
+ if mode != "upload":
105
+ raise ValueError(f"Unknown direct media mode: {mode!r}")
106
+
107
+ if not file_name:
108
+ file_name = os.path.basename(local_path) or None
109
+
110
+ attach_resp = await client.post_attachment_file(
111
+ local_path,
112
+ file_name=file_name,
113
+ mime_type=mime_type,
114
+ )
115
+
116
+ if attach_resp.get("status") != "ok":
117
+ raise RuntimeError(attach_resp.get("message") or "Attachment upload failed")
118
+
119
+ attach_data = attach_resp.get("data") or {}
120
+ attachment_id = attach_data.get("attachment_id")
121
+ if not attachment_id:
122
+ raise RuntimeError(
123
+ "Attachment upload succeeded but attachment_id is missing"
124
+ )
125
+
126
+ download_url = _attachment_download_url(client, str(attachment_id))
127
+ normalized["file_path"] = download_url
128
+ if onebot_like:
129
+ normalized.setdefault("file", download_url)
130
+ normalized.pop("url", None)
131
+ normalized.pop("file_name", None)
132
+ normalized.pop("mime_type", None)
133
+ uploaded_attachments.append(attach_data)
134
+ normalized_chain.append(normalized)
135
+ else:
136
+ normalized_chain.append(dict(part))
137
+
138
+ return normalized_chain, uploaded_attachments
139
+
140
+ # Prefer local paths (more compatible with Napcat / Windows), but keep an upload fallback.
141
+ try:
142
+ mode = _direct_media_mode(client)
143
+ except ValueError as e:
144
+ return {
145
+ "status": "error",
146
+ "message": str(e),
147
+ "platform_id": platform_id,
148
+ "session_id": str(target_id),
149
+ "message_type": message_type,
150
+ }
151
+ modes_to_try = ["local", "upload"] if mode == "auto" else [mode]
152
+ last_error: Dict[str, Any] | None = None
153
+
154
+ for attempt_mode in modes_to_try:
155
+ try:
156
+ normalized_chain, uploaded_attachments = await build_chain(attempt_mode)
157
+ except FileNotFoundError as e:
158
+ return {
159
+ "status": "error",
160
+ "message": str(e),
161
+ "platform_id": platform_id,
162
+ "session_id": str(target_id),
163
+ "message_type": message_type,
164
+ "hint": "If you passed a relative path, set ASTRBOTMCP_FILE_ROOT (or run the server in the correct working directory).",
165
+ }
166
+ except ValueError as e:
167
+ return {
168
+ "status": "error",
169
+ "message": str(e),
170
+ "platform_id": platform_id,
171
+ "session_id": str(target_id),
172
+ "message_type": message_type,
173
+ "hint": "Set ASTRBOTMCP_FILE_ROOT to control how relative paths are resolved.",
174
+ }
175
+ except Exception as e:
176
+ return {
177
+ "status": "error",
178
+ "message": str(e),
179
+ "platform_id": platform_id,
180
+ "session_id": str(target_id),
181
+ "message_type": message_type,
182
+ "attempt_mode": attempt_mode,
183
+ }
184
+
185
+ if not normalized_chain:
186
+ return {
187
+ "status": "error",
188
+ "message": "message_chain did not produce any valid message parts",
189
+ "platform_id": platform_id,
190
+ "session_id": str(target_id),
191
+ "message_type": message_type,
192
+ }
193
+
194
+ try:
195
+ direct_resp = await client.send_platform_message_direct(
196
+ platform_id=platform_id,
197
+ message_type=message_type,
198
+ session_id=str(target_id),
199
+ message_chain=normalized_chain,
200
+ )
201
+ except Exception as e:
202
+ status_code = getattr(getattr(e, "response", None), "status_code", None)
203
+ hint = "Ensure AstrBot includes /api/platform/send_message and you are authenticated."
204
+ if status_code in (404, 405):
205
+ hint = (
206
+ "Your AstrBot may not expose /api/platform/send_message (some versions only provide "
207
+ "/api/platform/stats and /api/platform/webhook). Upgrade AstrBot or add an HTTP route for sending."
208
+ )
209
+ return {
210
+ "status": "error",
211
+ "message": (
212
+ f"AstrBot API error: HTTP {status_code}"
213
+ if status_code is not None
214
+ else f"AstrBot API error: {e}"
215
+ ),
216
+ "platform_id": platform_id,
217
+ "session_id": str(target_id),
218
+ "message_type": message_type,
219
+ "attempt_mode": attempt_mode,
220
+ "detail": _httpx_error_detail(e),
221
+ "hint": hint,
222
+ }
223
+
224
+ status = direct_resp.get("status")
225
+ if status == "ok":
226
+ data = direct_resp.get("data") or {}
227
+ return {
228
+ "status": "ok",
229
+ "platform_id": data.get("platform_id", platform_id),
230
+ "session_id": data.get("session_id", str(target_id)),
231
+ "message_type": data.get("message_type", message_type),
232
+ "attempt_mode": attempt_mode,
233
+ "uploaded_attachments": uploaded_attachments,
234
+ }
235
+
236
+ last_error = {
237
+ "status": status,
238
+ "platform_id": platform_id,
239
+ "session_id": str(target_id),
240
+ "message_type": message_type,
241
+ "attempt_mode": attempt_mode,
242
+ "message": direct_resp.get("message"),
243
+ "raw": direct_resp,
244
+ }
245
+
246
+ return last_error or {
247
+ "status": "error",
248
+ "message": "Failed to send message",
249
+ "platform_id": platform_id,
250
+ "session_id": str(target_id),
251
+ "message_type": message_type,
252
+ }
@@ -0,0 +1,71 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, List, Tuple
4
+
5
+ from ...astrbot_client import AstrBotClient
6
+ from .utils import _extract_plain_text_from_history_item, _format_quote_block
7
+
8
+
9
+ async def _resolve_webchat_quotes(
10
+ client: AstrBotClient, *, session_id: str, reply_ids: List[str]
11
+ ) -> Tuple[str, Dict[str, Any]]:
12
+ """
13
+ Resolve WebChat `message_saved.id` -> quoted text by calling /api/chat/get_session.
14
+ Best-effort: returns a quote prefix text and debug info.
15
+ """
16
+ cleaned: List[str] = []
17
+ for rid in reply_ids:
18
+ s = str(rid).strip()
19
+ if s:
20
+ cleaned.append(s)
21
+ if not cleaned:
22
+ return "", {"resolved": {}, "missing": []}
23
+
24
+ try:
25
+ resp = await client.get_platform_session(session_id=session_id)
26
+ except Exception as e:
27
+ return "", {"error": str(e), "resolved": {}, "missing": cleaned}
28
+
29
+ if resp.get("status") != "ok":
30
+ return "", {"status": resp.get("status"), "message": resp.get("message"), "raw": resp}
31
+
32
+ data = resp.get("data") or {}
33
+ history = data.get("history") or []
34
+ if not isinstance(history, list):
35
+ return "", {"resolved": {}, "missing": cleaned, "raw_history_type": str(type(history))}
36
+
37
+ index: Dict[str, Dict[str, Any]] = {}
38
+ for item in history:
39
+ if not isinstance(item, dict):
40
+ continue
41
+ mid = item.get("id")
42
+ if mid is None:
43
+ continue
44
+ index[str(mid)] = item
45
+
46
+ resolved: Dict[str, str] = {}
47
+ missing: List[str] = []
48
+ blocks: List[str] = []
49
+ for rid in cleaned:
50
+ item = index.get(str(rid))
51
+ if not item:
52
+ missing.append(rid)
53
+ blocks.append(
54
+ _format_quote_block(
55
+ message_id=str(rid),
56
+ sender="missing",
57
+ text="<not found in /api/chat/get_session history>",
58
+ )
59
+ )
60
+ continue
61
+ sender = (
62
+ item.get("sender_name")
63
+ or item.get("sender_id")
64
+ or "unknown"
65
+ )
66
+ txt = _extract_plain_text_from_history_item(item)
67
+ block = _format_quote_block(message_id=str(rid), sender=str(sender), text=txt)
68
+ resolved[str(rid)] = block
69
+ blocks.append(block)
70
+
71
+ return "".join(blocks), {"resolved": resolved, "missing": missing}
@@ -0,0 +1,62 @@
1
+ from __future__ import annotations
2
+
3
+ import textwrap
4
+ from typing import Any, Dict, List
5
+
6
+
7
+ def _extract_plain_text_from_history_item(item: Dict[str, Any]) -> str:
8
+ content = item.get("content") or {}
9
+ if not isinstance(content, dict):
10
+ return str(content)
11
+ message = content.get("message") or []
12
+ if not isinstance(message, list):
13
+ return str(message)
14
+
15
+ chunks: List[str] = []
16
+ for part in message:
17
+ if not isinstance(part, dict):
18
+ continue
19
+ p_type = part.get("type")
20
+ if p_type == "plain":
21
+ txt = part.get("text")
22
+ if isinstance(txt, str) and txt:
23
+ chunks.append(txt)
24
+ elif p_type in ("image", "file", "record", "video"):
25
+ name = part.get("filename") or part.get("attachment_id") or ""
26
+ if name:
27
+ chunks.append(f"[{p_type}:{name}]")
28
+ else:
29
+ chunks.append(f"[{p_type}]")
30
+ else:
31
+ if p_type:
32
+ chunks.append(f"[{p_type}]")
33
+ return "".join(chunks).strip()
34
+
35
+
36
+ def _format_quote_block(*, message_id: str, sender: str, text: str) -> str:
37
+ sender = (sender or "unknown").strip() or "unknown"
38
+ text = (text or "").strip()
39
+ if not text:
40
+ text = "<empty>"
41
+ text = textwrap.shorten(text, width=800, placeholder="…")
42
+ return f"[引用消息 {message_id} | {sender}] {text}\n"
43
+
44
+
45
+ def _normalize_history_message_id(value: Any) -> Any:
46
+ """
47
+ AstrBot WebChat reply expects `message_id` to be the history record primary key (usually int).
48
+ Keep original value if it cannot be safely converted.
49
+ """
50
+ if value is None:
51
+ return None
52
+ if isinstance(value, int):
53
+ return value
54
+ s = str(value).strip()
55
+ if not s:
56
+ return value
57
+ if s.isdigit():
58
+ try:
59
+ return int(s)
60
+ except Exception:
61
+ return value
62
+ return value