disagreement 0.2.0rc1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. disagreement/__init__.py +2 -4
  2. disagreement/audio.py +42 -5
  3. disagreement/cache.py +43 -4
  4. disagreement/caching.py +121 -0
  5. disagreement/client.py +1682 -1535
  6. disagreement/enums.py +10 -3
  7. disagreement/error_handler.py +5 -1
  8. disagreement/errors.py +1341 -3
  9. disagreement/event_dispatcher.py +3 -5
  10. disagreement/ext/__init__.py +1 -0
  11. disagreement/ext/app_commands/__init__.py +0 -2
  12. disagreement/ext/app_commands/commands.py +0 -2
  13. disagreement/ext/app_commands/context.py +0 -2
  14. disagreement/ext/app_commands/converters.py +2 -4
  15. disagreement/ext/app_commands/decorators.py +5 -7
  16. disagreement/ext/app_commands/handler.py +1 -3
  17. disagreement/ext/app_commands/hybrid.py +0 -2
  18. disagreement/ext/commands/__init__.py +63 -61
  19. disagreement/ext/commands/cog.py +0 -2
  20. disagreement/ext/commands/converters.py +16 -5
  21. disagreement/ext/commands/core.py +728 -563
  22. disagreement/ext/commands/decorators.py +294 -219
  23. disagreement/ext/commands/errors.py +0 -2
  24. disagreement/ext/commands/help.py +0 -2
  25. disagreement/ext/commands/view.py +1 -3
  26. disagreement/gateway.py +632 -586
  27. disagreement/http.py +1362 -1041
  28. disagreement/interactions.py +0 -2
  29. disagreement/models.py +2682 -2263
  30. disagreement/shard_manager.py +0 -2
  31. disagreement/ui/view.py +167 -165
  32. disagreement/voice_client.py +263 -162
  33. {disagreement-0.2.0rc1.dist-info → disagreement-0.4.0.dist-info}/METADATA +33 -6
  34. disagreement-0.4.0.dist-info/RECORD +55 -0
  35. disagreement-0.2.0rc1.dist-info/RECORD +0 -54
  36. {disagreement-0.2.0rc1.dist-info → disagreement-0.4.0.dist-info}/WHEEL +0 -0
  37. {disagreement-0.2.0rc1.dist-info → disagreement-0.4.0.dist-info}/licenses/LICENSE +0 -0
  38. {disagreement-0.2.0rc1.dist-info → disagreement-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,563 +1,728 @@
1
- # disagreement/ext/commands/core.py
2
-
3
- from __future__ import annotations
4
-
5
- import asyncio
6
- import logging
7
- import inspect
8
- from typing import (
9
- TYPE_CHECKING,
10
- Optional,
11
- List,
12
- Dict,
13
- Any,
14
- Union,
15
- Callable,
16
- Awaitable,
17
- Tuple,
18
- get_origin,
19
- get_args,
20
- )
21
-
22
- from .view import StringView
23
- from .errors import (
24
- CommandError,
25
- CommandNotFound,
26
- BadArgument,
27
- MissingRequiredArgument,
28
- ArgumentParsingError,
29
- CheckFailure,
30
- CommandInvokeError,
31
- )
32
- from .converters import run_converters, DEFAULT_CONVERTERS, Converter
33
- from disagreement.typing import Typing
34
-
35
- logger = logging.getLogger(__name__)
36
-
37
- if TYPE_CHECKING:
38
- from .cog import Cog
39
- from disagreement.client import Client
40
- from disagreement.models import Message, User
41
-
42
-
43
class Command:
    """
    A prefix command backed by a coroutine function.

    Attributes:
        name (str): Primary name; defaults to the callback's ``__name__``.
        callback (Callable[..., Awaitable[None]]): Coroutine function run on invocation.
        aliases (List[str]): Alternative names for the command.
        brief (Optional[str]): Short description used by help output.
        description (Optional[str]): Long description; falls back to the callback docstring.
        cog (Optional['Cog']): Owning cog; when set it is passed as the callback's first argument.
        params (Mapping[str, inspect.Parameter]): Parameters of the callback's signature.
        checks (List[Callable]): Sync or async predicates that must all be truthy
            before the callback runs.
        max_concurrency (Optional[Tuple[int, str]]): Value of the callback's
            ``__max_concurrency__`` attribute, or ``None`` when it is absent.
    """

    def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
        if not asyncio.iscoroutinefunction(callback):
            raise TypeError("Command callback must be a coroutine function.")

        option = attrs.get
        self.callback: Callable[..., Awaitable[None]] = callback
        self.name: str = option("name", callback.__name__)
        self.aliases: List[str] = option("aliases", [])
        self.brief: Optional[str] = option("brief")
        self.description: Optional[str] = option("description") or callback.__doc__
        self.cog: Optional["Cog"] = option("cog")

        self.params = inspect.signature(callback).parameters

        # Decorators stash checks and concurrency limits on the function object itself.
        self.checks: List[Callable[["CommandContext"], Awaitable[bool] | bool]] = list(
            getattr(callback, "__command_checks__", ())
        )
        self.max_concurrency: Optional[Tuple[int, str]] = getattr(
            callback, "__max_concurrency__", None
        )

    def add_check(
        self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
    ) -> None:
        """Append *predicate* to the checks evaluated before each invocation."""
        self.checks.append(predicate)

    async def invoke(self, ctx: "CommandContext", *args: Any, **kwargs: Any) -> None:
        """Run every check, then await the callback with ``ctx`` and the parsed args.

        Raises:
            CheckFailure: If any check returns (or resolves to) a falsy value.
        """
        for check in self.checks:
            verdict = check(ctx)
            if inspect.isawaitable(verdict):
                verdict = await verdict
            if not verdict:
                raise CheckFailure("Check predicate failed.")

        # Cog commands are unbound functions, so the cog instance goes first.
        leading = (self.cog, ctx) if self.cog else (ctx,)
        await self.callback(*leading, *args, **kwargs)


PrefixCommand = Command  # Alias for clarity in hybrid commands
99
-
100
-
101
class CommandContext:
    """
    Represents the context in which a command is being invoked.

    Bundles the triggering message, the client, the matched prefix and command,
    and the parsed arguments. Also offers helpers that act on the invoking
    message's channel (``send``, ``reply``, ``edit``, ``typing``).
    """

    def __init__(
        self,
        *,
        message: "Message",
        bot: "Client",
        prefix: str,
        command: "Command",
        invoked_with: str,
        args: Optional[List[Any]] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        cog: Optional["Cog"] = None,
    ):
        self.message: "Message" = message
        self.bot: "Client" = bot
        self.prefix: str = prefix
        self.command: "Command" = command
        # The exact word the user typed (may be an alias of ``command.name``).
        self.invoked_with: str = invoked_with
        self.args: List[Any] = args or []
        self.kwargs: Dict[str, Any] = kwargs or {}
        self.cog: Optional["Cog"] = cog

        self.author: "User" = message.author

    @property
    def guild(self):
        """The guild this command was invoked in.

        Returns ``None`` when the message has no ``guild_id`` or the client has
        no ``get_guild`` method. Otherwise returns whatever ``bot.get_guild`` returns.
        """
        if self.message.guild_id and hasattr(self.bot, "get_guild"):
            return self.bot.get_guild(self.message.guild_id)
        return None

    async def reply(
        self,
        content: Optional[str] = None,
        *,
        mention_author: Optional[bool] = None,
        **kwargs: Any,
    ) -> "Message":
        """Replies to the invoking message.

        Parameters
        ----------
        content: Optional[str]
            The content to send.
        mention_author: Optional[bool]
            Whether to mention the author in the reply. If ``None`` the
            client's :attr:`mention_replies` value is used.
        **kwargs:
            Forwarded to ``bot.send_message``. A caller-supplied
            ``allowed_mentions`` mapping is copied, and ``replied_user`` is
            only filled in when the caller did not set it.

        Returns
        -------
        Whatever ``bot.send_message`` returns (presumably the sent message).
        """

        allowed_mentions = kwargs.pop("allowed_mentions", None)
        if mention_author is None:
            mention_author = getattr(self.bot, "mention_replies", False)

        if allowed_mentions is None:
            allowed_mentions = {"replied_user": mention_author}
        else:
            # Copy so the caller's mapping is not mutated.
            allowed_mentions = dict(allowed_mentions)
            allowed_mentions.setdefault("replied_user", mention_author)

        return await self.bot.send_message(
            channel_id=self.message.channel_id,
            content=content,
            message_reference={
                "message_id": self.message.id,
                "channel_id": self.message.channel_id,
                "guild_id": self.message.guild_id,
            },
            allowed_mentions=allowed_mentions,
            **kwargs,
        )

    async def send(self, content: Optional[str] = None, **kwargs: Any) -> "Message":
        """Send a message to the channel the command was invoked in.

        Parameters
        ----------
        content: Optional[str]
            The text to send. It may be omitted (``None``) when the message is
            described entirely by other keyword arguments, as with :meth:`reply`.
        **kwargs:
            Forwarded unchanged to ``bot.send_message``.
        """
        return await self.bot.send_message(
            channel_id=self.message.channel_id, content=content, **kwargs
        )

    async def edit(
        self,
        message: Union[str, "Message"],
        *,
        content: Optional[str] = None,
        **kwargs: Any,
    ) -> "Message":
        """Edits a message previously sent by the bot.

        ``message`` may be a message ID string or an object with an ``id``
        attribute. The edit targets the invoking message's channel.
        """

        message_id = message if isinstance(message, str) else message.id
        return await self.bot.edit_message(
            channel_id=self.message.channel_id,
            message_id=message_id,
            content=content,
            **kwargs,
        )

    def typing(self) -> "Typing":
        """Return a typing context manager for this context's channel."""

        return self.bot.typing(self.message.channel_id)
202
-
203
-
204
class CommandHandler:
    """
    Manages command registration, parsing, and dispatching.

    Commands are stored in ``commands`` keyed by their lower-cased name and
    lower-cased aliases, so :meth:`get_command` lookups are case-insensitive.
    """

    def __init__(
        self,
        client: "Client",
        prefix: Union[
            str, List[str], Callable[["Client", "Message"], Union[str, List[str]]]
        ],
    ):
        self.client: "Client" = client
        self.prefix: Union[
            str, List[str], Callable[["Client", "Message"], Union[str, List[str]]]
        ] = prefix
        self.commands: Dict[str, Command] = {}
        self.cogs: Dict[str, "Cog"] = {}
        # command name -> {bucket key -> number of invocations currently running}
        self._concurrency: Dict[str, Dict[str, int]] = {}
        # Strong references to fire-and-forget tasks (cog_load / cog_unload).
        # The event loop keeps only weak references, so without this set a
        # task could be garbage-collected before it finishes.
        self._background_tasks: "set[asyncio.Task[Any]]" = set()

        from .help import HelpCommand

        self.add_command(HelpCommand(self))

    def _spawn(self, coro: Any) -> "asyncio.Task[Any]":
        """Schedule *coro* as a task and keep it referenced until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def add_command(self, command: Command) -> None:
        """Register *command* under its name and aliases, case-insensitively.

        Raises:
            ValueError: If the command's name, in any letter case, is already registered.
        """
        key = command.name.lower()
        # Keys are stored lower-cased, so the duplicate check must be too;
        # otherwise "Help" would silently overwrite "help".
        if key in self.commands:
            raise ValueError(f"Command '{command.name}' is already registered.")

        self.commands[key] = command
        for alias in command.aliases:
            alias_key = alias.lower()
            if alias_key in self.commands:
                logger.warning(
                    "Alias '%s' for command '%s' conflicts with an existing command or alias.",
                    alias,
                    command.name,
                )
            self.commands[alias_key] = command

    def remove_command(self, name: str) -> Optional[Command]:
        """Unregister the command reachable as *name* (its name or one of its aliases).

        Every key that still maps to the removed command is dropped. Keys that
        now point at a different command (for example, a shadowed alias) are kept.

        Returns:
            The removed command, or ``None`` if *name* was not registered.
        """
        command = self.commands.pop(name.lower(), None)
        if command is None:
            return None
        for label in (command.name, *command.aliases):
            key = label.lower()
            if self.commands.get(key) is command:
                del self.commands[key]
        return command

    def get_command(self, name: str) -> Optional[Command]:
        """Return the command registered under *name* (case-insensitive), or ``None``."""
        return self.commands.get(name.lower())

    def add_cog(self, cog_to_add: "Cog") -> None:
        """Register a cog together with its commands and event listeners.

        If the cog defines an ``async`` ``cog_load``, it is scheduled as a task,
        so a running event loop is required in that case.

        Raises:
            TypeError: If *cog_to_add* is not a ``Cog`` instance.
            ValueError: If a cog with the same ``cog_name`` is already registered.
        """
        from .cog import Cog

        if not isinstance(cog_to_add, Cog):
            raise TypeError("Argument must be a subclass of Cog.")

        if cog_to_add.cog_name in self.cogs:
            raise ValueError(
                f"Cog with name '{cog_to_add.cog_name}' is already registered."
            )

        self.cogs[cog_to_add.cog_name] = cog_to_add

        for cmd in cog_to_add.get_commands():
            self.add_command(cmd)

        if hasattr(self.client, "_event_dispatcher"):
            for event_name, callback in cog_to_add.get_listeners():
                self.client._event_dispatcher.register(event_name.upper(), callback)
        else:
            logger.warning(
                "Client does not have '_event_dispatcher'. Listeners for cog '%s' not registered.",
                cog_to_add.cog_name,
            )

        if hasattr(cog_to_add, "cog_load") and inspect.iscoroutinefunction(
            cog_to_add.cog_load
        ):
            self._spawn(cog_to_add.cog_load())

        logger.info("Cog '%s' added.", cog_to_add.cog_name)

    def remove_cog(self, cog_name: str) -> Optional["Cog"]:
        """Unregister the cog named *cog_name* and its commands.

        Listeners are not unregistered; only a debug message is logged for each.

        Returns:
            The removed cog, or ``None`` if no cog had that name.
        """
        cog_to_remove = self.cogs.pop(cog_name, None)
        if cog_to_remove:
            for cmd in cog_to_remove.get_commands():
                self.remove_command(cmd.name)

            if hasattr(self.client, "_event_dispatcher"):
                for event_name, callback in cog_to_remove.get_listeners():
                    logger.debug(
                        "Listener '%s' for event '%s' from cog '%s' needs manual unregistration logic in EventDispatcher.",
                        callback.__name__,
                        event_name,
                        cog_name,
                    )

            if hasattr(cog_to_remove, "cog_unload") and inspect.iscoroutinefunction(
                cog_to_remove.cog_unload
            ):
                self._spawn(cog_to_remove.cog_unload())

            cog_to_remove._eject()
            logger.info("Cog '%s' removed.", cog_name)
        return cog_to_remove

    @staticmethod
    def _concurrency_key(ctx: CommandContext, scope: str) -> Any:
        """Return the bucket key for *ctx* under a ``max_concurrency`` *scope*."""
        if scope == "user":
            return ctx.author.id
        if scope == "guild":
            # Without a guild id (presumably a DM), fall back to a per-user bucket.
            return ctx.message.guild_id or ctx.author.id
        return "global"

    def _acquire_concurrency(self, ctx: CommandContext) -> None:
        """Reserve a concurrency slot for ``ctx.command``.

        Raises:
            MaxConcurrencyReached: If the bucket is already at its limit.
        """
        mc = getattr(ctx.command, "max_concurrency", None)
        if not mc:
            return
        limit, scope = mc
        key = self._concurrency_key(ctx, scope)
        buckets = self._concurrency.setdefault(ctx.command.name, {})
        current = buckets.get(key, 0)
        if current >= limit:
            from .errors import MaxConcurrencyReached

            raise MaxConcurrencyReached(limit)
        buckets[key] = current + 1

    def _release_concurrency(self, ctx: CommandContext) -> None:
        """Release a slot taken by :meth:`_acquire_concurrency`, pruning empty buckets."""
        mc = getattr(ctx.command, "max_concurrency", None)
        if not mc:
            return
        _, scope = mc
        key = self._concurrency_key(ctx, scope)
        buckets = self._concurrency.get(ctx.command.name)
        if not buckets:
            return
        current = buckets.get(key, 0)
        if current <= 1:
            buckets.pop(key, None)
        else:
            buckets[key] = current - 1
        if not buckets:
            self._concurrency.pop(ctx.command.name, None)

    async def get_prefix(self, message: "Message") -> Union[str, List[str], None]:
        """Resolve the prefix(es) that apply to *message*.

        A callable prefix is called with ``(client, message)``. If the result is
        awaitable, it is awaited. This covers ``async def`` functions as well as
        sync callables (lambdas, partials) that return a coroutine.
        """
        if not callable(self.prefix):
            return self.prefix
        result = self.prefix(self.client, message)
        if inspect.isawaitable(result):
            result = await result
        return result

    @staticmethod
    def _match_prefix(content: str, prefixes: Any) -> Optional[str]:
        """Return the prefix that *content* starts with, or ``None``.

        For a list/tuple, the first matching entry wins. An empty-string match
        counts as "no prefix".
        """
        if isinstance(prefixes, str):
            return prefixes if content.startswith(prefixes) else None
        if isinstance(prefixes, (list, tuple)):
            for candidate in prefixes:
                if content.startswith(candidate):
                    return candidate or None
        return None

    async def _convert_argument(
        self, ctx: CommandContext, param: inspect.Parameter, raw: str
    ) -> Any:
        """Convert the raw string *raw* according to *param*'s annotation.

        Handles ``Union``/``Optional`` (including ``X | Y`` on Python 3.10+) by
        trying each member in order. A failed ``Optional[T]`` whose default is
        ``None`` yields ``None``.

        Raises:
            BadArgument: If no converter accepts *raw*.
        """
        import types

        annotation = param.annotation
        origin = get_origin(annotation)
        union_type = getattr(types, "UnionType", None)

        if origin is Union or (union_type is not None and origin is union_type):
            union_args = get_args(annotation)
            is_optional = len(union_args) == 2 and type(None) in union_args

            last_err: Optional[BadArgument] = None
            for t_arg in union_args:
                if t_arg is type(None):
                    continue
                try:
                    return await run_converters(ctx, t_arg, raw)
                except BadArgument as e:
                    last_err = e

            # Every concrete member failed to convert.
            if is_optional and param.default is None:
                return None
            if last_err:
                raise last_err
            raise BadArgument(
                f"Could not convert '{raw}' to any of {union_args} for param '{param.name}'."
            )

        if annotation is inspect.Parameter.empty or annotation is str:
            return raw
        return await run_converters(ctx, annotation, raw)

    async def _parse_arguments(
        self, command: Command, ctx: CommandContext, view: StringView
    ) -> Tuple[List[Any], Dict[str, Any]]:
        """Parse the remaining text in *view* into call arguments for *command*.

        Returns:
            ``(args, kwargs)`` to pass after ``ctx`` when invoking the callback.

        Raises:
            MissingRequiredArgument: If input runs out before a required parameter.
            BadArgument: On an unterminated quote or a failed conversion.
        """
        args_list: List[Any] = []
        kwargs_dict: Dict[str, Any] = {}
        params_to_parse = list(command.params.values())

        # The implicit ``self`` (cog commands) and ``ctx`` parameters are not parsed.
        if params_to_parse and params_to_parse[0].name == "self" and command.cog:
            params_to_parse.pop(0)
        if params_to_parse and params_to_parse[0].name == "ctx":
            params_to_parse.pop(0)

        for param in params_to_parse:
            view.skip_whitespace()
            final_value_for_param: Any = inspect.Parameter.empty

            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                # ``*args`` swallows every remaining word (unconverted) and ends parsing.
                while not view.eof:
                    view.skip_whitespace()
                    if view.eof:
                        break
                    word = view.get_word()
                    if word or not view.eof:
                        args_list.append(word)
                    elif view.eof:
                        break
                break

            arg_str_value: Optional[str] = None  # raw text for the current param

            if view.eof:  # no more input
                if param.default is not inspect.Parameter.empty:
                    final_value_for_param = param.default
                elif param.kind != inspect.Parameter.VAR_KEYWORD:
                    raise MissingRequiredArgument(param.name)
                else:  # VAR_KEYWORD at EOF is fine
                    break
            else:
                # A trailing ``str`` positional consumes the rest of the input verbatim.
                is_last_pos_str_greedy = (
                    param == params_to_parse[-1]
                    and param.annotation is str
                    and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
                )

                if is_last_pos_str_greedy:
                    arg_str_value = view.read_rest().strip()
                    if not arg_str_value and param.default is not inspect.Parameter.empty:
                        final_value_for_param = param.default
                    else:  # includes an empty string if that's what's left
                        final_value_for_param = arg_str_value
                elif view.buffer[view.index] == '"':
                    arg_str_value = view.get_quoted_string()
                    if arg_str_value == "" and view.buffer[view.index] == '"':
                        raise BadArgument(
                            f"Unterminated quoted string for argument '{param.name}'."
                        )
                else:
                    arg_str_value = view.get_word()

            # Not resolved by the EOF/greedy paths: convert the raw text.
            if final_value_for_param is inspect.Parameter.empty:
                if arg_str_value is None:
                    if param.default is not inspect.Parameter.empty:
                        final_value_for_param = param.default
                    else:
                        raise MissingRequiredArgument(param.name)
                else:
                    final_value_for_param = await self._convert_argument(
                        ctx, param, arg_str_value
                    )

            # Final check if value was resolved.
            if final_value_for_param is inspect.Parameter.empty:
                if param.default is not inspect.Parameter.empty:
                    final_value_for_param = param.default
                elif param.kind != inspect.Parameter.VAR_KEYWORD:
                    raise MissingRequiredArgument(param.name)

            if final_value_for_param is not inspect.Parameter.empty:
                if param.kind in (
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    inspect.Parameter.POSITIONAL_ONLY,
                ):
                    args_list.append(final_value_for_param)
                elif param.kind == inspect.Parameter.KEYWORD_ONLY:
                    kwargs_dict[param.name] = final_value_for_param

        return args_list, kwargs_dict

    async def process_commands(self, message: "Message") -> None:
        """Parse *message* as a command invocation and run the matching command.

        Command errors are logged and forwarded to ``client.on_command_error``
        when it exists. Any other exception is wrapped in ``CommandInvokeError``
        first.
        """
        if not message.content:
            return

        prefix_to_use = await self.get_prefix(message)
        if not prefix_to_use:
            return

        actual_prefix = self._match_prefix(message.content, prefix_to_use)
        if actual_prefix is None:
            return

        view = StringView(message.content[len(actual_prefix) :])

        command_name = view.get_word()
        if not command_name:
            return

        command = self.get_command(command_name)
        if not command:
            return

        ctx = CommandContext(
            message=message,
            bot=self.client,
            prefix=actual_prefix,
            command=command,
            invoked_with=command_name,
            cog=command.cog,
        )

        try:
            parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view)
            ctx.args = parsed_args
            ctx.kwargs = parsed_kwargs
            self._acquire_concurrency(ctx)
            try:
                await command.invoke(ctx, *parsed_args, **parsed_kwargs)
            finally:
                self._release_concurrency(ctx)
        except CommandError as e:
            logger.error("Command error for '%s': %s", command.name, e)
            if hasattr(self.client, "on_command_error"):
                await self.client.on_command_error(ctx, e)
        except Exception as e:
            # Unexpected failures get a full traceback in the log.
            logger.exception("Unexpected error invoking command '%s': %s", command.name, e)
            exc = CommandInvokeError(e)
            if hasattr(self.client, "on_command_error"):
                await self.client.on_command_error(ctx, exc)
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import inspect
6
+ from typing import (
7
+ TYPE_CHECKING,
8
+ Optional,
9
+ List,
10
+ Dict,
11
+ Any,
12
+ Union,
13
+ Callable,
14
+ Awaitable,
15
+ Tuple,
16
+ get_origin,
17
+ get_args,
18
+ )
19
+
20
+ from .view import StringView
21
+ from .errors import (
22
+ CommandError,
23
+ CommandNotFound,
24
+ BadArgument,
25
+ MissingRequiredArgument,
26
+ ArgumentParsingError,
27
+ CheckFailure,
28
+ CommandInvokeError,
29
+ )
30
+ from .converters import Greedy, run_converters, DEFAULT_CONVERTERS, Converter
31
+ from disagreement.typing import Typing
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+ if TYPE_CHECKING:
36
+ from .cog import Cog
37
+ from disagreement.client import Client
38
+ from disagreement.models import Message, User
39
+
40
+
41
class GroupMixin:
    """Registry of named subcommands, plus decorators for adding them.

    Names and aliases are stored lower-cased in ``commands``, so
    :meth:`get_command` lookups are case-insensitive.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.commands: Dict[str, "Command"] = {}
        self.name: str = ""

    def command(
        self, **attrs: Any
    ) -> Callable[[Callable[..., Awaitable[None]]], "Command"]:
        """Decorator that wraps a coroutine in a :class:`Command` and registers it here.

        The new command takes this object's ``cog`` attribute (``None`` if unset).
        """

        def decorator(func: Callable[..., Awaitable[None]]) -> "Command":
            cmd = Command(func, **attrs)
            cmd.cog = getattr(self, "cog", None)
            self.add_command(cmd)
            return cmd

        return decorator

    def group(
        self, **attrs: Any
    ) -> Callable[[Callable[..., Awaitable[None]]], "Group"]:
        """Like :meth:`command`, but the wrapped coroutine becomes a nested :class:`Group`."""

        def decorator(func: Callable[..., Awaitable[None]]) -> "Group":
            cmd = Group(func, **attrs)
            cmd.cog = getattr(self, "cog", None)
            self.add_command(cmd)
            return cmd

        return decorator

    def add_command(self, command: "Command") -> None:
        """Register *command* under its lower-cased name and aliases.

        Raises:
            ValueError: If the name, in any letter case, is already registered here.
        """
        key = command.name.lower()
        # Keys are lower-cased, so compare lower-cased; otherwise "Sub" would
        # silently replace an existing "sub".
        if key in self.commands:
            raise ValueError(
                f"Command '{command.name}' is already registered in group '{self.name}'."
            )
        self.commands[key] = command
        for alias in command.aliases:
            alias_key = alias.lower()
            if alias_key in self.commands:
                logger.warning(
                    "Alias '%s' for command '%s' in group '%s' conflicts with an existing command or alias.",
                    alias,
                    command.name,
                    self.name,
                )
            self.commands[alias_key] = command

    def get_command(self, name: str) -> Optional["Command"]:
        """Return the subcommand registered under *name* (case-insensitive), or ``None``."""
        return self.commands.get(name.lower())


class Command(GroupMixin):
    """
    Represents a bot command.

    Attributes:
        name (str): The primary name of the command.
        callback (Callable[..., Awaitable[None]]): The coroutine function to execute.
        aliases (List[str]): Alternative names for the command.
        brief (Optional[str]): A short description for help commands.
        description (Optional[str]): A longer description for help commands;
            falls back to the callback's docstring.
        cog (Optional['Cog']): Reference to the Cog this command belongs to.
        invoke_without_command (bool): Value of the ``invoke_without_command``
            option (default ``False``).
        params (Dict[str, inspect.Parameter]): Cached parameters of the callback.
        checks (List[Callable]): Local predicates run before the callback.
        max_concurrency (Optional[Tuple[int, str]]): The callback's
            ``__max_concurrency__`` value, if present.
    """

    def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
        if not asyncio.iscoroutinefunction(callback):
            raise TypeError("Command callback must be a coroutine function.")

        super().__init__(**attrs)
        self.callback: Callable[..., Awaitable[None]] = callback
        self.name: str = attrs.get("name", callback.__name__)
        self.aliases: List[str] = attrs.get("aliases", [])
        self.brief: Optional[str] = attrs.get("brief")
        self.description: Optional[str] = attrs.get("description") or callback.__doc__
        self.cog: Optional["Cog"] = attrs.get("cog")
        self.invoke_without_command: bool = attrs.get("invoke_without_command", False)

        self.params = inspect.signature(callback).parameters
        self.checks: List[Callable[["CommandContext"], Awaitable[bool] | bool]] = []
        if hasattr(callback, "__command_checks__"):
            self.checks.extend(getattr(callback, "__command_checks__"))

        self.max_concurrency: Optional[Tuple[int, str]] = None
        if hasattr(callback, "__max_concurrency__"):
            self.max_concurrency = getattr(callback, "__max_concurrency__")

    def add_check(
        self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
    ) -> None:
        """Append a local check evaluated before each invocation."""
        self.checks.append(predicate)

    @staticmethod
    async def _evaluate(predicate: Callable[..., Any], ctx: "CommandContext") -> Any:
        """Call *predicate* with *ctx*, awaiting the result if it is awaitable."""
        result = predicate(ctx)
        if inspect.isawaitable(result):
            result = await result
        return result

    async def _run_checks(self, ctx: "CommandContext") -> None:
        """Runs all cog, local and global checks for the command.

        Raises:
            CheckFailure: If any check is falsy.
            CommandInvokeError: If ``cog_check`` raises something other than
                ``CheckFailure``.
        """
        # Run cog-level check first.
        if self.cog:
            cog_check = getattr(self.cog, "cog_check", None)
            if cog_check:
                try:
                    passed = await self._evaluate(cog_check, ctx)
                except CheckFailure:
                    raise
                except Exception as e:
                    raise CommandInvokeError(e) from e
                if not passed:
                    raise CheckFailure(
                        f"The cog-level check for command '{self.name}' failed."
                    )

        # Run local checks.
        for predicate in self.checks:
            if not await self._evaluate(predicate, ctx):
                raise CheckFailure(f"A local check for command '{self.name}' failed.")

        # Then run global checks from the handler.
        if hasattr(ctx.bot, "command_handler"):
            for predicate in ctx.bot.command_handler._global_checks:
                if not await self._evaluate(predicate, ctx):
                    raise CheckFailure(
                        f"A global check failed for command '{self.name}'."
                    )

    async def invoke(self, ctx: "CommandContext", *args: Any, **kwargs: Any) -> None:
        """Run the checks, then the callback wrapped in the cog's before/after hooks.

        ``cog_after_invoke`` runs even when the callback raises. It does not run
        when ``cog_before_invoke`` itself raises.
        """
        await self._run_checks(ctx)

        before_invoke = None
        after_invoke = None

        if self.cog:
            before_invoke = getattr(self.cog, "cog_before_invoke", None)
            after_invoke = getattr(self.cog, "cog_after_invoke", None)

        if before_invoke:
            await before_invoke(ctx)

        try:
            if self.cog:
                await self.callback(self.cog, ctx, *args, **kwargs)
            else:
                await self.callback(ctx, *args, **kwargs)
        finally:
            if after_invoke:
                await after_invoke(ctx)


class Group(Command):
    """A command that can have subcommands."""

    def __init__(self, callback: Callable[..., Awaitable[None]], **attrs: Any):
        super().__init__(callback, **attrs)


PrefixCommand = Command  # Alias for clarity in hybrid commands
198
+
199
+
200
class CommandContext:
    """
    Represents the context in which a command is being invoked.

    Bundles the triggering message, the client, the matched prefix and command,
    and the parsed arguments. Also offers helpers that act on the invoking
    message's channel (``send``, ``reply``, ``edit``, ``typing``).
    """

    def __init__(
        self,
        *,
        message: "Message",
        bot: "Client",
        prefix: str,
        command: "Command",
        invoked_with: str,
        args: Optional[List[Any]] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        cog: Optional["Cog"] = None,
    ):
        self.message: "Message" = message
        self.bot: "Client" = bot
        self.prefix: str = prefix
        self.command: "Command" = command
        # The exact word the user typed (may be an alias of ``command.name``).
        self.invoked_with: str = invoked_with
        self.args: List[Any] = args or []
        self.kwargs: Dict[str, Any] = kwargs or {}
        self.cog: Optional["Cog"] = cog

        self.author: "User" = message.author

    @property
    def guild(self):
        """The guild this command was invoked in.

        Returns ``None`` when the message has no ``guild_id`` or the client has
        no ``get_guild`` method. Otherwise returns whatever ``bot.get_guild`` returns.
        """
        if self.message.guild_id and hasattr(self.bot, "get_guild"):
            return self.bot.get_guild(self.message.guild_id)
        return None

    async def reply(
        self,
        content: Optional[str] = None,
        *,
        mention_author: Optional[bool] = None,
        **kwargs: Any,
    ) -> "Message":
        """Replies to the invoking message.

        Parameters
        ----------
        content: Optional[str]
            The content to send.
        mention_author: Optional[bool]
            Whether to mention the author in the reply. If ``None`` the
            client's :attr:`mention_replies` value is used.
        **kwargs:
            Forwarded to ``bot.send_message``. A caller-supplied
            ``allowed_mentions`` mapping is copied, and ``replied_user`` is
            only filled in when the caller did not set it.

        Returns
        -------
        Whatever ``bot.send_message`` returns (presumably the sent message).
        """

        allowed_mentions = kwargs.pop("allowed_mentions", None)
        if mention_author is None:
            mention_author = getattr(self.bot, "mention_replies", False)

        if allowed_mentions is None:
            allowed_mentions = {"replied_user": mention_author}
        else:
            # Copy so the caller's mapping is not mutated.
            allowed_mentions = dict(allowed_mentions)
            allowed_mentions.setdefault("replied_user", mention_author)

        return await self.bot.send_message(
            channel_id=self.message.channel_id,
            content=content,
            message_reference={
                "message_id": self.message.id,
                "channel_id": self.message.channel_id,
                "guild_id": self.message.guild_id,
            },
            allowed_mentions=allowed_mentions,
            **kwargs,
        )

    async def send(self, content: Optional[str] = None, **kwargs: Any) -> "Message":
        """Send a message to the channel the command was invoked in.

        Parameters
        ----------
        content: Optional[str]
            The text to send. It may be omitted (``None``) when the message is
            described entirely by other keyword arguments, as with :meth:`reply`.
        **kwargs:
            Forwarded unchanged to ``bot.send_message``.
        """
        return await self.bot.send_message(
            channel_id=self.message.channel_id, content=content, **kwargs
        )

    async def edit(
        self,
        message: Union[str, "Message"],
        *,
        content: Optional[str] = None,
        **kwargs: Any,
    ) -> "Message":
        """Edits a message previously sent by the bot.

        ``message`` may be a message ID string or an object with an ``id``
        attribute. The edit targets the invoking message's channel.
        """

        message_id = message if isinstance(message, str) else message.id
        return await self.bot.edit_message(
            channel_id=self.message.channel_id,
            message_id=message_id,
            content=content,
            **kwargs,
        )

    def typing(self) -> "Typing":
        """Return a typing context manager for this context's channel."""

        return self.bot.typing(self.message.channel_id)
301
+
302
+
303
+ class CommandHandler:
304
+ """
305
+ Manages command registration, parsing, and dispatching.
306
+ """
307
+
308
+ def __init__(
309
+ self,
310
+ client: "Client",
311
+ prefix: Union[
312
+ str, List[str], Callable[["Client", "Message"], Union[str, List[str]]]
313
+ ],
314
+ ):
315
+ self.client: "Client" = client
316
+ self.prefix: Union[
317
+ str, List[str], Callable[["Client", "Message"], Union[str, List[str]]]
318
+ ] = prefix
319
+ self.commands: Dict[str, Command] = {}
320
+ self.cogs: Dict[str, "Cog"] = {}
321
+ self._concurrency: Dict[str, Dict[str, int]] = {}
322
+ self._global_checks: List[
323
+ Callable[["CommandContext"], Awaitable[bool] | bool]
324
+ ] = []
325
+
326
+ from .help import HelpCommand
327
+
328
+ self.add_command(HelpCommand(self))
329
+
330
+ def add_check(
331
+ self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
332
+ ) -> None:
333
+ """Adds a global check to the command handler."""
334
+ self._global_checks.append(predicate)
335
+
336
+ def add_command(self, command: Command) -> None:
337
+ if command.name in self.commands:
338
+ raise ValueError(f"Command '{command.name}' is already registered.")
339
+
340
+ self.commands[command.name.lower()] = command
341
+ for alias in command.aliases:
342
+ if alias in self.commands:
343
+ logger.warning(
344
+ "Alias '%s' for command '%s' conflicts with an existing command or alias.",
345
+ alias,
346
+ command.name,
347
+ )
348
+ self.commands[alias.lower()] = command
349
+
350
+ if isinstance(command, Group):
351
+ for sub_cmd in command.commands.values():
352
+ if sub_cmd.name in self.commands:
353
+ logger.warning(
354
+ "Subcommand '%s' of group '%s' conflicts with a top-level command.",
355
+ sub_cmd.name,
356
+ command.name,
357
+ )
358
+
359
+ def remove_command(self, name: str) -> Optional[Command]:
360
+ command = self.commands.pop(name.lower(), None)
361
+ if command:
362
+ for alias in command.aliases:
363
+ self.commands.pop(alias.lower(), None)
364
+ return command
365
+
366
+ def get_command(self, name: str) -> Optional[Command]:
367
+ return self.commands.get(name.lower())
368
+
369
+ def add_cog(self, cog_to_add: "Cog") -> None:
370
+ from .cog import Cog
371
+
372
+ if not isinstance(cog_to_add, Cog):
373
+ raise TypeError("Argument must be a subclass of Cog.")
374
+
375
+ if cog_to_add.cog_name in self.cogs:
376
+ raise ValueError(
377
+ f"Cog with name '{cog_to_add.cog_name}' is already registered."
378
+ )
379
+
380
+ self.cogs[cog_to_add.cog_name] = cog_to_add
381
+
382
+ for cmd in cog_to_add.get_commands():
383
+ self.add_command(cmd)
384
+
385
+ if hasattr(self.client, "_event_dispatcher"):
386
+ for event_name, callback in cog_to_add.get_listeners():
387
+ self.client._event_dispatcher.register(event_name.upper(), callback)
388
+ else:
389
+ logger.warning(
390
+ "Client does not have '_event_dispatcher'. Listeners for cog '%s' not registered.",
391
+ cog_to_add.cog_name,
392
+ )
393
+
394
+ if hasattr(cog_to_add, "cog_load") and inspect.iscoroutinefunction(
395
+ cog_to_add.cog_load
396
+ ):
397
+ asyncio.create_task(cog_to_add.cog_load())
398
+
399
+ logger.info("Cog '%s' added.", cog_to_add.cog_name)
400
+
401
+ def remove_cog(self, cog_name: str) -> Optional["Cog"]:
402
+ cog_to_remove = self.cogs.pop(cog_name, None)
403
+ if cog_to_remove:
404
+ for cmd in cog_to_remove.get_commands():
405
+ self.remove_command(cmd.name)
406
+
407
+ if hasattr(self.client, "_event_dispatcher"):
408
+ for event_name, callback in cog_to_remove.get_listeners():
409
+ logger.debug(
410
+ "Listener '%s' for event '%s' from cog '%s' needs manual unregistration logic in EventDispatcher.",
411
+ callback.__name__,
412
+ event_name,
413
+ cog_name,
414
+ )
415
+
416
+ if hasattr(cog_to_remove, "cog_unload") and inspect.iscoroutinefunction(
417
+ cog_to_remove.cog_unload
418
+ ):
419
+ asyncio.create_task(cog_to_remove.cog_unload())
420
+
421
+ cog_to_remove._eject()
422
+ logger.info("Cog '%s' removed.", cog_name)
423
+ return cog_to_remove
424
+
425
+ def _acquire_concurrency(self, ctx: CommandContext) -> None:
426
+ mc = getattr(ctx.command, "max_concurrency", None)
427
+ if not mc:
428
+ return
429
+ limit, scope = mc
430
+ if scope == "user":
431
+ key = ctx.author.id
432
+ elif scope == "guild":
433
+ key = ctx.message.guild_id or ctx.author.id
434
+ else:
435
+ key = "global"
436
+ buckets = self._concurrency.setdefault(ctx.command.name, {})
437
+ current = buckets.get(key, 0)
438
+ if current >= limit:
439
+ from .errors import MaxConcurrencyReached
440
+
441
+ raise MaxConcurrencyReached(limit)
442
+ buckets[key] = current + 1
443
+
444
+ def _release_concurrency(self, ctx: CommandContext) -> None:
445
+ mc = getattr(ctx.command, "max_concurrency", None)
446
+ if not mc:
447
+ return
448
+ _, scope = mc
449
+ if scope == "user":
450
+ key = ctx.author.id
451
+ elif scope == "guild":
452
+ key = ctx.message.guild_id or ctx.author.id
453
+ else:
454
+ key = "global"
455
+ buckets = self._concurrency.get(ctx.command.name)
456
+ if not buckets:
457
+ return
458
+ current = buckets.get(key, 0)
459
+ if current <= 1:
460
+ buckets.pop(key, None)
461
+ else:
462
+ buckets[key] = current - 1
463
+ if not buckets:
464
+ self._concurrency.pop(ctx.command.name, None)
465
+
466
+ async def get_prefix(self, message: "Message") -> Union[str, List[str], None]:
467
+ if callable(self.prefix):
468
+ if inspect.iscoroutinefunction(self.prefix):
469
+ return await self.prefix(self.client, message)
470
+ else:
471
+ return self.prefix(self.client, message) # type: ignore
472
+ return self.prefix
473
+
474
+ async def _parse_arguments(
475
+ self, command: Command, ctx: CommandContext, view: StringView
476
+ ) -> Tuple[List[Any], Dict[str, Any]]:
477
+ args_list = []
478
+ kwargs_dict = {}
479
+ params_to_parse = list(command.params.values())
480
+
481
+ if params_to_parse and params_to_parse[0].name == "self" and command.cog:
482
+ params_to_parse.pop(0)
483
+ if params_to_parse and params_to_parse[0].name == "ctx":
484
+ params_to_parse.pop(0)
485
+
486
+ for param in params_to_parse:
487
+ view.skip_whitespace()
488
+ final_value_for_param: Any = inspect.Parameter.empty
489
+
490
+ if param.kind == inspect.Parameter.VAR_POSITIONAL:
491
+ while not view.eof:
492
+ view.skip_whitespace()
493
+ if view.eof:
494
+ break
495
+ word = view.get_word()
496
+ if word or not view.eof:
497
+ args_list.append(word)
498
+ elif view.eof:
499
+ break
500
+ break
501
+
502
+ arg_str_value: Optional[str] = (
503
+ None # Holds the raw string for current param
504
+ )
505
+
506
+ annotation = param.annotation
507
+ if inspect.isclass(annotation) and issubclass(annotation, Greedy):
508
+ greedy_values = []
509
+ converter_type = annotation.converter
510
+ while not view.eof:
511
+ view.skip_whitespace()
512
+ if view.eof:
513
+ break
514
+ start = view.index
515
+ if view.buffer[view.index] == '"':
516
+ arg_str_value = view.get_quoted_string()
517
+ if arg_str_value == "" and view.buffer[view.index] == '"':
518
+ raise BadArgument(
519
+ f"Unterminated quoted string for argument '{param.name}'."
520
+ )
521
+ else:
522
+ arg_str_value = view.get_word()
523
+ try:
524
+ converted = await run_converters(
525
+ ctx, converter_type, arg_str_value
526
+ )
527
+ except BadArgument:
528
+ view.index = start
529
+ break
530
+ greedy_values.append(converted)
531
+ final_value_for_param = greedy_values
532
+ arg_str_value = None
533
+ elif view.eof: # No more input string
534
+ if param.default is not inspect.Parameter.empty:
535
+ final_value_for_param = param.default
536
+ elif param.kind != inspect.Parameter.VAR_KEYWORD:
537
+ raise MissingRequiredArgument(param.name)
538
+ else: # VAR_KEYWORD at EOF is fine
539
+ break
540
+ else: # Input available
541
+ is_last_pos_str_greedy = (
542
+ param == params_to_parse[-1]
543
+ and param.annotation is str
544
+ and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
545
+ )
546
+
547
+ if is_last_pos_str_greedy:
548
+ arg_str_value = view.read_rest().strip()
549
+ if (
550
+ not arg_str_value
551
+ and param.default is not inspect.Parameter.empty
552
+ ):
553
+ final_value_for_param = param.default
554
+ else: # Includes empty string if that's what's left
555
+ final_value_for_param = arg_str_value
556
+ else: # Not greedy, or not string, or not last positional
557
+ if view.buffer[view.index] == '"':
558
+ arg_str_value = view.get_quoted_string()
559
+ if arg_str_value == "" and view.buffer[view.index] == '"':
560
+ raise BadArgument(
561
+ f"Unterminated quoted string for argument '{param.name}'."
562
+ )
563
+ else:
564
+ arg_str_value = view.get_word()
565
+
566
+ # If final_value_for_param was not set by greedy logic, try conversion
567
+ if final_value_for_param is inspect.Parameter.empty:
568
+ if arg_str_value is None:
569
+ if param.default is not inspect.Parameter.empty:
570
+ final_value_for_param = param.default
571
+ else:
572
+ raise MissingRequiredArgument(param.name)
573
+ else: # We have an arg_str_value (could be empty string "" from quotes)
574
+ annotation = param.annotation
575
+ origin = get_origin(annotation)
576
+
577
+ if origin is Union: # Handles Optional[T] and Union[T1, T2]
578
+ union_args = get_args(annotation)
579
+ is_optional = (
580
+ len(union_args) == 2 and type(None) in union_args
581
+ )
582
+
583
+ converted_for_union = False
584
+ last_err_union: Optional[BadArgument] = None
585
+ for t_arg in union_args:
586
+ if t_arg is type(None):
587
+ continue
588
+ try:
589
+ final_value_for_param = await run_converters(
590
+ ctx, t_arg, arg_str_value
591
+ )
592
+ converted_for_union = True
593
+ break
594
+ except BadArgument as e:
595
+ last_err_union = e
596
+
597
+ if not converted_for_union:
598
+ if (
599
+ is_optional and param.default is None
600
+ ): # Special handling for Optional[T] if conversion failed
601
+ # If arg_str_value was "" and type was Optional[str], StringConverter would return ""
602
+ # If arg_str_value was "" and type was Optional[int], BadArgument would be raised.
603
+ # This path is for when all actual types in Optional[T] fail conversion.
604
+ # If default is None, we can assign None.
605
+ final_value_for_param = None
606
+ elif last_err_union:
607
+ raise last_err_union
608
+ else:
609
+ raise BadArgument(
610
+ f"Could not convert '{arg_str_value}' to any of {union_args} for param '{param.name}'."
611
+ )
612
+ elif annotation is inspect.Parameter.empty or annotation is str:
613
+ final_value_for_param = arg_str_value
614
+ else: # Standard type hint
615
+ final_value_for_param = await run_converters(
616
+ ctx, annotation, arg_str_value
617
+ )
618
+
619
+ # Final check if value was resolved
620
+ if final_value_for_param is inspect.Parameter.empty:
621
+ if param.default is not inspect.Parameter.empty:
622
+ final_value_for_param = param.default
623
+ elif param.kind != inspect.Parameter.VAR_KEYWORD:
624
+ # This state implies an issue if required and no default, and no input was parsed.
625
+ raise MissingRequiredArgument(
626
+ f"Parameter '{param.name}' could not be resolved."
627
+ )
628
+
629
+ # Assign to args_list or kwargs_dict if a value was determined
630
+ if final_value_for_param is not inspect.Parameter.empty:
631
+ if (
632
+ param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
633
+ or param.kind == inspect.Parameter.POSITIONAL_ONLY
634
+ ):
635
+ args_list.append(final_value_for_param)
636
+ elif param.kind == inspect.Parameter.KEYWORD_ONLY:
637
+ kwargs_dict[param.name] = final_value_for_param
638
+
639
+ return args_list, kwargs_dict
640
+
641
+ async def process_commands(self, message: "Message") -> None:
642
+ if not message.content:
643
+ return
644
+
645
+ prefix_to_use = await self.get_prefix(message)
646
+ if not prefix_to_use:
647
+ return
648
+
649
+ actual_prefix: Optional[str] = None
650
+ if isinstance(prefix_to_use, list):
651
+ for p in prefix_to_use:
652
+ if message.content.startswith(p):
653
+ actual_prefix = p
654
+ break
655
+ if not actual_prefix:
656
+ return
657
+ elif isinstance(prefix_to_use, str):
658
+ if message.content.startswith(prefix_to_use):
659
+ actual_prefix = prefix_to_use
660
+ else:
661
+ return
662
+ else:
663
+ return
664
+
665
+ if actual_prefix is None:
666
+ return
667
+
668
+ content_without_prefix = message.content[len(actual_prefix) :]
669
+ view = StringView(content_without_prefix)
670
+
671
+ command_name = view.get_word()
672
+ if not command_name:
673
+ return
674
+
675
+ command = self.get_command(command_name)
676
+ if not command:
677
+ return
678
+
679
+ invoked_with = command_name
680
+ original_command = command
681
+
682
+ if isinstance(command, Group):
683
+ view.skip_whitespace()
684
+ potential_subcommand = view.get_word()
685
+ if potential_subcommand:
686
+ subcommand = command.get_command(potential_subcommand)
687
+ if subcommand:
688
+ command = subcommand
689
+ invoked_with += f" {potential_subcommand}"
690
+ elif command.invoke_without_command:
691
+ view.index -= len(potential_subcommand) + view.previous
692
+ else:
693
+ raise CommandNotFound(
694
+ f"Subcommand '{potential_subcommand}' not found."
695
+ )
696
+
697
+ ctx = CommandContext(
698
+ message=message,
699
+ bot=self.client,
700
+ prefix=actual_prefix,
701
+ command=command,
702
+ invoked_with=invoked_with,
703
+ cog=command.cog,
704
+ )
705
+
706
+ try:
707
+ parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view)
708
+ ctx.args = parsed_args
709
+ ctx.kwargs = parsed_kwargs
710
+ self._acquire_concurrency(ctx)
711
+ try:
712
+ await command.invoke(ctx, *parsed_args, **parsed_kwargs)
713
+ finally:
714
+ self._release_concurrency(ctx)
715
+ except CommandError as e:
716
+ logger.error("Command error for '%s': %s", original_command.name, e)
717
+ if hasattr(self.client, "on_command_error"):
718
+ await self.client.on_command_error(ctx, e)
719
+ except Exception as e:
720
+ logger.error(
721
+ "Unexpected error invoking command '%s': %s", original_command.name, e
722
+ )
723
+ exc = CommandInvokeError(e)
724
+ if hasattr(self.client, "on_command_error"):
725
+ await self.client.on_command_error(ctx, exc)
726
+ else:
727
+ if hasattr(self.client, "on_command_completion"):
728
+ await self.client.on_command_completion(ctx)