disagreement 0.1.0rc3__py3-none-any.whl → 0.3.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
disagreement/gateway.py CHANGED
@@ -1,575 +1,630 @@
1
- # disagreement/gateway.py
2
-
3
- """
4
- Manages the WebSocket connection to the Discord Gateway.
5
- """
6
-
7
- import asyncio
8
- import logging
9
- import traceback
10
- import aiohttp
11
- import json
12
- import zlib
13
- import time
14
- import random
15
- from typing import Optional, TYPE_CHECKING, Any, Dict
16
-
17
- from .enums import GatewayOpcode, GatewayIntent
18
- from .errors import GatewayException, DisagreementException, AuthenticationError
19
- from .interactions import Interaction
20
-
21
- if TYPE_CHECKING:
22
- from .client import Client # For type hinting
23
- from .event_dispatcher import EventDispatcher
24
- from .http import HTTPClient
25
- from .interactions import Interaction # Added for INTERACTION_CREATE
26
-
27
- # ZLIB Decompression constants
28
- ZLIB_SUFFIX = b"\x00\x00\xff\xff"
29
- MAX_DECOMPRESSION_SIZE = 10 * 1024 * 1024 # 10 MiB, adjust as needed
30
-
31
-
32
- logger = logging.getLogger(__name__)
33
-
34
-
35
- class GatewayClient:
36
- """
37
- Handles the Discord Gateway WebSocket connection, heartbeating, and event dispatching.
38
- """
39
-
40
- def __init__(
41
- self,
42
- http_client: "HTTPClient",
43
- event_dispatcher: "EventDispatcher",
44
- token: str,
45
- intents: int,
46
- client_instance: "Client", # Pass the main client instance
47
- verbose: bool = False,
48
- *,
49
- shard_id: Optional[int] = None,
50
- shard_count: Optional[int] = None,
51
- max_retries: int = 5,
52
- max_backoff: float = 60.0,
53
- ):
54
- self._http: "HTTPClient" = http_client
55
- self._dispatcher: "EventDispatcher" = event_dispatcher
56
- self._token: str = token
57
- self._intents: int = intents
58
- self._client_instance: "Client" = client_instance # Store client instance
59
- self.verbose: bool = verbose
60
- self._shard_id: Optional[int] = shard_id
61
- self._shard_count: Optional[int] = shard_count
62
- self._max_retries: int = max_retries
63
- self._max_backoff: float = max_backoff
64
-
65
- self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
66
- self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
67
- self._heartbeat_interval: Optional[float] = None
68
- self._last_sequence: Optional[int] = None
69
- self._session_id: Optional[str] = None
70
- self._resume_gateway_url: Optional[str] = None
71
-
72
- self._keep_alive_task: Optional[asyncio.Task] = None
73
- self._receive_task: Optional[asyncio.Task] = None
74
-
75
- self._last_heartbeat_sent: Optional[float] = None
76
- self._last_heartbeat_ack: Optional[float] = None
77
-
78
- # For zlib decompression
79
- self._buffer = bytearray()
80
- self._inflator = zlib.decompressobj()
81
-
82
- async def _reconnect(self) -> None:
83
- """Attempts to reconnect using exponential backoff with jitter."""
84
- delay = 1.0
85
- for attempt in range(self._max_retries):
86
- try:
87
- await self.connect()
88
- return
89
- except Exception as e: # noqa: BLE001
90
- if attempt >= self._max_retries - 1:
91
- logger.error(
92
- "Reconnect failed after %s attempts: %s", attempt + 1, e
93
- )
94
- raise
95
- jitter = random.uniform(0, delay)
96
- wait_time = min(delay + jitter, self._max_backoff)
97
- logger.warning(
98
- "Reconnect attempt %s failed: %s. Retrying in %.2f seconds...",
99
- attempt + 1,
100
- e,
101
- wait_time,
102
- )
103
- await asyncio.sleep(wait_time)
104
- delay = min(delay * 2, self._max_backoff)
105
-
106
- async def _decompress_message(
107
- self, message_bytes: bytes
108
- ) -> Optional[Dict[str, Any]]:
109
- """Decompresses a zlib-compressed message from the Gateway."""
110
- self._buffer.extend(message_bytes)
111
-
112
- if len(message_bytes) < 4 or message_bytes[-4:] != ZLIB_SUFFIX:
113
- # Message is not complete or not zlib compressed in the expected way
114
- return None
115
- # Or handle partial messages if Discord ever sends them fragmented like this,
116
- # but typically each binary message is a complete zlib stream.
117
-
118
- try:
119
- decompressed = self._inflator.decompress(self._buffer)
120
- self._buffer.clear() # Reset buffer after successful decompression
121
- return json.loads(decompressed.decode("utf-8"))
122
- except zlib.error as e:
123
- logger.error("Zlib decompression error: %s", e)
124
- self._buffer.clear() # Clear buffer on error
125
- self._inflator = zlib.decompressobj() # Reset inflator
126
- return None
127
- except json.JSONDecodeError as e:
128
- logger.error("JSON decode error after decompression: %s", e)
129
- return None
130
-
131
- async def _send_json(self, payload: Dict[str, Any]):
132
- if self._ws and not self._ws.closed:
133
- if self.verbose:
134
- logger.debug("GATEWAY SEND: %s", payload)
135
- await self._ws.send_json(payload)
136
- else:
137
- logger.warning(
138
- "Gateway send attempted but WebSocket is closed or not available."
139
- )
140
- # raise GatewayException("WebSocket is not connected.")
141
-
142
- async def _heartbeat(self):
143
- """Sends a heartbeat to the Gateway."""
144
- self._last_heartbeat_sent = time.monotonic()
145
- payload = {"op": GatewayOpcode.HEARTBEAT, "d": self._last_sequence}
146
- await self._send_json(payload)
147
- # print("Sent heartbeat.")
148
-
149
- async def _keep_alive(self):
150
- """Manages the heartbeating loop."""
151
- if self._heartbeat_interval is None:
152
- # This should not happen if HELLO was processed correctly
153
- logger.error("Heartbeat interval not set. Cannot start keep_alive.")
154
- return
155
-
156
- try:
157
- while True:
158
- await self._heartbeat()
159
- await asyncio.sleep(
160
- self._heartbeat_interval / 1000
161
- ) # Interval is in ms
162
- except asyncio.CancelledError:
163
- logger.debug("Keep_alive task cancelled.")
164
- except Exception as e:
165
- logger.error("Error in keep_alive loop: %s", e)
166
- # Potentially trigger a reconnect here or notify client
167
- await self._client_instance.close_gateway(code=1000) # Generic close
168
-
169
- async def _identify(self):
170
- """Sends the IDENTIFY payload to the Gateway."""
171
- payload = {
172
- "op": GatewayOpcode.IDENTIFY,
173
- "d": {
174
- "token": self._token,
175
- "intents": self._intents,
176
- "properties": {
177
- "$os": "python", # Or platform.system()
178
- "$browser": "disagreement", # Library name
179
- "$device": "disagreement", # Library name
180
- },
181
- "compress": True, # Request zlib compression
182
- },
183
- }
184
- if self._shard_id is not None and self._shard_count is not None:
185
- payload["d"]["shard"] = [self._shard_id, self._shard_count]
186
- await self._send_json(payload)
187
- logger.info("Sent IDENTIFY.")
188
-
189
- async def _resume(self):
190
- """Sends the RESUME payload to the Gateway."""
191
- if not self._session_id or self._last_sequence is None:
192
- logger.warning("Cannot RESUME: session_id or last_sequence is missing.")
193
- await self._identify() # Fallback to identify
194
- return
195
-
196
- payload = {
197
- "op": GatewayOpcode.RESUME,
198
- "d": {
199
- "token": self._token,
200
- "session_id": self._session_id,
201
- "seq": self._last_sequence,
202
- },
203
- }
204
- await self._send_json(payload)
205
- logger.info(
206
- "Sent RESUME for session %s at sequence %s.",
207
- self._session_id,
208
- self._last_sequence,
209
- )
210
-
211
- async def update_presence(
212
- self,
213
- status: str,
214
- activity_name: Optional[str] = None,
215
- activity_type: int = 0,
216
- since: int = 0,
217
- afk: bool = False,
218
- ):
219
- """Sends the presence update payload to the Gateway."""
220
- payload = {
221
- "op": GatewayOpcode.PRESENCE_UPDATE,
222
- "d": {
223
- "since": since,
224
- "activities": (
225
- [
226
- {
227
- "name": activity_name,
228
- "type": activity_type,
229
- }
230
- ]
231
- if activity_name
232
- else []
233
- ),
234
- "status": status,
235
- "afk": afk,
236
- },
237
- }
238
- await self._send_json(payload)
239
-
240
- async def _handle_dispatch(self, data: Dict[str, Any]):
241
- """Handles DISPATCH events (actual Discord events)."""
242
- event_name = data.get("t")
243
- sequence_num = data.get("s")
244
- raw_event_d_payload = data.get(
245
- "d"
246
- ) # This is the 'd' field from the gateway event
247
-
248
- if sequence_num is not None:
249
- self._last_sequence = sequence_num
250
-
251
- if event_name == "READY": # Special handling for READY
252
- if not isinstance(raw_event_d_payload, dict):
253
- logger.error(
254
- "READY event 'd' payload is not a dict or is missing: %s",
255
- raw_event_d_payload,
256
- )
257
- # Consider raising an error or attempting a reconnect
258
- return
259
- self._session_id = raw_event_d_payload.get("session_id")
260
- self._resume_gateway_url = raw_event_d_payload.get("resume_gateway_url")
261
-
262
- app_id_str = "N/A"
263
- # Store application_id on the client instance
264
- if (
265
- "application" in raw_event_d_payload
266
- and isinstance(raw_event_d_payload["application"], dict)
267
- and "id" in raw_event_d_payload["application"]
268
- ):
269
- app_id_value = raw_event_d_payload["application"]["id"]
270
- self._client_instance.application_id = (
271
- app_id_value # Snowflake can be str or int
272
- )
273
- app_id_str = str(app_id_value)
274
- else:
275
- logger.warning(
276
- "Could not find application ID in READY payload. App commands may not work."
277
- )
278
-
279
- # Parse and store the bot's own user object
280
- if "user" in raw_event_d_payload and isinstance(
281
- raw_event_d_payload["user"], dict
282
- ):
283
- try:
284
- # Assuming Client has a parse_user method that takes user data dict
285
- # and returns a User object, also caching it.
286
- bot_user_obj = self._client_instance.parse_user(
287
- raw_event_d_payload["user"]
288
- )
289
- self._client_instance.user = bot_user_obj
290
- logger.info(
291
- "Gateway READY. Bot User: %s#%s. Session ID: %s. App ID: %s. Resume URL: %s",
292
- bot_user_obj.username,
293
- bot_user_obj.discriminator,
294
- self._session_id,
295
- app_id_str,
296
- self._resume_gateway_url,
297
- )
298
- except Exception as e:
299
- logger.error("Error parsing bot user from READY payload: %s", e)
300
- logger.info(
301
- "Gateway READY (user parse failed). Session ID: %s. App ID: %s. Resume URL: %s",
302
- self._session_id,
303
- app_id_str,
304
- self._resume_gateway_url,
305
- )
306
- else:
307
- logger.warning("Bot user object not found or invalid in READY payload.")
308
- logger.info(
309
- "Gateway READY (no user). Session ID: %s. App ID: %s. Resume URL: %s",
310
- self._session_id,
311
- app_id_str,
312
- self._resume_gateway_url,
313
- )
314
-
315
- await self._dispatcher.dispatch(event_name, raw_event_d_payload)
316
- elif event_name == "INTERACTION_CREATE":
317
- # print(f"GATEWAY RECV INTERACTION_CREATE: {raw_event_d_payload}")
318
- if isinstance(raw_event_d_payload, dict):
319
- interaction = Interaction(
320
- data=raw_event_d_payload, client_instance=self._client_instance
321
- )
322
- await self._dispatcher.dispatch(
323
- "INTERACTION_CREATE", raw_event_d_payload
324
- )
325
- # Dispatch to a new client method that will then call AppCommandHandler
326
- if hasattr(self._client_instance, "process_interaction"):
327
- asyncio.create_task(
328
- self._client_instance.process_interaction(interaction)
329
- ) # type: ignore
330
- else:
331
- logger.warning(
332
- "Client instance does not have process_interaction method for INTERACTION_CREATE."
333
- )
334
- else:
335
- logger.error(
336
- "INTERACTION_CREATE event 'd' payload is not a dict: %s",
337
- raw_event_d_payload,
338
- )
339
- elif event_name == "RESUMED":
340
- logger.info("Gateway RESUMED successfully.")
341
- # RESUMED 'd' payload is often an empty object or debug info.
342
- # Ensure it's a dict for the dispatcher.
343
- event_data_to_dispatch = (
344
- raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
345
- )
346
- await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
347
- elif event_name:
348
- # For other events, ensure 'd' is a dict, or pass {} if 'd' is null/missing.
349
- # Models/parsers in EventDispatcher will need to handle potentially empty dicts.
350
- event_data_to_dispatch = (
351
- raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
352
- )
353
- # print(f"GATEWAY RECV EVENT: {event_name} | DATA: {event_data_to_dispatch}")
354
- await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
355
- else:
356
- logger.warning("Received dispatch with no event name: %s", data)
357
-
358
- async def _process_message(self, msg: aiohttp.WSMessage):
359
- """Processes a single message from the WebSocket."""
360
- if msg.type == aiohttp.WSMsgType.TEXT:
361
- try:
362
- data = json.loads(msg.data)
363
- except json.JSONDecodeError:
364
- logger.error("Failed to decode JSON from Gateway: %s", msg.data[:200])
365
- return
366
- elif msg.type == aiohttp.WSMsgType.BINARY:
367
- decompressed_data = await self._decompress_message(msg.data)
368
- if decompressed_data is None:
369
- logger.error(
370
- "Failed to decompress or decode binary message from Gateway."
371
- )
372
- return
373
- data = decompressed_data
374
- elif msg.type == aiohttp.WSMsgType.ERROR:
375
- logger.error(
376
- "WebSocket error: %s",
377
- self._ws.exception() if self._ws else "Unknown WSError",
378
- )
379
- raise GatewayException(
380
- f"WebSocket error: {self._ws.exception() if self._ws else 'Unknown WSError'}"
381
- )
382
- elif msg.type == aiohttp.WSMsgType.CLOSED:
383
- close_code = (
384
- self._ws.close_code
385
- if self._ws and hasattr(self._ws, "close_code")
386
- else "N/A"
387
- )
388
- logger.warning(
389
- "WebSocket connection closed by server. Code: %s", close_code
390
- )
391
- # Raise an exception to signal the closure to the client's main run loop
392
- raise GatewayException(f"WebSocket closed by server. Code: {close_code}")
393
- else:
394
- logger.warning("Received unhandled WebSocket message type: %s", msg.type)
395
- return
396
-
397
- if self.verbose:
398
- logger.debug("GATEWAY RECV: %s", data)
399
- op = data.get("op")
400
- # 'd' payload (event_data) is handled specifically by each opcode handler below
401
-
402
- if op == GatewayOpcode.DISPATCH:
403
- await self._handle_dispatch(data) # _handle_dispatch will extract 'd'
404
- elif op == GatewayOpcode.HEARTBEAT: # Server requests a heartbeat
405
- await self._heartbeat()
406
- elif op == GatewayOpcode.RECONNECT: # Server requests a reconnect
407
- logger.info(
408
- "Gateway requested RECONNECT. Closing and will attempt to reconnect."
409
- )
410
- await self.close(code=4000, reconnect=True)
411
- elif op == GatewayOpcode.INVALID_SESSION:
412
- # The 'd' payload for INVALID_SESSION is a boolean indicating resumability
413
- can_resume = data.get("d") is True
414
- logger.warning(
415
- "Gateway indicated INVALID_SESSION. Resumable: %s", can_resume
416
- )
417
- if not can_resume:
418
- self._session_id = None # Clear session_id to force re-identify
419
- self._last_sequence = None
420
- # Close and reconnect. The connect logic will decide to resume or identify.
421
- await self.close(code=4000 if can_resume else 4009, reconnect=True)
422
- elif op == GatewayOpcode.HELLO:
423
- hello_d_payload = data.get("d")
424
- if (
425
- not isinstance(hello_d_payload, dict)
426
- or "heartbeat_interval" not in hello_d_payload
427
- ):
428
- logger.error(
429
- "HELLO event 'd' payload is invalid or missing heartbeat_interval: %s",
430
- hello_d_payload,
431
- )
432
- await self.close(code=1011) # Internal error, malformed HELLO
433
- return
434
- self._heartbeat_interval = hello_d_payload["heartbeat_interval"]
435
- logger.info(
436
- "Gateway HELLO. Heartbeat interval: %sms.", self._heartbeat_interval
437
- )
438
- # Start heartbeating
439
- if self._keep_alive_task:
440
- self._keep_alive_task.cancel()
441
- self._keep_alive_task = self._loop.create_task(self._keep_alive())
442
-
443
- # Identify or Resume
444
- if self._session_id and self._resume_gateway_url: # Check if we can resume
445
- logger.info("Attempting to RESUME session.")
446
- await self._resume()
447
- else:
448
- logger.info("Performing initial IDENTIFY.")
449
- await self._identify()
450
- elif op == GatewayOpcode.HEARTBEAT_ACK:
451
- self._last_heartbeat_ack = time.monotonic()
452
- # print("Received heartbeat ACK.")
453
- pass # Good, connection is alive
454
- else:
455
- logger.warning(
456
- "Received unhandled Gateway Opcode: %s with data: %s", op, data
457
- )
458
-
459
- async def _receive_loop(self):
460
- """Continuously receives and processes messages from the WebSocket."""
461
- if not self._ws or self._ws.closed:
462
- logger.warning(
463
- "Receive loop cannot start: WebSocket is not connected or closed."
464
- )
465
- return
466
-
467
- try:
468
- async for msg in self._ws:
469
- await self._process_message(msg)
470
- except asyncio.CancelledError:
471
- logger.debug("Receive_loop task cancelled.")
472
- except aiohttp.ClientConnectionError as e:
473
- logger.warning(
474
- "ClientConnectionError in receive_loop: %s. Attempting reconnect.", e
475
- )
476
- await self.close(code=1006, reconnect=True) # Abnormal closure
477
- except Exception as e:
478
- logger.error("Unexpected error in receive_loop: %s", e)
479
- traceback.print_exc()
480
- await self.close(code=1011, reconnect=True)
481
- finally:
482
- logger.info("Receive_loop ended.")
483
- # If the loop ends unexpectedly (not due to explicit close),
484
- # the main client might want to try reconnecting.
485
-
486
- async def connect(self):
487
- """Connects to the Discord Gateway."""
488
- if self._ws and not self._ws.closed:
489
- logger.warning("Gateway already connected or connecting.")
490
- return
491
-
492
- gateway_url = (
493
- self._resume_gateway_url or (await self._http.get_gateway_bot())["url"]
494
- )
495
- if not gateway_url.endswith("?v=10&encoding=json&compress=zlib-stream"):
496
- gateway_url += "?v=10&encoding=json&compress=zlib-stream"
497
-
498
- logger.info("Connecting to Gateway: %s", gateway_url)
499
- try:
500
- await self._http._ensure_session() # Ensure the HTTP client's session is active
501
- assert (
502
- self._http._session is not None
503
- ), "HTTPClient session not initialized after ensure_session"
504
- self._ws = await self._http._session.ws_connect(gateway_url, max_msg_size=0)
505
- logger.info("Gateway WebSocket connection established.")
506
-
507
- if self._receive_task:
508
- self._receive_task.cancel()
509
- self._receive_task = self._loop.create_task(self._receive_loop())
510
-
511
- except aiohttp.ClientConnectorError as e:
512
- raise GatewayException(
513
- f"Failed to connect to Gateway (Connector Error): {e}"
514
- ) from e
515
- except aiohttp.WSServerHandshakeError as e:
516
- if e.status == 401: # Unauthorized during handshake
517
- raise AuthenticationError(
518
- f"Gateway handshake failed (401 Unauthorized): {e.message}. Check your bot token."
519
- ) from e
520
- raise GatewayException(
521
- f"Gateway handshake failed (Status: {e.status}): {e.message}"
522
- ) from e
523
- except Exception as e: # Catch other potential errors during connection
524
- raise GatewayException(
525
- f"An unexpected error occurred during Gateway connection: {e}"
526
- ) from e
527
-
528
- async def close(self, code: int = 1000, *, reconnect: bool = False):
529
- """Closes the Gateway connection."""
530
- logger.info("Closing Gateway connection with code %s...", code)
531
- if self._keep_alive_task and not self._keep_alive_task.done():
532
- self._keep_alive_task.cancel()
533
- try:
534
- await self._keep_alive_task
535
- except asyncio.CancelledError:
536
- pass # Expected
537
-
538
- if self._receive_task and not self._receive_task.done():
539
- current = asyncio.current_task(loop=self._loop)
540
- self._receive_task.cancel()
541
- if self._receive_task is not current:
542
- try:
543
- await self._receive_task
544
- except asyncio.CancelledError:
545
- pass # Expected
546
-
547
- if self._ws and not self._ws.closed:
548
- await self._ws.close(code=code)
549
- logger.info("Gateway WebSocket closed.")
550
-
551
- self._ws = None
552
- # Do not reset session_id, last_sequence, or resume_gateway_url here
553
- # if the close code indicates a resumable disconnect (e.g. 4000-4009, or server-initiated RECONNECT)
554
- # The connect logic will decide whether to resume or re-identify.
555
- # However, if it's a non-resumable close (e.g. Invalid Session non-resumable), clear them.
556
- if code == 4009: # Invalid session, not resumable
557
- logger.info("Clearing session state due to non-resumable invalid session.")
558
- self._session_id = None
559
- self._last_sequence = None
560
- self._resume_gateway_url = None # This might be re-fetched anyway
561
-
562
- @property
563
- def latency(self) -> Optional[float]:
564
- """Returns the latency between heartbeat and ACK in seconds."""
565
- if self._last_heartbeat_sent is None or self._last_heartbeat_ack is None:
566
- return None
567
- return self._last_heartbeat_ack - self._last_heartbeat_sent
568
-
569
- @property
570
- def last_heartbeat_sent(self) -> Optional[float]:
571
- return self._last_heartbeat_sent
572
-
573
- @property
574
- def last_heartbeat_ack(self) -> Optional[float]:
575
- return self._last_heartbeat_ack
1
+ # disagreement/gateway.py
2
+
3
+ """
4
+ Manages the WebSocket connection to the Discord Gateway.
5
+ """
6
+
7
+ import asyncio
8
+ import logging
9
+ import traceback
10
+ import aiohttp
11
+ import json
12
+ import zlib
13
+ import time
14
+ import random
15
+ from typing import Optional, TYPE_CHECKING, Any, Dict
16
+
17
+ from .enums import GatewayOpcode, GatewayIntent
18
+ from .errors import GatewayException, DisagreementException, AuthenticationError
19
+ from .interactions import Interaction
20
+
21
+ if TYPE_CHECKING:
22
+ from .client import Client # For type hinting
23
+ from .event_dispatcher import EventDispatcher
24
+ from .http import HTTPClient
25
+ from .interactions import Interaction # Added for INTERACTION_CREATE
26
+
27
+ # ZLIB Decompression constants
28
+ ZLIB_SUFFIX = b"\x00\x00\xff\xff"
29
+ MAX_DECOMPRESSION_SIZE = 10 * 1024 * 1024 # 10 MiB, adjust as needed
30
+
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ class GatewayClient:
36
+ """
37
+ Handles the Discord Gateway WebSocket connection, heartbeating, and event dispatching.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ http_client: "HTTPClient",
43
+ event_dispatcher: "EventDispatcher",
44
+ token: str,
45
+ intents: int,
46
+ client_instance: "Client", # Pass the main client instance
47
+ verbose: bool = False,
48
+ *,
49
+ shard_id: Optional[int] = None,
50
+ shard_count: Optional[int] = None,
51
+ max_retries: int = 5,
52
+ max_backoff: float = 60.0,
53
+ ):
54
+ self._http: "HTTPClient" = http_client
55
+ self._dispatcher: "EventDispatcher" = event_dispatcher
56
+ self._token: str = token
57
+ self._intents: int = intents
58
+ self._client_instance: "Client" = client_instance # Store client instance
59
+ self.verbose: bool = verbose
60
+ self._shard_id: Optional[int] = shard_id
61
+ self._shard_count: Optional[int] = shard_count
62
+ self._max_retries: int = max_retries
63
+ self._max_backoff: float = max_backoff
64
+
65
+ self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
66
+ self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
67
+ self._heartbeat_interval: Optional[float] = None
68
+ self._last_sequence: Optional[int] = None
69
+ self._session_id: Optional[str] = None
70
+ self._resume_gateway_url: Optional[str] = None
71
+
72
+ self._keep_alive_task: Optional[asyncio.Task] = None
73
+ self._receive_task: Optional[asyncio.Task] = None
74
+
75
+ self._last_heartbeat_sent: Optional[float] = None
76
+ self._last_heartbeat_ack: Optional[float] = None
77
+
78
+ # For zlib decompression
79
+ self._buffer = bytearray()
80
+ self._inflator = zlib.decompressobj()
81
+
82
+ self._member_chunk_requests: Dict[str, asyncio.Future] = {}
83
+
84
+ async def _reconnect(self) -> None:
85
+ """Attempts to reconnect using exponential backoff with jitter."""
86
+ delay = 1.0
87
+ for attempt in range(self._max_retries):
88
+ try:
89
+ await self.connect()
90
+ return
91
+ except Exception as e: # noqa: BLE001
92
+ if attempt >= self._max_retries - 1:
93
+ logger.error(
94
+ "Reconnect failed after %s attempts: %s", attempt + 1, e
95
+ )
96
+ raise
97
+ jitter = random.uniform(0, delay)
98
+ wait_time = min(delay + jitter, self._max_backoff)
99
+ logger.warning(
100
+ "Reconnect attempt %s failed: %s. Retrying in %.2f seconds...",
101
+ attempt + 1,
102
+ e,
103
+ wait_time,
104
+ )
105
+ await asyncio.sleep(wait_time)
106
+ delay = min(delay * 2, self._max_backoff)
107
+
108
+ async def _decompress_message(
109
+ self, message_bytes: bytes
110
+ ) -> Optional[Dict[str, Any]]:
111
+ """Decompresses a zlib-compressed message from the Gateway."""
112
+ self._buffer.extend(message_bytes)
113
+
114
+ if len(message_bytes) < 4 or message_bytes[-4:] != ZLIB_SUFFIX:
115
+ # Message is not complete or not zlib compressed in the expected way
116
+ return None
117
+ # Or handle partial messages if Discord ever sends them fragmented like this,
118
+ # but typically each binary message is a complete zlib stream.
119
+
120
+ try:
121
+ decompressed = self._inflator.decompress(self._buffer)
122
+ self._buffer.clear() # Reset buffer after successful decompression
123
+ return json.loads(decompressed.decode("utf-8"))
124
+ except zlib.error as e:
125
+ logger.error("Zlib decompression error: %s", e)
126
+ self._buffer.clear() # Clear buffer on error
127
+ self._inflator = zlib.decompressobj() # Reset inflator
128
+ return None
129
+ except json.JSONDecodeError as e:
130
+ logger.error("JSON decode error after decompression: %s", e)
131
+ return None
132
+
133
+ async def _send_json(self, payload: Dict[str, Any]):
134
+ if self._ws and not self._ws.closed:
135
+ if self.verbose:
136
+ logger.debug("GATEWAY SEND: %s", payload)
137
+ await self._ws.send_json(payload)
138
+ else:
139
+ logger.warning(
140
+ "Gateway send attempted but WebSocket is closed or not available."
141
+ )
142
+ # raise GatewayException("WebSocket is not connected.")
143
+
144
+ async def _heartbeat(self):
145
+ """Sends a heartbeat to the Gateway."""
146
+ self._last_heartbeat_sent = time.monotonic()
147
+ payload = {"op": GatewayOpcode.HEARTBEAT, "d": self._last_sequence}
148
+ await self._send_json(payload)
149
+ # print("Sent heartbeat.")
150
+
151
+ async def _keep_alive(self):
152
+ """Manages the heartbeating loop."""
153
+ if self._heartbeat_interval is None:
154
+ # This should not happen if HELLO was processed correctly
155
+ logger.error("Heartbeat interval not set. Cannot start keep_alive.")
156
+ return
157
+
158
+ try:
159
+ while True:
160
+ await self._heartbeat()
161
+ await asyncio.sleep(
162
+ self._heartbeat_interval / 1000
163
+ ) # Interval is in ms
164
+ except asyncio.CancelledError:
165
+ logger.debug("Keep_alive task cancelled.")
166
+ except Exception as e:
167
+ logger.error("Error in keep_alive loop: %s", e)
168
+ # Potentially trigger a reconnect here or notify client
169
+ await self._client_instance.close_gateway(code=1000) # Generic close
170
+
171
+ async def _identify(self):
172
+ """Sends the IDENTIFY payload to the Gateway."""
173
+ payload = {
174
+ "op": GatewayOpcode.IDENTIFY,
175
+ "d": {
176
+ "token": self._token,
177
+ "intents": self._intents,
178
+ "properties": {
179
+ "$os": "python", # Or platform.system()
180
+ "$browser": "disagreement", # Library name
181
+ "$device": "disagreement", # Library name
182
+ },
183
+ "compress": True, # Request zlib compression
184
+ },
185
+ }
186
+ if self._shard_id is not None and self._shard_count is not None:
187
+ payload["d"]["shard"] = [self._shard_id, self._shard_count]
188
+ await self._send_json(payload)
189
+ logger.info("Sent IDENTIFY.")
190
+
191
+ async def _resume(self):
192
+ """Sends the RESUME payload to the Gateway."""
193
+ if not self._session_id or self._last_sequence is None:
194
+ logger.warning("Cannot RESUME: session_id or last_sequence is missing.")
195
+ await self._identify() # Fallback to identify
196
+ return
197
+
198
+ payload = {
199
+ "op": GatewayOpcode.RESUME,
200
+ "d": {
201
+ "token": self._token,
202
+ "session_id": self._session_id,
203
+ "seq": self._last_sequence,
204
+ },
205
+ }
206
+ await self._send_json(payload)
207
+ logger.info(
208
+ "Sent RESUME for session %s at sequence %s.",
209
+ self._session_id,
210
+ self._last_sequence,
211
+ )
212
+
213
+ async def update_presence(
214
+ self,
215
+ status: str,
216
+ activity_name: Optional[str] = None,
217
+ activity_type: int = 0,
218
+ since: int = 0,
219
+ afk: bool = False,
220
+ ):
221
+ """Sends the presence update payload to the Gateway."""
222
+ payload = {
223
+ "op": GatewayOpcode.PRESENCE_UPDATE,
224
+ "d": {
225
+ "since": since,
226
+ "activities": (
227
+ [
228
+ {
229
+ "name": activity_name,
230
+ "type": activity_type,
231
+ }
232
+ ]
233
+ if activity_name
234
+ else []
235
+ ),
236
+ "status": status,
237
+ "afk": afk,
238
+ },
239
+ }
240
+ await self._send_json(payload)
241
+
242
+ async def request_guild_members(
243
+ self,
244
+ guild_id: str,
245
+ query: str = "",
246
+ limit: int = 0,
247
+ presences: bool = False,
248
+ user_ids: Optional[list[str]] = None,
249
+ nonce: Optional[str] = None,
250
+ ):
251
+ """Sends the request guild members payload to the Gateway."""
252
+ payload = {
253
+ "op": GatewayOpcode.REQUEST_GUILD_MEMBERS,
254
+ "d": {
255
+ "guild_id": guild_id,
256
+ "query": query,
257
+ "limit": limit,
258
+ "presences": presences,
259
+ },
260
+ }
261
+ if user_ids:
262
+ payload["d"]["user_ids"] = user_ids
263
+ if nonce:
264
+ payload["d"]["nonce"] = nonce
265
+
266
+ await self._send_json(payload)
267
+
268
async def _handle_dispatch(self, data: Dict[str, Any]):
    """Handle a DISPATCH frame, i.e. an actual Discord event.

    The frame's sequence number (``s``) is remembered when present, then
    the event name (``t``) selects the handling:

    * ``READY`` stores the session id, resume URL, application id and the
      bot's own user on this object / the client, then forwards the event.
    * ``GUILD_MEMBERS_CHUNK`` accumulates members on the future registered
      under the chunk's nonce and resolves it on the last chunk. This event
      is not forwarded to the dispatcher.
    * ``INTERACTION_CREATE`` forwards the raw payload and schedules
      ``client.process_interaction`` as a background task.
    * ``RESUMED`` forwards the event and additionally emits ``SHARD_RESUME``.
    * Any other named event is forwarded with its payload (``{}`` when the
      payload is not a dict).

    Args:
        data: The decoded gateway frame; its ``t``, ``s`` and ``d`` keys are
            read.
    """
    event_name = data.get("t")
    sequence_num = data.get("s")
    payload = data.get("d")  # This is the 'd' field from the gateway event

    if sequence_num is not None:
        self._last_sequence = sequence_num

    if event_name == "READY":  # Special handling for READY
        if not isinstance(payload, dict):
            logger.error(
                "READY event 'd' payload is not a dict or is missing: %s",
                payload,
            )
            # Consider raising an error or attempting a reconnect
            return
        self._session_id = payload.get("session_id")
        self._resume_gateway_url = payload.get("resume_gateway_url")

        app_id_str = "N/A"
        # Store application_id on the client instance
        application = payload.get("application")
        if isinstance(application, dict) and "id" in application:
            app_id_value = application["id"]
            self._client_instance.application_id = app_id_value
            app_id_str = str(app_id_value)
        else:
            logger.warning(
                "Could not find application ID in READY payload. App commands may not work."
            )

        # Parse and store the bot's own user object
        user_data = payload.get("user")
        if isinstance(user_data, dict):
            try:
                bot_user_obj = self._client_instance.parse_user(user_data)
                self._client_instance.user = bot_user_obj
                logger.info(
                    "Gateway READY. Bot User: %s#%s. Session ID: %s. App ID: %s. Resume URL: %s",
                    bot_user_obj.username,
                    bot_user_obj.discriminator,
                    self._session_id,
                    app_id_str,
                    self._resume_gateway_url,
                )
            except Exception as e:
                # Best effort: a user-parsing problem must not prevent READY
                # from being delivered to listeners.
                logger.error("Error parsing bot user from READY payload: %s", e)
                logger.info(
                    "Gateway READY (user parse failed). Session ID: %s. App ID: %s. Resume URL: %s",
                    self._session_id,
                    app_id_str,
                    self._resume_gateway_url,
                )
        else:
            logger.warning("Bot user object not found or invalid in READY payload.")
            logger.info(
                "Gateway READY (no user). Session ID: %s. App ID: %s. Resume URL: %s",
                self._session_id,
                app_id_str,
                self._resume_gateway_url,
            )

        await self._dispatcher.dispatch(event_name, payload)
    elif event_name == "GUILD_MEMBERS_CHUNK":
        if isinstance(payload, dict):
            nonce = payload.get("nonce")
            if nonce and nonce in self._member_chunk_requests:
                future = self._member_chunk_requests[nonce]
                if not future.done():
                    # Members received so far are kept in a list stored on
                    # the future object itself.
                    if not hasattr(future, "_members"):
                        future._members = []  # type: ignore
                    future._members.extend(payload.get("members", []))  # type: ignore

                    # If this is the last chunk, resolve the future
                    is_last_chunk = (
                        payload.get("chunk_index")
                        == payload.get("chunk_count", 1) - 1
                    )
                    if is_last_chunk:
                        future.set_result(future._members)  # type: ignore
                        del self._member_chunk_requests[nonce]
    elif event_name == "INTERACTION_CREATE":
        if isinstance(payload, dict):
            interaction = Interaction(
                data=payload, client_instance=self._client_instance
            )
            await self._dispatcher.dispatch("INTERACTION_CREATE", payload)
            # Hand the parsed interaction to the client, if it supports it.
            if hasattr(self._client_instance, "process_interaction"):
                task = asyncio.create_task(
                    self._client_instance.process_interaction(interaction)
                )  # type: ignore
                # The event loop keeps only weak references to tasks, so an
                # otherwise unreferenced task may be garbage-collected before
                # it finishes. Hold a strong reference until it is done.
                pending = getattr(self, "_interaction_tasks", None)
                if pending is None:
                    pending = set()
                    self._interaction_tasks = pending
                pending.add(task)
                task.add_done_callback(pending.discard)
            else:
                logger.warning(
                    "Client instance does not have process_interaction method for INTERACTION_CREATE."
                )
        else:
            logger.error(
                "INTERACTION_CREATE event 'd' payload is not a dict: %s",
                payload,
            )
    elif event_name == "RESUMED":
        logger.info("Gateway RESUMED successfully.")
        # Ensure the dispatcher always receives a dict.
        event_data_to_dispatch = payload if isinstance(payload, dict) else {}
        await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
        await self._dispatcher.dispatch(
            "SHARD_RESUME", {"shard_id": self._shard_id}
        )
    elif event_name:
        # For other events, pass {} when 'd' is null, missing or not a dict.
        event_data_to_dispatch = payload if isinstance(payload, dict) else {}
        await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
    else:
        logger.warning("Received dispatch with no event name: %s", data)
404
+
405
async def _process_message(self, msg: aiohttp.WSMessage):
    """Decode a single WebSocket message and act on its gateway opcode.

    TEXT messages are parsed as JSON and BINARY messages are passed through
    ``self._decompress_message``; any other non-error message type is logged
    and ignored. The decoded frame is then routed by its ``op`` field.

    Args:
        msg: The message received from the WebSocket.

    Raises:
        GatewayException: If the message type is ERROR or CLOSED.
    """
    kind = msg.type

    # Transport-level failures are reported to the caller as exceptions.
    if kind == aiohttp.WSMsgType.ERROR:
        ws_error = self._ws.exception() if self._ws else "Unknown WSError"
        logger.error("WebSocket error: %s", ws_error)
        raise GatewayException(f"WebSocket error: {ws_error}")

    if kind == aiohttp.WSMsgType.CLOSED:
        close_code = "N/A"
        if self._ws and hasattr(self._ws, "close_code"):
            close_code = self._ws.close_code
        logger.warning("WebSocket connection closed by server. Code: %s", close_code)
        # Signal the closure to whoever is driving the receive loop.
        raise GatewayException(f"WebSocket closed by server. Code: {close_code}")

    # Turn the message body into a decoded frame.
    if kind == aiohttp.WSMsgType.TEXT:
        try:
            frame = json.loads(msg.data)
        except json.JSONDecodeError:
            logger.error("Failed to decode JSON from Gateway: %s", msg.data[:200])
            return
    elif kind == aiohttp.WSMsgType.BINARY:
        frame = await self._decompress_message(msg.data)
        if frame is None:
            logger.error(
                "Failed to decompress or decode binary message from Gateway."
            )
            return
    else:
        logger.warning("Received unhandled WebSocket message type: %s", msg.type)
        return

    if self.verbose:
        logger.debug("GATEWAY RECV: %s", frame)

    opcode = frame.get("op")

    if opcode == GatewayOpcode.DISPATCH:
        # The handler extracts the 'd' payload itself.
        await self._handle_dispatch(frame)
        return

    if opcode == GatewayOpcode.HEARTBEAT:
        # The server asked for an immediate heartbeat.
        await self._heartbeat()
        return

    if opcode == GatewayOpcode.RECONNECT:
        logger.info(
            "Gateway requested RECONNECT. Closing and will attempt to reconnect."
        )
        await self.close(code=4000, reconnect=True)
        return

    if opcode == GatewayOpcode.INVALID_SESSION:
        # 'd' is a boolean: only a literal True means the session can resume.
        resumable = frame.get("d") is True
        logger.warning(
            "Gateway indicated INVALID_SESSION. Resumable: %s", resumable
        )
        if resumable:
            close_code = 4000
        else:
            # Drop the session so the next connection re-identifies.
            self._session_id = None
            self._last_sequence = None
            close_code = 4009
        await self.close(code=close_code, reconnect=True)
        return

    if opcode == GatewayOpcode.HELLO:
        hello = frame.get("d")
        hello_is_valid = isinstance(hello, dict) and "heartbeat_interval" in hello
        if not hello_is_valid:
            logger.error(
                "HELLO event 'd' payload is invalid or missing heartbeat_interval: %s",
                hello,
            )
            await self.close(code=1011)  # Internal error, malformed HELLO
            return

        self._heartbeat_interval = hello["heartbeat_interval"]
        logger.info(
            "Gateway HELLO. Heartbeat interval: %sms.", self._heartbeat_interval
        )

        # Replace any previous keep-alive task with a fresh one.
        if self._keep_alive_task:
            self._keep_alive_task.cancel()
        self._keep_alive_task = self._loop.create_task(self._keep_alive())

        # Resume when a session id and a resume URL are both stored,
        # otherwise identify from scratch.
        if self._session_id and self._resume_gateway_url:
            logger.info("Attempting to RESUME session.")
            await self._resume()
        else:
            logger.info("Performing initial IDENTIFY.")
            await self._identify()
        return

    if opcode == GatewayOpcode.HEARTBEAT_ACK:
        # Record when the acknowledgement arrived.
        self._last_heartbeat_ack = time.monotonic()
        return

    logger.warning(
        "Received unhandled Gateway Opcode: %s with data: %s", opcode, frame
    )
505
+
506
async def _receive_loop(self):
    """Receive messages from the WebSocket and process them until it ends.

    Returns immediately (with a warning) when there is no open WebSocket.
    Cancellation is logged and swallowed. A connection error closes the
    gateway with code 1006 and any other exception closes it with code 1011;
    both pass ``reconnect=True`` to ``close``.
    """
    if not self._ws or self._ws.closed:
        logger.warning(
            "Receive loop cannot start: WebSocket is not connected or closed."
        )
        return

    try:
        async for msg in self._ws:
            await self._process_message(msg)
    except asyncio.CancelledError:
        logger.debug("Receive_loop task cancelled.")
    except aiohttp.ClientConnectionError as e:
        logger.warning(
            "ClientConnectionError in receive_loop: %s. Attempting reconnect.", e
        )
        await self.close(code=1006, reconnect=True)  # Abnormal closure
    except Exception as e:
        # logger.exception records the traceback through the logging system
        # instead of printing it straight to stderr.
        logger.exception("Unexpected error in receive_loop: %s", e)
        await self.close(code=1011, reconnect=True)
    finally:
        logger.info("Receive_loop ended.")
        # If the loop ends unexpectedly (not due to explicit close),
        # the main client might want to try reconnecting.
532
+
533
async def connect(self):
    """Open the Gateway WebSocket and start the receive loop.

    Does nothing (beyond a warning) when a WebSocket is already open. The
    stored resume URL is preferred; otherwise the URL is fetched through the
    HTTP client. After connecting, a ``SHARD_CONNECT`` event is dispatched.

    Raises:
        AuthenticationError: If the WebSocket handshake fails with status 401.
        GatewayException: For any other failure while connecting.
    """
    if self._ws and not self._ws.closed:
        logger.warning("Gateway already connected or connecting.")
        return

    query = "?v=10&encoding=json&compress=zlib-stream"
    url = self._resume_gateway_url
    if not url:
        gateway_info = await self._http.get_gateway_bot()
        url = gateway_info["url"]
    if not url.endswith(query):
        url = url + query

    logger.info("Connecting to Gateway: %s", url)
    try:
        # The WebSocket is opened on the HTTP client's session.
        await self._http._ensure_session()
        assert (
            self._http._session is not None
        ), "HTTPClient session not initialized after ensure_session"
        self._ws = await self._http._session.ws_connect(url, max_msg_size=0)
        logger.info("Gateway WebSocket connection established.")

        # Only one receive loop may run at a time.
        previous_receiver = self._receive_task
        if previous_receiver:
            previous_receiver.cancel()
        self._receive_task = self._loop.create_task(self._receive_loop())

        await self._dispatcher.dispatch(
            "SHARD_CONNECT", {"shard_id": self._shard_id}
        )
    except aiohttp.ClientConnectorError as e:
        raise GatewayException(
            f"Failed to connect to Gateway (Connector Error): {e}"
        ) from e
    except aiohttp.WSServerHandshakeError as e:
        if e.status != 401:
            raise GatewayException(
                f"Gateway handshake failed (Status: {e.status}): {e.message}"
            ) from e
        # Unauthorized during handshake
        raise AuthenticationError(
            f"Gateway handshake failed (401 Unauthorized): {e.message}. Check your bot token."
        ) from e
    except Exception as e:  # Any other failure during connection
        raise GatewayException(
            f"An unexpected error occurred during Gateway connection: {e}"
        ) from e
578
+
579
async def close(self, code: int = 1000, *, reconnect: bool = False):
    """Close the Gateway connection and stop its background tasks.

    The keep-alive task is cancelled, the receive task is cancelled unless
    this call is running inside it, the WebSocket is closed with ``code``,
    and a ``SHARD_DISCONNECT`` event is dispatched.

    Args:
        code: WebSocket close code. ``4009`` additionally clears the stored
            session id, sequence number and resume URL.
        reconnect: Accepted from callers that intend to reconnect; this
            method itself does not read it.
    """
    logger.info("Closing Gateway connection with code %s...", code)
    if self._keep_alive_task and not self._keep_alive_task.done():
        self._keep_alive_task.cancel()
        try:
            await self._keep_alive_task
        except asyncio.CancelledError:
            pass  # Expected

    receive_task = self._receive_task
    if receive_task and not receive_task.done():
        current = asyncio.current_task(loop=self._loop)
        # close() is also called from inside the receive task (RECONNECT,
        # INVALID_SESSION, malformed HELLO, receive-loop error handlers).
        # Cancelling the task we are running in would make the next
        # suspending await below raise CancelledError and skip the rest of
        # the shutdown, so only cancel/await the task from the outside.
        if receive_task is not current:
            receive_task.cancel()
            try:
                await receive_task
            except asyncio.CancelledError:
                pass  # Expected

    if self._ws and not self._ws.closed:
        await self._ws.close(code=code)
        logger.info("Gateway WebSocket closed.")

    self._ws = None
    # Session state (session_id, last_sequence, resume_gateway_url) is kept
    # for every close code except 4009, so the next connection can decide
    # whether to resume or re-identify.
    if code == 4009:  # Invalid session, not resumable
        logger.info("Clearing session state due to non-resumable invalid session.")
        self._session_id = None
        self._last_sequence = None
        self._resume_gateway_url = None  # This might be re-fetched anyway

    await self._dispatcher.dispatch(
        "SHARD_DISCONNECT", {"shard_id": self._shard_id}
    )
616
+
617
@property
def latency(self) -> Optional[float]:
    """Difference between the last heartbeat ACK time and the last send time.

    Returns:
        ``_last_heartbeat_ack - _last_heartbeat_sent``, or ``None`` when
        either value has not been recorded.
    """
    sent_at = self._last_heartbeat_sent
    acked_at = self._last_heartbeat_ack
    if sent_at is not None and acked_at is not None:
        return acked_at - sent_at
    return None
623
+
624
@property
def last_heartbeat_sent(self) -> Optional[float]:
    """Read-only view of ``_last_heartbeat_sent`` (``None`` when unset).

    NOTE(review): the value is assigned outside this section; presumably a
    ``time.monotonic()`` timestamp like the ACK time -- confirm.
    """
    sent_at = self._last_heartbeat_sent
    return sent_at
627
+
628
@property
def last_heartbeat_ack(self) -> Optional[float]:
    """Read-only view of ``_last_heartbeat_ack`` (``None`` when unset).

    The value is the ``time.monotonic()`` reading stored when a
    HEARTBEAT_ACK frame is processed.
    """
    acked_at = self._last_heartbeat_ack
    return acked_at