wecom-aibot-python-sdk 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aibot/__init__.py +50 -0
- aibot/api.py +74 -0
- aibot/client.py +362 -0
- aibot/crypto_utils.py +73 -0
- aibot/logger.py +47 -0
- aibot/message_handler.py +89 -0
- aibot/types.py +170 -0
- aibot/utils.py +32 -0
- aibot/ws.py +574 -0
- wecom_aibot_python_sdk-1.0.0.dist-info/METADATA +365 -0
- wecom_aibot_python_sdk-1.0.0.dist-info/RECORD +14 -0
- wecom_aibot_python_sdk-1.0.0.dist-info/WHEEL +5 -0
- wecom_aibot_python_sdk-1.0.0.dist-info/licenses/LICENSE +21 -0
- wecom_aibot_python_sdk-1.0.0.dist-info/top_level.txt +1 -0
aibot/ws.py
ADDED
|
@@ -0,0 +1,574 @@
|
|
|
1
|
+
"""
|
|
2
|
+
WebSocket 长连接管理器
|
|
3
|
+
|
|
4
|
+
对标 Node.js SDK src/ws.ts
|
|
5
|
+
负责维护与企业微信的 WebSocket 长连接,包括心跳、重连、认证、串行回复队列等。
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import json
|
|
10
|
+
import ssl
|
|
11
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
12
|
+
|
|
13
|
+
# TLS context for the WebSocket connection: prefer certifi's CA bundle and
# fall back to the platform's default trust store when certifi is not installed.
try:
    import certifi
except ImportError:
    _SSL_CONTEXT = ssl.create_default_context()
else:
    _SSL_CONTEXT = ssl.create_default_context(cafile=certifi.where())
|
|
19
|
+
|
|
20
|
+
from .types import WsCmd, WsFrame
|
|
21
|
+
from .utils import generate_req_id
|
|
22
|
+
|
|
23
|
+
# Built-in default WebSocket endpoint; WsConnectionManager uses it when no
# (or an empty) ws_url is passed.
DEFAULT_WS_URL = "wss://openws.work.weixin.qq.com"

try:
    import websockets
except ImportError as exc:
    # The requirement is quoted so that ">=" is not parsed as a shell redirection
    # when the hint is copy-pasted.
    raise ImportError('请安装 websockets: pip install "websockets>=12.0"') from exc

try:
    # Only used as a type annotation. This legacy class may be removed from
    # newer websockets releases; its absence must not be reported as
    # "websockets is not installed".
    from websockets.client import WebSocketClientProtocol
except ImportError:
    WebSocketClientProtocol = Any  # type: ignore[misc,assignment]


def _ws_is_open(ws: Any) -> bool:
    """Return whether *ws* is open, for both old and new websockets connection APIs."""
    if hasattr(ws, "open"):
        # websockets <= 13.x
        return ws.open
    if hasattr(ws, "state"):
        # websockets >= 14.x
        try:
            from websockets.protocol import State
        except ImportError:
            return ws.state.name == "OPEN"
        return ws.state is State.OPEN
    return False
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class _ReplyQueueItem:
    """A queued reply: the frame to send and the future its caller awaits."""

    __slots__ = ("frame", "future")

    def __init__(self, frame: WsFrame, future: "asyncio.Future[WsFrame]"):
        # frame: the reply frame that will be written to the WebSocket
        # future: completed with the ack frame, or failed, once the reply is processed
        self.frame, self.future = frame, future
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class WsConnectionManager:
    """
    Manager for the long-lived WebSocket connection to WeCom.

    Handles the auth (subscribe) frame, an application-level heartbeat,
    reconnection with exponential back-off, routing of inbound frames, and a
    per-``req_id`` serial reply queue that waits for each reply's ack before
    sending the next reply for the same ``req_id``.

    Callbacks (all optional, called synchronously from the manager's tasks):
        on_connected()           -- socket opened, before the auth frame is sent
        on_authenticated()       -- auth response with errcode 0 received
        on_disconnected(reason)  -- the receive loop observed the connection closing
        on_message(frame)        -- push, event and unrecognised frames
        on_reconnecting(attempt) -- before each reconnect back-off sleep
        on_error(exc)            -- connection, auth, dispatch or retry-limit errors
    """

    def __init__(
        self,
        logger: Any,
        heartbeat_interval: int = 30000,
        reconnect_base_delay: int = 1000,
        max_reconnect_attempts: int = 10,
        ws_url: Optional[str] = None,
    ):
        """
        :param logger: logger object providing ``debug``/``info``/``warn``/``error``
        :param heartbeat_interval: heartbeat period in milliseconds
        :param reconnect_base_delay: first reconnect delay in milliseconds, doubled per attempt
        :param max_reconnect_attempts: consecutive reconnect attempts before giving up; -1 = unlimited
        :param ws_url: WebSocket endpoint; falls back to DEFAULT_WS_URL when empty
        """
        self._logger = logger
        self._ws_url = ws_url or DEFAULT_WS_URL
        self._heartbeat_interval = heartbeat_interval
        self._reconnect_base_delay = reconnect_base_delay
        self._max_reconnect_attempts = max_reconnect_attempts

        self._ws: Optional[WebSocketClientProtocol] = None
        self._heartbeat_task: Optional[asyncio.Task[None]] = None
        self._receive_task: Optional[asyncio.Task[None]] = None
        self._reconnect_attempts: int = 0
        self._is_manual_close: bool = False

        # Strong references to spawned tasks: the event loop only keeps weak
        # references, so an otherwise unreferenced task may be garbage-collected
        # before it finishes.
        self._background_tasks: set = set()

        # Authentication credentials (see set_credentials)
        self._bot_id: str = ""
        self._bot_secret: str = ""

        # Heartbeat state
        self._missed_pong_count: int = 0  # heartbeats sent since the last heartbeat ack
        self._max_missed_pong: int = 2
        self._reconnect_max_delay: int = 30000  # ms, cap for the exponential back-off

        # Serial reply queues, keyed by req_id
        self._reply_queues: Dict[str, List[_ReplyQueueItem]] = {}
        self._pending_acks: Dict[
            str,
            Tuple["asyncio.Future[WsFrame]", Optional[asyncio.TimerHandle]],
        ] = {}
        self._reply_ack_timeout: float = 5.0  # seconds
        self._max_reply_queue_size: int = 100
        self._processing_queues: set = set()  # req_ids whose queue has an active worker

        # Callbacks
        self.on_connected: Optional[Callable[[], None]] = None
        self.on_authenticated: Optional[Callable[[], None]] = None
        self.on_disconnected: Optional[Callable[[str], None]] = None
        self.on_message: Optional[Callable[[WsFrame], None]] = None
        self.on_reconnecting: Optional[Callable[[int], None]] = None
        self.on_error: Optional[Callable[[Exception], None]] = None

    # ------------------------------------------------------------------ helpers

    def _spawn(self, coro: Any) -> "asyncio.Future[Any]":
        """Schedule *coro* as a task and keep a strong reference until it is done."""
        task = asyncio.ensure_future(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    @staticmethod
    def _current_task() -> Optional["asyncio.Task[Any]"]:
        """Return the running task, or None when called outside a running event loop."""
        try:
            return asyncio.current_task()
        except RuntimeError:
            return None

    # --------------------------------------------------------------- connection

    def set_credentials(self, bot_id: str, bot_secret: str) -> None:
        """Store the credentials sent in the auth frame."""
        self._bot_id = bot_id
        self._bot_secret = bot_secret

    async def connect(self) -> None:
        """
        Open the WebSocket connection, send the auth frame and start receiving.

        If the attempt fails, this keeps retrying with exponential back-off and
        returns once a connection is established, the retry budget is
        exhausted, or ``disconnect()`` has been called.
        """
        self._is_manual_close = False
        if not await self._try_connect():
            await self._schedule_reconnect()

    async def _try_connect(self) -> bool:
        """Make a single connection attempt; return True on success, False on failure."""
        # Drop any previous connection first
        await self._cleanup_ws()

        self._logger.info(f"Connecting to WebSocket: {self._ws_url}...")

        try:
            self._ws = await websockets.connect(
                self._ws_url,
                ssl=_SSL_CONTEXT,
                ping_interval=None,  # heartbeats are sent as application frames instead
                ping_timeout=None,
                close_timeout=5,
            )

            self._reconnect_attempts = 0
            self._missed_pong_count = 0

            self._logger.info("WebSocket connection established, sending auth...")

            if self.on_connected:
                self.on_connected()

            await self._send_auth()

            self._receive_task = self._spawn(self._receive_loop())
        except Exception as e:
            self._logger.error(f"Failed to create WebSocket connection: {e}")
            if self.on_error:
                self.on_error(e)
            return False
        return True

    async def _cleanup_ws(self) -> None:
        """Stop the receive loop (unless we are running inside it) and close the socket."""
        task, self._receive_task = self._receive_task, None
        # During a reconnect this runs inside the receive task itself; cancelling
        # and awaiting ourselves would only inject (and swallow) a cancellation.
        if task is not None and not task.done() and task is not self._current_task():
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                pass

        if self._ws:
            try:
                await self._ws.close()
            except Exception:
                pass
            self._ws = None

    async def _send_auth(self) -> None:
        """Send the subscribe (auth) frame; failures are logged, not raised."""
        try:
            await self.send(
                {
                    "cmd": WsCmd.SUBSCRIBE,
                    "headers": {"req_id": generate_req_id(WsCmd.SUBSCRIBE)},
                    "body": {
                        "bot_id": self._bot_id,
                        "secret": self._bot_secret,
                    },
                }
            )
            self._logger.info("Auth frame sent")
        except Exception as e:
            self._logger.error(f"Failed to send auth frame: {e}")

    async def _schedule_reconnect(self) -> None:
        """
        Retry the connection with exponential back-off until an attempt succeeds.

        A loop rather than mutual recursion with ``connect()``, so an unlimited
        retry budget (``max_reconnect_attempts == -1``) does not grow the
        coroutine call chain with every failed attempt.
        """
        while True:
            if (
                self._max_reconnect_attempts != -1
                and self._reconnect_attempts >= self._max_reconnect_attempts
            ):
                self._logger.error(
                    f"Max reconnect attempts reached ({self._max_reconnect_attempts}), giving up"
                )
                if self.on_error:
                    self.on_error(Exception("Max reconnect attempts exceeded"))
                return

            self._reconnect_attempts += 1
            # Exponential back-off (1s, 2s, 4s, ... with the default base), capped
            delay = min(
                self._reconnect_base_delay * (2 ** (self._reconnect_attempts - 1)),
                self._reconnect_max_delay,
            )

            self._logger.info(
                f"Reconnecting in {delay}ms (attempt {self._reconnect_attempts})..."
            )
            if self.on_reconnecting:
                self.on_reconnecting(self._reconnect_attempts)

            await asyncio.sleep(delay / 1000)
            if self._is_manual_close:
                return

            if await self._try_connect():
                return

    # ------------------------------------------------------------------ receive

    async def _receive_loop(self) -> None:
        """
        Read and dispatch frames until the connection ends, then handle the disconnect.

        Iterating a websockets connection ends *without* an exception when the
        connection is closed cleanly, for example after the heartbeat check
        closed it with code 1000. That case gets the same
        disconnect/reconnect handling as an abnormal close.
        """
        ws = self._ws
        try:
            async for raw_message in ws:  # type: ignore
                self._handle_raw_message(raw_message)
        except websockets.exceptions.ConnectionClosed as e:
            reason_str = str(e) or f"code: {e.code}"
        except asyncio.CancelledError:
            return
        except Exception as e:
            self._logger.error(f"WebSocket error: {e}")
            if self.on_error:
                self.on_error(e)
            return
        else:
            reason_str = self._describe_close(ws)

        await self._handle_connection_closed(reason_str)

    @staticmethod
    def _describe_close(ws: Any) -> str:
        """Build a close description from the connection's close code/reason, if exposed."""
        code = getattr(ws, "close_code", None)
        reason = getattr(ws, "close_reason", None)
        return f"code: {code}, reason: {reason}" if reason else f"code: {code}"

    async def _handle_connection_closed(self, reason_str: str) -> None:
        """Tear down per-connection state, notify, and reconnect unless closed manually."""
        self._logger.warn(f"WebSocket connection closed: {reason_str}")
        self._stop_heartbeat()
        self._clear_pending_messages(f"WebSocket connection closed ({reason_str})")
        if self.on_disconnected:
            self.on_disconnected(reason_str)
        if not self._is_manual_close:
            await self._schedule_reconnect()

    def _handle_raw_message(self, raw_message: Any) -> None:
        """Decode one inbound message and dispatch it; errors are contained per message."""
        try:
            if isinstance(raw_message, bytes):
                raw_message = raw_message.decode("utf-8")
            frame: WsFrame = json.loads(raw_message)
        except ValueError as e:  # json.JSONDecodeError and UnicodeDecodeError
            self._logger.error(f"Failed to parse WebSocket message: {e}")
            return

        try:
            self._handle_frame(frame)
        except Exception as e:
            # A failing user callback or a malformed frame must not end the
            # receive loop; otherwise the connection would stay open but deaf.
            self._logger.error(f"Failed to handle WebSocket frame: {e}")
            if self.on_error:
                self.on_error(e)

    def _handle_frame(self, frame: WsFrame) -> None:
        """
        Route one decoded frame.

        CALLBACK / EVENT_CALLBACK pushes go to ``on_message``. Other frames are
        matched by ``headers.req_id``: acks of pending replies, the auth
        (subscribe) response, or a heartbeat ack. Anything else is logged and
        forwarded to ``on_message``.
        """
        cmd = frame.get("cmd")

        # Message push
        if cmd == WsCmd.CALLBACK:
            self._logger.debug(f"Received push message: {json.dumps(frame.get('body', {}))}")
            if self.on_message:
                self.on_message(frame)
            return

        # Event push
        if cmd == WsCmd.EVENT_CALLBACK:
            self._logger.debug(f"Received event callback: {json.dumps(frame.get('body', {}))}")
            if self.on_message:
                self.on_message(frame)
            return

        # Frames without a push cmd: auth response, heartbeat ack or reply ack.
        # ``or`` also tolerates explicit nulls in the frame.
        headers = frame.get("headers") or {}
        req_id = headers.get("req_id") or ""

        # Ack for a reply sent through send_reply()
        if req_id in self._pending_acks:
            self._handle_reply_ack(req_id, frame)
            return

        if req_id.startswith(WsCmd.SUBSCRIBE):
            # Auth response
            errcode = frame.get("errcode")
            if errcode != 0:
                self._logger.error(
                    f"Authentication failed: errcode={errcode}, errmsg={frame.get('errmsg')}"
                )
                if self.on_error:
                    self.on_error(
                        Exception(
                            f"Authentication failed: {frame.get('errmsg')} (code: {errcode})"
                        )
                    )
                return
            self._logger.info("Authentication successful")
            self._start_heartbeat()
            if self.on_authenticated:
                self.on_authenticated()
            return

        if req_id.startswith(WsCmd.HEARTBEAT):
            # Heartbeat ack
            errcode = frame.get("errcode")
            if errcode != 0:
                self._logger.warn(
                    f"Heartbeat ack error: errcode={errcode}, errmsg={frame.get('errmsg')}"
                )
                return
            self._missed_pong_count = 0
            self._logger.debug("Received heartbeat ack")
            return

        # Unknown frame type
        self._logger.warn(f"Received unknown frame: {json.dumps(frame)}")
        if self.on_message:
            self.on_message(frame)

    # ---------------------------------------------------------------- heartbeat

    def _start_heartbeat(self) -> None:
        """(Re)start the periodic heartbeat task."""
        self._stop_heartbeat()
        self._heartbeat_task = self._spawn(self._heartbeat_loop())
        self._logger.debug(
            f"Heartbeat timer started, interval: {self._heartbeat_interval}ms"
        )

    def _stop_heartbeat(self) -> None:
        """
        Stop the heartbeat task.

        When called from inside the heartbeat task itself, the task is only
        detached, not cancelled, so it can finish its current work (closing the
        socket) and exit on its own.
        """
        task, self._heartbeat_task = self._heartbeat_task, None
        if task is None or task.done():
            return
        if task is not self._current_task():
            task.cancel()
        self._logger.debug("Heartbeat timer stopped")

    async def _heartbeat_loop(self) -> None:
        """Send a heartbeat every ``heartbeat_interval`` ms until told to stop."""
        try:
            while True:
                await asyncio.sleep(self._heartbeat_interval / 1000)
                if not await self._send_heartbeat():
                    return
        except asyncio.CancelledError:
            pass

    async def _send_heartbeat(self) -> bool:
        """
        Send one heartbeat frame, or close the socket if too many acks were missed.

        :return: False when the connection was considered dead and the
            heartbeat loop should stop; True otherwise
        """
        # Too many heartbeats without an ack: treat the connection as dead
        if self._missed_pong_count >= self._max_missed_pong:
            self._logger.warn(
                f"No heartbeat ack received for {self._missed_pong_count} consecutive pings, "
                "connection considered dead"
            )
            self._stop_heartbeat()
            # Force-close the underlying connection; the receive loop then
            # observes the close and schedules the reconnect.
            ws = self._ws
            if ws:
                try:
                    await ws.close()
                except Exception:
                    pass
            return False

        self._missed_pong_count += 1
        try:
            await self.send(
                {
                    "cmd": WsCmd.HEARTBEAT,
                    "headers": {"req_id": generate_req_id(WsCmd.HEARTBEAT)},
                }
            )
            extra = (
                f" (awaiting {self._missed_pong_count} pong)"
                if self._missed_pong_count > 1
                else ""
            )
            self._logger.debug(f"Heartbeat sent{extra}")
        except Exception as e:
            self._logger.error(f"Failed to send heartbeat: {e}")
        return True

    # ------------------------------------------------------------------ sending

    async def send(self, frame: WsFrame) -> None:
        """
        Send a frame as JSON text.

        :param frame: WebSocket frame
        :raises RuntimeError: if the connection is not open
        """
        if self._ws and _ws_is_open(self._ws):
            await self._ws.send(json.dumps(frame))
        else:
            raise RuntimeError("WebSocket not connected, unable to send data")

    async def send_reply(
        self, req_id: str, body: Any, cmd: str = WsCmd.RESPONSE
    ) -> WsFrame:
        """
        Send a reply over the WebSocket through the per-req_id serial queue.

        Replies sharing a ``req_id`` are sent one at a time; each waits for its
        ack, or a failure or timeout, before the next one is sent.

        :param req_id: the req_id taken from the callback being answered
        :param body: reply body
        :param cmd: command of the reply frame, WsCmd.RESPONSE by default
        :return: the ack frame
        :raises RuntimeError: if the queue for ``req_id`` is full, the ack
            reports an error, or the connection closes first
        :raises TimeoutError: if no ack arrives within the ack timeout
        """
        queue = self._reply_queues.setdefault(req_id, [])

        # Keep a single req_id's queue from growing without bound
        if len(queue) >= self._max_reply_queue_size:
            self._logger.warn(
                f"Reply queue for reqId {req_id} exceeds max size ({self._max_reply_queue_size}), "
                "rejecting new message"
            )
            raise RuntimeError(
                f"Reply queue for reqId {req_id} exceeds max size ({self._max_reply_queue_size})"
            )

        future: asyncio.Future[WsFrame] = asyncio.get_running_loop().create_future()
        frame: WsFrame = {
            "cmd": cmd,
            "headers": {"req_id": req_id},
            "body": body,
        }
        queue.append(_ReplyQueueItem(frame, future))

        # First item and no active worker: start processing this queue
        if len(queue) == 1 and req_id not in self._processing_queues:
            self._spawn(self._process_reply_queue(req_id))

        return await future

    async def _process_reply_queue(self, req_id: str) -> None:
        """Send the queued replies for ``req_id`` one by one, waiting for each ack."""
        self._processing_queues.add(req_id)
        loop = asyncio.get_running_loop()

        try:
            while True:
                queue = self._reply_queues.get(req_id)
                if not queue:
                    self._reply_queues.pop(req_id, None)
                    break

                item = queue[0]

                # Register the ack slot *before* sending: the receive task may
                # process the ack while send() is suspended, and an unregistered
                # ack would be routed to on_message and this reply would time out.
                ack_future: asyncio.Future[WsFrame] = loop.create_future()
                self._pending_acks[req_id] = (ack_future, None)

                try:
                    await self.send(item.frame)
                    self._logger.debug(
                        f"Reply message sent via WebSocket, reqId: {req_id}, queue length: {len(queue)}"
                    )
                except Exception as e:
                    self._logger.error(f"Failed to send reply for reqId {req_id}: {e}")
                    pending = self._pending_acks.get(req_id)
                    if pending is not None and pending[0] is ack_future:
                        del self._pending_acks[req_id]
                    if not ack_future.done():
                        ack_future.cancel()
                    queue.pop(0)
                    if not item.future.done():
                        item.future.set_exception(e)
                    continue

                # The ack timeout starts once the frame has been handed to the
                # socket, unless the ack (or a connection-close failure) already arrived.
                if not ack_future.done():
                    timeout_handle = loop.call_later(
                        self._reply_ack_timeout,
                        self._on_reply_ack_timeout,
                        req_id,
                        ack_future,
                    )
                    self._pending_acks[req_id] = (ack_future, timeout_handle)

                try:
                    ack_frame = await ack_future
                except Exception as e:
                    queue.pop(0)
                    if not item.future.done():
                        item.future.set_exception(e)
                else:
                    # Ack received successfully
                    queue.pop(0)
                    if not item.future.done():
                        item.future.set_result(ack_frame)
        finally:
            self._processing_queues.discard(req_id)

    def _on_reply_ack_timeout(
        self, req_id: str, ack_future: "asyncio.Future[WsFrame]"
    ) -> None:
        """Timer callback: fail ``ack_future`` when its ack did not arrive in time."""
        self._logger.warn(
            f"Reply ack timeout ({self._reply_ack_timeout}s) for reqId: {req_id}"
        )
        pending = self._pending_acks.get(req_id)
        if pending is not None and pending[0] is ack_future:
            del self._pending_acks[req_id]
        if not ack_future.done():
            ack_future.set_exception(
                TimeoutError(
                    f"Reply ack timeout ({self._reply_ack_timeout}s) for reqId: {req_id}"
                )
            )

    def _handle_reply_ack(self, req_id: str, frame: WsFrame) -> None:
        """Resolve the pending ack for ``req_id``: a result on errcode 0, otherwise RuntimeError."""
        pending = self._pending_acks.pop(req_id, None)
        if not pending:
            return

        ack_future, timeout_handle = pending

        # Cancel the timeout
        if timeout_handle:
            timeout_handle.cancel()

        errcode = frame.get("errcode")
        if errcode != 0:
            self._logger.warn(
                f"Reply ack error: reqId={req_id}, errcode={errcode}, errmsg={frame.get('errmsg')}"
            )
            if not ack_future.done():
                ack_future.set_exception(
                    RuntimeError(
                        f"Reply ack error: errcode={errcode}, errmsg={frame.get('errmsg')}"
                    )
                )
        else:
            self._logger.debug(f"Reply ack received for reqId: {req_id}")
            if not ack_future.done():
                ack_future.set_result(frame)

    def _clear_pending_messages(self, reason: str) -> None:
        """Fail every pending ack and queued reply with RuntimeError(reason)."""
        for ack_future, timeout_handle in self._pending_acks.values():
            if timeout_handle:
                timeout_handle.cancel()
            if not ack_future.done():
                ack_future.set_exception(RuntimeError(reason))
        self._pending_acks.clear()

        for req_id, queue in self._reply_queues.items():
            for item in queue:
                if not item.future.done():
                    item.future.set_exception(
                        RuntimeError(f"{reason}, reply for reqId: {req_id} cancelled")
                    )
        self._reply_queues.clear()

    # --------------------------------------------------------------- disconnect

    def disconnect(self) -> None:
        """Close the connection on purpose (synchronous; the socket is closed asynchronously)."""
        self._is_manual_close = True
        self._stop_heartbeat()
        self._clear_pending_messages("Connection manually closed")

        if self._ws:
            self._spawn(self._async_disconnect())

        self._logger.info("WebSocket connection manually closed")

    async def _async_disconnect(self) -> None:
        """Cancel the receive loop, then close the socket with code 1000."""
        if self._receive_task and not self._receive_task.done():
            self._receive_task.cancel()
            try:
                await self._receive_task
            except (asyncio.CancelledError, Exception):
                pass

        if self._ws:
            try:
                await self._ws.close(code=1000, reason="Manual disconnect")
            except Exception:
                pass
            self._ws = None

    @property
    def is_connected(self) -> bool:
        """Whether a connection object exists and reports itself open."""
        return self._ws is not None and _ws_is_open(self._ws)
|