disagreement 0.2.0rc1__py3-none-any.whl → 0.3.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,219 +1,298 @@
1
- # disagreement/ext/commands/decorators.py
2
- from __future__ import annotations
3
-
4
- import asyncio
5
- import inspect
6
- import time
7
- from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
8
-
9
- if TYPE_CHECKING:
10
- from .core import Command, CommandContext
11
- from disagreement.permissions import Permissions
12
- from disagreement.models import Member, Guild, Channel
13
-
14
-
15
def command(
    name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
) -> Callable:
    """
    A decorator that transforms a function into a Command.

    Args:
        name (Optional[str]): The name of the command. Defaults to the function name.
        aliases (Optional[List[str]]): Alternative names for the command.
        **attrs: Additional attributes to pass to the Command constructor
            (e.g., brief, description, hidden).

    Returns:
        Callable: A decorator that builds a ``Command`` for the function, stores it
        on the function as ``__command_object__`` and returns the function itself.

    Raises:
        TypeError: (from the returned decorator) if the target is not a coroutine
            function, or if it has already been turned into a command.
    """

    def decorator(
        func: Callable[..., Awaitable[None]],
    ) -> Callable[..., Awaitable[None]]:
        if not asyncio.iscoroutinefunction(func):
            raise TypeError("Command callback must be a coroutine function.")

        # ``__command_object__`` is the marker this decorator itself sets, so it is
        # the attribute that actually reveals a second application of ``@command``.
        # ``__command_attrs__`` is still honoured for compatibility. Checking before
        # building anything avoids constructing a throwaway Command.
        if hasattr(func, "__command_object__") or hasattr(func, "__command_attrs__"):
            raise TypeError("Function is already a command or has command attributes.")

        from .core import Command

        cmd_name = name or func.__name__
        cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
        func.__command_object__ = cmd  # type: ignore
        return func

    return decorator
49
-
50
-
51
def listener(
    name: Optional[str] = None,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Mark a coroutine function as an event listener within a Cog.

    The event name (``name``, or the function's ``__name__`` when ``name`` is
    empty/None) is stored on the function as ``__listener_name__``; the function
    itself is returned unchanged otherwise.

    Raises:
        TypeError: (from the returned decorator) if the target is not a
            coroutine function.
    """

    def mark(coro_fn: Callable[..., Awaitable[None]]) -> Callable[..., Awaitable[None]]:
        if asyncio.iscoroutinefunction(coro_fn):
            coro_fn.__listener_name__ = name or coro_fn.__name__  # type: ignore[attr-defined]
            return coro_fn
        raise TypeError("Listener callback must be a coroutine function.")

    return mark
69
-
70
-
71
def check(
    predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Return a decorator that registers ``predicate`` as a check on a command.

    Predicates accumulate in the function's ``__command_checks__`` list in the
    order the decorators are applied (the decorator nearest the ``def`` first).
    The function is returned as-is apart from that attribute.
    """

    def register(func: Callable[..., Awaitable[None]]) -> Callable[..., Awaitable[None]]:
        try:
            registered = func.__command_checks__  # type: ignore[attr-defined]
        except AttributeError:
            registered = []
        registered.append(predicate)
        func.__command_checks__ = registered  # type: ignore[attr-defined]
        return func

    return register
85
-
86
-
87
def check_any(
    *predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Combine several checks so the command passes if any one of them passes.

    Predicates are tried in order and may be plain or async callables; the first
    truthy result wins. A predicate that raises ``CheckFailure`` is recorded and
    the next one is tried; any other exception propagates immediately.

    Raises (from the combined check, at invocation time):
        CheckAnyFailure: built from the collected ``CheckFailure`` instances when
            no predicate passes.
    """

    async def _outcome(pred, ctx):
        # Accept both synchronous and awaitable predicate results.
        value = pred(ctx)
        if inspect.isawaitable(value):
            value = await value
        return value

    async def predicate(ctx: "CommandContext") -> bool:
        from .errors import CheckAnyFailure, CheckFailure

        collected = []
        for candidate in predicates:
            try:
                if await _outcome(candidate, ctx):
                    return True
            except CheckFailure as failure:
                collected.append(failure)
        raise CheckAnyFailure(collected)

    return check(predicate)
108
-
109
-
110
def max_concurrency(
    number: int, per: str = "user"
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Limit how many concurrent invocations of a command are allowed.

    Parameters
    ----------
    number:
        The maximum number of concurrent invocations; must be at least 1.
    per:
        The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.

    Raises
    ------
    ValueError
        If ``number`` is below 1 or ``per`` is not one of the allowed scopes.

    Notes
    -----
    The decorator only records ``(number, per)`` on the function as
    ``__max_concurrency__``; enforcement presumably happens in the command
    machinery that reads that attribute.
    """
    allowed_scopes = frozenset({"user", "guild", "global"})

    if number < 1:
        raise ValueError("Concurrency number must be at least 1.")
    if per not in allowed_scopes:
        raise ValueError("per must be 'user', 'guild', or 'global'.")

    limit = (number, per)

    def apply_limit(func: Callable[..., Awaitable[None]]) -> Callable[..., Awaitable[None]]:
        func.__max_concurrency__ = limit  # type: ignore[attr-defined]
        return func

    return apply_limit
135
-
136
-
137
def cooldown(
    rate: int, per: float
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Per-user cooldown: allow at most ``rate`` uses every ``per`` seconds.

    Each user gets a sliding window of ``per`` seconds (measured with
    ``time.monotonic``). Once ``rate`` successful uses fall inside the window,
    further invocations fail until the oldest of them leaves the window.

    Args:
        rate: Number of uses permitted per window. Must be at least 1.
        per: Length of the window in seconds.

    Raises:
        ValueError: if ``rate`` is less than 1.
        CommandOnCooldown: (from the check, at invocation time) constructed with
            the number of seconds until the user may invoke the command again.
    """
    if rate < 1:
        # Mirrors max_concurrency's validation; a zero rate would block forever.
        raise ValueError("Cooldown rate must be at least 1.")

    # command name -> author id -> monotonic timestamps of recent uses, oldest first.
    buckets: dict[str, dict[Any, list[float]]] = {}

    async def predicate(ctx: "CommandContext") -> bool:
        from .errors import CommandOnCooldown

        now = time.monotonic()
        user_buckets = buckets.setdefault(ctx.command.name, {})
        # Drop uses that have slid out of the window; this also keeps each list
        # bounded by ``rate`` entries.
        recent = [t for t in user_buckets.get(ctx.author.id, ()) if now - t < per]
        user_buckets[ctx.author.id] = recent
        if len(recent) >= rate:
            # The oldest use in the window is the next one to expire.
            raise CommandOnCooldown(recent[0] + per - now)
        recent.append(now)
        return True

    return check(predicate)
156
-
157
-
158
def _compute_permissions(
    member: "Member", channel: "Channel", guild: "Guild"
) -> "Permissions":
    """Return the effective permissions ``member`` has in ``channel``.

    This delegates entirely to ``channel.permissions_for``; ``guild`` is
    currently not consulted.
    """
    resolve = channel.permissions_for
    return resolve(member)
163
-
164
-
165
def requires_permissions(
    *perms: "Permissions",
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Check that the invoking member has the given permissions in the channel.

    Resolution order at invocation time:

    * channel: ``ctx.channel``, then ``bot.get_channel``, then ``bot.fetch_channel``
      (using ``ctx.message.channel_id``);
    * guild: ``channel.guild``, then ``bot.get_guild``, then ``bot.fetch_guild``
      (only when the channel has a truthy ``guild_id``);
    * member: ``ctx.author`` if it is a ``Member``, else ``guild.get_member``,
      then ``bot.fetch_member``.

    In a DM (no guild and no ``guild_id``) the check passes only when no
    permissions were requested.

    Raises (from the check, at invocation time):
        CheckFailure: if the channel, guild or member cannot be resolved, if
            permissions are requested in a DM, or if any permission is missing.
    """

    async def predicate(ctx: "CommandContext") -> bool:
        from .errors import CheckFailure
        from disagreement.permissions import has_permissions, missing_permissions
        from disagreement.models import Member

        async def resolve_channel():
            found = getattr(ctx, "channel", None)
            if found is not None:
                return found
            if hasattr(ctx.bot, "get_channel"):
                found = ctx.bot.get_channel(ctx.message.channel_id)
                if found is not None:
                    return found
            if hasattr(ctx.bot, "fetch_channel"):
                found = await ctx.bot.fetch_channel(ctx.message.channel_id)
            return found

        async def resolve_guild(chan):
            found = getattr(chan, "guild", None)
            # Only look the guild up when the channel claims to belong to one.
            if found or not (hasattr(chan, "guild_id") and chan.guild_id):
                return found
            if hasattr(ctx.bot, "get_guild"):
                found = ctx.bot.get_guild(chan.guild_id)
            if not found and hasattr(ctx.bot, "fetch_guild"):
                found = await ctx.bot.fetch_guild(chan.guild_id)
            return found

        channel = await resolve_channel()
        if channel is None:
            raise CheckFailure("Channel for permission check not found.")

        guild = await resolve_guild(channel)
        if not guild:
            if hasattr(channel, "guild_id") and channel.guild_id:
                # A guild channel whose guild could not be resolved.
                raise CheckFailure("Guild for permission check not found.")
            # DM channel: there are no permissions to evaluate.
            if perms:
                raise CheckFailure("Permission checks are not supported in DMs.")
            return True

        author = ctx.author
        if isinstance(author, Member):
            member = author
        else:
            member = guild.get_member(author.id)
            if not member and hasattr(ctx.bot, "fetch_member"):
                member = await ctx.bot.fetch_member(guild.id, author.id)
        if not member:
            raise CheckFailure("Could not resolve author to a guild member.")

        effective = _compute_permissions(member, channel, guild)
        if has_permissions(effective, *perms):
            return True

        lacking = missing_permissions(effective, *perms)
        lacking_names = ", ".join(p.name for p in lacking if p.name)
        raise CheckFailure(f"Missing permissions: {lacking_names}")

    return check(predicate)
1
+ # disagreement/ext/commands/decorators.py
2
+ from __future__ import annotations
3
+
4
+ import asyncio
5
+ import inspect
6
+ import time
7
+ from typing import Callable, Any, Optional, List, TYPE_CHECKING, Awaitable
8
+
9
+ if TYPE_CHECKING:
10
+ from .core import Command, CommandContext
11
+ from disagreement.permissions import Permissions
12
+ from disagreement.models import Member, Guild, Channel
13
+
14
+
15
def command(
    name: Optional[str] = None, aliases: Optional[List[str]] = None, **attrs: Any
) -> Callable:
    """
    A decorator that transforms a function into a Command.

    Args:
        name (Optional[str]): The name of the command. Defaults to the function name.
        aliases (Optional[List[str]]): Alternative names for the command.
        **attrs: Additional attributes to pass to the Command constructor
            (e.g., brief, description, hidden).

    Returns:
        Callable: A decorator that builds a ``Command`` for the function, stores it
        on the function as ``__command_object__`` and returns the function itself.

    Raises:
        TypeError: (from the returned decorator) if the target is not a coroutine
            function, or if it has already been turned into a command.
    """

    def decorator(
        func: Callable[..., Awaitable[None]],
    ) -> Callable[..., Awaitable[None]]:
        if not asyncio.iscoroutinefunction(func):
            raise TypeError("Command callback must be a coroutine function.")

        # ``__command_object__`` is the marker this decorator itself sets, so it is
        # the attribute that actually reveals a second application of ``@command``.
        # ``__command_attrs__`` is still honoured for compatibility. Checking before
        # building anything avoids constructing a throwaway Command.
        if hasattr(func, "__command_object__") or hasattr(func, "__command_attrs__"):
            raise TypeError("Function is already a command or has command attributes.")

        from .core import Command

        cmd_name = name or func.__name__
        cmd = Command(callback=func, name=cmd_name, aliases=aliases or [], **attrs)
        func.__command_object__ = cmd  # type: ignore
        return func

    return decorator
49
+
50
+
51
def listener(
    name: Optional[str] = None,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Mark a coroutine function as an event listener within a Cog.

    The event name (``name``, or the function's ``__name__`` when ``name`` is
    empty/None) is stored on the function as ``__listener_name__``; the function
    itself is returned unchanged otherwise.

    Raises:
        TypeError: (from the returned decorator) if the target is not a
            coroutine function.
    """

    def mark(coro_fn: Callable[..., Awaitable[None]]) -> Callable[..., Awaitable[None]]:
        if asyncio.iscoroutinefunction(coro_fn):
            coro_fn.__listener_name__ = name or coro_fn.__name__  # type: ignore[attr-defined]
            return coro_fn
        raise TypeError("Listener callback must be a coroutine function.")

    return mark
69
+
70
+
71
def check(
    predicate: Callable[["CommandContext"], Awaitable[bool] | bool],
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Return a decorator that registers ``predicate`` as a check on a command.

    Predicates accumulate in the function's ``__command_checks__`` list in the
    order the decorators are applied (the decorator nearest the ``def`` first).
    The function is returned as-is apart from that attribute.
    """

    def register(func: Callable[..., Awaitable[None]]) -> Callable[..., Awaitable[None]]:
        try:
            registered = func.__command_checks__  # type: ignore[attr-defined]
        except AttributeError:
            registered = []
        registered.append(predicate)
        func.__command_checks__ = registered  # type: ignore[attr-defined]
        return func

    return register
85
+
86
+
87
def check_any(
    *predicates: Callable[["CommandContext"], Awaitable[bool] | bool]
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Combine several checks so the command passes if any one of them passes.

    Predicates are tried in order and may be plain or async callables; the first
    truthy result wins. A predicate that raises ``CheckFailure`` is recorded and
    the next one is tried; any other exception propagates immediately.

    Raises (from the combined check, at invocation time):
        CheckAnyFailure: built from the collected ``CheckFailure`` instances when
            no predicate passes.
    """

    async def _outcome(pred, ctx):
        # Accept both synchronous and awaitable predicate results.
        value = pred(ctx)
        if inspect.isawaitable(value):
            value = await value
        return value

    async def predicate(ctx: "CommandContext") -> bool:
        from .errors import CheckAnyFailure, CheckFailure

        collected = []
        for candidate in predicates:
            try:
                if await _outcome(candidate, ctx):
                    return True
            except CheckFailure as failure:
                collected.append(failure)
        raise CheckAnyFailure(collected)

    return check(predicate)
108
+
109
+
110
def max_concurrency(
    number: int, per: str = "user"
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Limit how many concurrent invocations of a command are allowed.

    Parameters
    ----------
    number:
        The maximum number of concurrent invocations; must be at least 1.
    per:
        The scope of the limiter. Can be ``"user"``, ``"guild"`` or ``"global"``.

    Raises
    ------
    ValueError
        If ``number`` is below 1 or ``per`` is not one of the allowed scopes.

    Notes
    -----
    The decorator only records ``(number, per)`` on the function as
    ``__max_concurrency__``; enforcement presumably happens in the command
    machinery that reads that attribute.
    """
    allowed_scopes = frozenset({"user", "guild", "global"})

    if number < 1:
        raise ValueError("Concurrency number must be at least 1.")
    if per not in allowed_scopes:
        raise ValueError("per must be 'user', 'guild', or 'global'.")

    limit = (number, per)

    def apply_limit(func: Callable[..., Awaitable[None]]) -> Callable[..., Awaitable[None]]:
        func.__max_concurrency__ = limit  # type: ignore[attr-defined]
        return func

    return apply_limit
135
+
136
+
137
def cooldown(
    rate: int, per: float
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Per-user cooldown: allow at most ``rate`` uses every ``per`` seconds.

    Each user gets a sliding window of ``per`` seconds (measured with
    ``time.monotonic``). Once ``rate`` successful uses fall inside the window,
    further invocations fail until the oldest of them leaves the window.

    Args:
        rate: Number of uses permitted per window. Must be at least 1.
        per: Length of the window in seconds.

    Raises:
        ValueError: if ``rate`` is less than 1.
        CommandOnCooldown: (from the check, at invocation time) constructed with
            the number of seconds until the user may invoke the command again.
    """
    if rate < 1:
        # Mirrors max_concurrency's validation; a zero rate would block forever.
        raise ValueError("Cooldown rate must be at least 1.")

    # command name -> author id -> monotonic timestamps of recent uses, oldest first.
    buckets: dict[str, dict[Any, list[float]]] = {}

    async def predicate(ctx: "CommandContext") -> bool:
        from .errors import CommandOnCooldown

        now = time.monotonic()
        user_buckets = buckets.setdefault(ctx.command.name, {})
        # Drop uses that have slid out of the window; this also keeps each list
        # bounded by ``rate`` entries.
        recent = [t for t in user_buckets.get(ctx.author.id, ()) if now - t < per]
        user_buckets[ctx.author.id] = recent
        if len(recent) >= rate:
            # The oldest use in the window is the next one to expire.
            raise CommandOnCooldown(recent[0] + per - now)
        recent.append(now)
        return True

    return check(predicate)
156
+
157
+
158
def _compute_permissions(
    member: "Member", channel: "Channel", guild: "Guild"
) -> "Permissions":
    """Return the effective permissions ``member`` has in ``channel``.

    This delegates entirely to ``channel.permissions_for``; ``guild`` is
    currently not consulted.
    """
    resolve = channel.permissions_for
    return resolve(member)
163
+
164
+
165
def requires_permissions(
    *perms: "Permissions",
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Check that the invoking member has the given permissions in the channel.

    Resolution order at invocation time:

    * channel: ``ctx.channel``, then ``bot.get_channel``, then ``bot.fetch_channel``
      (using ``ctx.message.channel_id``);
    * guild: ``channel.guild``, then ``bot.get_guild``, then ``bot.fetch_guild``
      (only when the channel has a truthy ``guild_id``);
    * member: ``ctx.author`` if it is a ``Member``, else ``guild.get_member``,
      then ``bot.fetch_member``.

    In a DM (no guild and no ``guild_id``) the check passes only when no
    permissions were requested.

    Raises (from the check, at invocation time):
        CheckFailure: if the channel, guild or member cannot be resolved, if
            permissions are requested in a DM, or if any permission is missing.
    """

    async def predicate(ctx: "CommandContext") -> bool:
        from .errors import CheckFailure
        from disagreement.permissions import has_permissions, missing_permissions
        from disagreement.models import Member

        async def resolve_channel():
            found = getattr(ctx, "channel", None)
            if found is not None:
                return found
            if hasattr(ctx.bot, "get_channel"):
                found = ctx.bot.get_channel(ctx.message.channel_id)
                if found is not None:
                    return found
            if hasattr(ctx.bot, "fetch_channel"):
                found = await ctx.bot.fetch_channel(ctx.message.channel_id)
            return found

        async def resolve_guild(chan):
            found = getattr(chan, "guild", None)
            # Only look the guild up when the channel claims to belong to one.
            if found or not (hasattr(chan, "guild_id") and chan.guild_id):
                return found
            if hasattr(ctx.bot, "get_guild"):
                found = ctx.bot.get_guild(chan.guild_id)
            if not found and hasattr(ctx.bot, "fetch_guild"):
                found = await ctx.bot.fetch_guild(chan.guild_id)
            return found

        channel = await resolve_channel()
        if channel is None:
            raise CheckFailure("Channel for permission check not found.")

        guild = await resolve_guild(channel)
        if not guild:
            if hasattr(channel, "guild_id") and channel.guild_id:
                # A guild channel whose guild could not be resolved.
                raise CheckFailure("Guild for permission check not found.")
            # DM channel: there are no permissions to evaluate.
            if perms:
                raise CheckFailure("Permission checks are not supported in DMs.")
            return True

        author = ctx.author
        if isinstance(author, Member):
            member = author
        else:
            member = guild.get_member(author.id)
            if not member and hasattr(ctx.bot, "fetch_member"):
                member = await ctx.bot.fetch_member(guild.id, author.id)
        if not member:
            raise CheckFailure("Could not resolve author to a guild member.")

        effective = _compute_permissions(member, channel, guild)
        if has_permissions(effective, *perms):
            return True

        lacking = missing_permissions(effective, *perms)
        lacking_names = ", ".join(p.name for p in lacking if p.name)
        raise CheckFailure(f"Missing permissions: {lacking_names}")

    return check(predicate)
220
+
221
def has_role(
    name_or_id: str | int,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Check that the invoking member has a role with the given name or ID.

    ``name_or_id`` is compared as a string against both the ID (also
    stringified) and the name of each of the member's roles, so
    ``has_role(123)`` and ``has_role("123")`` behave the same. This matches
    how ``has_any_role`` compares its arguments.

    Raises (from the check, at invocation time):
        CheckFailure: when used outside a guild, when the author cannot be
            resolved to a guild member, or when the member lacks the role.
    """
    target = str(name_or_id)

    async def predicate(ctx: "CommandContext") -> bool:
        from .errors import CheckFailure
        from disagreement.models import Member

        guild = ctx.guild
        if not guild:
            raise CheckFailure("This command cannot be used in DMs.")

        author = ctx.author
        if not isinstance(author, Member):
            # Prefer the guild's member cache (as requires_permissions does) and
            # only hit the API when it misses.
            get_member = getattr(guild, "get_member", None)
            cached = get_member(author.id) if get_member is not None else None
            if cached:
                author = cached
            else:
                try:
                    author = await ctx.bot.fetch_member(guild.id, author.id)
                except Exception as exc:
                    raise CheckFailure(
                        "Could not resolve author to a guild member."
                    ) from exc

        if not author:
            raise CheckFailure("Could not resolve author to a guild member.")

        # ``author.roles`` is matched against ``role.id`` (presumably a list of
        # role IDs), so go through the guild's Role objects to compare names too.
        if any(
            str(role.id) == target or role.name == target
            for role in guild.roles
            if role.id in author.roles
        ):
            return True

        raise CheckFailure(f"You need the '{name_or_id}' role to use this command.")

    return check(predicate)
257
+
258
+
259
def has_any_role(
    *names_or_ids: str | int,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Check that the invoking member has any of the roles with the given names or IDs.

    Every entry in ``names_or_ids`` is stringified and compared against both
    the ID (also stringified) and the name of each of the member's roles.

    Raises (from the check, at invocation time):
        CheckFailure: when used outside a guild, when the author cannot be
            resolved to a guild member, or when the member has none of the roles.
    """
    # Stringify once at decoration time so each invocation only does set lookups.
    wanted = set(map(str, names_or_ids))

    async def predicate(ctx: "CommandContext") -> bool:
        from .errors import CheckFailure
        from disagreement.models import Member

        guild = ctx.guild
        if not guild:
            raise CheckFailure("This command cannot be used in DMs.")

        author = ctx.author
        if not isinstance(author, Member):
            # Prefer the guild's member cache (as requires_permissions does) and
            # only hit the API when it misses.
            get_member = getattr(guild, "get_member", None)
            cached = get_member(author.id) if get_member is not None else None
            if cached:
                author = cached
            else:
                try:
                    author = await ctx.bot.fetch_member(guild.id, author.id)
                except Exception as exc:
                    raise CheckFailure(
                        "Could not resolve author to a guild member."
                    ) from exc

        if not author:
            raise CheckFailure("Could not resolve author to a guild member.")

        # ``author.roles`` is matched against ``role.id`` (presumably a list of
        # role IDs), so go through the guild's Role objects to compare names too.
        if any(
            str(role.id) in wanted or role.name in wanted
            for role in guild.roles
            if role.id in author.roles
        ):
            return True

        role_list = ", ".join(f"'{r}'" for r in names_or_ids)
        raise CheckFailure(
            f"You need one of the following roles to use this command: {role_list}"
        )

    return check(predicate)