disagreement 0.0.1__py3-none-any.whl → 0.1.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- disagreement/__init__.py +8 -3
- disagreement/audio.py +116 -0
- disagreement/client.py +176 -6
- disagreement/color.py +50 -0
- disagreement/components.py +2 -2
- disagreement/errors.py +13 -8
- disagreement/event_dispatcher.py +102 -45
- disagreement/ext/__init__.py +0 -0
- disagreement/ext/app_commands/__init__.py +46 -0
- disagreement/ext/app_commands/commands.py +513 -0
- disagreement/ext/app_commands/context.py +556 -0
- disagreement/ext/app_commands/converters.py +478 -0
- disagreement/ext/app_commands/decorators.py +569 -0
- disagreement/ext/app_commands/handler.py +627 -0
- disagreement/ext/commands/__init__.py +57 -0
- disagreement/ext/commands/cog.py +155 -0
- disagreement/ext/commands/converters.py +175 -0
- disagreement/ext/commands/core.py +497 -0
- disagreement/ext/commands/decorators.py +192 -0
- disagreement/ext/commands/errors.py +76 -0
- disagreement/ext/commands/help.py +37 -0
- disagreement/ext/commands/view.py +103 -0
- disagreement/ext/loader.py +54 -0
- disagreement/ext/tasks.py +182 -0
- disagreement/gateway.py +67 -21
- disagreement/http.py +104 -3
- disagreement/models.py +308 -1
- disagreement/shard_manager.py +2 -0
- disagreement/utils.py +10 -0
- disagreement/voice_client.py +42 -0
- {disagreement-0.0.1.dist-info → disagreement-0.1.0rc1.dist-info}/METADATA +47 -33
- disagreement-0.1.0rc1.dist-info/RECORD +52 -0
- disagreement-0.0.1.dist-info/RECORD +0 -32
- {disagreement-0.0.1.dist-info → disagreement-0.1.0rc1.dist-info}/WHEEL +0 -0
- {disagreement-0.0.1.dist-info → disagreement-0.1.0rc1.dist-info}/licenses/LICENSE +0 -0
- {disagreement-0.0.1.dist-info → disagreement-0.1.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,76 @@
|
|
1
|
+
# disagreement/ext/commands/errors.py
|
2
|
+
|
3
|
+
"""
|
4
|
+
Custom exceptions for the command extension.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from disagreement.errors import DisagreementException
|
8
|
+
|
9
|
+
|
10
|
+
class CommandError(DisagreementException):
    """Root of the exception hierarchy used by the commands extension.

    Catch this type to handle any error raised by the extension's own
    exception classes defined in this module.
    """
|
14
|
+
|
15
|
+
|
16
|
+
class CommandNotFound(CommandError):
    """Signals that no command matches the requested name.

    Attributes:
        command_name: The name that could not be resolved.
    """

    def __init__(self, command_name: str):
        self.command_name = command_name
        message = f"Command '{command_name}' not found."
        super().__init__(message)
|
22
|
+
|
23
|
+
|
24
|
+
class BadArgument(CommandError):
    """Signals that a command argument could not be parsed or validated.

    More specific argument errors in this module derive from this class.
    """
|
28
|
+
|
29
|
+
|
30
|
+
class MissingRequiredArgument(BadArgument):
    """Signals that a required command argument was not supplied.

    Attributes:
        param_name: Name of the parameter that is missing.
    """

    def __init__(self, param_name: str):
        self.param_name = param_name
        message = f"Missing required argument: {param_name}"
        super().__init__(message)
|
36
|
+
|
37
|
+
|
38
|
+
class ArgumentParsingError(BadArgument):
    """Signals a failure that happened while the argument string was parsed."""
|
42
|
+
|
43
|
+
|
44
|
+
class CheckFailure(CommandError):
    """Signals that a command check did not pass.

    Base class for the more specific check errors in this module.
    """
|
48
|
+
|
49
|
+
|
50
|
+
class CheckAnyFailure(CheckFailure):
    """Raised when every check given to :func:`check_any` fails.

    Attributes:
        errors: The individual failures, in the order they were supplied.
    """

    def __init__(self, errors: list[CheckFailure]):
        self.errors = errors
        # One combined message built from each failure's own text.
        combined = "; ".join(map(str, errors))
        super().__init__(f"All checks failed: {combined}")
|
57
|
+
|
58
|
+
|
59
|
+
class CommandOnCooldown(CheckFailure):
    """Raised when a command is invoked before its cooldown has expired.

    Attributes:
        retry_after: The value passed by the raiser; the message renders it
            as a number of seconds with two decimals.
    """

    def __init__(self, retry_after: float):
        self.retry_after = retry_after
        message = f"Command is on cooldown. Retry in {retry_after:.2f}s"
        super().__init__(message)
|
65
|
+
|
66
|
+
|
67
|
+
class CommandInvokeError(CommandError):
    """Wraps an exception raised while a command was being invoked.

    Attributes:
        original: The exception that was wrapped.
    """

    def __init__(self, original: Exception):
        self.original = original
        message = f"Error during command invocation: {original}"
        super().__init__(message)
|
73
|
+
|
74
|
+
|
75
|
+
# Add more specific errors as needed, e.g., UserNotFound, ChannelNotFound, etc.
|
76
|
+
# These might inherit from BadArgument.
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# disagreement/ext/commands/help.py
|
2
|
+
|
3
|
+
from typing import List, Optional
|
4
|
+
|
5
|
+
from .core import Command, CommandContext, CommandHandler
|
6
|
+
|
7
|
+
|
8
|
+
class HelpCommand(Command):
    """Built-in ``help`` command that describes the handler's commands.

    Called with a command name it sends that command's description; called
    without one it sends a one-line summary for every registered command.
    """

    def __init__(self, handler: CommandHandler) -> None:
        self.handler = handler

        async def callback(ctx: CommandContext, command: Optional[str] = None) -> None:
            if not command:
                # ``dict.fromkeys`` drops repeated Command objects while
                # keeping their first-seen order.
                summaries: List[str] = [
                    f"{ctx.prefix}{entry.name} - {entry.brief or entry.description or ''}".strip()
                    for entry in dict.fromkeys(handler.commands.values())
                ]
                if not summaries:
                    await ctx.send("No commands available.")
                    return
                await ctx.send("\n".join(summaries))
                return

            found = handler.get_command(command)
            # Only accept a result whose own name matches the request
            # (case-insensitively).
            if not found or found.name.lower() != command.lower():
                await ctx.send(f"Command '{command}' not found.")
                return

            details = found.description or found.brief or "No description provided."
            await ctx.send(f"**{ctx.prefix}{found.name}**\n{details}")

        super().__init__(
            callback,
            name="help",
            brief="Show command help.",
            description="Displays help for commands.",
        )
|
@@ -0,0 +1,103 @@
|
|
1
|
+
# disagreement/ext/commands/view.py
|
2
|
+
|
3
|
+
import re
|
4
|
+
|
5
|
+
|
6
|
+
class StringView:
    """Cursor over a string, used to pull command arguments out one at a time.

    The view tracks ``index`` (the next unread character) and ``previous``
    (where the most recent read started) so a caller can :meth:`undo` the
    last read.

    Attributes:
        buffer: The text being read.
        original: The text as passed to the constructor.
        index: Position of the next unread character.
        end: Length of ``buffer`` at construction time.
        previous: Value of ``index`` when the most recent read began.
    """

    # Compiled once and matched in place with ``pos`` so that reading a word
    # never copies the unread remainder of the buffer.
    _WORD_PATTERN = re.compile(r"\S+")

    def __init__(self, buffer: str):
        self.buffer: str = buffer
        self.original: str = buffer  # Kept for error reporting if needed
        self.index: int = 0
        self.end: int = len(buffer)
        self.previous: int = 0  # Index before the last successful read

    @property
    def remaining(self) -> str:
        """Return the part of the buffer that has not been consumed yet."""
        return self.buffer[self.index :]

    @property
    def eof(self) -> bool:
        """Return ``True`` once ``index`` has reached the end of the buffer."""
        return self.index >= self.end

    def skip_whitespace(self) -> None:
        """Advance ``index`` past any whitespace at the current position."""
        while not self.eof and self.buffer[self.index].isspace():
            self.index += 1

    def get_word(self) -> str:
        """Read the next run of non-whitespace characters.

        Leading whitespace is skipped first.

        Returns:
            The word, or ``""`` when only whitespace (or nothing) is left.
        """
        self.skip_whitespace()
        if self.eof:
            return ""

        self.previous = self.index
        match = self._WORD_PATTERN.match(self.buffer, self.index)
        if match is None:
            # Defensive: cannot happen once whitespace is skipped and we are
            # not at EOF.
            return ""
        self.index = match.end()
        return match.group(0)

    def get_quoted_string(self) -> str:
        """Read a double-quoted string, honouring backslash escapes.

        Leading whitespace is skipped first.  A backslash makes the following
        character literal, so ``\\"`` yields a quote inside the result.

        Returns:
            The text between the quotes.  ``""`` is returned when the next
            character is not a double quote, or when the closing quote is
            missing; in the latter case ``index`` is put back on the opening
            quote.
        """
        self.skip_whitespace()
        if self.eof or self.buffer[self.index] != '"':
            return ""  # Or raise an error, or return None

        self.previous = self.index
        self.index += 1  # Skip the opening quote
        result = []
        escaped = False

        while not self.eof:
            char = self.buffer[self.index]
            self.index += 1

            if escaped:
                result.append(char)
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                return "".join(result)  # Closing quote found
            else:
                result.append(char)

        # EOF was reached before the closing quote: rewind so the caller can
        # try a different read from the same position.
        self.index = self.previous
        # Consider raising an error like UnterminatedQuotedStringError
        return ""  # Or raise

    def read_rest(self) -> str:
        """Consume and return everything left after skipping leading whitespace."""
        self.skip_whitespace()
        if self.eof:
            return ""

        self.previous = self.index
        result = self.buffer[self.index :]
        self.index = self.end
        return result

    def undo(self) -> None:
        """Move ``index`` back to where the most recent read started."""
        self.index = self.previous
|
100
|
+
|
101
|
+
# Could add more methods like:
|
102
|
+
# peek() - look at next char without consuming
|
103
|
+
# match_regex(pattern) - consume if regex matches
|
@@ -0,0 +1,54 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from importlib import import_module
|
4
|
+
import sys
|
5
|
+
from types import ModuleType
|
6
|
+
from typing import Dict
|
7
|
+
|
8
|
+
__all__ = ["load_extension", "unload_extension", "reload_extension"]
|
9
|
+
|
10
|
+
_loaded_extensions: Dict[str, ModuleType] = {}

# Strong references to setup awaitables scheduled on an already-running loop,
# so they cannot be garbage collected before they finish.
_pending_setup_tasks: set = set()


def load_extension(name: str) -> ModuleType:
    """Load an extension by name.

    The extension module must define a ``setup`` coroutine or function that
    will be called after loading. Any value returned by ``setup`` is ignored.

    When ``setup`` returns an awaitable it is driven to completion: on a
    private event loop if none is running in this thread, otherwise it is
    scheduled on the running loop (and therefore finishes after this function
    has returned).

    Args:
        name: Importable module name of the extension.

    Returns:
        The imported extension module.

    Raises:
        ValueError: If the extension is already loaded.
        ImportError: If the module does not define ``setup``.
    """
    import asyncio
    import inspect

    if name in _loaded_extensions:
        raise ValueError(f"Extension '{name}' already loaded")

    module = import_module(name)

    if not hasattr(module, "setup"):
        raise ImportError(f"Extension '{name}' does not define a setup function")

    result = module.setup()
    if inspect.isawaitable(result):
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running here.  Use a throwaway loop rather than
            # asyncio.run() so the thread's current-loop setting is untouched.
            private_loop = asyncio.new_event_loop()
            try:
                private_loop.run_until_complete(result)
            finally:
                private_loop.close()
        else:
            # We are inside a running loop and cannot block on it; schedule
            # the setup and keep a reference until it is done.
            pending = asyncio.ensure_future(result)
            _pending_setup_tasks.add(pending)
            pending.add_done_callback(_pending_setup_tasks.discard)

    _loaded_extensions[name] = module
    return module
|
31
|
+
|
32
|
+
|
33
|
+
def unload_extension(name: str) -> None:
    """Remove a loaded extension from the registry and from ``sys.modules``.

    If the module defines ``teardown`` it is called (with no arguments) after
    the registry entry has been removed.

    Raises:
        ValueError: If no extension with this name is loaded.
    """

    if (extension := _loaded_extensions.pop(name, None)) is None:
        raise ValueError(f"Extension '{name}' is not loaded")

    if hasattr(extension, "teardown"):
        extension.teardown()

    # Forget the cached module so a later import executes it again.
    sys.modules.pop(name, None)
|
44
|
+
|
45
|
+
|
46
|
+
def reload_extension(name: str) -> ModuleType:
    """Unload the named extension, then load it again.

    Thin wrapper around :func:`unload_extension` and :func:`load_extension`;
    whatever either of them raises propagates unchanged.

    Returns:
        The module object returned by :func:`load_extension`.
    """

    unload_extension(name)
    reloaded = load_extension(name)
    return reloaded
|
@@ -0,0 +1,182 @@
|
|
1
|
+
import asyncio
|
2
|
+
import datetime
|
3
|
+
from typing import Any, Awaitable, Callable, Optional
|
4
|
+
|
5
|
+
__all__ = ["loop", "Task"]
|
6
|
+
|
7
|
+
|
8
|
+
class Task:
    """Run a coroutine function repeatedly in a background asyncio task.

    The schedule is either a fixed interval or a daily ``time_of_day``.  In
    interval mode the first call happens immediately and the delay is applied
    between calls.  In ``time_of_day`` mode every call, including the first,
    waits for the next occurrence of that time.

    Args:
        coro: Coroutine function to call on every iteration.
        seconds: Interval component in seconds.
        minutes: Interval component in minutes.
        hours: Interval component in hours.
        delta: Interval as a ``timedelta``; when given it is used instead of
            ``seconds``/``minutes``/``hours``.
        time_of_day: Daily time to run at.  May be naive or timezone-aware.
        on_error: Called with the exception when ``coro`` raises.  Without
            it the exception propagates and ends the loop.

    Raises:
        ValueError: If ``time_of_day`` is combined with any interval argument.
        TypeError: If ``delta`` is not a ``datetime.timedelta``.
    """

    def __init__(
        self,
        coro: Callable[..., Awaitable[Any]],
        *,
        seconds: float = 0.0,
        minutes: float = 0.0,
        hours: float = 0.0,
        delta: Optional[datetime.timedelta] = None,
        time_of_day: Optional[datetime.time] = None,
        on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
    ) -> None:
        self._coro = coro
        self._task: Optional[asyncio.Task[None]] = None
        if time_of_day is not None and (
            seconds or minutes or hours or delta is not None
        ):
            raise ValueError("time_of_day cannot be used with an interval")

        if delta is not None:
            if not isinstance(delta, datetime.timedelta):
                raise TypeError("delta must be a datetime.timedelta")
            interval_seconds = delta.total_seconds()
        else:
            interval_seconds = seconds + minutes * 60.0 + hours * 3600.0

        self._seconds = float(interval_seconds)
        self._time_of_day = time_of_day
        self._on_error = on_error

    def _seconds_until_time(self) -> float:
        """Return the seconds from now until the next ``time_of_day``."""
        assert self._time_of_day is not None
        # Take "now" in the schedule's own timezone.  With a naive
        # ``time_of_day`` tzinfo is None and this is plain local time; with an
        # aware one both sides of the comparison below are aware, which avoids
        # the naive-vs-aware TypeError.
        now = datetime.datetime.now(tz=self._time_of_day.tzinfo)
        target = datetime.datetime.combine(now.date(), self._time_of_day)
        if target <= now:
            # Today's slot has already passed; aim for tomorrow.
            target += datetime.timedelta(days=1)
        return (target - now).total_seconds()

    async def _run(self, *args: Any, **kwargs: Any) -> None:
        """Loop body executed inside the background asyncio task."""
        try:
            first = True
            while True:
                if self._time_of_day is not None:
                    await asyncio.sleep(self._seconds_until_time())
                elif not first:
                    # Interval mode: no delay before the very first call.
                    await asyncio.sleep(self._seconds)

                try:
                    await self._coro(*args, **kwargs)
                except Exception as exc:  # noqa: BLE001
                    if self._on_error is not None:
                        await _maybe_call(self._on_error, exc)
                    else:
                        raise

                first = False
        except asyncio.CancelledError:
            # Cancellation (see stop()) ends the loop quietly.
            pass

    def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
        """Start the loop unless it is already running.

        ``args``/``kwargs`` are forwarded to the coroutine on every call.

        Returns:
            The asyncio task driving the loop (the existing one when the
            loop was already running).
        """
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._run(*args, **kwargs))
        return self._task

    def stop(self) -> None:
        """Cancel the background task, if there is one."""
        if self._task is not None:
            self._task.cancel()
            self._task = None

    @property
    def running(self) -> bool:
        """``True`` while a background task exists and has not finished."""
        return self._task is not None and not self._task.done()
|
82
|
+
|
83
|
+
|
84
|
+
async def _maybe_call(
    func: Callable[[Exception], Optional[Awaitable[None]]], exc: Exception
) -> None:
    """Call ``func(exc)`` and await the result when it is awaitable.

    This lets an error handler be either a plain function or a coroutine
    function.

    Args:
        func: Handler to invoke.
        exc: Exception passed to the handler.
    """
    # Local import: only this helper needs ``inspect``.
    import inspect

    result = func(exc)
    # isawaitable also covers futures and objects defining ``__await__``,
    # not only coroutine objects.
    if inspect.isawaitable(result):
        await result
|
90
|
+
|
91
|
+
|
92
|
+
class _Loop:
    """Holds a function and its schedule, and builds a :class:`Task` on start.

    Works both as a plain object (module-level functions) and as a descriptor
    (methods): attribute access returns a ``_BoundLoop`` wrapping this object
    and the instance it was accessed through.

    NOTE(review): ``_task`` and ``_owner`` are stored on this one object, so
    every instance of the owning class shares the same running task and
    owner.  Per-instance state would need a larger change.
    """

    def __init__(
        self,
        func: Callable[..., Awaitable[Any]],
        *,
        seconds: float = 0.0,
        minutes: float = 0.0,
        hours: float = 0.0,
        delta: Optional[datetime.timedelta] = None,
        time_of_day: Optional[datetime.time] = None,
        on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
    ) -> None:
        self.func = func
        self.seconds = seconds
        self.minutes = minutes
        self.hours = hours
        self.delta = delta
        self.time_of_day = time_of_day
        self.on_error = on_error
        self._task: Optional[Task] = None
        self._owner: Any = None

    def __get__(self, obj: Any, objtype: Any = None) -> "_BoundLoop":
        """Return a wrapper tying this loop to ``obj``."""
        return _BoundLoop(self, obj)

    def _coro(self, *args: Any, **kwargs: Any) -> Awaitable[Any]:
        """Call the wrapped function, passing the owner first when one is set."""
        if self._owner is None:
            return self.func(*args, **kwargs)
        return self.func(self._owner, *args, **kwargs)

    def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
        """Start the loop and return the asyncio task that drives it.

        If the loop is already running, the existing task is returned instead
        of creating a second one.  Replacing a running ``Task`` would leave it
        alive with nothing able to stop it.
        """
        if self._task is not None and self._task.running:
            # Task.start() returns the live asyncio task without spawning.
            return self._task.start(*args, **kwargs)

        self._task = Task(
            self._coro,
            seconds=self.seconds,
            minutes=self.minutes,
            hours=self.hours,
            delta=self.delta,
            time_of_day=self.time_of_day,
            on_error=self.on_error,
        )
        return self._task.start(*args, **kwargs)

    def stop(self) -> None:
        """Stop the loop if it was started."""
        if self._task is not None:
            self._task.stop()

    @property
    def running(self) -> bool:
        """``True`` while the underlying task is running."""
        return self._task.running if self._task else False
|
141
|
+
|
142
|
+
|
143
|
+
class _BoundLoop:
    """Pairs a ``_Loop`` with the object it was accessed through.

    All operations are forwarded to the parent loop; ``start`` first records
    the owner on the parent so the wrapped function receives it.
    """

    def __init__(self, parent: _Loop, owner: Any) -> None:
        self._owner = owner
        self._parent = parent

    def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
        """Record the owner on the parent loop, then start it."""
        target = self._parent
        target._owner = self._owner
        return target.start(*args, **kwargs)

    def stop(self) -> None:
        """Stop the parent loop."""
        self._parent.stop()

    @property
    def running(self) -> bool:
        """Whether the parent loop is currently running."""
        return self._parent.running
|
158
|
+
|
159
|
+
|
160
|
+
def loop(
    *,
    seconds: float = 0.0,
    minutes: float = 0.0,
    hours: float = 0.0,
    delta: Optional[datetime.timedelta] = None,
    time_of_day: Optional[datetime.time] = None,
    on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
) -> Callable[[Callable[..., Awaitable[Any]]], _Loop]:
    """Return a decorator that wraps a coroutine function in a ``_Loop``.

    All keyword arguments are passed through unchanged to ``_Loop``.
    """

    # Captured once; every decorated function gets the same schedule options.
    options = {
        "seconds": seconds,
        "minutes": minutes,
        "hours": hours,
        "delta": delta,
        "time_of_day": time_of_day,
        "on_error": on_error,
    }

    def decorator(func: Callable[..., Awaitable[Any]]) -> _Loop:
        return _Loop(func, **options)

    return decorator
|
disagreement/gateway.py
CHANGED
@@ -10,6 +10,7 @@ import aiohttp
|
|
10
10
|
import json
|
11
11
|
import zlib
|
12
12
|
import time
|
13
|
+
import random
|
13
14
|
from typing import Optional, TYPE_CHECKING, Any, Dict
|
14
15
|
|
15
16
|
from .enums import GatewayOpcode, GatewayIntent
|
@@ -43,6 +44,8 @@ class GatewayClient:
|
|
43
44
|
*,
|
44
45
|
shard_id: Optional[int] = None,
|
45
46
|
shard_count: Optional[int] = None,
|
47
|
+
max_retries: int = 5,
|
48
|
+
max_backoff: float = 60.0,
|
46
49
|
):
|
47
50
|
self._http: "HTTPClient" = http_client
|
48
51
|
self._dispatcher: "EventDispatcher" = event_dispatcher
|
@@ -52,6 +55,8 @@ class GatewayClient:
|
|
52
55
|
self.verbose: bool = verbose
|
53
56
|
self._shard_id: Optional[int] = shard_id
|
54
57
|
self._shard_count: Optional[int] = shard_count
|
58
|
+
self._max_retries: int = max_retries
|
59
|
+
self._max_backoff: float = max_backoff
|
55
60
|
|
56
61
|
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
|
57
62
|
self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
@@ -63,10 +68,33 @@ class GatewayClient:
|
|
63
68
|
self._keep_alive_task: Optional[asyncio.Task] = None
|
64
69
|
self._receive_task: Optional[asyncio.Task] = None
|
65
70
|
|
71
|
+
self._last_heartbeat_sent: Optional[float] = None
|
72
|
+
self._last_heartbeat_ack: Optional[float] = None
|
73
|
+
|
66
74
|
# For zlib decompression
|
67
75
|
self._buffer = bytearray()
|
68
76
|
self._inflator = zlib.decompressobj()
|
69
77
|
|
78
|
+
async def _reconnect(self) -> None:
    """Attempt to reconnect using exponential backoff with jitter.

    ``connect()`` is tried up to ``_max_retries`` times, and always at least
    once, so a non-positive setting cannot turn this into a silent no-op.
    After each failure the wait is the current delay plus a random jitter of
    up to that delay, capped at ``_max_backoff``; the delay then doubles
    (also capped).

    Raises:
        Exception: Whatever the final ``connect()`` attempt raised.
    """
    attempts = max(1, self._max_retries)
    delay = 1.0
    for attempt in range(attempts):
        try:
            await self.connect()
            return
        except Exception as e:  # noqa: BLE001
            if attempt >= attempts - 1:
                print(f"Reconnect failed after {attempt + 1} attempts: {e}")
                raise
            jitter = random.uniform(0, delay)
            wait_time = min(delay + jitter, self._max_backoff)
            print(
                f"Reconnect attempt {attempt + 1} failed: {e}. "
                f"Retrying in {wait_time:.2f} seconds..."
            )
            await asyncio.sleep(wait_time)
            delay = min(delay * 2, self._max_backoff)
|
97
|
+
|
70
98
|
async def _decompress_message(
|
71
99
|
self, message_bytes: bytes
|
72
100
|
) -> Optional[Dict[str, Any]]:
|
@@ -103,6 +131,7 @@ class GatewayClient:
|
|
103
131
|
|
104
132
|
async def _heartbeat(self):
|
105
133
|
"""Sends a heartbeat to the Gateway."""
|
134
|
+
self._last_heartbeat_sent = time.monotonic()
|
106
135
|
payload = {"op": GatewayOpcode.HEARTBEAT, "d": self._last_sequence}
|
107
136
|
await self._send_json(payload)
|
108
137
|
# print("Sent heartbeat.")
|
@@ -166,6 +195,7 @@ class GatewayClient:
|
|
166
195
|
print(
|
167
196
|
f"Sent RESUME for session {self._session_id} at sequence {self._last_sequence}."
|
168
197
|
)
|
198
|
+
|
169
199
|
async def update_presence(
|
170
200
|
self,
|
171
201
|
status: str,
|
@@ -179,14 +209,16 @@ class GatewayClient:
|
|
179
209
|
"op": GatewayOpcode.PRESENCE_UPDATE,
|
180
210
|
"d": {
|
181
211
|
"since": since,
|
182
|
-
"activities":
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
212
|
+
"activities": (
|
213
|
+
[
|
214
|
+
{
|
215
|
+
"name": activity_name,
|
216
|
+
"type": activity_type,
|
217
|
+
}
|
218
|
+
]
|
219
|
+
if activity_name
|
220
|
+
else []
|
221
|
+
),
|
190
222
|
"status": status,
|
191
223
|
"afk": afk,
|
192
224
|
},
|
@@ -347,7 +379,7 @@ class GatewayClient:
|
|
347
379
|
await self._heartbeat()
|
348
380
|
elif op == GatewayOpcode.RECONNECT: # Server requests a reconnect
|
349
381
|
print("Gateway requested RECONNECT. Closing and will attempt to reconnect.")
|
350
|
-
await self.close(code=4000
|
382
|
+
await self.close(code=4000, reconnect=True)
|
351
383
|
elif op == GatewayOpcode.INVALID_SESSION:
|
352
384
|
# The 'd' payload for INVALID_SESSION is a boolean indicating resumability
|
353
385
|
can_resume = data.get("d") is True
|
@@ -356,9 +388,7 @@ class GatewayClient:
|
|
356
388
|
self._session_id = None # Clear session_id to force re-identify
|
357
389
|
self._last_sequence = None
|
358
390
|
# Close and reconnect. The connect logic will decide to resume or identify.
|
359
|
-
await self.close(
|
360
|
-
code=4000 if can_resume else 4009
|
361
|
-
) # 4009 for non-resumable
|
391
|
+
await self.close(code=4000 if can_resume else 4009, reconnect=True)
|
362
392
|
elif op == GatewayOpcode.HELLO:
|
363
393
|
hello_d_payload = data.get("d")
|
364
394
|
if (
|
@@ -385,6 +415,7 @@ class GatewayClient:
|
|
385
415
|
print("Performing initial IDENTIFY.")
|
386
416
|
await self._identify()
|
387
417
|
elif op == GatewayOpcode.HEARTBEAT_ACK:
|
418
|
+
self._last_heartbeat_ack = time.monotonic()
|
388
419
|
# print("Received heartbeat ACK.")
|
389
420
|
pass # Good, connection is alive
|
390
421
|
else:
|
@@ -403,13 +434,11 @@ class GatewayClient:
|
|
403
434
|
print("Receive_loop task cancelled.")
|
404
435
|
except aiohttp.ClientConnectionError as e:
|
405
436
|
print(f"ClientConnectionError in receive_loop: {e}. Attempting reconnect.")
|
406
|
-
|
407
|
-
await self.close(code=1006) # Abnormal closure
|
437
|
+
await self.close(code=1006, reconnect=True) # Abnormal closure
|
408
438
|
except Exception as e:
|
409
439
|
print(f"Unexpected error in receive_loop: {e}")
|
410
440
|
traceback.print_exc()
|
411
|
-
|
412
|
-
await self.close(code=1011) # Internal error
|
441
|
+
await self.close(code=1011, reconnect=True)
|
413
442
|
finally:
|
414
443
|
print("Receive_loop ended.")
|
415
444
|
# If the loop ends unexpectedly (not due to explicit close),
|
@@ -457,7 +486,7 @@ class GatewayClient:
|
|
457
486
|
f"An unexpected error occurred during Gateway connection: {e}"
|
458
487
|
) from e
|
459
488
|
|
460
|
-
async def close(self, code: int = 1000):
|
489
|
+
async def close(self, code: int = 1000, *, reconnect: bool = False):
|
461
490
|
"""Closes the Gateway connection."""
|
462
491
|
print(f"Closing Gateway connection with code {code}...")
|
463
492
|
if self._keep_alive_task and not self._keep_alive_task.done():
|
@@ -468,11 +497,13 @@ class GatewayClient:
|
|
468
497
|
pass # Expected
|
469
498
|
|
470
499
|
if self._receive_task and not self._receive_task.done():
|
500
|
+
current = asyncio.current_task(loop=self._loop)
|
471
501
|
self._receive_task.cancel()
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
502
|
+
if self._receive_task is not current:
|
503
|
+
try:
|
504
|
+
await self._receive_task
|
505
|
+
except asyncio.CancelledError:
|
506
|
+
pass # Expected
|
476
507
|
|
477
508
|
if self._ws and not self._ws.closed:
|
478
509
|
await self._ws.close(code=code)
|
@@ -488,3 +519,18 @@ class GatewayClient:
|
|
488
519
|
self._session_id = None
|
489
520
|
self._last_sequence = None
|
490
521
|
self._resume_gateway_url = None # This might be re-fetched anyway
|
522
|
+
|
523
|
+
@property
def latency(self) -> Optional[float]:
    """Return the seconds between the last heartbeat and its ACK.

    Returns:
        ``None`` until one heartbeat has been both sent and acknowledged, and
        also while the most recent heartbeat is still waiting for its ACK.
        In that window the stored ACK time belongs to an earlier heartbeat,
        so the difference would be negative.
    """
    sent = self._last_heartbeat_sent
    acked = self._last_heartbeat_ack
    if sent is None or acked is None:
        return None
    if acked < sent:
        # The ACK on record predates the latest heartbeat.
        return None
    return acked - sent
|
529
|
+
|
530
|
+
@property
def last_heartbeat_sent(self) -> Optional[float]:
    """Return the stored send time of the latest heartbeat, or ``None``."""
    sent_at = self._last_heartbeat_sent
    return sent_at
|
533
|
+
|
534
|
+
@property
def last_heartbeat_ack(self) -> Optional[float]:
    """Return the stored time of the latest heartbeat ACK, or ``None``."""
    acked_at = self._last_heartbeat_ack
    return acked_at
|