nonebot-plugin-skills 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -0,0 +1,2193 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import base64
|
|
5
|
+
import json
|
|
6
|
+
import re
|
|
7
|
+
import time
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import List, Optional, Tuple, cast
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
from google import genai
|
|
13
|
+
from google.genai import types
|
|
14
|
+
from nonebot import get_driver, logger, on_command, on_message
|
|
15
|
+
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageEvent, MessageSegment
|
|
16
|
+
from nonebot.params import CommandArg
|
|
17
|
+
from nonebot.plugin import PluginMetadata
|
|
18
|
+
|
|
19
|
+
from .config import config
|
|
20
|
+
|
|
21
|
+
__plugin_meta__ = PluginMetadata(
|
|
22
|
+
name="nonebot-plugin-skills",
|
|
23
|
+
description="基于 Gemini 的头像/图片处理与聊天插件,支持上下文缓存与群/私聊隔离",
|
|
24
|
+
usage="指令:处理头像 <指令> / 技能|聊天 <内容> / 天气 <城市>",
|
|
25
|
+
type="application",
|
|
26
|
+
homepage="https://github.com/yourname/nonebot-plugin-skills",
|
|
27
|
+
supported_adapters={"~onebot.v11"},
|
|
28
|
+
)
|
|
29
|
+
def _mask_api_key(text: str) -> str:
    """Replace the configured Google API key with "***" inside *text*.

    Returns *text* unchanged when no key is configured.
    """
    key = config.google_api_key
    if not key:
        return text
    return text.replace(key, "***")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _truncate(text: str, limit: int = 800) -> str:
|
|
36
|
+
if len(text) <= limit:
|
|
37
|
+
return text
|
|
38
|
+
return text[:limit].rstrip() + "..."
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _safe_error_message(exc: Exception) -> str:
    """Build a single-line, log/chat-safe description of *exc*.

    HTTP status errors get a truncated copy of the response body appended;
    the configured API key is masked and all whitespace collapsed so the
    message fits on one line.
    """
    detail = str(exc)
    if isinstance(exc, httpx.HTTPStatusError):
        # Include the (truncated) server response body for HTTP failures.
        response_text = _truncate(exc.response.text)
        detail = f"{detail} | response: {response_text}"
    detail = _mask_api_key(detail)
    detail = detail.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
    detail = _collapse_spaces(detail)
    if detail:
        return detail
    # str(exc) was empty — fall back to the exception type name.
    return f"{type(exc).__name__}: 未知错误"
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
_MD_CODE_FENCE_START = re.compile(r"```[^\n]*\n")
|
|
55
|
+
_MD_INLINE_CODE = re.compile(r"`([^`]*)`")
|
|
56
|
+
_MD_IMAGE = re.compile(r"!\[([^\]]*)\]\([^)]+\)")
|
|
57
|
+
_MD_LINK = re.compile(r"\[([^\]]+)\]\([^)]+\)")
|
|
58
|
+
_MD_BOLD = re.compile(r"\*\*(.*?)\*\*")
|
|
59
|
+
_MD_UNDERLINE = re.compile(r"__(.*?)__")
|
|
60
|
+
_MD_STRIKE = re.compile(r"~~(.*?)~~")
|
|
61
|
+
_MD_HRULE = re.compile(r"^\s*([-*_]\s*){3,}$", re.M)
|
|
62
|
+
_MD_LIST_MARKER = re.compile(r"^\s*(?:[-*+•·*]|\d+[).、])\s*")
|
|
63
|
+
_LINE_SPLIT_RE = re.compile(r"\r?\n+")
|
|
64
|
+
|
|
65
|
+
_MODEL_REPLY_MAX_CHARS = 0
|
|
66
|
+
_WAIT_NEXT_IMAGE_SEC = 60.0
|
|
67
|
+
|
|
68
|
+
_CHAT_SYSTEM_PROMPT = (
|
|
69
|
+
"Role\n"
|
|
70
|
+
"你是asoul成员嘉然,会尽量满足提问者的帮助。\n"
|
|
71
|
+
"你是一个普通 QQ 用户,正在和朋友私聊或群聊。\n\n"
|
|
72
|
+
"Goal\n"
|
|
73
|
+
"根据对方消息给出自然、简短、真实的回复。\n\n"
|
|
74
|
+
"Rules\n"
|
|
75
|
+
"1. 只当普通人,不要自称 AI/助手,也不要提系统或规则。\n"
|
|
76
|
+
"2. 输出纯文本,不使用 Markdown 或代码块。\n"
|
|
77
|
+
"3. 适合 QQ 消息,精炼、不啰嗦;口语、简短,可自然换行;少用句号,问号和感叹号克制。\n"
|
|
78
|
+
"4. 不要夸张表演,不堆流行语,不写长段解释。\n"
|
|
79
|
+
"5. 不清楚就简短追问,不编造细节。\n"
|
|
80
|
+
"6. 只回复当前消息,不要夹带已回复过的旧话题。\n"
|
|
81
|
+
"Output\n"
|
|
82
|
+
"只输出最终回复内容。\n"
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
_IMAGE_CHAT_SYSTEM_PROMPT = (
|
|
86
|
+
"你现在是asoul成员嘉然,会尽量满足提问者的帮助。\n"
|
|
87
|
+
"你在进行图片内容对话,只需回答当前指令或问题。\n"
|
|
88
|
+
"不要补充已回复过的历史话题,不要输出 Markdown 或代码块。\n"
|
|
89
|
+
"回答适合 QQ 消息,精炼、不啰嗦,简短、口语化,可自然换行。\n"
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
_TRAVEL_SYSTEM_PROMPT = (
|
|
93
|
+
"你是旅行规划助手,给出清晰、实用、可执行的旅行建议。\n"
|
|
94
|
+
"输出纯文本,不使用 Markdown 或代码块。\n"
|
|
95
|
+
"适合 QQ 消息,精炼、不啰嗦。\n"
|
|
96
|
+
"结构清晰,可自然换行,尽量不要空行,包含景点/活动/用餐/交通/住宿要点。\n"
|
|
97
|
+
"请自动生成该城市最常见的规划天数。\n"
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
_INTENT_SYSTEM_PROMPT = (
|
|
101
|
+
"你是消息意图解析器,只输出 JSON,不要解释或补充说明。"
|
|
102
|
+
"不要输出拒绝/免责声明/权限说明(例如“我无法访问账号”)。"
|
|
103
|
+
"严格输出如下 JSON:"
|
|
104
|
+
"{"
|
|
105
|
+
"\"action\": \"chat|image_chat|image_generate|image_create|weather|avatar_get|travel_plan|history_clear|ignore\","
|
|
106
|
+
"\"target\": \"message_image|reply_image|at_user|last_image|sender_avatar|group_avatar|qq_avatar|message_id|wait_next|city|trip|none\","
|
|
107
|
+
"\"instruction\": \"string\","
|
|
108
|
+
"\"params\": {\"qq\": \"string\", \"message_id\": \"int\", \"city\": \"string\","
|
|
109
|
+
" \"destination\": \"string\", \"days\": \"int\", \"nights\": \"int\", \"reply\": \"string\"}"
|
|
110
|
+
"}"
|
|
111
|
+
"说明:"
|
|
112
|
+
"- action=chat 表示普通聊天;instruction 为要回复的文本。"
|
|
113
|
+
"- action=image_chat 表示聊这张图(不生成图);instruction 为想问/想说的内容。"
|
|
114
|
+
"- action=image_generate 表示基于参考图生成/编辑;instruction 为帮忙生成关于xx的图片的处理指令。"
|
|
115
|
+
"- action=image_create 表示无参考图生成;instruction 为生成指令。"
|
|
116
|
+
"- action=weather 表示查询天气;instruction 为地点,target=city,params.city 填地点。"
|
|
117
|
+
"- action=avatar_get 表示获取头像;instruction 可为空,target 可为 sender_avatar 或 group_avatar 等。"
|
|
118
|
+
"- action=travel_plan 表示旅行规划;instruction 为完整需求;target=trip;"
|
|
119
|
+
"params.destination 为目的地,params.days 为天数,params.nights 为晚数。"
|
|
120
|
+
"- action=history_clear 表示清除当前会话(当前聊天或群)历史记录;instruction 为空或简短确认。"
|
|
121
|
+
"- action=ignore 表示不处理;instruction 为空字符串。"
|
|
122
|
+
"- target 仅在 image_chat/image_generate 时使用:"
|
|
123
|
+
" message_image=本消息里的图;reply_image=回复消息里的图;"
|
|
124
|
+
" at_user=@用户头像;last_image=最近图片;sender_avatar=发送者头像;group_avatar=群头像;"
|
|
125
|
+
" qq_avatar=指定 QQ 头像(params.qq);"
|
|
126
|
+
" message_id=指定消息ID图片(params.message_id);"
|
|
127
|
+
" wait_next=等待下一张图;none=无参考图。"
|
|
128
|
+
"params 里只在对应 target 时填写:"
|
|
129
|
+
"- target=qq_avatar 时填写 params.qq。"
|
|
130
|
+
"- target=message_id 时填写 params.message_id。"
|
|
131
|
+
"其他情况 params 为空对象。"
|
|
132
|
+
"若旅行或天气缺关键信息,仍输出对应 action,缺失字段留空。"
|
|
133
|
+
"上下文可能包含“昵称: 内容”的格式,需识别说话人。"
|
|
134
|
+
"如需发送等待/过渡语,可在 params.reply 中填写一句短句。"
|
|
135
|
+
" 如果文本包含多行,默认第一行是当前消息;只有当前消息无法判断时才参考后续上下文/回复内容。"
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
_DUPLICATE_TEXT_TTL_SEC = 60.0
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class UnsupportedImageError(RuntimeError):
    """Raised for images the plugin cannot process (see `_is_supported_image_url`)."""
    pass
|
|
143
|
+
|
|
144
|
+
_SELF_ID_PATTERNS = [
|
|
145
|
+
re.compile(r"^(作为|我作为)(一名|一个)?(人工智能|AI|语言模型|模型).*?[,,。]\s*", re.I),
|
|
146
|
+
re.compile(r"^我是(一名|一个)?(人工智能|AI|语言模型|模型).*?[,,。]\s*", re.I),
|
|
147
|
+
]
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _strip_markdown(text: str) -> str:
    """Remove Markdown markup from *text*, keeping only the readable content.

    Order matters: code fences are removed before inline code, and inline
    spans (code/images/links/bold/underline/strike) before the per-line
    cleanup of headings, blockquotes, and list markers.
    """
    if not text:
        return text
    text = _MD_CODE_FENCE_START.sub("", text)
    text = text.replace("```", "")
    text = _MD_INLINE_CODE.sub(r"\1", text)
    text = _MD_IMAGE.sub(r"\1", text)
    text = _MD_LINK.sub(r"\1", text)
    text = _MD_BOLD.sub(r"\1", text)
    text = _MD_UNDERLINE.sub(r"\1", text)
    text = _MD_STRIKE.sub(r"\1", text)
    lines: List[str] = []
    for raw_line in text.splitlines():
        line = raw_line.rstrip()
        # Strip heading hashes, blockquote '>' prefixes and list markers.
        line = re.sub(r"^\s{0,3}#{1,6}\s+", "", line)
        line = re.sub(r"^\s{0,3}>\s?", "", line)
        line = _MD_LIST_MARKER.sub("", line)
        lines.append(line)
    text = "\n".join(lines)
    text = _MD_HRULE.sub("", text)
    # Collapse runs of 3+ newlines to a single blank line.
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def _remove_self_identification(text: str) -> str:
    """Strip leading "as an AI/model..." self-identification from every line."""
    if not text:
        return text
    result: List[str] = []
    for raw in text.splitlines():
        stripped = raw.strip()
        for pattern in _SELF_ID_PATTERNS:
            stripped = pattern.sub("", stripped)
        result.append(stripped)
    return "\n".join(result).strip()
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _remove_prompt_leakage(text: str) -> str:
|
|
187
|
+
if not text:
|
|
188
|
+
return text
|
|
189
|
+
cleaned_lines: List[str] = []
|
|
190
|
+
for raw_line in text.splitlines():
|
|
191
|
+
line = raw_line.strip()
|
|
192
|
+
lower = line.lower()
|
|
193
|
+
if lower.startswith("system prompt") or lower.startswith("system instruction"):
|
|
194
|
+
continue
|
|
195
|
+
if line.startswith(("系统提示", "系统指令", "提示词", "系统消息")):
|
|
196
|
+
continue
|
|
197
|
+
cleaned_lines.append(raw_line.strip())
|
|
198
|
+
return "\n".join(cleaned_lines).strip()
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _ensure_plain_text(text: str) -> str:
    """Normalize model output into plain chat text.

    Strips Markdown markup, removes leaked system-prompt lines, then drops
    "as an AI" style self-identification.
    """
    if not text:
        return text
    text = _strip_markdown(text)
    text = _remove_prompt_leakage(text)
    text = _remove_self_identification(text)
    return text.strip()
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _collapse_spaces(text: str) -> str:
|
|
211
|
+
return re.sub(r"\s+", " ", text).strip()
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _normalize_user_name(value: Optional[object]) -> str:
    """Coerce a nickname/card/user-id value into a clean single-line name.

    Returns "" for None or blank input. Newlines become spaces, whitespace
    runs are collapsed, and trailing ASCII/full-width colons are stripped so
    the name can safely prefix a "name: text" context line.
    """
    if value is None:
        return ""
    name = str(value).strip()
    if not name:
        return ""
    name = name.replace("\r", " ").replace("\n", " ")
    name = _collapse_spaces(name)
    return name.strip("::")
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _event_user_name(event: MessageEvent) -> str:
    """Display name for the event sender: group card, then nickname, then user id."""
    sender = getattr(event, "sender", None)
    name = None
    if sender is not None:
        name = getattr(sender, "card", None) or getattr(sender, "nickname", None)
    if not name:
        # Fall back to the numeric user id when no readable name exists.
        name = getattr(event, "user_id", None)
    return _normalize_user_name(name)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _sender_user_name(sender: object) -> str:
    """Display name for a raw sender object: card, then nickname, then user id."""
    if sender is None:
        return ""
    name = getattr(sender, "card", None) or getattr(sender, "nickname", None)
    if not name:
        name = getattr(sender, "user_id", None)
    return _normalize_user_name(name)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _format_context_line(text: str, user_name: Optional[str]) -> str:
    """Prefix *text* with "name: " when a usable speaker name is available."""
    name = _normalize_user_name(user_name)
    if name:
        return f"{name}: {text}"
    return text
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def _compact_reply_lines(text: str) -> str:
|
|
252
|
+
if not text:
|
|
253
|
+
return text
|
|
254
|
+
lines = [line.strip() for line in text.split("\n")]
|
|
255
|
+
lines = [line for line in lines if line]
|
|
256
|
+
return "\n".join(lines).strip()
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def _transition_text(action: str) -> Optional[str]:
|
|
260
|
+
if action in {"image_create"}:
|
|
261
|
+
return "正在生成图片,请稍候..."
|
|
262
|
+
if action in {"image_generate"}:
|
|
263
|
+
return "正在处理图片,请稍候..."
|
|
264
|
+
if action in {"weather", "travel_plan", "avatar_get", "image_chat"}:
|
|
265
|
+
return "我看看喵"
|
|
266
|
+
return None
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _intent_transition_text(intent: dict) -> str:
    """Return the model-suggested transition line from intent params, or ""."""
    params = _intent_params(intent)
    reply = params.get("reply")
    if isinstance(reply, str):
        return reply.strip()
    return ""
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
async def _send_transition(action: str, send_func) -> None:
    """Send the canned transition line for *action* via *send_func*, if one exists."""
    text = _transition_text(action)
    if text:
        await send_func(text)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def _format_reply_text(text: str) -> str:
    """Clean model output into QQ-ready text.

    Sanitizes to plain text first, then normalizes line endings, trims each
    line, and collapses runs of 3+ newlines into a single blank line.
    """
    if not text:
        return text
    cleaned = _ensure_plain_text(text)
    if not cleaned:
        return ""
    normalized = cleaned.replace("\r\n", "\n").replace("\r", "\n")
    lines = [line.strip() for line in normalized.split("\n")]
    normalized = "\n".join(lines)
    normalized = re.sub(r"\n{3,}", "\n\n", normalized)
    return normalized.strip()
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _limit_reply_text(text: str, limit: int = _MODEL_REPLY_MAX_CHARS) -> str:
    """Clamp *text* to *limit* characters.

    A non-positive or non-integer limit disables clamping (the module default
    is 0, i.e. unlimited).
    """
    if not text:
        return text
    try:
        max_chars = int(limit)
    except Exception:
        return text
    if max_chars <= 0 or len(text) <= max_chars:
        return text
    return text[:max_chars]
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _redact_large_data(value: object, depth: int = 0) -> object:
|
|
311
|
+
if depth > 4:
|
|
312
|
+
return "..."
|
|
313
|
+
if isinstance(value, bytes):
|
|
314
|
+
return f"<{len(value)} bytes>"
|
|
315
|
+
if isinstance(value, dict):
|
|
316
|
+
result: dict[str, object] = {}
|
|
317
|
+
for key, val in value.items():
|
|
318
|
+
if key == "data" and isinstance(val, (bytes, str)):
|
|
319
|
+
size = len(val)
|
|
320
|
+
unit = "bytes" if isinstance(val, bytes) else "chars"
|
|
321
|
+
result[key] = f"<{size} {unit}>"
|
|
322
|
+
else:
|
|
323
|
+
result[key] = _redact_large_data(val, depth + 1)
|
|
324
|
+
return result
|
|
325
|
+
if isinstance(value, list):
|
|
326
|
+
trimmed = value[:20]
|
|
327
|
+
result_list = [_redact_large_data(item, depth + 1) for item in trimmed]
|
|
328
|
+
if len(value) > 20:
|
|
329
|
+
result_list.append("...")
|
|
330
|
+
return result_list
|
|
331
|
+
return value
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def _dump_response(response: object) -> str:
    """Serialize an SDK response object for logging, redacting bulky payloads.

    Tries ``model_dump()`` / ``to_dict()`` first and JSON-encodes the
    redacted result; on any failure it falls back to a masked, truncated
    ``str()``/``repr()``.
    """
    for attr in ("model_dump", "to_dict"):
        method = getattr(response, attr, None)
        if callable(method):
            try:
                data = method()
                redacted = _redact_large_data(data)
                return json.dumps(redacted, ensure_ascii=True)
            except Exception:
                # Best-effort: fall through to the string fallback below.
                pass
    try:
        text = str(response)
    except Exception:
        text = repr(response)
    return _truncate(_mask_api_key(text), 1200)
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def _log_response_text(prefix: str, response: object) -> None:
    """Log the response's ``.text`` (masked and truncated) when it is a non-empty string."""
    text = getattr(response, "text", None)
    if isinstance(text, str) and text.strip():
        logger.info("{}: {}", prefix, _truncate(_mask_api_key(text), 1200))
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
@dataclass
class HistoryItem:
    """One chat turn kept in a session's rolling history."""

    role: str  # Gemini role string (e.g. "user"); see _history_to_gemini
    text: str  # turn text (speaker prefix is added when sent to Gemini)
    ts: float  # Unix timestamp, used for TTL pruning in _prune_state
    user_id: Optional[str] = None  # sender id, user turns only
    user_name: Optional[str] = None  # sender display name, user turns only
    to_bot: bool = False  # whether the message addressed the bot
    message_id: Optional[int] = None  # originating OneBot message id
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
@dataclass
class SessionState:
    """Mutable per-session (one group or one private chat) state."""

    history: List[HistoryItem]  # rolling conversation history
    last_image_url: Optional[str]  # most recently seen image URL
    image_cache: dict[int, tuple[str, float]]  # message_id -> (image url, ts)
    pending_image_waiters: dict[str, asyncio.Future[str]]  # waiters for a "next image"
    handled_message_ids: dict[int, float]  # message_id -> handled ts (dedupe)
    handled_texts: dict[str, float]  # "user:text" key -> handled ts (dedupe)
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
_SESSIONS: dict[str, SessionState] = {}
|
|
381
|
+
_CLIENT: Optional[genai.Client] = None
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def _session_id(event: MessageEvent) -> str:
    """Stable session key: one shared per group, one per private-chat peer."""
    if isinstance(event, GroupMessageEvent):
        return f"group:{event.group_id}"
    return f"private:{event.get_user_id()}"
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def _now() -> float:
|
|
391
|
+
return time.time()
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
def _event_ts(event: MessageEvent) -> float:
    """Timestamp of *event*, falling back to the current time when absent or invalid."""
    value = getattr(event, "time", None)
    if isinstance(value, (int, float)) and value > 0:
        return float(value)
    return _now()
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def _get_state(session_id: str) -> SessionState:
    """Return the SessionState for *session_id*, creating an empty one on first use."""
    state = _SESSIONS.get(session_id)
    if state is None:
        state = SessionState(
            history=[],
            last_image_url=None,
            image_cache={},
            pending_image_waiters={},
            handled_message_ids={},
            handled_texts={},
        )
        _SESSIONS[session_id] = state
    return state
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def _get_client() -> genai.Client:
    """Lazily create and cache the module-wide Gemini client.

    Raises:
        RuntimeError: when ``GOOGLE_API_KEY`` is not configured.
    """
    global _CLIENT
    if _CLIENT is None:
        if not config.google_api_key:
            raise RuntimeError("未配置 GOOGLE_API_KEY")
        _CLIENT = genai.Client(api_key=config.google_api_key)
    return _CLIENT
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def _prune_state(state: SessionState) -> None:
    """Drop expired entries from *state* in place.

    The TTL is ``config.history_ttl_sec`` floored at 30 seconds; history is
    additionally capped at ``config.history_max_messages``. Dedupe text keys
    use the larger of the TTL and ``_DUPLICATE_TEXT_TTL_SEC``.
    """
    ttl = max(30, int(config.history_ttl_sec))
    cutoff = _now() - ttl
    state.history = [item for item in state.history if item.ts >= cutoff]
    if len(state.history) > config.history_max_messages:
        # Keep only the most recent messages.
        state.history = state.history[-config.history_max_messages :]
    if state.image_cache:
        state.image_cache = {
            msg_id: (url, ts)
            for msg_id, (url, ts) in state.image_cache.items()
            if ts >= cutoff
        }
    if state.handled_message_ids:
        state.handled_message_ids = {
            msg_id: ts for msg_id, ts in state.handled_message_ids.items() if ts >= cutoff
        }
    if state.handled_texts:
        text_cutoff = _now() - max(ttl, int(_DUPLICATE_TEXT_TTL_SEC))
        state.handled_texts = {
            key: ts for key, ts in state.handled_texts.items() if ts >= text_cutoff
        }
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
def _clear_session_state(state: SessionState) -> None:
    """Reset *state* in place: wipe history, caches, dedupe maps; cancel pending waiters."""
    state.history = []
    state.last_image_url = None
    state.image_cache = {}
    for waiter in list(state.pending_image_waiters.values()):
        if not waiter.done():
            waiter.cancel()
    state.pending_image_waiters = {}
    state.handled_message_ids = {}
    state.handled_texts = {}
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
_UNSUPPORTED_IMAGE_EXTS = (".gif", ".apng")
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def _handled_text_key(user_id: str, text: str) -> str:
|
|
465
|
+
return f"{user_id}:{text}"
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def _is_duplicate_request(state: SessionState, event: MessageEvent, text: str) -> bool:
    """Return True when this message was already handled.

    A request is a duplicate when its message id is known, or the same user
    sent the same text within ``_DUPLICATE_TEXT_TTL_SEC`` seconds.
    """
    msg_id = getattr(event, "message_id", None)
    if isinstance(msg_id, int) and msg_id in state.handled_message_ids:
        return True
    stripped = text.strip()
    if not stripped:
        # Empty text cannot be de-duplicated by content.
        return False
    key = _handled_text_key(str(event.get_user_id()), stripped)
    ts = state.handled_texts.get(key)
    if ts is None:
        return False
    return (_now() - ts) <= _DUPLICATE_TEXT_TTL_SEC
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
def _mark_handled_request(state: SessionState, event: MessageEvent, text: str) -> None:
    """Record this message as handled (by message id and user+text), then prune."""
    ts = _event_ts(event)
    msg_id = getattr(event, "message_id", None)
    if isinstance(msg_id, int):
        state.handled_message_ids[msg_id] = ts
    stripped = text.strip()
    if stripped:
        key = _handled_text_key(str(event.get_user_id()), stripped)
        state.handled_texts[key] = ts
    # Opportunistic cleanup keeps the dedupe maps bounded.
    _prune_state(state)
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
def _is_supported_image_url(url: str) -> bool:
    """Return False for empty URLs and animated formats (.gif/.apng, GIF data URLs)."""
    if not url:
        return False
    lowered = url.lower()
    if lowered.startswith("data:image/gif"):
        return False
    # Compare extensions against the path only, ignoring query and fragment.
    path_only = lowered.split("?", 1)[0].split("#", 1)[0]
    return not path_only.endswith(_UNSUPPORTED_IMAGE_EXTS)
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
def _extract_first_image_url(message: Message) -> Optional[str]:
    """Return the URL of the first supported image segment in *message*, if any.

    Unsupported images (e.g. GIFs) are skipped rather than ending the scan.
    """
    for segment in message:
        if segment.type != "image":
            continue
        candidate = segment.data.get("url") or segment.data.get("file")
        if candidate and _is_supported_image_url(candidate):
            return candidate
    return None
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
def _extract_at_user(message: Message) -> Optional[str]:
    """Return the first concrete @-mentioned QQ id in *message* (ignoring @all)."""
    for segment in message:
        if segment.type != "at":
            continue
        qq = segment.data.get("qq")
        if qq and qq != "all":
            return str(qq)
    return None
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
def _avatar_url(qq: str) -> str:
|
|
527
|
+
return f"http://q.qlogo.cn/headimg_dl?dst_uin={qq}&spec=640"
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def _group_avatar_url(group_id: int) -> str:
|
|
531
|
+
return f"http://p.qlogo.cn/gh/{group_id}/{group_id}/640"
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
WEATHER_CODE_MAP = {
|
|
535
|
+
0: "晴",
|
|
536
|
+
1: "大部晴朗",
|
|
537
|
+
2: "局部多云",
|
|
538
|
+
3: "多云",
|
|
539
|
+
45: "有雾",
|
|
540
|
+
48: "雾凇",
|
|
541
|
+
51: "毛毛雨",
|
|
542
|
+
53: "毛毛雨",
|
|
543
|
+
55: "毛毛雨",
|
|
544
|
+
56: "冻毛毛雨",
|
|
545
|
+
57: "冻毛毛雨",
|
|
546
|
+
61: "小雨",
|
|
547
|
+
63: "中雨",
|
|
548
|
+
65: "大雨",
|
|
549
|
+
66: "冻雨",
|
|
550
|
+
67: "冻雨",
|
|
551
|
+
71: "小雪",
|
|
552
|
+
73: "中雪",
|
|
553
|
+
75: "大雪",
|
|
554
|
+
77: "雪粒",
|
|
555
|
+
80: "阵雨",
|
|
556
|
+
81: "较强阵雨",
|
|
557
|
+
82: "强阵雨",
|
|
558
|
+
85: "阵雪",
|
|
559
|
+
86: "大阵雪",
|
|
560
|
+
95: "雷暴",
|
|
561
|
+
96: "雷暴伴冰雹",
|
|
562
|
+
99: "强雷暴伴冰雹",
|
|
563
|
+
}
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
def _format_number(value: Optional[float], digits: int = 1) -> str:
|
|
567
|
+
if value is None:
|
|
568
|
+
return "未知"
|
|
569
|
+
try:
|
|
570
|
+
number = float(value)
|
|
571
|
+
except (TypeError, ValueError):
|
|
572
|
+
return str(value)
|
|
573
|
+
text = f"{number:.{digits}f}"
|
|
574
|
+
return text.rstrip("0").rstrip(".")
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
def _format_measure(value: Optional[float], unit: str, digits: int = 1) -> str:
    """Format a numeric measurement with its unit suffix; "未知" when missing."""
    if value is None:
        return "未知"
    return f"{_format_number(value, digits)}{unit}"
|
|
581
|
+
|
|
582
|
+
|
|
583
|
+
def _wind_level_from_speed(value: Optional[float], unit: str) -> str:
|
|
584
|
+
if value is None:
|
|
585
|
+
return "风力未知"
|
|
586
|
+
try:
|
|
587
|
+
speed = float(value)
|
|
588
|
+
except (TypeError, ValueError):
|
|
589
|
+
return "风力未知"
|
|
590
|
+
unit_text = (unit or "").lower()
|
|
591
|
+
speed_mps = speed
|
|
592
|
+
if "km" in unit_text:
|
|
593
|
+
speed_mps = speed / 3.6
|
|
594
|
+
elif "m/s" in unit_text or "mps" in unit_text:
|
|
595
|
+
speed_mps = speed
|
|
596
|
+
elif "mph" in unit_text:
|
|
597
|
+
speed_mps = speed * 0.44704
|
|
598
|
+
# Beaufort scale (m/s)
|
|
599
|
+
thresholds = [0.3, 1.6, 3.4, 5.5, 8.0, 10.8, 13.9, 17.2, 20.8, 24.5, 28.5, 32.7]
|
|
600
|
+
level = 0
|
|
601
|
+
for idx, limit in enumerate(thresholds):
|
|
602
|
+
if speed_mps < limit:
|
|
603
|
+
level = idx
|
|
604
|
+
break
|
|
605
|
+
else:
|
|
606
|
+
level = 12
|
|
607
|
+
return f"风力{level}级"
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
def _weather_code_desc(code: Optional[float]) -> str:
    """Chinese description of a WMO weather code via WEATHER_CODE_MAP."""
    if code is None:
        return "未知天气"
    try:
        code_int = int(code)
    except (TypeError, ValueError):
        return "未知天气"
    # Unknown-but-numeric codes keep the number for debugging.
    return WEATHER_CODE_MAP.get(code_int, f"未知天气({code_int})")
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
def _is_rain_code(code: Optional[float]) -> bool:
|
|
621
|
+
if code is None:
|
|
622
|
+
return False
|
|
623
|
+
try:
|
|
624
|
+
code_int = int(code)
|
|
625
|
+
except (TypeError, ValueError):
|
|
626
|
+
return False
|
|
627
|
+
return code_int in {
|
|
628
|
+
51,
|
|
629
|
+
53,
|
|
630
|
+
55,
|
|
631
|
+
56,
|
|
632
|
+
57,
|
|
633
|
+
61,
|
|
634
|
+
63,
|
|
635
|
+
65,
|
|
636
|
+
66,
|
|
637
|
+
67,
|
|
638
|
+
80,
|
|
639
|
+
81,
|
|
640
|
+
82,
|
|
641
|
+
95,
|
|
642
|
+
96,
|
|
643
|
+
99,
|
|
644
|
+
}
|
|
645
|
+
|
|
646
|
+
|
|
647
|
+
def _is_snow_code(code: Optional[float]) -> bool:
|
|
648
|
+
if code is None:
|
|
649
|
+
return False
|
|
650
|
+
try:
|
|
651
|
+
code_int = int(code)
|
|
652
|
+
except (TypeError, ValueError):
|
|
653
|
+
return False
|
|
654
|
+
return code_int in {71, 73, 75, 77, 85, 86}
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
def _weather_clothing_advice(
    temperature: Optional[float],
    weather_code: Optional[float],
) -> str:
    """Short clothing advice from a temperature band plus rain/snow extras.

    Bands assume the temperature is in °C (Open-Meteo's default unit —
    TODO confirm against the fetch parameters). Unparseable temperatures
    fall back to generic advice.
    """
    if temperature is None:
        base = "注意增减衣物"
    else:
        try:
            temp = float(temperature)
        except (TypeError, ValueError):
            base = "注意增减衣物"
        else:
            if temp >= 30:
                base = "有点热 注意防晒"
            elif temp >= 26:
                base = "偏热 注意防晒"
            elif temp >= 20:
                base = "比较舒服 注意早晚温差"
            elif temp >= 12:
                base = "有点凉 注意保暖"
            elif temp >= 5:
                base = "偏冷 注意保暖"
            else:
                base = "很冷 注意保暖"

    # Append precipitation-specific reminders.
    extras: List[str] = []
    if _is_rain_code(weather_code):
        extras.append("带伞")
    if _is_snow_code(weather_code):
        extras.append("注意防滑")
    if extras:
        return f"{base} {','.join(extras)}"
    return base
|
|
690
|
+
|
|
691
|
+
|
|
692
|
+
def _normalize_weather_query(query: str) -> str:
|
|
693
|
+
cleaned = re.sub(r"(天气|气温|温度|湿度|风速|风力)", "", query or "")
|
|
694
|
+
cleaned = cleaned.strip(" ,,")
|
|
695
|
+
return cleaned or query
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
async def _build_weather_messages(query: str) -> List[str]:
    """Resolve *query* to a place, fetch current weather, and format reply lines.

    Returns a one-element list with the formatted reply (or an error
    message), or an empty list when formatting produced nothing.
    """
    normalized_query = _normalize_weather_query(query)
    location = await _geocode_location(normalized_query)
    if not location:
        return [f"未找到地点:{query}"]
    name = location.get("name") or normalized_query
    admin1 = location.get("admin1")
    country = location.get("country")
    country_code = location.get("country_code")
    # Domestic (CN) places show only the city; foreign ones get country/region.
    is_domestic = str(country_code or "").upper() == "CN" or str(country or "") in {
        "中国",
        "中华人民共和国",
        "China",
    }
    if is_domestic:
        display_name = str(name)
    else:
        display_parts: List[str] = []
        if country:
            display_parts.append(str(country))
        if admin1 and admin1 not in display_parts:
            display_parts.append(str(admin1))
        if name and name not in display_parts:
            display_parts.append(str(name))
        display_name = " ".join(display_parts) if display_parts else str(name)
    lat = float(location["latitude"])
    lon = float(location["longitude"])
    data = await _fetch_current_weather(lat, lon)
    if not data:
        return ["天气服务返回异常,请稍后再试。"]
    current = data.get("current", {}) if isinstance(data, dict) else {}
    units = data.get("current_units", {}) if isinstance(data, dict) else {}
    temp_unit = units.get("temperature_2m") or "°C"
    wind_unit = units.get("wind_speed_10m") or "m/s"
    temp_value = current.get("temperature_2m")
    wind_value = current.get("wind_speed_10m")
    weather_code = current.get("weather_code")
    temp = _format_measure(temp_value, temp_unit)
    wind_level = _wind_level_from_speed(wind_value, str(wind_unit))
    code_desc = _weather_code_desc(weather_code)
    advice = _weather_clothing_advice(temp_value, weather_code)
    line2 = f"{display_name} 现在{temp} {code_desc} {wind_level}"
    line3 = f"{advice}"
    reply = _format_reply_text(f"{line2} {line3}")
    return [reply] if reply else []
|
|
743
|
+
|
|
744
|
+
|
|
745
|
+
async def _geocode_location(query: str) -> Optional[dict]:
    """Resolve a place name via the Open-Meteo geocoding API; best match or None.

    Raises httpx.HTTPStatusError on non-2xx responses.
    """
    params = {"name": query, "count": 1, "language": "zh", "format": "json"}
    async with httpx.AsyncClient(timeout=config.request_timeout) as client:
        resp = await client.get("https://geocoding-api.open-meteo.com/v1/search", params=params)
        resp.raise_for_status()
        data = resp.json()
    results = data.get("results") if isinstance(data, dict) else None
    if not results:
        return None
    return results[0]
|
|
755
|
+
|
|
756
|
+
|
|
757
|
+
async def _fetch_current_weather(lat: float, lon: float) -> Optional[dict]:
    """Fetch current conditions from the Open-Meteo forecast API.

    Returns the decoded JSON payload, or None when it is not a dict.
    Raises httpx.HTTPStatusError on non-2xx responses.
    """
    params = {
        "latitude": lat,
        "longitude": lon,
        "current": "temperature_2m,apparent_temperature,relative_humidity_2m,weather_code,wind_speed_10m",
        "timezone": "auto",
    }
    async with httpx.AsyncClient(timeout=config.request_timeout) as client:
        resp = await client.get("https://api.open-meteo.com/v1/forecast", params=params)
        resp.raise_for_status()
        data = resp.json()
    if not isinstance(data, dict):
        return None
    return data
|
|
771
|
+
|
|
772
|
+
|
|
773
|
+
def _history_to_gemini(state: SessionState) -> List[types.Content]:
    """Convert stored history into Gemini ``Content`` turns.

    User turns are prefixed with "name: " so the model can tell speakers
    apart in group chats.
    """
    contents: List[types.Content] = []
    for item in state.history:
        text = item.text
        if item.role == "user":
            name = _normalize_user_name(item.user_name) or _normalize_user_name(item.user_id)
            if name:
                text = f"{name}: {text}"
        contents.append(
            types.Content(
                role=item.role,
                parts=[types.Part.from_text(text=text)],
            )
        )
    return contents
|
|
788
|
+
|
|
789
|
+
|
|
790
|
+
def _generate_config_fields() -> Optional[set[str]]:
    """Field names of ``types.GenerateContentConfig``, or None when undiscoverable.

    Checks ``model_fields`` first, then the older ``__fields__`` attribute;
    returning None means callers should assume any field is allowed.
    """
    fields = getattr(types.GenerateContentConfig, "model_fields", None)
    if isinstance(fields, dict):
        return set(fields.keys())
    fields = getattr(types.GenerateContentConfig, "__fields__", None)
    if isinstance(fields, dict):
        return set(fields.keys())
    return None
|
|
798
|
+
|
|
799
|
+
|
|
800
|
+
def _build_generate_config(
    *,
    system_instruction: Optional[str] = None,
    response_mime_type: Optional[str] = None,
    response_modalities: Optional[List[str]] = None,
) -> Tuple[Optional[types.GenerateContentConfigOrDict], bool]:
    """Assemble a generate-content config dict limited to SDK-supported fields.

    Returns ``(config_or_None, system_instruction_included)``. Callers use
    the flag to fall back to injecting the system prompt as a user turn when
    the SDK does not accept ``system_instruction``.
    """
    fields = _generate_config_fields()
    # A field is allowed when the field set is unknown (None) or contains it.
    allow_system = bool(system_instruction) and (
        fields is None or "system_instruction" in fields
    )
    allow_mime = bool(response_mime_type) and (
        fields is None or "response_mime_type" in fields
    )
    allow_modalities = bool(response_modalities) and (
        fields is None or "response_modalities" in fields
    )
    if not allow_system and not allow_mime and not allow_modalities:
        return None, False
    config_obj: dict[str, object] = {}
    system_used = False
    if allow_system:
        config_obj["system_instruction"] = system_instruction
        system_used = True
    if allow_mime:
        config_obj["response_mime_type"] = response_mime_type
    if allow_modalities:
        config_obj["response_modalities"] = response_modalities
    if not config_obj:
        return None, False
    return cast(types.GenerateContentConfigOrDict, config_obj), system_used
|
|
830
|
+
|
|
831
|
+
|
|
832
|
+
def _iter_response_parts(response: object) -> List[object]:
|
|
833
|
+
parts: List[object] = []
|
|
834
|
+
candidates = getattr(response, "candidates", None)
|
|
835
|
+
if candidates:
|
|
836
|
+
for cand in candidates:
|
|
837
|
+
content = getattr(cand, "content", None)
|
|
838
|
+
cand_parts = getattr(content, "parts", None) if content else None
|
|
839
|
+
if cand_parts:
|
|
840
|
+
parts.extend(cand_parts)
|
|
841
|
+
if not parts:
|
|
842
|
+
direct_parts = getattr(response, "parts", None)
|
|
843
|
+
if direct_parts:
|
|
844
|
+
parts.extend(direct_parts)
|
|
845
|
+
return parts
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
def _extract_inline_data(part: object) -> Optional[object]:
|
|
849
|
+
if isinstance(part, dict):
|
|
850
|
+
return part.get("inline_data") or part.get("inlineData")
|
|
851
|
+
return getattr(part, "inline_data", None)
|
|
852
|
+
|
|
853
|
+
|
|
854
|
+
def _extract_text_value(part: object) -> Optional[str]:
|
|
855
|
+
if isinstance(part, dict):
|
|
856
|
+
value = part.get("text")
|
|
857
|
+
return value if isinstance(value, str) else None
|
|
858
|
+
value = getattr(part, "text", None)
|
|
859
|
+
return value if isinstance(value, str) else None
|
|
860
|
+
|
|
861
|
+
|
|
862
|
+
async def _call_gemini_text(prompt: str, state: SessionState) -> str:
    """Send a text prompt (plus session history) to the Gemini text model.

    The chat system prompt is passed as a real ``system_instruction`` when the
    installed SDK version supports it; otherwise it is emulated by prepending
    it as a plain user turn.

    Returns the formatted, compacted, length-limited reply text ("" when the
    model produced no text at all).
    """
    client = _get_client()
    contents = _history_to_gemini(state)
    config_obj, system_used = _build_generate_config(system_instruction=_CHAT_SYSTEM_PROMPT)
    if _CHAT_SYSTEM_PROMPT and not system_used:
        # SDK does not accept system_instruction: fake it with a leading turn.
        contents.insert(
            0,
            types.Content(
                role="user",
                parts=[types.Part.from_text(text=_CHAT_SYSTEM_PROMPT)],
            ),
        )
    contents.append(types.Content(role="user", parts=[types.Part.from_text(text=prompt)]))
    response = await asyncio.wait_for(
        client.aio.models.generate_content(
            model=config.gemini_text_model,
            contents=contents,
            config=config_obj,
        ),
        timeout=config.request_timeout,
    )
    if config.gemini_log_response:
        logger.info("Gemini text response: {}", _dump_response(response))
    _log_response_text("Gemini text content", response)
    if response.text:
        raw = response.text.strip()
    else:
        # Fallback: stitch together any textual parts from the candidates.
        # Use _extract_text_value so dict-shaped parts are handled too,
        # consistent with _call_gemini_image_chat (a plain getattr would
        # silently drop dict parts).
        text_parts = [
            value
            for part in _iter_response_parts(response)
            if (value := _extract_text_value(part))
        ]
        raw = "\n".join(text_parts).strip()
    cleaned = _format_reply_text(raw)
    cleaned = _compact_reply_lines(cleaned)
    return _limit_reply_text(cleaned)
|
|
899
|
+
|
|
900
|
+
|
|
901
|
+
def _build_travel_prompt(intent: dict) -> str:
    """Compose the full prompt text for a travel-plan request.

    Combines the travel system prompt, an optional destination line, the
    (duration-stripped) user instruction, and a fixed output-format note.
    """
    params = _intent_params(intent)
    dest_text = str(params.get("destination") or "").strip()
    instruction = str(intent.get("instruction") or "").strip()
    cleaned_instruction = _strip_travel_duration(instruction)

    lines = [_TRAVEL_SYSTEM_PROMPT.strip()]
    lines.append(f"请规划{dest_text}旅行行程。" if dest_text else "请规划旅行行程。")
    if cleaned_instruction:
        lines.append(f"需求补充:{cleaned_instruction}")
    lines.append(
        "输出要求:纯文本,结构清晰,可自然换行,包含景点/活动/用餐/交通/住宿要点。"
    )
    return "\n".join(lines)
|
|
918
|
+
|
|
919
|
+
|
|
920
|
+
async def _call_gemini_travel_plan(intent: dict, state: SessionState) -> str:
    """Ask the Gemini text model for a travel itinerary.

    Builds the prompt from the normalized travel intent, appends it to the
    session history, and returns the formatted, length-limited reply text
    ("" when the model produced no text).
    """
    client = _get_client()
    contents = _history_to_gemini(state)
    prompt = _build_travel_prompt(intent)
    contents.append(types.Content(role="user", parts=[types.Part.from_text(text=prompt)]))
    config_obj, _ = _build_generate_config()
    response = await asyncio.wait_for(
        client.aio.models.generate_content(
            model=config.gemini_text_model,
            contents=contents,
            config=config_obj,
        ),
        timeout=config.request_timeout,
    )
    if config.gemini_log_response:
        logger.info("Gemini travel response: {}", _dump_response(response))
    _log_response_text("Gemini travel content", response)
    if response.text:
        raw = response.text.strip()
    else:
        # Use _extract_text_value so dict-shaped parts are handled too,
        # consistent with _call_gemini_image_chat.
        text_parts = [
            value
            for part in _iter_response_parts(response)
            if (value := _extract_text_value(part))
        ]
        raw = "\n".join(text_parts).strip()
    return _limit_reply_text(_format_reply_text(raw))
|
|
948
|
+
|
|
949
|
+
|
|
950
|
+
async def _download_image_bytes(url: str) -> Tuple[str, bytes]:
    """Fetch an image over HTTP and return ``(content_type, raw_bytes)``.

    Raises:
        UnsupportedImageError: for GIF content (animated images unsupported).
        httpx.HTTPStatusError: on a non-success HTTP status.
    """
    async with httpx.AsyncClient(timeout=config.request_timeout) as http_client:
        resp = await http_client.get(url)
        resp.raise_for_status()
        mime = resp.headers.get("content-type", "image/jpeg")
        payload = resp.content
    if isinstance(mime, str) and mime.lower().startswith("image/gif"):
        raise UnsupportedImageError("不支持动图")
    return mime, payload
|
|
959
|
+
|
|
960
|
+
|
|
961
|
+
async def _call_gemini_image(prompt: str, image_url: str, state: SessionState) -> Tuple[bool, str]:
    """Run an image-editing request against the Gemini image model.

    Downloads the source image, sends it together with *prompt* and the
    session history, and scans the response parts in order.

    Returns:
        ``(True, base64_payload)`` when the model returned image bytes, or
        ``(False, reply_text)`` when it answered with text only.

    Raises:
        RuntimeError: when the response contains neither an image nor text.
    """
    client = _get_client()
    content_type, image_bytes = await _download_image_bytes(image_url)
    contents = _history_to_gemini(state)
    contents.append(
        types.Content(
            role="user",
            parts=[
                types.Part.from_text(text=prompt),
                types.Part.from_bytes(data=image_bytes, mime_type=content_type),
            ],
        )
    )

    # Ask for both modalities: the model may answer with an image or text.
    config_obj, _ = _build_generate_config(response_modalities=["TEXT", "IMAGE"])
    response = await asyncio.wait_for(
        client.aio.models.generate_content(
            model=config.gemini_image_model,
            contents=contents,
            config=config_obj,
        ),
        timeout=config.image_timeout,
    )
    if config.gemini_log_response:
        logger.info("Gemini image response: {}", _dump_response(response))
    _log_response_text("Gemini image content", response)

    # First part with usable inline data wins; otherwise the first textual
    # part becomes the (text-only) reply.
    for part in _iter_response_parts(response):
        inline_data = _extract_inline_data(part)
        text_value = _extract_text_value(part)
        if inline_data:
            if isinstance(inline_data, dict):
                data = inline_data.get("data")
            else:
                data = getattr(inline_data, "data", None)
            if isinstance(data, bytes):
                # Raw bytes: encode for transport as base64 text.
                return True, base64.b64encode(data).decode("ascii")
            if isinstance(data, str):
                # Already base64-encoded by the SDK.
                return True, data
        if text_value:
            cleaned = _format_reply_text(text_value)
            cleaned = _limit_reply_text(cleaned)
            return False, cleaned or "(没有生成到有效文本)"
    # Last resort: a top-level text attribute on the response.
    if getattr(response, "text", None):
        cleaned = _format_reply_text(getattr(response, "text"))
        cleaned = _limit_reply_text(cleaned)
        return False, cleaned or "(没有生成到有效文本)"
    raise RuntimeError("未获取到有效图片结果")
|
|
1009
|
+
|
|
1010
|
+
|
|
1011
|
+
async def _call_gemini_image_chat(prompt: str, image_url: str, state: SessionState) -> str:
    """Chat about an image: send *prompt* plus the downloaded image and return text.

    Requests TEXT-only output from the image model. The image-chat system
    prompt is passed as ``system_instruction`` when the SDK supports it,
    otherwise prepended as a plain user turn. Returns the formatted,
    length-limited reply ("" when no text came back).
    """
    client = _get_client()
    content_type, image_bytes = await _download_image_bytes(image_url)
    contents = _history_to_gemini(state)
    config_obj, system_used = _build_generate_config(
        system_instruction=_IMAGE_CHAT_SYSTEM_PROMPT,
        response_modalities=["TEXT"],
    )
    if _IMAGE_CHAT_SYSTEM_PROMPT and not system_used:
        # SDK rejected system_instruction: emulate it with a leading turn.
        contents.insert(
            0,
            types.Content(
                role="user",
                parts=[types.Part.from_text(text=_IMAGE_CHAT_SYSTEM_PROMPT)],
            ),
        )
    contents.append(
        types.Content(
            role="user",
            parts=[
                types.Part.from_text(text=prompt),
                types.Part.from_bytes(data=image_bytes, mime_type=content_type),
            ],
        )
    )
    response = await asyncio.wait_for(
        client.aio.models.generate_content(
            model=config.gemini_image_model,
            contents=contents,
            config=config_obj,
        ),
        timeout=config.image_timeout,
    )
    if config.gemini_log_response:
        logger.info("Gemini image chat response: {}", _dump_response(response))
    _log_response_text("Gemini image chat content", response)
    if response.text:
        cleaned = _format_reply_text(response.text.strip())
        cleaned = _limit_reply_text(cleaned)
        return cleaned
    # Fallback: join any textual parts (dict- or object-shaped) from candidates.
    text_parts: List[str] = []
    for part in _iter_response_parts(response):
        text_value = _extract_text_value(part)
        if text_value:
            text_parts.append(text_value)
    cleaned = _format_reply_text("\n".join(text_parts).strip())
    cleaned = _limit_reply_text(cleaned)
    return cleaned
|
|
1059
|
+
|
|
1060
|
+
|
|
1061
|
+
async def _call_gemini_text_to_image(prompt: str, state: SessionState) -> Tuple[bool, str]:
    """Generate an image from a text prompt via the Gemini image model.

    Returns:
        ``(True, base64_payload)`` when image bytes were produced, or
        ``(False, reply_text)`` when the model answered with text instead.

    Raises:
        RuntimeError: when the response has neither an image nor text.
    """
    client = _get_client()
    contents = _history_to_gemini(state)
    contents.append(types.Content(role="user", parts=[types.Part.from_text(text=prompt)]))
    config_obj, _ = _build_generate_config(response_modalities=["IMAGE"])
    response = await asyncio.wait_for(
        client.aio.models.generate_content(
            model=config.gemini_image_model,
            contents=contents,
            config=config_obj,
        ),
        timeout=config.image_timeout,
    )
    if config.gemini_log_response:
        logger.info("Gemini text-to-image response: {}", _dump_response(response))
    _log_response_text("Gemini text-to-image content", response)
    # Scan parts in order: first inline image wins, else first textual part.
    for part in _iter_response_parts(response):
        inline_data = _extract_inline_data(part)
        text_value = _extract_text_value(part)
        if inline_data:
            if isinstance(inline_data, dict):
                data = inline_data.get("data")
            else:
                data = getattr(inline_data, "data", None)
            if isinstance(data, bytes):
                # Raw bytes: base64-encode for transport.
                return True, base64.b64encode(data).decode("ascii")
            if isinstance(data, str):
                # Already base64-encoded by the SDK.
                return True, data
        if text_value:
            cleaned = _format_reply_text(text_value)
            cleaned = _limit_reply_text(cleaned)
            return False, cleaned or "(没有生成到有效文本)"
    # Last resort: a top-level text attribute on the response.
    if getattr(response, "text", None):
        cleaned = _format_reply_text(getattr(response, "text"))
        cleaned = _limit_reply_text(cleaned)
        return False, cleaned or "(没有生成到有效文本)"
    raise RuntimeError("未获取到有效图片结果")
|
|
1098
|
+
|
|
1099
|
+
|
|
1100
|
+
def _image_segment_from_result(result: str) -> MessageSegment:
    """Wrap a Gemini image result into an OneBot image segment.

    Accepts an http(s) URL, a ``base64://`` or ``data:image`` URI, or a bare
    base64 payload (which gets the ``base64://`` prefix added).

    Raises:
        RuntimeError: when *result* is empty.
    """
    if not result:
        raise RuntimeError("图片结果为空")
    if result.startswith(("http://", "https://", "base64://", "data:image")):
        return MessageSegment.image(result)
    return MessageSegment.image(f"base64://{result}")
|
|
1110
|
+
|
|
1111
|
+
|
|
1112
|
+
def _append_history(
    state: SessionState,
    role: str,
    text: str,
    *,
    user_id: Optional[str] = None,
    user_name: Optional[str] = None,
    to_bot: bool = False,
    ts: Optional[float] = None,
    message_id: Optional[int] = None,
) -> None:
    """Append one conversation turn to the session history, then prune.

    When *ts* is omitted the current time is used. ``to_bot`` marks turns
    that were addressed to the bot.
    """
    item = HistoryItem(
        role=role,
        text=text,
        ts=ts if ts is not None else _now(),
        user_id=user_id,
        user_name=user_name,
        to_bot=to_bot,
        message_id=message_id,
    )
    state.history.append(item)
    _prune_state(state)
|
|
1135
|
+
|
|
1136
|
+
|
|
1137
|
+
# Matcher registration. NoneBot runs lower priority numbers first: the
# explicit commands (priority 5) go before the free-form NLP handler (15),
# and the passive history collector runs last (99) without blocking, so it
# records every message regardless of which handler consumed it.
history_collector = on_message(priority=99, block=False)
nlp_handler = on_message(priority=15, block=False)
avatar_handler = on_command("处理头像", priority=5)
chat_handler = on_command("技能", aliases={"聊天", "对话"}, priority=5)
weather_handler = on_command("天气", aliases={"查询天气", "查天气"}, priority=5)
travel_handler = on_command("旅行规划", aliases={"旅行计划", "行程规划", "旅行", "行程"}, priority=5)
|
|
1143
|
+
|
|
1144
|
+
|
|
1145
|
+
@history_collector.handle()
async def _collect_history(event: MessageEvent):
    """Passively record every incoming message into the per-session state.

    Caches any attached image (as the session's last image, and by message id
    for later reply-lookup), wakes any coroutine waiting for this user's next
    image, and appends the plain text to the history with a flag noting
    whether it was addressed to the bot.
    """
    session_id = _session_id(event)
    state = _get_state(session_id)

    text = event.get_plaintext().strip()
    image_url = _extract_first_image_url(event.get_message())
    if image_url:
        state.last_image_url = image_url
        msg_id = getattr(event, "message_id", None)
        if isinstance(msg_id, int):
            # Index by message id so a later reply can recover this image.
            state.image_cache[msg_id] = (image_url, _event_ts(event))
        # Resolve any pending "wait for next image" future for this user.
        _notify_pending_image(state, str(event.get_user_id()), image_url)

    if text:
        user_name = _event_user_name(event)
        _append_history(
            state,
            "user",
            text,
            user_id=str(event.get_user_id()),
            user_name=user_name,
            to_bot=_should_trigger_nlp(event, text),
            ts=_event_ts(event),
            message_id=getattr(event, "message_id", None),
        )
|
|
1171
|
+
|
|
1172
|
+
|
|
1173
|
+
def _is_command_message(text: str) -> bool:
|
|
1174
|
+
text = text.strip()
|
|
1175
|
+
if not text:
|
|
1176
|
+
return False
|
|
1177
|
+
try:
|
|
1178
|
+
starts = list(get_driver().config.command_start or [])
|
|
1179
|
+
except Exception:
|
|
1180
|
+
starts = ["/"]
|
|
1181
|
+
if not starts:
|
|
1182
|
+
return False
|
|
1183
|
+
command_words = [
|
|
1184
|
+
"处理头像",
|
|
1185
|
+
"技能",
|
|
1186
|
+
"聊天",
|
|
1187
|
+
"对话",
|
|
1188
|
+
"天气",
|
|
1189
|
+
"查询天气",
|
|
1190
|
+
"查天气",
|
|
1191
|
+
"旅行规划",
|
|
1192
|
+
"旅行计划",
|
|
1193
|
+
"行程规划",
|
|
1194
|
+
"旅行",
|
|
1195
|
+
"行程",
|
|
1196
|
+
]
|
|
1197
|
+
for prefix in starts:
|
|
1198
|
+
if not prefix:
|
|
1199
|
+
continue
|
|
1200
|
+
for word in command_words:
|
|
1201
|
+
if text.startswith(prefix + word):
|
|
1202
|
+
return True
|
|
1203
|
+
return False
|
|
1204
|
+
|
|
1205
|
+
|
|
1206
|
+
def _match_keyword(text: str) -> Optional[str]:
    """Return the first configured bot keyword found in *text*, else None."""
    return next((kw for kw in config.bot_keywords if kw and kw in text), None)
|
|
1211
|
+
|
|
1212
|
+
|
|
1213
|
+
def _is_at_bot(event: MessageEvent) -> bool:
|
|
1214
|
+
message = event.get_message()
|
|
1215
|
+
for seg in message:
|
|
1216
|
+
if seg.type == "at":
|
|
1217
|
+
qq = seg.data.get("qq")
|
|
1218
|
+
if qq and str(qq) == str(event.self_id):
|
|
1219
|
+
return True
|
|
1220
|
+
return False
|
|
1221
|
+
|
|
1222
|
+
|
|
1223
|
+
def _is_reply_to_bot(event: MessageEvent) -> bool:
|
|
1224
|
+
reply = getattr(event, "reply", None)
|
|
1225
|
+
if not reply:
|
|
1226
|
+
return False
|
|
1227
|
+
sender = getattr(reply, "sender", None)
|
|
1228
|
+
sender_id = getattr(sender, "user_id", None)
|
|
1229
|
+
if sender_id is None:
|
|
1230
|
+
return False
|
|
1231
|
+
return str(sender_id) == str(event.self_id)
|
|
1232
|
+
|
|
1233
|
+
|
|
1234
|
+
def _should_trigger_nlp(event: MessageEvent, text: str) -> bool:
    """Decide whether a free-form message should reach the NLP handler.

    Private chats always trigger. Group chats trigger only when the bot is
    addressed (to-me / @), replied to, or a configured keyword appears.
    """
    if not isinstance(event, GroupMessageEvent):
        return True
    try:
        addressed = bool(event.is_tome())
    except Exception:
        # is_tome() unavailable on this event: fall back to @-segment scan.
        addressed = _is_at_bot(event)
    if addressed:
        return True
    if _is_reply_to_bot(event):
        return True
    return _match_keyword(text) is not None
|
|
1246
|
+
|
|
1247
|
+
|
|
1248
|
+
def _extract_reply_context(
    event: MessageEvent,
    state: SessionState,
) -> Tuple[Optional[str], Optional[str]]:
    """Resolve the text and author name of the message this event replies to.

    Prefers the locally cached history entry matched by message id; falls
    back to the raw reply payload. Returns ``(text, name)``; both may be None.
    """
    reply = getattr(event, "reply", None)
    if not reply:
        return None, None

    reply_id = getattr(reply, "message_id", None)
    if reply_id is not None:
        match = next(
            (item for item in reversed(state.history) if item.message_id == reply_id),
            None,
        )
        if match is not None:
            return match.text, (match.user_name or match.user_id)

    reply_message = getattr(reply, "message", None)
    if reply_message:
        try:
            text = reply_message.extract_plain_text().strip()
        except Exception:
            text = None
        if text:
            name = _sender_user_name(getattr(reply, "sender", None))
            return text, name or None
    name = _sender_user_name(getattr(reply, "sender", None))
    return None, name or None
|
|
1271
|
+
|
|
1272
|
+
|
|
1273
|
+
def _extract_reply_image_url(event: MessageEvent, state: SessionState) -> Optional[str]:
    """Find the image URL of the replied-to message, if any.

    Checks the raw reply payload first, then the session image cache keyed
    by the replied-to message id.
    """
    reply = getattr(event, "reply", None)
    if not reply:
        return None
    reply_message = getattr(reply, "message", None)
    url = _extract_first_image_url(reply_message) if reply_message else None
    if url:
        return url
    reply_id = getattr(reply, "message_id", None)
    if reply_id is None:
        return None
    cached = state.image_cache.get(int(reply_id))
    return cached[0] if cached else None
|
|
1288
|
+
|
|
1289
|
+
|
|
1290
|
+
def _coerce_int(value: object) -> Optional[int]:
|
|
1291
|
+
try:
|
|
1292
|
+
if isinstance(value, bool):
|
|
1293
|
+
return None
|
|
1294
|
+
if isinstance(value, int):
|
|
1295
|
+
return value
|
|
1296
|
+
if isinstance(value, str) and value.strip().isdigit():
|
|
1297
|
+
return int(value.strip())
|
|
1298
|
+
except Exception:
|
|
1299
|
+
return None
|
|
1300
|
+
return None
|
|
1301
|
+
|
|
1302
|
+
|
|
1303
|
+
async def _resolve_image_url(
    intent: dict,
    *,
    event: MessageEvent,
    state: SessionState,
    current_image_url: Optional[str],
    reply_image_url: Optional[str],
    at_user: Optional[str],
) -> Optional[str]:
    """Map an intent "target" to a concrete image URL.

    Supports the current message's image, the replied-to image, avatars
    (sender / @-ed user / arbitrary QQ / group), the session's last image,
    a cached image looked up by message id, or waiting for the user's next
    image. Returns None when the target cannot be resolved.
    """
    target = str(intent.get("target") or "").lower()
    params = _intent_params(intent)
    user_id = str(event.get_user_id())

    if target == "message_image":
        return current_image_url
    if target == "reply_image":
        return reply_image_url
    if target == "at_user":
        # Avatar of the user mentioned via @.
        return _avatar_url(at_user) if at_user else None
    if target == "last_image":
        return state.last_image_url
    if target == "sender_avatar":
        return _avatar_url(user_id)
    if target == "group_avatar":
        # Only meaningful in a group chat.
        if isinstance(event, GroupMessageEvent):
            return _group_avatar_url(int(event.group_id))
        return None
    if target == "qq_avatar":
        qq = params.get("qq")
        if qq:
            return _avatar_url(str(qq))
        return None
    if target == "message_id":
        # Look up a previously cached image by its message id.
        msg_id = _coerce_int(params.get("message_id"))
        if msg_id is None:
            return None
        cached = state.image_cache.get(msg_id)
        return cached[0] if cached else None
    if target == "wait_next":
        # Block (bounded) until the user sends their next image.
        return await _wait_next_image(state, user_id, _WAIT_NEXT_IMAGE_SEC)
    return None
|
|
1344
|
+
|
|
1345
|
+
|
|
1346
|
+
def _collect_context_messages(
    state: SessionState,
    current_user_id: str,
    *,
    ts: float,
    limit: int,
    future: bool,
    current_text: str,
) -> List[str]:
    """Gather up to *limit* bot-directed user lines around timestamp *ts*.

    With ``future=False`` the most recent messages at or before *ts* are
    taken (returned oldest-first); with ``future=True`` the earliest messages
    strictly after *ts*. The triggering message itself (same text and user)
    is skipped.
    """
    if limit <= 0:
        return []
    source = state.history if future else reversed(state.history)
    lines: List[str] = []
    for item in source:
        if item.role != "user" or not item.to_bot:
            continue
        in_window = item.ts > ts if future else item.ts <= ts
        if not in_window:
            continue
        if item.text == current_text and item.user_id == current_user_id:
            continue
        lines.append(_format_context_line(item.text, item.user_name or item.user_id))
        if len(lines) >= limit:
            break
    return lines if future else lines[::-1]
|
|
1377
|
+
|
|
1378
|
+
|
|
1379
|
+
def _notify_pending_image(state: SessionState, user_id: str, image_url: str) -> None:
|
|
1380
|
+
waiter = state.pending_image_waiters.pop(user_id, None)
|
|
1381
|
+
if waiter and not waiter.done():
|
|
1382
|
+
waiter.set_result(image_url)
|
|
1383
|
+
|
|
1384
|
+
|
|
1385
|
+
async def _wait_next_image(
    state: SessionState,
    user_id: str,
    timeout_sec: float,
) -> Optional[str]:
    """Wait up to *timeout_sec* for this user's next image URL.

    Registers a future that _notify_pending_image resolves when the user
    posts an image. Returns the URL, or None on timeout/cancellation.
    """
    # Cancel any stale waiter so at most one wait is active per user.
    waiter = state.pending_image_waiters.get(user_id)
    if waiter and not waiter.done():
        waiter.cancel()
    loop = asyncio.get_running_loop()
    future: asyncio.Future[str] = loop.create_future()
    state.pending_image_waiters[user_id] = future
    try:
        return await asyncio.wait_for(future, timeout=timeout_sec)
    except Exception:
        # Timeout (or other failure): treated as "no image arrived".
        return None
    finally:
        # Remove the entry only if it is still ours — a newer call may
        # already have replaced it with a fresh future.
        current = state.pending_image_waiters.get(user_id)
        if current is future:
            state.pending_image_waiters.pop(user_id, None)
|
|
1404
|
+
|
|
1405
|
+
|
|
1406
|
+
async def _build_intent_text(
    event: MessageEvent,
    state: SessionState,
    text: str,
) -> str:
    """Assemble the full text handed to the intent classifier.

    Combines the triggering message, the replied-to line (if any), a few
    preceding bot-directed messages, and — after an optional short wait —
    a few messages that arrive *after* the trigger, so rapid follow-ups are
    included. Each config read falls back to a safe default on any error.
    """
    try:
        max_prev = max(0, int(getattr(config, "nlp_context_history_messages", 2)))
    except Exception:
        max_prev = 2
    try:
        max_future = max(0, int(getattr(config, "nlp_context_future_messages", 2)))
    except Exception:
        max_future = 2
    try:
        wait_sec = max(0.0, float(getattr(config, "nlp_context_future_wait_sec", 1.0)))
    except Exception:
        wait_sec = 1.0

    ts = _event_ts(event)
    user_id = str(event.get_user_id())
    reply_text, reply_name = _extract_reply_context(event, state)

    prev_texts = _collect_context_messages(
        state,
        user_id,
        ts=ts,
        limit=max_prev,
        future=False,
        current_text=text,
    )
    future_texts: List[str] = []
    if max_future > 0:
        # Give quick follow-up messages a moment to land in the history.
        if wait_sec > 0:
            await asyncio.sleep(wait_sec)
        future_texts = _collect_context_messages(
            state,
            user_id,
            ts=ts,
            limit=max_future,
            future=True,
            current_text=text,
        )

    reply_line = ""
    if reply_text:
        reply_line = (
            _format_context_line(reply_text, reply_name)
            if reply_name
            else f"回复内容: {reply_text}"
        )
    combined = [
        part
        for part in [text, reply_line, *prev_texts, *future_texts]
        if part
    ]
    if not combined:
        return text
    return "\n".join(combined)
|
|
1464
|
+
|
|
1465
|
+
|
|
1466
|
+
def _build_primary_intent_text(
    event: MessageEvent,
    state: SessionState,
    text: str,
) -> str:
    """Prefix the user text with the replied-to line when it adds context.

    Returns *text* unchanged when there is no reply or the reply duplicates
    the message itself.
    """
    reply_text, reply_name = _extract_reply_context(event, state)
    if not reply_text or reply_text.strip() == text.strip():
        return text
    if reply_name:
        reply_line = _format_context_line(reply_text, reply_name)
    else:
        reply_line = f"回复内容: {reply_text}"
    return f"{text}\n{reply_line}"
|
|
1482
|
+
|
|
1483
|
+
|
|
1484
|
+
# Intent "action" values the classifier may emit; anything else is rejected
# by _normalize_intent.
_ALLOWED_ACTIONS = {
    "chat",
    "image_chat",
    "image_generate",
    "image_create",
    "weather",
    "avatar_get",
    "travel_plan",
    "history_clear",
    "ignore",
}
# Intent "target" values describing where the subject image comes from
# (see _resolve_image_url for how each is resolved).
_ALLOWED_TARGETS = {
    "message_image",
    "reply_image",
    "at_user",
    "last_image",
    "sender_avatar",
    "group_avatar",
    "qq_avatar",
    "message_id",
    "wait_next",
    "trip",
    "none",
}
|
|
1508
|
+
|
|
1509
|
+
|
|
1510
|
+
def _intent_params(intent: Optional[dict]) -> dict[str, object]:
|
|
1511
|
+
if not isinstance(intent, dict):
|
|
1512
|
+
return {}
|
|
1513
|
+
raw_params = intent.get("params")
|
|
1514
|
+
return raw_params if isinstance(raw_params, dict) else {}
|
|
1515
|
+
|
|
1516
|
+
|
|
1517
|
+
def _normalize_intent(
    intent: Optional[dict],
    has_image: bool,
    has_reply_image: bool,
    at_user: Optional[str],
    state: SessionState,
) -> Optional[dict]:
    """Validate and normalize a raw classifier intent dict.

    Rejects unknown actions and empty instructions (returning None), and
    fills in sensible defaults per action — notably picking an image target
    for image actions based on what the current message actually offers.
    Always returns a dict with "action", "instruction", "target", "params"
    (except the bare ``{"action": "ignore"}`` case).
    """
    if not isinstance(intent, dict):
        return None
    action = str(intent.get("action", "")).strip().lower()
    if action not in _ALLOWED_ACTIONS:
        return None
    if action == "ignore":
        # "ignore" carries no payload at all.
        return {"action": "ignore"}
    instruction = intent.get("instruction")
    if not isinstance(instruction, str) or not instruction.strip():
        return None
    params = _intent_params(intent)
    target = str(intent.get("target", "")).strip().lower()

    if action == "image_create":
        # Pure text-to-image: there is never a source image.
        return {
            "action": action,
            "instruction": instruction.strip(),
            "target": "none",
            "params": params,
        }

    if action in {"image_chat", "image_generate"}:
        if target not in _ALLOWED_TARGETS:
            target = ""
        if not target or target == "none":
            # Choose the best available image source, in priority order:
            # attached image > replied-to image > @-ed user's avatar >
            # session's last image > wait for the user's next image.
            if has_image:
                target = "message_image"
            elif has_reply_image:
                target = "reply_image"
            elif at_user:
                target = "at_user"
            elif state.last_image_url:
                target = "last_image"
            else:
                target = "wait_next"
        return {
            "action": action,
            "instruction": instruction.strip(),
            "target": target,
            "params": params,
        }

    if action == "avatar_get":
        if target not in _ALLOWED_TARGETS:
            target = ""
        if not target or target == "none":
            # Default to the requesting user's own avatar.
            target = "sender_avatar"
        return {
            "action": action,
            "instruction": instruction.strip(),
            "target": target,
            "params": params,
        }

    if action == "weather":
        # Prefer an explicit params["city"]; fall back to the instruction.
        city = ""
        raw_city = params.get("city")
        if isinstance(raw_city, str):
            city = raw_city.strip()
        if not city and isinstance(instruction, str):
            city = instruction.strip()
        return {
            "action": action,
            "instruction": city,
            "target": "city",
            "params": {"city": city} if city else {},
        }

    if action == "travel_plan":
        days = _coerce_int(params.get("days"))
        nights = _coerce_int(params.get("nights"))
        destination = ""
        raw_destination = params.get("destination") or params.get("city")
        if isinstance(raw_destination, str):
            destination = raw_destination.strip()
        if (days is None or nights is None) and isinstance(instruction, str):
            # Try to parse "N days / M nights" out of the free-form text.
            parsed_days, parsed_nights = _extract_travel_duration(instruction)
            if days is None:
                days = parsed_days
            if nights is None:
                nights = parsed_nights
        if not destination and isinstance(instruction, str):
            destination = _extract_travel_destination(instruction) or ""
        normalized_params: dict[str, object] = {}
        if days is not None:
            normalized_params["days"] = days
        if nights is not None:
            normalized_params["nights"] = nights
        if destination:
            normalized_params["destination"] = destination
        return {
            "action": action,
            "instruction": instruction.strip(),
            "target": "trip",
            "params": normalized_params,
        }

    # Remaining actions (chat, history_clear) need no target.
    return {"action": action, "instruction": instruction.strip(), "params": params}
|
|
1622
|
+
|
|
1623
|
+
|
|
1624
|
+
async def _build_travel_plan_reply(
    intent: dict,
    state: SessionState,
    event: MessageEvent,
) -> Optional[str]:
    """Produce a travel-plan reply for a normalized travel intent.

    Returns a prompt-for-destination message when no destination is set,
    None when the model returned nothing, otherwise the plan text. On
    success, records both the request summary and the reply in history.
    """
    params = _intent_params(intent)
    destination = params.get("destination")
    destination_text = destination.strip() if isinstance(destination, str) else ""
    if not destination_text:
        return "请告诉我目的地,例如:北京"
    # Work on copies so the caller's intent dict is not mutated.
    normalized_params = dict(params)
    normalized_params["destination"] = destination_text
    intent = dict(intent)
    intent["params"] = normalized_params
    reply = await _call_gemini_travel_plan(intent, state)
    if not reply:
        return None
    instruction = str(intent.get("instruction") or "").strip()
    cleaned_instruction = _strip_travel_duration(instruction)
    summary = f"{destination_text}"
    if cleaned_instruction and cleaned_instruction not in summary:
        summary = f"{summary} 需求:{cleaned_instruction}"
    user_name = _event_user_name(event)
    # Record a compact request summary plus the full reply for later context.
    _append_history(
        state,
        "user",
        f"旅行规划:{summary}",
        user_id=str(event.get_user_id()),
        user_name=user_name,
        to_bot=True,
    )
    _append_history(state, "model", reply)
    return reply
|
|
1657
|
+
|
|
1658
|
+
|
|
1659
|
+
async def _dispatch_intent(
    intent: dict,
    state: SessionState,
    event: MessageEvent,
    text: str,
    *,
    image_url: Optional[str],
    reply_image_url: Optional[str],
    at_user: Optional[str],
    send_func,
) -> None:
    """Route a normalized intent to its action branch and send the reply.

    Args:
        intent: Normalized intent dict carrying "action", "instruction",
            "target" and "params" keys.
        state: Per-session conversation state used for history bookkeeping.
        event: The triggering OneBot message event.
        text: Raw plaintext of the message; used to mark handled requests.
        image_url: URL of an image attached to the current message, if any.
        reply_image_url: URL of an image in the quoted message, if any.
        at_user: Id of an @-mentioned user, if any.
        send_func: Awaitable callable used to deliver replies to the chat.
    """
    action = str(intent.get("action", "ignore")).lower()
    if action == "ignore":
        return
    user_name = _event_user_name(event)

    # --- chat: plain text reply with history bookkeeping ---
    if action == "chat":
        prompt = intent.get("instruction")
        try:
            reply = await _call_gemini_text(str(prompt), state)
            if not reply:
                return
            _append_history(
                state,
                "user",
                str(prompt),
                user_id=str(event.get_user_id()),
                user_name=user_name,
                to_bot=True,
            )
            _append_history(state, "model", reply)
            await send_func(reply)
            _mark_handled_request(state, event, text)
        except Exception as exc:
            # NOTE(review): unlike the other branches, chat failures are
            # logged but not reported to the user — confirm this is intended.
            logger.error("NLP chat failed: {}", _safe_error_message(exc))
        return

    # --- weather: may yield several messages; history stores them joined ---
    if action == "weather":
        query = str(intent.get("instruction") or "").strip()
        if not query:
            await send_func("请告诉我城市或地区,例如:天气 北京")
            return
        await _send_transition(action, send_func)
        try:
            messages = await _build_weather_messages(query)
            if not messages:
                return
            reply_text = "\n".join(messages)
            _append_history(
                state,
                "user",
                f"天气:{query}",
                user_id=str(event.get_user_id()),
                user_name=user_name,
                to_bot=True,
            )
            _append_history(state, "model", reply_text)
            for msg in messages:
                await send_func(msg)
            _mark_handled_request(state, event, text)
        except Exception as exc:
            logger.error("NLP weather failed: {}", _safe_error_message(exc))
            await send_func(f"出错了:{_safe_error_message(exc)}")
        return

    # --- travel_plan: requires a destination; reply built by helper ---
    if action == "travel_plan":
        params = _intent_params(intent)
        destination = params.get("destination")
        if not isinstance(destination, str) or not destination.strip():
            await send_func("请告诉我目的地,例如:北京")
            return
        await _send_transition(action, send_func)
        try:
            reply = await _build_travel_plan_reply(intent, state, event)
            if not reply:
                return
            await send_func(reply)
            _mark_handled_request(state, event, text)
        except Exception as exc:
            logger.error("NLP travel failed: {}", _safe_error_message(exc))
            await send_func(f"出错了:{_safe_error_message(exc)}")
        return

    # --- history_clear: wipe the session and confirm ---
    if action == "history_clear":
        _clear_session_state(state)
        await send_func("已清除当前会话记录,可以继续聊啦。")
        return

    # --- avatar_get: fetch and send an avatar image, no Gemini call ---
    if action == "avatar_get":
        target = str(intent.get("target") or "").lower()
        params = _intent_params(intent)
        if target == "qq_avatar" and not params.get("qq"):
            await send_func("请提供 QQ 号。")
            return
        await _send_transition(action, send_func)
        # Avatar lookups deliberately ignore images in the current message.
        image_url = await _resolve_image_url(
            intent,
            event=event,
            state=state,
            current_image_url=None,
            reply_image_url=None,
            at_user=at_user,
        )
        if not image_url:
            await send_func("未找到可用的头像。")
            return
        await send_func(_image_segment_from_result(image_url))
        _mark_handled_request(state, event, text)
        return

    # Remaining branches all need the instruction/target/params triple.
    prompt = str(intent.get("instruction"))
    target = str(intent.get("target") or "").lower()
    params = _intent_params(intent)

    # --- image_create: text-to-image, no source picture involved ---
    if action == "image_create":
        transition_text = _intent_transition_text(intent)
        if transition_text:
            await send_func(transition_text)
        try:
            is_image, result = await _call_gemini_text_to_image(prompt, state)
            _append_history(
                state,
                "user",
                f"生成图片:{prompt}",
                user_id=str(event.get_user_id()),
                user_name=user_name,
                to_bot=True,
            )
            if is_image:
                _append_history(state, "model", "[已生成图片]")
                await send_func("已生成图片。")
                await send_func(_image_segment_from_result(result))
                _mark_handled_request(state, event, text)
            else:
                # Model answered with text instead of image data.
                _append_history(state, "model", result)
                await send_func(f"生成结果:{result}")
                _mark_handled_request(state, event, text)
        except Exception as exc:
            logger.error("NLP image create failed: {}", _safe_error_message(exc))
            await send_func(f"出错了:{_safe_error_message(exc)}")
        return

    # Only the two image-based actions remain; anything else is dropped.
    if action not in {"image_chat", "image_generate"}:
        return

    # Validate target-specific required params before resolving an image.
    if target == "qq_avatar" and not params.get("qq"):
        await send_func("请提供 QQ 号。")
        return
    if target == "message_id" and not params.get("message_id"):
        await send_func("请提供消息 ID。")
        return
    if target == "wait_next":
        # The actual waiting happens inside _resolve_image_url; this just
        # tells the user the 60-second window has started.
        await send_func("请在60秒内发送图片。")

    image_url = await _resolve_image_url(
        intent,
        event=event,
        state=state,
        current_image_url=image_url,
        reply_image_url=reply_image_url,
        at_user=at_user,
    )
    if not image_url:
        await send_func("未找到可处理的图片或头像。")
        return

    # --- image_chat: talk about the image, text-only reply ---
    if action == "image_chat":
        try:
            await _send_transition(action, send_func)
            reply = await _call_gemini_image_chat(prompt, image_url, state)
            if not reply:
                return
            _append_history(
                state,
                "user",
                f"聊图:{prompt}",
                user_id=str(event.get_user_id()),
                user_name=user_name,
                to_bot=True,
            )
            _append_history(state, "model", reply)
            await send_func(reply)
            _mark_handled_request(state, event, text)
        except UnsupportedImageError:
            await send_func("这个格式我处理不了,发张静态图吧。")
        except Exception as exc:
            logger.error("NLP image chat failed: {}", _safe_error_message(exc))
            await send_func(f"出错了:{_safe_error_message(exc)}")
        return

    # --- image_generate (fallthrough): edit the resolved image ---
    try:
        transition_text = _intent_transition_text(intent)
        if transition_text:
            await send_func(transition_text)
        is_image, result = await _call_gemini_image(prompt, image_url, state)
        _append_history(
            state,
            "user",
            f"处理头像:{prompt}",
            user_id=str(event.get_user_id()),
            user_name=user_name,
            to_bot=True,
        )
        if is_image:
            _append_history(state, "model", "[已生成图片]")
            await send_func("已完成修改。")
            await send_func(_image_segment_from_result(result))
            _mark_handled_request(state, event, text)
        else:
            # Model answered with text instead of an edited image.
            _append_history(state, "model", result)
            await send_func(f"修改结果:{result}")
            _mark_handled_request(state, event, text)
    except UnsupportedImageError:
        await send_func("这个格式我处理不了,发张静态图吧。")
    except Exception as exc:
        logger.error("NLP image failed: {}", _safe_error_message(exc))
        await send_func(f"出错了:{_safe_error_message(exc)}")
|
|
1876
|
+
|
|
1877
|
+
|
|
1878
|
+
def _clarify_intent_text(has_image: bool) -> str:
|
|
1879
|
+
if has_image:
|
|
1880
|
+
return "我没太听懂,你是想聊这张图、处理图片、查天气还是旅行规划?"
|
|
1881
|
+
return "我没太听懂,你是想聊天、处理图片、无图生成、查天气、旅行规划还是清除历史?"
|
|
1882
|
+
|
|
1883
|
+
|
|
1884
|
+
# Strong keywords whose presence marks a message as travel-related.
_TRAVEL_KEYWORDS = ("旅行", "旅游", "行程", "出行", "游玩")
# Weaker planning words; only meaningful alongside other travel context.
_TRAVEL_WEAK_KEYWORDS = ("规划", "计划")
# Day count: 1-2 ASCII digits or up to three Chinese numeral chars before "天".
_TRAVEL_DAYS_RE = re.compile(r"([0-9]{1,2}|[零一二三四五六七八九十两]{1,3})\s*天")
# Night count: same numeral forms before "晚" or "夜".
_TRAVEL_NIGHTS_RE = re.compile(r"([0-9]{1,2}|[零一二三四五六七八九十两]{1,3})\s*(?:晚|夜)")
# Destination: a CJK/alphanumeric run (max 20 chars) right after 去/到/在.
_TRAVEL_DEST_RE = re.compile(r"(?:去|到|在)\s*([\u4e00-\u9fffA-Za-z0-9]{1,20})")
|
|
1889
|
+
|
|
1890
|
+
|
|
1891
|
+
def _chinese_number_to_int(value: str) -> Optional[int]:
|
|
1892
|
+
if not value:
|
|
1893
|
+
return None
|
|
1894
|
+
digits = {
|
|
1895
|
+
"零": 0,
|
|
1896
|
+
"一": 1,
|
|
1897
|
+
"二": 2,
|
|
1898
|
+
"两": 2,
|
|
1899
|
+
"三": 3,
|
|
1900
|
+
"四": 4,
|
|
1901
|
+
"五": 5,
|
|
1902
|
+
"六": 6,
|
|
1903
|
+
"七": 7,
|
|
1904
|
+
"八": 8,
|
|
1905
|
+
"九": 9,
|
|
1906
|
+
}
|
|
1907
|
+
if value.isdigit():
|
|
1908
|
+
return int(value)
|
|
1909
|
+
if value in digits:
|
|
1910
|
+
return digits[value]
|
|
1911
|
+
if value == "十":
|
|
1912
|
+
return 10
|
|
1913
|
+
if len(value) == 2 and value[0] == "十":
|
|
1914
|
+
tail = digits.get(value[1])
|
|
1915
|
+
return 10 + tail if tail is not None else None
|
|
1916
|
+
if len(value) == 2 and value[1] == "十":
|
|
1917
|
+
head = digits.get(value[0])
|
|
1918
|
+
return head * 10 if head is not None else None
|
|
1919
|
+
if len(value) == 3 and value[1] == "十":
|
|
1920
|
+
head = digits.get(value[0])
|
|
1921
|
+
tail = digits.get(value[2])
|
|
1922
|
+
if head is None or tail is None:
|
|
1923
|
+
return None
|
|
1924
|
+
return head * 10 + tail
|
|
1925
|
+
return None
|
|
1926
|
+
|
|
1927
|
+
|
|
1928
|
+
def _parse_travel_number(token: str) -> Optional[int]:
    """Parse a duration token that may be ASCII digits or Chinese numerals."""
    if not token:
        return None
    parsed = _coerce_int(token)
    # Fall back to Chinese-numeral parsing when the int coercion fails.
    return parsed if parsed is not None else _chinese_number_to_int(token)
|
|
1935
|
+
|
|
1936
|
+
|
|
1937
|
+
def _extract_travel_duration(text: str) -> Tuple[Optional[int], Optional[int]]:
    """Pull ``(days, nights)`` counts out of free-form travel text.

    Either element is ``None`` when the corresponding pattern is absent or
    its numeral cannot be parsed.
    """
    if not text:
        return None, None
    day_hit = _TRAVEL_DAYS_RE.search(text)
    night_hit = _TRAVEL_NIGHTS_RE.search(text)
    days = _parse_travel_number(day_hit.group(1)) if day_hit else None
    nights = _parse_travel_number(night_hit.group(1)) if night_hit else None
    return days, nights
|
|
1949
|
+
|
|
1950
|
+
|
|
1951
|
+
def _extract_travel_destination(text: str) -> Optional[str]:
    """Guess the travel destination from free-form text.

    Prefers an explicit 去/到/在 phrase; otherwise strips durations,
    punctuation and travel keywords and returns whatever remains.
    """
    if not text:
        return None
    explicit = _TRAVEL_DEST_RE.search(text)
    if explicit:
        return explicit.group(1).strip()
    # No explicit marker: peel away everything that is not the destination.
    leftover = _TRAVEL_NIGHTS_RE.sub("", _TRAVEL_DAYS_RE.sub("", text))
    leftover = re.sub(r"[,,。.!!??/]", " ", leftover)
    for keyword in (*_TRAVEL_KEYWORDS, *_TRAVEL_WEAK_KEYWORDS, "去", "到", "在"):
        leftover = leftover.replace(keyword, " ")
    leftover = _collapse_spaces(leftover)
    return leftover or None
|
|
1967
|
+
|
|
1968
|
+
|
|
1969
|
+
def _strip_travel_duration(text: str) -> str:
    """Remove day/night duration phrases from *text* and collapse whitespace."""
    if not text:
        return ""
    without_days = _TRAVEL_DAYS_RE.sub("", text)
    return _collapse_spaces(_TRAVEL_NIGHTS_RE.sub("", without_days))
|
|
1975
|
+
|
|
1976
|
+
|
|
1977
|
+
def _extract_json(text: str) -> Optional[dict]:
|
|
1978
|
+
text = text.strip()
|
|
1979
|
+
if not text:
|
|
1980
|
+
return None
|
|
1981
|
+
try:
|
|
1982
|
+
return json.loads(text)
|
|
1983
|
+
except Exception:
|
|
1984
|
+
pass
|
|
1985
|
+
start = text.find("{")
|
|
1986
|
+
end = text.rfind("}")
|
|
1987
|
+
if start != -1 and end != -1 and end > start:
|
|
1988
|
+
snippet = text[start : end + 1]
|
|
1989
|
+
try:
|
|
1990
|
+
return json.loads(snippet)
|
|
1991
|
+
except Exception:
|
|
1992
|
+
return None
|
|
1993
|
+
return None
|
|
1994
|
+
|
|
1995
|
+
|
|
1996
|
+
async def _classify_intent(
    text: str,
    state: SessionState,
    has_image: bool,
    has_reply_image: bool,
    at_user: Optional[str],
) -> Optional[dict]:
    """Ask Gemini to classify the message into a structured intent dict.

    Returns ``None`` when no API key is configured or when the model output
    contains no parsable JSON object.
    """
    if not config.google_api_key:
        return None
    client = _get_client()
    system = _INTENT_SYSTEM_PROMPT
    # Describe the message context as flat "key: value" lines for the model.
    context_lines = [
        f"文本: {text}",
        f"消息包含图片: {has_image}",
        f"回复里有图片: {has_reply_image}",
        f"是否@用户: {bool(at_user)}",
        f"是否有最近图片: {bool(state.last_image_url)}",
    ]
    user_prompt = "\n".join(context_lines) + "\n"
    config_obj, system_used = _build_generate_config(
        system_instruction=system,
        response_mime_type="application/json",
    )
    # When the generate-config could not carry the system instruction,
    # inline it ahead of the user prompt instead.
    if system and not system_used:
        user_prompt = f"{system}\n\n{user_prompt}"
    request = client.aio.models.generate_content(
        model=config.gemini_text_model,
        contents=[types.Content(role="user", parts=[types.Part.from_text(text=user_prompt)])],
        config=config_obj,
    )
    response = await asyncio.wait_for(request, timeout=config.request_timeout)
    if config.gemini_log_response:
        logger.info("Gemini intent response: {}", _dump_response(response))
        _log_response_text("Gemini intent content", response)
    return _extract_json(response.text or "")
|
|
2033
|
+
|
|
2034
|
+
|
|
2035
|
+
@nlp_handler.handle()
async def _handle_natural_language(bot: Bot, event: MessageEvent):
    """Free-form message entry point: classify intent via Gemini and dispatch.

    Runs only when NLP mode is enabled, the message has plain text that is
    not a command, the sender is not the bot itself, the trigger heuristics
    pass, and an API key is configured.
    """
    if not config.nlp_enable:
        return
    text = event.get_plaintext().strip()
    if not text:
        return
    if _is_command_message(text):
        return
    # Ignore messages echoed from the bot's own account.
    if str(event.get_user_id()) == str(event.self_id):
        return
    if not _should_trigger_nlp(event, text):
        return
    if not config.google_api_key:
        return

    session_id = _session_id(event)
    state = _get_state(session_id)
    # Skip duplicates of a request we already answered in this session.
    if _is_duplicate_request(state, event, text):
        return
    image_url = _extract_first_image_url(event.get_message())
    at_user = _extract_at_user(event.get_message())
    reply_image_url = _extract_reply_image_url(event, state)
    has_image = image_url is not None
    has_reply_image = reply_image_url is not None

    # First pass: classify using the primary intent text.
    try:
        primary_text = _build_primary_intent_text(event, state, text)
        intent_raw = await _classify_intent(
            primary_text, state, has_image, has_reply_image, at_user
        )
    except Exception as exc:
        logger.error("Intent classify failed: {}", _safe_error_message(exc))
        return

    intent = _normalize_intent(intent_raw, has_image, has_reply_image, at_user, state)
    # Second pass: retry with the alternate intent text, but only when it
    # actually differs from the first prompt.
    if not intent:
        try:
            intent_text = await _build_intent_text(event, state, text)
            if intent_text and intent_text != primary_text:
                intent_raw = await _classify_intent(
                    intent_text, state, has_image, has_reply_image, at_user
                )
                intent = _normalize_intent(
                    intent_raw, has_image, has_reply_image, at_user, state
                )
        except Exception as exc:
            logger.error("Intent classify failed: {}", _safe_error_message(exc))
            return
    if not intent:
        # Still nothing usable: ask the user to clarify.
        await nlp_handler.send(_clarify_intent_text(has_image))
        return

    # When this message replies to another, default the target message id
    # to the replied-to message (without overriding an explicit one).
    reply = getattr(event, "reply", None)
    reply_id = getattr(reply, "message_id", None) if reply else None
    if reply_id is not None and isinstance(intent.get("params"), dict):
        intent["params"].setdefault("message_id", reply_id)
    await _dispatch_intent(
        intent,
        state,
        event,
        text,
        image_url=image_url,
        reply_image_url=reply_image_url,
        at_user=at_user,
        send_func=nlp_handler.send,
    )
|
|
2102
|
+
|
|
2103
|
+
|
|
2104
|
+
async def _handle_command_via_intent(
    event: MessageEvent,
    *,
    text: str,
    send_func,
) -> None:
    """Shared command backend: classify *text* into an intent and dispatch it.

    Unlike the free-form NLP path, classification failures here are reported
    back to the user instead of being silently dropped.
    """
    if not config.google_api_key:
        await send_func("未配置 GOOGLE_API_KEY")
        return
    state = _get_state(_session_id(event))
    message = event.get_message()
    direct_image = _extract_first_image_url(message)
    mentioned = _extract_at_user(message)
    quoted_image = _extract_reply_image_url(event, state)
    has_direct = direct_image is not None
    has_quoted = quoted_image is not None
    try:
        raw_intent = await _classify_intent(
            text, state, has_direct, has_quoted, mentioned
        )
    except Exception as exc:
        logger.error("Intent classify failed: {}", _safe_error_message(exc))
        await send_func("意图解析失败,请稍后再试。")
        return
    intent = _normalize_intent(raw_intent, has_direct, has_quoted, mentioned, state)
    if not intent:
        await send_func(_clarify_intent_text(has_direct))
        return
    # Default the target message id to the quoted message, if any.
    quoted = getattr(event, "reply", None)
    quoted_id = getattr(quoted, "message_id", None) if quoted else None
    if quoted_id is not None and isinstance(intent.get("params"), dict):
        intent["params"].setdefault("message_id", quoted_id)
    await _dispatch_intent(
        intent,
        state,
        event,
        text,
        image_url=direct_image,
        reply_image_url=quoted_image,
        at_user=mentioned,
        send_func=send_func,
    )
|
|
2146
|
+
|
|
2147
|
+
|
|
2148
|
+
@avatar_handler.handle()
async def handle_avatar(bot: Bot, event: MessageEvent, args: Message = CommandArg()):
    """处理头像 command: forward the instruction through intent dispatch."""
    instruction = args.extract_plain_text().strip()
    if not instruction:
        # finish() raises, aborting the handler with the usage hint.
        await avatar_handler.finish("请告诉我你想怎么处理头像,例如:处理头像 变成赛博朋克风")
    await _handle_command_via_intent(
        event, text=f"处理头像 {instruction}", send_func=avatar_handler.send
    )
|
|
2158
|
+
|
|
2159
|
+
|
|
2160
|
+
@chat_handler.handle()
async def handle_chat(bot: Bot, event: MessageEvent, args: Message = CommandArg()):
    """聊天 command: forward the chat content through intent dispatch."""
    content = args.extract_plain_text().strip()
    if not content:
        # finish() raises, aborting the handler with the usage hint.
        await chat_handler.finish("请发送要聊天的内容,例如:聊天 你好")
    await _handle_command_via_intent(
        event, text=f"聊天 {content}", send_func=chat_handler.send
    )
|
|
2170
|
+
|
|
2171
|
+
|
|
2172
|
+
@weather_handler.handle()
async def handle_weather(bot: Bot, event: MessageEvent, args: Message = CommandArg()):
    """天气 command: forward the location query through intent dispatch."""
    location = args.extract_plain_text().strip()
    if not location:
        # finish() raises, aborting the handler with the usage hint.
        await weather_handler.finish("请提供城市或地区,例如:天气 北京")
    await _handle_command_via_intent(
        event, text=f"天气 {location}", send_func=weather_handler.send
    )
|
|
2182
|
+
|
|
2183
|
+
|
|
2184
|
+
@travel_handler.handle()
async def handle_travel(bot: Bot, event: MessageEvent, args: Message = CommandArg()):
    """旅行规划 command: forward the itinerary request through intent dispatch."""
    request = args.extract_plain_text().strip()
    if not request:
        # finish() raises, aborting the handler with the usage hint.
        await travel_handler.finish("请提供行程需求,例如:旅行规划 3天2晚 北京")
    await _handle_command_via_intent(
        event, text=f"旅行规划 {request}", send_func=travel_handler.send
    )
|