disagreement 0.0.2__py3-none-any.whl → 0.1.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- disagreement/__init__.py +8 -3
- disagreement/audio.py +116 -0
- disagreement/client.py +176 -6
- disagreement/color.py +50 -0
- disagreement/components.py +2 -2
- disagreement/errors.py +13 -8
- disagreement/event_dispatcher.py +102 -45
- disagreement/ext/commands/__init__.py +9 -1
- disagreement/ext/commands/core.py +7 -0
- disagreement/ext/commands/decorators.py +72 -30
- disagreement/ext/loader.py +12 -1
- disagreement/ext/tasks.py +101 -8
- disagreement/gateway.py +56 -13
- disagreement/http.py +104 -3
- disagreement/models.py +308 -1
- disagreement/shard_manager.py +2 -0
- disagreement/utils.py +10 -0
- disagreement/voice_client.py +42 -0
- {disagreement-0.0.2.dist-info → disagreement-0.1.0rc1.dist-info}/METADATA +9 -2
- {disagreement-0.0.2.dist-info → disagreement-0.1.0rc1.dist-info}/RECORD +23 -20
- {disagreement-0.0.2.dist-info → disagreement-0.1.0rc1.dist-info}/WHEEL +0 -0
- {disagreement-0.0.2.dist-info → disagreement-0.1.0rc1.dist-info}/licenses/LICENSE +0 -0
- {disagreement-0.0.2.dist-info → disagreement-0.1.0rc1.dist-info}/top_level.txt +0 -0
disagreement/event_dispatcher.py
CHANGED
@@ -47,10 +47,20 @@ class EventDispatcher:
|
|
47
47
|
# Pre-defined parsers for specific event types to convert raw data to models
|
48
48
|
self._event_parsers: Dict[str, Callable[[Dict[str, Any]], Any]] = {
|
49
49
|
"MESSAGE_CREATE": self._parse_message_create,
|
50
|
+
"MESSAGE_UPDATE": self._parse_message_update,
|
51
|
+
"MESSAGE_DELETE": self._parse_message_delete,
|
52
|
+
"MESSAGE_REACTION_ADD": self._parse_message_reaction,
|
53
|
+
"MESSAGE_REACTION_REMOVE": self._parse_message_reaction,
|
50
54
|
"INTERACTION_CREATE": self._parse_interaction_create,
|
51
55
|
"GUILD_CREATE": self._parse_guild_create,
|
52
56
|
"CHANNEL_CREATE": self._parse_channel_create,
|
57
|
+
"CHANNEL_UPDATE": self._parse_channel_update,
|
53
58
|
"PRESENCE_UPDATE": self._parse_presence_update,
|
59
|
+
"GUILD_MEMBER_ADD": self._parse_guild_member_add,
|
60
|
+
"GUILD_MEMBER_REMOVE": self._parse_guild_member_remove,
|
61
|
+
"GUILD_BAN_ADD": self._parse_guild_ban_add,
|
62
|
+
"GUILD_BAN_REMOVE": self._parse_guild_ban_remove,
|
63
|
+
"GUILD_ROLE_UPDATE": self._parse_guild_role_update,
|
54
64
|
"TYPING_START": self._parse_typing_start,
|
55
65
|
}
|
56
66
|
|
@@ -58,6 +68,21 @@ class EventDispatcher:
|
|
58
68
|
"""Parses raw MESSAGE_CREATE data into a Message object."""
|
59
69
|
return self._client.parse_message(data)
|
60
70
|
|
71
|
+
def _parse_message_update(self, data: Dict[str, Any]) -> Message:
|
72
|
+
"""Parses raw MESSAGE_UPDATE data into a Message object."""
|
73
|
+
return self._client.parse_message(data)
|
74
|
+
|
75
|
+
def _parse_message_delete(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
76
|
+
"""Parses MESSAGE_DELETE and updates message cache."""
|
77
|
+
message_id = data.get("id")
|
78
|
+
if message_id:
|
79
|
+
self._client._messages.pop(message_id, None)
|
80
|
+
return data
|
81
|
+
|
82
|
+
def _parse_message_reaction_raw(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
83
|
+
"""Returns the raw reaction payload."""
|
84
|
+
return data
|
85
|
+
|
61
86
|
def _parse_interaction_create(self, data: Dict[str, Any]) -> "Interaction":
|
62
87
|
"""Parses raw INTERACTION_CREATE data into an Interaction object."""
|
63
88
|
from .interactions import Interaction
|
@@ -88,6 +113,52 @@ class EventDispatcher:
|
|
88
113
|
|
89
114
|
return TypingStart(data, client_instance=self._client)
|
90
115
|
|
116
|
+
def _parse_message_reaction(self, data: Dict[str, Any]):
|
117
|
+
"""Parses raw reaction data into a Reaction object."""
|
118
|
+
|
119
|
+
from .models import Reaction
|
120
|
+
|
121
|
+
return Reaction(data, client_instance=self._client)
|
122
|
+
|
123
|
+
def _parse_guild_member_add(self, data: Dict[str, Any]):
|
124
|
+
"""Parses GUILD_MEMBER_ADD into a Member object."""
|
125
|
+
|
126
|
+
guild_id = str(data.get("guild_id"))
|
127
|
+
return self._client.parse_member(data, guild_id)
|
128
|
+
|
129
|
+
def _parse_guild_member_remove(self, data: Dict[str, Any]):
|
130
|
+
"""Parses GUILD_MEMBER_REMOVE into a GuildMemberRemove model."""
|
131
|
+
|
132
|
+
from .models import GuildMemberRemove
|
133
|
+
|
134
|
+
return GuildMemberRemove(data, client_instance=self._client)
|
135
|
+
|
136
|
+
def _parse_guild_ban_add(self, data: Dict[str, Any]):
|
137
|
+
"""Parses GUILD_BAN_ADD into a GuildBanAdd model."""
|
138
|
+
|
139
|
+
from .models import GuildBanAdd
|
140
|
+
|
141
|
+
return GuildBanAdd(data, client_instance=self._client)
|
142
|
+
|
143
|
+
def _parse_guild_ban_remove(self, data: Dict[str, Any]):
|
144
|
+
"""Parses GUILD_BAN_REMOVE into a GuildBanRemove model."""
|
145
|
+
|
146
|
+
from .models import GuildBanRemove
|
147
|
+
|
148
|
+
return GuildBanRemove(data, client_instance=self._client)
|
149
|
+
|
150
|
+
def _parse_channel_update(self, data: Dict[str, Any]):
|
151
|
+
"""Parses CHANNEL_UPDATE into a Channel object."""
|
152
|
+
|
153
|
+
return self._client.parse_channel(data)
|
154
|
+
|
155
|
+
def _parse_guild_role_update(self, data: Dict[str, Any]):
|
156
|
+
"""Parses GUILD_ROLE_UPDATE into a GuildRoleUpdate model."""
|
157
|
+
|
158
|
+
from .models import GuildRoleUpdate
|
159
|
+
|
160
|
+
return GuildRoleUpdate(data, client_instance=self._client)
|
161
|
+
|
91
162
|
# Potentially add _parse_user for events that directly provide a full user object
|
92
163
|
# def _parse_user_update(self, data: Dict[str, Any]) -> User:
|
93
164
|
# return User(data=data)
|
@@ -169,75 +240,61 @@ class EventDispatcher:
|
|
169
240
|
if not waiters:
|
170
241
|
self._waiters.pop(event_name, None)
|
171
242
|
|
172
|
-
async def
|
173
|
-
|
174
|
-
Dispatches an event to all registered listeners.
|
175
|
-
|
176
|
-
Args:
|
177
|
-
event_name (str): The name of the event (e.g., 'MESSAGE_CREATE').
|
178
|
-
raw_data (Dict[str, Any]): The raw data payload from the Discord Gateway for this event.
|
179
|
-
"""
|
180
|
-
event_name_upper = event_name.upper()
|
181
|
-
listeners = self._listeners.get(event_name_upper)
|
182
|
-
|
243
|
+
async def _dispatch_to_listeners(self, event_name: str, data: Any) -> None:
|
244
|
+
listeners = self._listeners.get(event_name)
|
183
245
|
if not listeners:
|
184
|
-
# print(f"No listeners for event {event_name_upper}")
|
185
246
|
return
|
186
247
|
|
187
|
-
|
188
|
-
if event_name_upper in self._event_parsers:
|
189
|
-
try:
|
190
|
-
parser = self._event_parsers[event_name_upper]
|
191
|
-
parsed_data = parser(raw_data)
|
192
|
-
except Exception as e:
|
193
|
-
print(f"Error parsing event data for {event_name_upper}: {e}")
|
194
|
-
# Optionally, dispatch with raw_data or raise, or log more formally
|
195
|
-
# For now, we'll proceed to dispatch with raw_data if parsing fails,
|
196
|
-
# or just log and return if parsed_data is critical.
|
197
|
-
# Let's assume if a parser exists, its output is critical.
|
198
|
-
return
|
248
|
+
self._resolve_waiters(event_name, data)
|
199
249
|
|
200
|
-
self._resolve_waiters(event_name_upper, parsed_data)
|
201
|
-
# print(f"Dispatching event {event_name_upper} with data: {parsed_data} to {len(listeners)} listeners.")
|
202
250
|
for listener in listeners:
|
203
251
|
try:
|
204
|
-
# Inspect the listener to see how many arguments it expects
|
205
252
|
sig = inspect.signature(listener)
|
206
253
|
num_params = len(sig.parameters)
|
207
254
|
|
208
|
-
if num_params == 0:
|
255
|
+
if num_params == 0:
|
209
256
|
await listener()
|
210
|
-
elif
|
211
|
-
|
212
|
-
): # Listener takes one argument (the parsed data or model)
|
213
|
-
await listener(parsed_data)
|
214
|
-
# elif num_params == 2 and event_name_upper == "MESSAGE_CREATE": # Special case for (client, message)
|
215
|
-
# await listener(self._client, parsed_data) # This might be too specific here
|
257
|
+
elif num_params == 1:
|
258
|
+
await listener(data)
|
216
259
|
else:
|
217
|
-
# Fallback or error if signature doesn't match expected patterns
|
218
|
-
# For now, assume one arg is the most common for parsed data.
|
219
|
-
# Or, if you want to be strict:
|
220
260
|
print(
|
221
|
-
f"Warning: Listener {listener.__name__} for {
|
261
|
+
f"Warning: Listener {listener.__name__} for {event_name} has an unhandled number of parameters ({num_params}). Skipping or attempting with one arg."
|
222
262
|
)
|
223
|
-
if num_params > 0:
|
224
|
-
await listener(
|
263
|
+
if num_params > 0:
|
264
|
+
await listener(data)
|
225
265
|
|
226
266
|
except Exception as e:
|
227
267
|
callback = self.on_dispatch_error
|
228
268
|
if callback is not None:
|
229
269
|
try:
|
230
|
-
await callback(
|
231
|
-
|
270
|
+
await callback(event_name, e, listener)
|
232
271
|
except Exception as hook_error:
|
233
272
|
print(f"Error in on_dispatch_error hook itself: {hook_error}")
|
234
273
|
else:
|
235
|
-
# Default error handling if no hook is set
|
236
274
|
print(
|
237
|
-
f"Error in event listener {listener.__name__} for {
|
275
|
+
f"Error in event listener {listener.__name__} for {event_name}: {e}"
|
238
276
|
)
|
239
277
|
if hasattr(self._client, "on_error"):
|
240
278
|
try:
|
241
|
-
await self._client.on_error(
|
279
|
+
await self._client.on_error(event_name, e, listener)
|
242
280
|
except Exception as client_err_e:
|
243
281
|
print(f"Error in client.on_error itself: {client_err_e}")
|
282
|
+
|
283
|
+
async def dispatch(self, event_name: str, raw_data: Dict[str, Any]):
|
284
|
+
"""Dispatch an event and its raw counterpart to all listeners."""
|
285
|
+
|
286
|
+
event_name_upper = event_name.upper()
|
287
|
+
raw_event_name = f"RAW_{event_name_upper}"
|
288
|
+
|
289
|
+
await self._dispatch_to_listeners(raw_event_name, raw_data)
|
290
|
+
|
291
|
+
parsed_data: Any = raw_data
|
292
|
+
if event_name_upper in self._event_parsers:
|
293
|
+
try:
|
294
|
+
parser = self._event_parsers[event_name_upper]
|
295
|
+
parsed_data = parser(raw_data)
|
296
|
+
except Exception as e:
|
297
|
+
print(f"Error parsing event data for {event_name_upper}: {e}")
|
298
|
+
return
|
299
|
+
|
300
|
+
await self._dispatch_to_listeners(event_name_upper, parsed_data)
|
@@ -10,7 +10,14 @@ from .core import (
|
|
10
10
|
CommandContext,
|
11
11
|
CommandHandler,
|
12
12
|
) # CommandHandler might be internal
|
13
|
-
from .decorators import
|
13
|
+
from .decorators import (
|
14
|
+
command,
|
15
|
+
listener,
|
16
|
+
check,
|
17
|
+
check_any,
|
18
|
+
cooldown,
|
19
|
+
requires_permissions,
|
20
|
+
)
|
14
21
|
from .errors import (
|
15
22
|
CommandError,
|
16
23
|
CommandNotFound,
|
@@ -36,6 +43,7 @@ __all__ = [
|
|
36
43
|
"check",
|
37
44
|
"check_any",
|
38
45
|
"cooldown",
|
46
|
+
"requires_permissions",
|
39
47
|
# Errors
|
40
48
|
"CommandError",
|
41
49
|
"CommandNotFound",
|
@@ -114,6 +114,13 @@ class CommandContext:
|
|
114
114
|
|
115
115
|
self.author: "User" = message.author
|
116
116
|
|
117
|
+
@property
|
118
|
+
def guild(self):
|
119
|
+
"""The guild this command was invoked in."""
|
120
|
+
if self.message.guild_id and hasattr(self.bot, "get_guild"):
|
121
|
+
return self.bot.get_guild(self.message.guild_id)
|
122
|
+
return None
|
123
|
+
|
117
124
|
async def reply(
|
118
125
|
self,
|
119
126
|
content: str,
|
@@ -1,4 +1,5 @@
|
|
1
1
|
# disagreement/ext/commands/decorators.py
|
2
|
+
from __future__ import annotations
|
2
3
|
|
3
4
|
import asyncio
|
4
5
|
import inspect
|
@@ -6,9 +7,9 @@ import time
|
|
6
7
|
from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
|
7
8
|
|
8
9
|
if TYPE_CHECKING:
|
9
|
-
from .core import Command, CommandContext
|
10
|
-
|
11
|
-
|
10
|
+
from .core import Command, CommandContext
|
11
|
+
from disagreement.permissions import Permissions
|
12
|
+
from disagreement.models import Member, Guild, Channel
|
12
13
|
|
13
14
|
|
14
15
|
def command(
|
@@ -33,32 +34,16 @@ def command(
|
|
33
34
|
if not asyncio.iscoroutinefunction(func):
|
34
35
|
raise TypeError("Command callback must be a coroutine function.")
|
35
36
|
|
36
|
-
from .core import
|
37
|
-
Command,
|
38
|
-
) # Late import to avoid circular dependencies at module load time
|
39
|
-
|
40
|
-
# The actual registration will happen when a Cog is added or if commands are global.
|
41
|
-
# For now, this decorator creates a Command instance and attaches it to the function,
|
42
|
-
# or returns a Command instance that can be collected.
|
37
|
+
from .core import Command
|
43
38
|
|
44
39
|
cmd_name = name or func.__name__
|
45
40
|
|
46
|
-
# Store command attributes on the function itself for later collection by Cog or Client
|
47
|
-
# This is a common pattern.
|
48
41
|
if hasattr(func, "__command_attrs__"):
|
49
|
-
# This case might occur if decorators are stacked in an unusual way,
|
50
|
-
# or if a function is decorated multiple times (which should be disallowed or handled).
|
51
|
-
# For now, let's assume one @command decorator per function.
|
52
42
|
raise TypeError("Function is already a command or has command attributes.")
|
53
43
|
|
54
|
-
# Create the command object. It will be registered by the Cog or Client.
|
55
44
|
cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
|
56
|
-
|
57
|
-
|
58
|
-
func.__command_object__ = cmd # type: ignore # type: ignore[attr-defined]
|
59
|
-
return func # Return the original function, now marked.
|
60
|
-
# Or return `cmd` if commands are registered globally immediately.
|
61
|
-
# For Cogs, returning `func` and letting Cog collect is cleaner.
|
45
|
+
func.__command_object__ = cmd # type: ignore
|
46
|
+
return func
|
62
47
|
|
63
48
|
return decorator
|
64
49
|
|
@@ -68,11 +53,6 @@ def listener(
|
|
68
53
|
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
69
54
|
"""
|
70
55
|
A decorator that marks a function as an event listener within a Cog.
|
71
|
-
The actual registration happens when the Cog is added to the client.
|
72
|
-
|
73
|
-
Args:
|
74
|
-
name (Optional[str]): The name of the event to listen to.
|
75
|
-
Defaults to the function name (e.g., `on_message`).
|
76
56
|
"""
|
77
57
|
|
78
58
|
def decorator(
|
@@ -81,13 +61,11 @@ def listener(
|
|
81
61
|
if not asyncio.iscoroutinefunction(func):
|
82
62
|
raise TypeError("Listener callback must be a coroutine function.")
|
83
63
|
|
84
|
-
# 'name' here is from the outer 'listener' scope (closure)
|
85
64
|
actual_event_name = name or func.__name__
|
86
|
-
# Store listener info on the function for Cog to collect
|
87
65
|
setattr(func, "__listener_name__", actual_event_name)
|
88
66
|
return func
|
89
67
|
|
90
|
-
return decorator
|
68
|
+
return decorator
|
91
69
|
|
92
70
|
|
93
71
|
def check(
|
@@ -148,3 +126,67 @@ def cooldown(
|
|
148
126
|
return True
|
149
127
|
|
150
128
|
return check(predicate)
|
129
|
+
|
130
|
+
|
131
|
+
def _compute_permissions(
|
132
|
+
member: "Member", channel: "Channel", guild: "Guild"
|
133
|
+
) -> "Permissions":
|
134
|
+
"""Compute the effective permissions for a member in a channel."""
|
135
|
+
return channel.permissions_for(member)
|
136
|
+
|
137
|
+
|
138
|
+
def requires_permissions(
|
139
|
+
*perms: "Permissions",
|
140
|
+
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
|
141
|
+
"""Check that the invoking member has the given permissions in the channel."""
|
142
|
+
|
143
|
+
async def predicate(ctx: "CommandContext") -> bool:
|
144
|
+
from .errors import CheckFailure
|
145
|
+
from disagreement.permissions import (
|
146
|
+
has_permissions,
|
147
|
+
missing_permissions,
|
148
|
+
)
|
149
|
+
from disagreement.models import Member
|
150
|
+
|
151
|
+
channel = getattr(ctx, "channel", None)
|
152
|
+
if channel is None and hasattr(ctx.bot, "get_channel"):
|
153
|
+
channel = ctx.bot.get_channel(ctx.message.channel_id)
|
154
|
+
if channel is None and hasattr(ctx.bot, "fetch_channel"):
|
155
|
+
channel = await ctx.bot.fetch_channel(ctx.message.channel_id)
|
156
|
+
|
157
|
+
if channel is None:
|
158
|
+
raise CheckFailure("Channel for permission check not found.")
|
159
|
+
|
160
|
+
guild = getattr(channel, "guild", None)
|
161
|
+
if not guild and hasattr(channel, "guild_id") and channel.guild_id:
|
162
|
+
if hasattr(ctx.bot, "get_guild"):
|
163
|
+
guild = ctx.bot.get_guild(channel.guild_id)
|
164
|
+
if not guild and hasattr(ctx.bot, "fetch_guild"):
|
165
|
+
guild = await ctx.bot.fetch_guild(channel.guild_id)
|
166
|
+
|
167
|
+
if not guild:
|
168
|
+
is_dm = not hasattr(channel, "guild_id") or not channel.guild_id
|
169
|
+
if is_dm:
|
170
|
+
if perms:
|
171
|
+
raise CheckFailure("Permission checks are not supported in DMs.")
|
172
|
+
return True
|
173
|
+
raise CheckFailure("Guild for permission check not found.")
|
174
|
+
|
175
|
+
member = ctx.author
|
176
|
+
if not isinstance(member, Member):
|
177
|
+
member = guild.get_member(ctx.author.id)
|
178
|
+
if not member and hasattr(ctx.bot, "fetch_member"):
|
179
|
+
member = await ctx.bot.fetch_member(guild.id, ctx.author.id)
|
180
|
+
|
181
|
+
if not member:
|
182
|
+
raise CheckFailure("Could not resolve author to a guild member.")
|
183
|
+
|
184
|
+
perms_value = _compute_permissions(member, channel, guild)
|
185
|
+
|
186
|
+
if not has_permissions(perms_value, *perms):
|
187
|
+
missing = missing_permissions(perms_value, *perms)
|
188
|
+
missing_names = ", ".join(p.name for p in missing if p.name)
|
189
|
+
raise CheckFailure(f"Missing permissions: {missing_names}")
|
190
|
+
return True
|
191
|
+
|
192
|
+
return check(predicate)
|
disagreement/ext/loader.py
CHANGED
@@ -5,7 +5,7 @@ import sys
|
|
5
5
|
from types import ModuleType
|
6
6
|
from typing import Dict
|
7
7
|
|
8
|
-
__all__ = ["load_extension", "unload_extension"]
|
8
|
+
__all__ = ["load_extension", "unload_extension", "reload_extension"]
|
9
9
|
|
10
10
|
_loaded_extensions: Dict[str, ModuleType] = {}
|
11
11
|
|
@@ -41,3 +41,14 @@ def unload_extension(name: str) -> None:
|
|
41
41
|
module.teardown()
|
42
42
|
|
43
43
|
sys.modules.pop(name, None)
|
44
|
+
|
45
|
+
|
46
|
+
def reload_extension(name: str) -> ModuleType:
|
47
|
+
"""Reload an extension by name.
|
48
|
+
|
49
|
+
This is a convenience wrapper around :func:`unload_extension` followed by
|
50
|
+
:func:`load_extension`.
|
51
|
+
"""
|
52
|
+
|
53
|
+
unload_extension(name)
|
54
|
+
return load_extension(name)
|
disagreement/ext/tasks.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
import asyncio
|
2
|
+
import datetime
|
2
3
|
from typing import Any, Awaitable, Callable, Optional
|
3
4
|
|
4
5
|
__all__ = ["loop", "Task"]
|
@@ -7,16 +8,61 @@ __all__ = ["loop", "Task"]
|
|
7
8
|
class Task:
|
8
9
|
"""Simple repeating task."""
|
9
10
|
|
10
|
-
def __init__(
|
11
|
+
def __init__(
|
12
|
+
self,
|
13
|
+
coro: Callable[..., Awaitable[Any]],
|
14
|
+
*,
|
15
|
+
seconds: float = 0.0,
|
16
|
+
minutes: float = 0.0,
|
17
|
+
hours: float = 0.0,
|
18
|
+
delta: Optional[datetime.timedelta] = None,
|
19
|
+
time_of_day: Optional[datetime.time] = None,
|
20
|
+
on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
|
21
|
+
) -> None:
|
11
22
|
self._coro = coro
|
12
|
-
self._seconds = float(seconds)
|
13
23
|
self._task: Optional[asyncio.Task[None]] = None
|
24
|
+
if time_of_day is not None and (
|
25
|
+
seconds or minutes or hours or delta is not None
|
26
|
+
):
|
27
|
+
raise ValueError("time_of_day cannot be used with an interval")
|
28
|
+
|
29
|
+
if delta is not None:
|
30
|
+
if not isinstance(delta, datetime.timedelta):
|
31
|
+
raise TypeError("delta must be a datetime.timedelta")
|
32
|
+
interval_seconds = delta.total_seconds()
|
33
|
+
else:
|
34
|
+
interval_seconds = seconds + minutes * 60.0 + hours * 3600.0
|
35
|
+
|
36
|
+
self._seconds = float(interval_seconds)
|
37
|
+
self._time_of_day = time_of_day
|
38
|
+
self._on_error = on_error
|
39
|
+
|
40
|
+
def _seconds_until_time(self) -> float:
|
41
|
+
assert self._time_of_day is not None
|
42
|
+
now = datetime.datetime.now()
|
43
|
+
target = datetime.datetime.combine(now.date(), self._time_of_day)
|
44
|
+
if target <= now:
|
45
|
+
target += datetime.timedelta(days=1)
|
46
|
+
return (target - now).total_seconds()
|
14
47
|
|
15
48
|
async def _run(self, *args: Any, **kwargs: Any) -> None:
|
16
49
|
try:
|
50
|
+
first = True
|
17
51
|
while True:
|
18
|
-
|
19
|
-
|
52
|
+
if self._time_of_day is not None:
|
53
|
+
await asyncio.sleep(self._seconds_until_time())
|
54
|
+
elif not first:
|
55
|
+
await asyncio.sleep(self._seconds)
|
56
|
+
|
57
|
+
try:
|
58
|
+
await self._coro(*args, **kwargs)
|
59
|
+
except Exception as exc: # noqa: BLE001
|
60
|
+
if self._on_error is not None:
|
61
|
+
await _maybe_call(self._on_error, exc)
|
62
|
+
else:
|
63
|
+
raise
|
64
|
+
|
65
|
+
first = False
|
20
66
|
except asyncio.CancelledError:
|
21
67
|
pass
|
22
68
|
|
@@ -35,10 +81,33 @@ class Task:
|
|
35
81
|
return self._task is not None and not self._task.done()
|
36
82
|
|
37
83
|
|
84
|
+
async def _maybe_call(
|
85
|
+
func: Callable[[Exception], Awaitable[None] | None], exc: Exception
|
86
|
+
) -> None:
|
87
|
+
result = func(exc)
|
88
|
+
if asyncio.iscoroutine(result):
|
89
|
+
await result
|
90
|
+
|
91
|
+
|
38
92
|
class _Loop:
|
39
|
-
def __init__(
|
93
|
+
def __init__(
|
94
|
+
self,
|
95
|
+
func: Callable[..., Awaitable[Any]],
|
96
|
+
*,
|
97
|
+
seconds: float = 0.0,
|
98
|
+
minutes: float = 0.0,
|
99
|
+
hours: float = 0.0,
|
100
|
+
delta: Optional[datetime.timedelta] = None,
|
101
|
+
time_of_day: Optional[datetime.time] = None,
|
102
|
+
on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
|
103
|
+
) -> None:
|
40
104
|
self.func = func
|
41
105
|
self.seconds = seconds
|
106
|
+
self.minutes = minutes
|
107
|
+
self.hours = hours
|
108
|
+
self.delta = delta
|
109
|
+
self.time_of_day = time_of_day
|
110
|
+
self.on_error = on_error
|
42
111
|
self._task: Optional[Task] = None
|
43
112
|
self._owner: Any = None
|
44
113
|
|
@@ -51,7 +120,15 @@ class _Loop:
|
|
51
120
|
return self.func(self._owner, *args, **kwargs)
|
52
121
|
|
53
122
|
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
54
|
-
self._task = Task(
|
123
|
+
self._task = Task(
|
124
|
+
self._coro,
|
125
|
+
seconds=self.seconds,
|
126
|
+
minutes=self.minutes,
|
127
|
+
hours=self.hours,
|
128
|
+
delta=self.delta,
|
129
|
+
time_of_day=self.time_of_day,
|
130
|
+
on_error=self.on_error,
|
131
|
+
)
|
55
132
|
return self._task.start(*args, **kwargs)
|
56
133
|
|
57
134
|
def stop(self) -> None:
|
@@ -80,10 +157,26 @@ class _BoundLoop:
|
|
80
157
|
return self._parent.running
|
81
158
|
|
82
159
|
|
83
|
-
def loop(
|
160
|
+
def loop(
|
161
|
+
*,
|
162
|
+
seconds: float = 0.0,
|
163
|
+
minutes: float = 0.0,
|
164
|
+
hours: float = 0.0,
|
165
|
+
delta: Optional[datetime.timedelta] = None,
|
166
|
+
time_of_day: Optional[datetime.time] = None,
|
167
|
+
on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
|
168
|
+
) -> Callable[[Callable[..., Awaitable[Any]]], _Loop]:
|
84
169
|
"""Decorator to create a looping task."""
|
85
170
|
|
86
171
|
def decorator(func: Callable[..., Awaitable[Any]]) -> _Loop:
|
87
|
-
return _Loop(
|
172
|
+
return _Loop(
|
173
|
+
func,
|
174
|
+
seconds=seconds,
|
175
|
+
minutes=minutes,
|
176
|
+
hours=hours,
|
177
|
+
delta=delta,
|
178
|
+
time_of_day=time_of_day,
|
179
|
+
on_error=on_error,
|
180
|
+
)
|
88
181
|
|
89
182
|
return decorator
|