wecom-aibot-python-sdk 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aibot/ws.py ADDED
@@ -0,0 +1,574 @@
1
+ """
2
+ WebSocket 长连接管理器
3
+
4
+ 对标 Node.js SDK src/ws.ts
5
+ 负责维护与企业微信的 WebSocket 长连接,包括心跳、重连、认证、串行回复队列等。
6
+ """
7
+
8
+ import asyncio
9
+ import json
10
+ import ssl
11
+ from typing import Any, Callable, Dict, List, Optional, Tuple
12
+
13
+ try:
14
+ import certifi
15
+ _SSL_CONTEXT = ssl.create_default_context(cafile=certifi.where())
16
+ except ImportError:
17
+ # 未安装 certifi 时回退到系统默认证书
18
+ _SSL_CONTEXT = ssl.create_default_context()
19
+
20
+ from .types import WsCmd, WsFrame
21
+ from .utils import generate_req_id
22
+
23
# Default WebSocket endpoint built into the SDK; WsConnectionManager falls back
# to it when no (or an empty) ws_url is supplied.
DEFAULT_WS_URL = "wss://openws.work.weixin.qq.com"
25
+
26
try:
    import websockets
except ImportError as exc:
    raise ImportError("请安装 websockets: pip install websockets>=12.0") from exc

try:
    # Only used for type annotations. This legacy name may be missing from some
    # websockets releases; its absence must not be reported as "websockets is
    # not installed", so fall back to ``Any`` instead of failing the import.
    from websockets.client import WebSocketClientProtocol
except ImportError:
    WebSocketClientProtocol = Any  # type: ignore[misc,assignment]


def _ws_is_open(ws: Any) -> bool:
    """
    Report whether a websockets connection object is open, across library versions.

    :param ws: Connection returned by ``websockets.connect``.
    :return: ``True`` when the connection is in the OPEN state, else ``False``
        (also ``False`` for objects exposing neither ``open`` nor ``state``).
    """
    if hasattr(ws, "open"):
        # websockets <= 13.x exposes a boolean ``open`` attribute
        return ws.open
    if hasattr(ws, "state"):
        # websockets >= 14.x exposes a ``state`` enum instead
        try:
            from websockets.protocol import State
        except ImportError:
            return ws.state.name == "OPEN"
        return ws.state is State.OPEN
    return False
45
+
46
+
47
class _ReplyQueueItem:
    """
    One entry of a per-req_id serial reply queue.

    ``frame`` is the reply frame waiting to be sent; ``future`` is handed back
    to the ``send_reply`` caller and is completed with the ack frame, or failed
    with the send/ack error.
    """

    __slots__ = ("frame", "future")

    def __init__(self, frame: WsFrame, future: "asyncio.Future[WsFrame]"):
        self.frame, self.future = frame, future
55
+
56
+
57
class WsConnectionManager:
    """
    WebSocket long-connection manager.

    Maintains the long-lived WebSocket connection to WeCom: authentication
    (subscribe frame), application-level heartbeats, reconnection with
    exponential backoff, and a serial reply queue per ``req_id`` (each reply
    waits for its ack before the next reply for the same ``req_id`` is sent).

    Callbacks are plain, optional attributes and are called synchronously
    (they are not awaited):

    * ``on_connected()`` - the WebSocket handshake succeeded.
    * ``on_authenticated()`` - the auth (subscribe) response reported success.
    * ``on_disconnected(reason)`` - the receive loop saw the connection end.
    * ``on_message(frame)`` - a message push, an event push, or an
      unrecognised frame.
    * ``on_reconnecting(attempt)`` - a reconnect attempt is about to start.
    * ``on_error(exc)`` - connect, auth, frame-handling or reconnect failure.
    """

    def __init__(
        self,
        logger: Any,
        heartbeat_interval: int = 30000,
        reconnect_base_delay: int = 1000,
        max_reconnect_attempts: int = 10,
        ws_url: Optional[str] = None,
    ):
        """
        :param logger: Logger used for diagnostics; must provide ``debug``,
            ``info``, ``warn`` and ``error`` methods.
        :param heartbeat_interval: Delay between heartbeat frames, in milliseconds.
        :param reconnect_base_delay: First reconnect delay in milliseconds; it
            doubles with each consecutive attempt and is capped at 30000 ms.
        :param max_reconnect_attempts: Consecutive reconnect attempts allowed
            before giving up; ``-1`` means retry forever.
        :param ws_url: Endpoint to connect to; ``DEFAULT_WS_URL`` is used when
            it is ``None`` or empty.
        """
        self._logger = logger
        self._ws_url = ws_url or DEFAULT_WS_URL
        self._heartbeat_interval = heartbeat_interval
        self._reconnect_base_delay = reconnect_base_delay
        self._max_reconnect_attempts = max_reconnect_attempts

        self._ws: Optional[WebSocketClientProtocol] = None
        self._heartbeat_task: Optional[asyncio.Task[None]] = None
        self._receive_task: Optional[asyncio.Task[None]] = None
        self._reconnect_attempts: int = 0
        self._is_manual_close: bool = False

        # Credentials sent in the subscribe (auth) frame
        self._bot_id: str = ""
        self._bot_secret: str = ""

        # Heartbeat bookkeeping: heartbeats attempted since the last heartbeat ack
        self._missed_pong_count: int = 0
        self._max_missed_pong: int = 2
        self._reconnect_max_delay: int = 30000  # ms

        # Serial reply queues, keyed by req_id
        self._reply_queues: Dict[str, List[_ReplyQueueItem]] = {}
        # req_id -> (future completed by the ack frame, its timeout timer)
        self._pending_acks: Dict[
            str,
            Tuple["asyncio.Future[WsFrame]", Optional[asyncio.TimerHandle]],
        ] = {}
        self._reply_ack_timeout: float = 5.0  # seconds
        self._max_reply_queue_size: int = 100
        self._processing_queues: set = set()  # req_ids that have a queue processor

        # Callbacks
        self.on_connected: Optional[Callable[[], None]] = None
        self.on_authenticated: Optional[Callable[[], None]] = None
        self.on_disconnected: Optional[Callable[[str], None]] = None
        self.on_message: Optional[Callable[[WsFrame], None]] = None
        self.on_reconnecting: Optional[Callable[[int], None]] = None
        self.on_error: Optional[Callable[[Exception], None]] = None

    def set_credentials(self, bot_id: str, bot_secret: str) -> None:
        """Store the bot id and secret sent in the auth frame of later connections."""
        self._bot_id = bot_id
        self._bot_secret = bot_secret

    @staticmethod
    def _current_task() -> "Optional[asyncio.Task[Any]]":
        """Return the running task, or ``None`` outside a running event loop."""
        try:
            return asyncio.current_task()
        except RuntimeError:
            return None

    # ------------------------------------------------------------------ connect

    async def connect(self) -> None:
        """
        Open the WebSocket connection, send the auth frame and start receiving.

        Failed attempts are retried with exponential backoff until one succeeds,
        ``max_reconnect_attempts`` is exhausted, or :meth:`disconnect` is
        called. Failures are reported via the logger and ``on_error``; they are
        not raised.
        """
        self._is_manual_close = False
        # Retry iteratively: a recursive connect -> reconnect -> connect chain
        # grows without bound when max_reconnect_attempts is -1.
        while not await self._connect_once():
            if not await self._wait_before_reconnect():
                return

    async def _connect_once(self) -> bool:
        """
        Make a single connection attempt.

        :return: ``True`` when the attempt is finished (connected, or abandoned
            because :meth:`disconnect` was called during the handshake);
            ``False`` when it failed and a retry may follow.
        """
        # Any heartbeat still bound to a previous connection must not keep
        # running against the replacement socket.
        self._stop_heartbeat()
        await self._cleanup_ws()

        self._logger.info(f"Connecting to WebSocket: {self._ws_url}...")

        try:
            self._ws = await websockets.connect(
                self._ws_url,
                ssl=_SSL_CONTEXT,
                ping_interval=None,  # heartbeats are managed by this class
                ping_timeout=None,
                close_timeout=5,
            )

            if self._is_manual_close:
                # disconnect() ran while the handshake was in flight; it could
                # not see this socket, so close it here.
                await self._cleanup_ws()
                return True

            self._reconnect_attempts = 0
            self._missed_pong_count = 0

            self._logger.info("WebSocket connection established, sending auth...")

            if self.on_connected:
                self.on_connected()

            await self._send_auth()

            self._receive_task = asyncio.ensure_future(self._receive_loop())
            return True
        except Exception as e:
            self._logger.error(f"Failed to create WebSocket connection: {e}")
            if self.on_error:
                self.on_error(e)
            return False

    async def _cancel_receive_task(self) -> None:
        """Cancel and await the receive-loop task, unless it is the caller itself."""
        task = self._receive_task
        self._receive_task = None
        # During a reconnect triggered by the receive loop, that loop is the
        # current task; cancelling and awaiting it would cancel ourselves.
        if task is None or task.done() or task is self._current_task():
            return
        task.cancel()
        try:
            await task
        except (asyncio.CancelledError, Exception):
            pass

    async def _cleanup_ws(self) -> None:
        """Stop the receive loop and close the current socket, ignoring close errors."""
        await self._cancel_receive_task()

        ws = self._ws
        self._ws = None
        if ws is not None:
            try:
                await ws.close()
            except Exception:
                pass

    async def _send_auth(self) -> None:
        """Send the subscribe (auth) frame; a send failure is logged, not raised."""
        try:
            await self.send(
                {
                    "cmd": WsCmd.SUBSCRIBE,
                    "headers": {"req_id": generate_req_id(WsCmd.SUBSCRIBE)},
                    "body": {
                        "bot_id": self._bot_id,
                        "secret": self._bot_secret,
                    },
                }
            )
            self._logger.info("Auth frame sent")
        except Exception as e:
            self._logger.error(f"Failed to send auth frame: {e}")

    # ------------------------------------------------------------------ receive

    async def _receive_loop(self) -> None:
        """
        Read and dispatch frames until the connection ends.

        A ``ConnectionClosed`` error and a normal end of iteration are both
        treated as a disconnect. Per the websockets docs, iteration stops
        without raising on a normal close (1000/1001). A disconnect triggers
        reconnection unless :meth:`disconnect` was called. Cancellation ends
        the loop silently; any other error is logged and reported via
        ``on_error``.
        """
        ws = self._ws
        try:
            async for raw_message in ws:  # type: ignore
                self._dispatch_raw_message(raw_message)
        except websockets.exceptions.ConnectionClosed as e:
            reason_str = str(e) or f"code: {e.code}"
        except asyncio.CancelledError:
            return
        except Exception as e:
            self._logger.error(f"WebSocket error: {e}")
            if self.on_error:
                self.on_error(e)
            return
        else:
            # Clean close (by the server, or by our own heartbeat check): this
            # path previously skipped disconnect handling and never reconnected.
            reason_str = f"code: {getattr(ws, 'close_code', None)}"
        await self._handle_connection_closed(reason_str)

    def _dispatch_raw_message(self, raw_message: Any) -> None:
        """Decode one raw message and route it; errors are contained to that message."""
        try:
            if isinstance(raw_message, bytes):
                raw_message = raw_message.decode("utf-8")
            frame: WsFrame = json.loads(raw_message)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            self._logger.error(f"Failed to parse WebSocket message: {e}")
            return

        try:
            self._handle_frame(frame)
        except Exception as e:
            # A malformed frame or a failing user callback must not stop the
            # receive loop; otherwise the connection stays open but unread.
            self._logger.error(f"Failed to handle WebSocket frame: {e}")
            if self.on_error:
                self.on_error(e)

    async def _handle_connection_closed(self, reason: str) -> None:
        """Tear down per-connection state and reconnect unless closed manually."""
        self._logger.warn(f"WebSocket connection closed: {reason}")
        self._stop_heartbeat()
        self._clear_pending_messages(f"WebSocket connection closed ({reason})")
        if self.on_disconnected:
            self.on_disconnected(reason)
        if not self._is_manual_close:
            await self._schedule_reconnect()

    def _handle_frame(self, frame: WsFrame) -> None:
        """
        Route one decoded frame.

        Message pushes (``WsCmd.CALLBACK``) and event pushes
        (``WsCmd.EVENT_CALLBACK``) go to ``on_message``. Other frames are matched
        on ``headers.req_id``:

        * an ack for a reply we are waiting on;
        * the auth response (``req_id`` prefixed with ``WsCmd.SUBSCRIBE``);
        * a heartbeat ack (prefixed with ``WsCmd.HEARTBEAT``).

        Anything else is logged and forwarded to ``on_message``.
        """
        cmd = frame.get("cmd")

        # Message push
        if cmd == WsCmd.CALLBACK:
            self._logger.debug(f"Received push message: {json.dumps(frame.get('body', {}))}")
            if self.on_message:
                self.on_message(frame)
            return

        # Event push
        if cmd == WsCmd.EVENT_CALLBACK:
            self._logger.debug(f"Received event callback: {json.dumps(frame.get('body', {}))}")
            if self.on_message:
                self.on_message(frame)
            return

        # Frames without a push cmd: auth response, heartbeat ack or reply ack.
        # Tolerate missing/null headers or req_id rather than raising.
        headers = frame.get("headers") or {}
        req_id = headers.get("req_id") or ""
        if not isinstance(req_id, str):
            req_id = str(req_id)

        # Ack for a queued reply
        if req_id in self._pending_acks:
            self._handle_reply_ack(req_id, frame)
            return

        if req_id.startswith(WsCmd.SUBSCRIBE):
            # Auth response
            errcode = frame.get("errcode")
            if errcode != 0:
                self._logger.error(
                    f"Authentication failed: errcode={errcode}, errmsg={frame.get('errmsg')}"
                )
                if self.on_error:
                    self.on_error(
                        Exception(
                            f"Authentication failed: {frame.get('errmsg')} (code: {errcode})"
                        )
                    )
                return
            self._logger.info("Authentication successful")
            self._start_heartbeat()
            if self.on_authenticated:
                self.on_authenticated()
            return

        if req_id.startswith(WsCmd.HEARTBEAT):
            # Heartbeat ack
            errcode = frame.get("errcode")
            if errcode != 0:
                self._logger.warn(
                    f"Heartbeat ack error: errcode={errcode}, errmsg={frame.get('errmsg')}"
                )
                return
            self._missed_pong_count = 0
            self._logger.debug("Received heartbeat ack")
            return

        # Unknown frame type
        self._logger.warn(f"Received unknown frame: {json.dumps(frame)}")
        if self.on_message:
            self.on_message(frame)

    # ---------------------------------------------------------------- heartbeat

    def _start_heartbeat(self) -> None:
        """(Re)start the heartbeat loop as a background task."""
        self._stop_heartbeat()
        self._heartbeat_task = asyncio.ensure_future(self._heartbeat_loop())
        self._logger.debug(
            f"Heartbeat timer started, interval: {self._heartbeat_interval}ms"
        )

    def _stop_heartbeat(self) -> None:
        """Detach the heartbeat task and cancel it unless it is the caller."""
        task = self._heartbeat_task
        self._heartbeat_task = None
        # The heartbeat loop stops itself on dead-connection detection. It must
        # not cancel itself mid-close(); once detached it exits on its own.
        if task is not None and not task.done() and task is not self._current_task():
            task.cancel()
        self._logger.debug("Heartbeat timer stopped")

    async def _heartbeat_loop(self) -> None:
        """Send a heartbeat every interval while this task is the registered heartbeat task."""
        me = self._current_task()
        try:
            while self._heartbeat_task is me:
                await asyncio.sleep(self._heartbeat_interval / 1000)
                await self._send_heartbeat()
        except asyncio.CancelledError:
            pass

    async def _send_heartbeat(self) -> None:
        """
        Send one heartbeat frame, or close the socket if too many went unacked.

        Once ``_max_missed_pong`` heartbeats are unacknowledged, the heartbeat
        is stopped and the socket is closed. The receive loop then observes the
        close and runs the disconnect/reconnect handling. Send failures are
        logged, not raised.
        """
        if self._missed_pong_count >= self._max_missed_pong:
            self._logger.warn(
                f"No heartbeat ack received for {self._missed_pong_count} consecutive pings, "
                "connection considered dead"
            )
            self._stop_heartbeat()
            # Force-close the underlying connection
            ws = self._ws
            if ws is not None:
                try:
                    await ws.close()
                except Exception:
                    pass
            return

        self._missed_pong_count += 1
        try:
            await self.send(
                {
                    "cmd": WsCmd.HEARTBEAT,
                    "headers": {"req_id": generate_req_id(WsCmd.HEARTBEAT)},
                }
            )
            extra = (
                f" (awaiting {self._missed_pong_count} pong)"
                if self._missed_pong_count > 1
                else ""
            )
            self._logger.debug(f"Heartbeat sent{extra}")
        except Exception as e:
            self._logger.error(f"Failed to send heartbeat: {e}")

    # ---------------------------------------------------------------- reconnect

    async def _wait_before_reconnect(self) -> bool:
        """
        Apply the reconnect policy and sleep for the backoff delay.

        :return: ``True`` if another connection attempt should be made;
            ``False`` if attempts are exhausted (reported via ``on_error``) or
            :meth:`disconnect` was called during the wait.
        """
        if (
            self._max_reconnect_attempts != -1
            and self._reconnect_attempts >= self._max_reconnect_attempts
        ):
            self._logger.error(
                f"Max reconnect attempts reached ({self._max_reconnect_attempts}), giving up"
            )
            if self.on_error:
                self.on_error(Exception("Max reconnect attempts exceeded"))
            return False

        self._reconnect_attempts += 1
        # Exponential backoff: base, 2*base, 4*base, ... capped at _reconnect_max_delay
        delay = min(
            self._reconnect_base_delay * (2 ** (self._reconnect_attempts - 1)),
            self._reconnect_max_delay,
        )

        self._logger.info(
            f"Reconnecting in {delay}ms (attempt {self._reconnect_attempts})..."
        )
        if self.on_reconnecting:
            self.on_reconnecting(self._reconnect_attempts)

        await asyncio.sleep(delay / 1000)
        return not self._is_manual_close

    async def _schedule_reconnect(self) -> None:
        """Wait for the backoff delay, then reconnect unless given up or closed manually."""
        if await self._wait_before_reconnect():
            await self.connect()

    # -------------------------------------------------------------------- send

    async def send(self, frame: WsFrame) -> None:
        """
        Serialize ``frame`` as JSON and send it.

        :param frame: WebSocket frame.
        :raises RuntimeError: if there is no open connection.
        """
        if self._ws and _ws_is_open(self._ws):
            await self._ws.send(json.dumps(frame))
        else:
            raise RuntimeError("WebSocket not connected, unable to send data")

    async def send_reply(
        self, req_id: str, body: Any, cmd: str = WsCmd.RESPONSE
    ) -> WsFrame:
        """
        Send a reply over the WebSocket, serialized per ``req_id``.

        Replies sharing a ``req_id`` are queued and sent one at a time, each
        waiting for its ack before the next is sent.

        :param req_id: The ``req_id`` from the callback being answered.
        :param body: Reply body.
        :param cmd: Command to send; defaults to ``WsCmd.RESPONSE``.
        :return: The ack frame.
        :raises RuntimeError: queue full, ack with non-zero errcode, not
            connected, or the connection closed/was closed while pending.
        :raises TimeoutError: no ack within the ack timeout.
        """
        loop = asyncio.get_running_loop()
        future: asyncio.Future[WsFrame] = loop.create_future()

        frame: WsFrame = {
            "cmd": cmd,
            "headers": {"req_id": req_id},
            "body": body,
        }

        queue = self._reply_queues.setdefault(req_id, [])

        # Keep a single req_id's queue from growing without bound
        if len(queue) >= self._max_reply_queue_size:
            self._logger.warn(
                f"Reply queue for reqId {req_id} exceeds max size ({self._max_reply_queue_size}), "
                "rejecting new message"
            )
            raise RuntimeError(
                f"Reply queue for reqId {req_id} exceeds max size ({self._max_reply_queue_size})"
            )

        queue.append(_ReplyQueueItem(frame, future))

        if req_id not in self._processing_queues:
            # Mark it now rather than when the task first runs, so a second call
            # before the task starts cannot spawn a duplicate processor.
            self._processing_queues.add(req_id)
            asyncio.ensure_future(self._process_reply_queue(req_id))

        return await future

    async def _process_reply_queue(self, req_id: str) -> None:
        """Send queued replies for ``req_id`` one by one, waiting for each ack."""
        self._processing_queues.add(req_id)
        loop = asyncio.get_running_loop()

        try:
            while True:
                queue = self._reply_queues.get(req_id)
                if not queue:
                    self._reply_queues.pop(req_id, None)
                    break

                item = queue[0]

                try:
                    await self.send(item.frame)
                except Exception as e:
                    self._logger.error(f"Failed to send reply for reqId {req_id}: {e}")
                    if queue and queue[0] is item:
                        queue.pop(0)
                    if not item.future.done():
                        item.future.set_exception(e)
                    continue

                self._logger.debug(
                    f"Reply message sent via WebSocket, reqId: {req_id}, queue length: {len(queue)}"
                )

                # Wait for the ack, bounded by the ack timeout
                ack_future: asyncio.Future[WsFrame] = loop.create_future()
                timeout_handle = loop.call_later(
                    self._reply_ack_timeout,
                    self._on_reply_ack_timeout,
                    req_id,
                    ack_future,
                )
                self._pending_acks[req_id] = (ack_future, timeout_handle)

                try:
                    ack_frame = await ack_future
                except Exception as e:
                    outcome_exc: Optional[Exception] = e
                    ack_frame = None
                else:
                    outcome_exc = None

                # The queue may have been orphaned by _clear_pending_messages;
                # only pop the item we actually processed.
                if queue and queue[0] is item:
                    queue.pop(0)
                if not item.future.done():
                    if outcome_exc is not None:
                        item.future.set_exception(outcome_exc)
                    else:
                        item.future.set_result(ack_frame)
        finally:
            self._processing_queues.discard(req_id)

    def _on_reply_ack_timeout(
        self, req_id: str, ack_future: "asyncio.Future[WsFrame]"
    ) -> None:
        """Timer callback: fail ``ack_future`` with ``TimeoutError`` if still pending."""
        self._logger.warn(
            f"Reply ack timeout ({self._reply_ack_timeout}s) for reqId: {req_id}"
        )
        # Only drop the entry if it still belongs to this waiter.
        entry = self._pending_acks.get(req_id)
        if entry is not None and entry[0] is ack_future:
            del self._pending_acks[req_id]
        if not ack_future.done():
            ack_future.set_exception(
                TimeoutError(
                    f"Reply ack timeout ({self._reply_ack_timeout}s) for reqId: {req_id}"
                )
            )

    def _handle_reply_ack(self, req_id: str, frame: WsFrame) -> None:
        """Complete the pending ack for ``req_id``; a non-zero errcode fails it."""
        pending = self._pending_acks.pop(req_id, None)
        if not pending:
            return

        ack_future, timeout_handle = pending

        # Cancel the timeout timer
        if timeout_handle:
            timeout_handle.cancel()

        errcode = frame.get("errcode")
        if errcode != 0:
            self._logger.warn(
                f"Reply ack error: reqId={req_id}, errcode={errcode}, errmsg={frame.get('errmsg')}"
            )
            if not ack_future.done():
                ack_future.set_exception(
                    RuntimeError(
                        f"Reply ack error: errcode={errcode}, errmsg={frame.get('errmsg')}"
                    )
                )
        else:
            self._logger.debug(f"Reply ack received for reqId: {req_id}")
            if not ack_future.done():
                ack_future.set_result(frame)

    def _clear_pending_messages(self, reason: str) -> None:
        """Fail every pending ack wait and queued reply with a ``RuntimeError``."""
        pending_acks = list(self._pending_acks.values())
        self._pending_acks.clear()
        for ack_future, timeout_handle in pending_acks:
            if timeout_handle:
                timeout_handle.cancel()
            if not ack_future.done():
                ack_future.set_exception(RuntimeError(reason))

        reply_queues = list(self._reply_queues.items())
        self._reply_queues.clear()
        for req_id, queue in reply_queues:
            for item in queue:
                if not item.future.done():
                    item.future.set_exception(
                        RuntimeError(f"{reason}, reply for reqId: {req_id} cancelled")
                    )

    # --------------------------------------------------------------- disconnect

    def disconnect(self) -> None:
        """
        Close the connection on purpose and stop reconnecting.

        Synchronous: pending replies are failed immediately and the socket
        shutdown is scheduled as a task on the event loop.
        """
        self._is_manual_close = True
        self._stop_heartbeat()
        self._clear_pending_messages("Connection manually closed")

        if self._ws is not None:
            asyncio.ensure_future(self._async_disconnect())

        self._logger.info("WebSocket connection manually closed")

    async def _async_disconnect(self) -> None:
        """Stop the receive loop and close the socket with code 1000."""
        await self._cancel_receive_task()

        ws = self._ws
        self._ws = None
        if ws is not None:
            try:
                await ws.close(code=1000, reason="Manual disconnect")
            except Exception:
                pass

    @property
    def is_connected(self) -> bool:
        """Whether a socket exists and reports itself open."""
        return self._ws is not None and _ws_is_open(self._ws)