disagreement 0.1.0rc2__py3-none-any.whl → 0.2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,9 @@
1
1
  # disagreement/ext/app_commands/handler.py
2
2
 
3
3
  import inspect
4
+ import json
5
+ import logging
6
+ import os
4
7
  from typing import (
5
8
  TYPE_CHECKING,
6
9
  Dict,
@@ -64,6 +67,11 @@ if not TYPE_CHECKING:
64
67
  Message = Any
65
68
 
66
69
 
70
+ logger = logging.getLogger(__name__)
71
+
72
+ COMMANDS_CACHE_FILE = ".disagreement_commands.json"
73
+
74
+
67
75
  class AppCommandHandler:
68
76
  """
69
77
  Manages application command registration, parsing, and dispatching.
@@ -80,6 +88,33 @@ class AppCommandHandler:
80
88
  self._app_command_groups: Dict[str, AppCommandGroup] = {}
81
89
  self._converter_registry: Dict[type, type] = {}
82
90
 
91
def _load_cached_ids(self) -> Dict[str, Dict[str, str]]:
    """Load the persisted command-ID cache from ``COMMANDS_CACHE_FILE``.

    Returns:
        The parsed cache (per the annotation, a mapping of scope key to a
        mapping of command name to command ID). An empty dict is returned
        when the file is missing, cannot be read, is not valid JSON, or
        does not contain a JSON object at the top level.
    """
    try:
        with open(COMMANDS_CACHE_FILE, "r", encoding="utf-8") as fp:
            data = json.load(fp)
    except FileNotFoundError:
        # No cache has been written yet; nothing to report.
        return {}
    except json.JSONDecodeError:
        logger.warning("Invalid command cache file. Ignoring.")
        return {}
    except (OSError, UnicodeDecodeError) as e:
        # Unreadable cache (permissions, path is a directory, bad encoding):
        # treat it like a missing cache instead of aborting the caller.
        logger.warning("Could not read command cache file: %s", e)
        return {}

    # Valid JSON is not necessarily an object; callers use dict methods on
    # the result, so anything else is rejected here.
    if not isinstance(data, dict):
        logger.warning("Invalid command cache file. Ignoring.")
        return {}
    return data
100
+
101
def _save_cached_ids(self, data: Dict[str, Dict[str, str]]) -> None:
    """Persist the command-ID cache to ``COMMANDS_CACHE_FILE`` as JSON.

    The write is best-effort: any failure is logged and swallowed. The data
    is first written to a sibling temporary file and then moved into place
    with ``os.replace`` so an interrupted write cannot leave a truncated
    cache file behind.

    Args:
        data: The complete cache to store; it replaces the previous file
            contents.
    """
    tmp_path = f"{COMMANDS_CACHE_FILE}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as fp:
            json.dump(data, fp, indent=2)
        # Swap the finished file in only after it was written completely.
        os.replace(tmp_path, COMMANDS_CACHE_FILE)
    except Exception as e:  # pragma: no cover - logging only
        logger.error("Failed to write command cache: %s", e)
        # Do not leave a partial temporary file lying around.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
107
+
108
def clear_stored_registrations(self) -> None:
    """Remove persisted command registration data.

    Deletes ``COMMANDS_CACHE_FILE``. A file that does not exist (or that
    disappears while this runs) is not an error; other ``OSError``s from
    ``os.remove`` propagate to the caller.
    """
    try:
        os.remove(COMMANDS_CACHE_FILE)
    except FileNotFoundError:
        # Already gone: removing directly avoids the gap between an
        # existence check and the delete.
        pass
112
+
113
def migrate_stored_registrations(self, new_path: str) -> None:
    """Relocate the persisted registration cache to ``new_path``.

    Nothing happens when ``COMMANDS_CACHE_FILE`` does not exist; otherwise
    the file is moved with ``os.replace``.
    """
    cache_missing = not os.path.exists(COMMANDS_CACHE_FILE)
    if cache_missing:
        return
    os.replace(COMMANDS_CACHE_FILE, new_path)
117
+
83
118
  def add_command(self, command: Union["AppCommand", "AppCommandGroup"]) -> None:
84
119
  """Adds an application command or a command group to the handler."""
85
120
  if isinstance(command, AppCommandGroup):
@@ -544,7 +579,7 @@ class AppCommandHandler:
544
579
  await command.invoke(ctx, *parsed_args, **parsed_kwargs)
545
580
 
546
581
  except Exception as e:
547
- print(f"Error invoking app command '{command.name}': {e}")
582
+ logger.error("Error invoking app command '%s': %s", command.name, e)
548
583
  await self.dispatch_app_command_error(ctx, e)
549
584
  # else:
550
585
  # # Default error reply if no handler on client
@@ -560,11 +595,13 @@ class AppCommandHandler:
560
595
  Synchronizes (registers/updates) all application commands with Discord.
561
596
  If guild_id is provided, syncs commands for that guild. Otherwise, syncs global commands.
562
597
  """
563
- commands_to_sync: List[Dict[str, Any]] = []
598
+ cache = self._load_cached_ids()
599
+ scope_key = str(guild_id) if guild_id else "global"
600
+ stored = cache.get(scope_key, {})
564
601
 
565
- # Collect commands based on scope (global or specific guild)
566
- # This needs to be more sophisticated to handle guild_ids on commands/groups
602
+ current_payloads: Dict[str, Dict[str, Any]] = {}
567
603
 
604
+ # Collect commands based on scope (global or specific guild)
568
605
  source_commands = (
569
606
  list(self._slash_commands.values())
570
607
  + list(self._user_commands.values())
@@ -573,55 +610,102 @@ class AppCommandHandler:
573
610
  )
574
611
 
575
612
  for cmd_or_group in source_commands:
576
- # Determine if this command/group should be synced for the current scope
577
613
  is_guild_specific_command = (
578
614
  cmd_or_group.guild_ids is not None and len(cmd_or_group.guild_ids) > 0
579
615
  )
580
616
 
581
- if guild_id: # Syncing for a specific guild
582
- # Skip if not a guild-specific command OR if it's for a different guild
617
+ if guild_id:
583
618
  if not is_guild_specific_command or (
584
619
  cmd_or_group.guild_ids is not None
585
620
  and guild_id not in cmd_or_group.guild_ids
586
621
  ):
587
622
  continue
588
- else: # Syncing global commands
623
+ else:
589
624
  if is_guild_specific_command:
590
- continue # Skip guild-specific commands when syncing global
625
+ continue
591
626
 
592
- # Use the to_dict() method from AppCommand or AppCommandGroup
593
627
  try:
594
- payload = cmd_or_group.to_dict()
595
- commands_to_sync.append(payload)
628
+ current_payloads[cmd_or_group.name] = cmd_or_group.to_dict()
596
629
  except AttributeError:
597
- print(
598
- f"Warning: Command or group '{cmd_or_group.name}' does not have a to_dict() method. Skipping."
630
+ logger.warning(
631
+ "Command or group '%s' does not have a to_dict() method. Skipping.",
632
+ cmd_or_group.name,
599
633
  )
600
634
  except Exception as e:
601
- print(
602
- f"Error converting command/group '{cmd_or_group.name}' to dict: {e}. Skipping."
635
+ logger.error(
636
+ "Error converting command/group '%s' to dict: %s. Skipping.",
637
+ cmd_or_group.name,
638
+ e,
603
639
  )
604
640
 
605
- if not commands_to_sync:
606
- print(
607
- f"No commands to sync for {'guild ' + str(guild_id) if guild_id else 'global'} scope."
641
+ if not current_payloads:
642
+ logger.info(
643
+ "No commands to sync for %s scope.",
644
+ f"guild {guild_id}" if guild_id else "global",
645
+ )
646
+ return
647
+
648
+ names_current = set(current_payloads)
649
+ names_stored = set(stored)
650
+
651
+ to_delete = names_stored - names_current
652
+ to_create = names_current - names_stored
653
+ to_update = names_current & names_stored
654
+
655
+ if not to_delete and not to_create and not to_update:
656
+ logger.info(
657
+ "Application commands already up to date for %s scope.", scope_key
608
658
  )
609
659
  return
610
660
 
611
661
  try:
612
- if guild_id:
613
- print(
614
- f"Syncing {len(commands_to_sync)} commands for guild {guild_id}..."
615
- )
616
- await self.client._http.bulk_overwrite_guild_application_commands(
617
- application_id, guild_id, commands_to_sync
618
- )
619
- else:
620
- print(f"Syncing {len(commands_to_sync)} global commands...")
621
- await self.client._http.bulk_overwrite_global_application_commands(
622
- application_id, commands_to_sync
623
- )
624
- print("Command sync successful.")
662
+ for name in to_delete:
663
+ cmd_id = stored[name]
664
+ if guild_id:
665
+ await self.client._http.delete_guild_application_command(
666
+ application_id, guild_id, cmd_id
667
+ )
668
+ else:
669
+ await self.client._http.delete_global_application_command(
670
+ application_id, cmd_id
671
+ )
672
+
673
+ new_ids: Dict[str, str] = {}
674
+ for name in to_create:
675
+ payload = current_payloads[name]
676
+ if guild_id:
677
+ result = await self.client._http.create_guild_application_command(
678
+ application_id, guild_id, payload
679
+ )
680
+ else:
681
+ result = await self.client._http.create_global_application_command(
682
+ application_id, payload
683
+ )
684
+ if result.id:
685
+ new_ids[name] = str(result.id)
686
+
687
+ for name in to_update:
688
+ payload = current_payloads[name]
689
+ cmd_id = stored[name]
690
+ if guild_id:
691
+ await self.client._http.edit_guild_application_command(
692
+ application_id, guild_id, cmd_id, payload
693
+ )
694
+ else:
695
+ await self.client._http.edit_global_application_command(
696
+ application_id, cmd_id, payload
697
+ )
698
+ new_ids[name] = cmd_id
699
+
700
+ final_ids: Dict[str, str] = {}
701
+ for name in names_current:
702
+ if name in new_ids:
703
+ final_ids[name] = new_ids[name]
704
+ else:
705
+ final_ids[name] = stored[name]
706
+
707
+ cache[scope_key] = final_ids
708
+ self._save_cached_ids(cache)
709
+ logger.info("Command sync successful.")
625
710
  except Exception as e:
626
- print(f"Error syncing application commands: {e}")
627
- # Consider re-raising or specific error handling
711
+ logger.error("Error syncing application commands: %s", e)
@@ -58,4 +58,4 @@ class HybridCommand(SlashCommand, PrefixCommand): # Inherit from both
58
58
  # The correct one will be called depending on how the command is dispatched.
59
59
  # The AppCommandHandler will use AppCommand.invoke (via SlashCommand).
60
60
  # The prefix CommandHandler will use PrefixCommand.invoke.
61
- # This seems acceptable.
61
+ # This seems acceptable.
@@ -16,6 +16,7 @@ from .decorators import (
16
16
  check,
17
17
  check_any,
18
18
  cooldown,
19
+ max_concurrency,
19
20
  requires_permissions,
20
21
  )
21
22
  from .errors import (
@@ -28,6 +29,7 @@ from .errors import (
28
29
  CheckAnyFailure,
29
30
  CommandOnCooldown,
30
31
  CommandInvokeError,
32
+ MaxConcurrencyReached,
31
33
  )
32
34
 
33
35
  __all__ = [
@@ -43,6 +45,7 @@ __all__ = [
43
45
  "check",
44
46
  "check_any",
45
47
  "cooldown",
48
+ "max_concurrency",
46
49
  "requires_permissions",
47
50
  # Errors
48
51
  "CommandError",
@@ -54,4 +57,5 @@ __all__ = [
54
57
  "CheckAnyFailure",
55
58
  "CommandOnCooldown",
56
59
  "CommandInvokeError",
60
+ "MaxConcurrencyReached",
57
61
  ]
@@ -1,6 +1,7 @@
1
1
  # disagreement/ext/commands/cog.py
2
2
 
3
3
  import inspect
4
+ import logging
4
5
  from typing import TYPE_CHECKING, List, Tuple, Callable, Awaitable, Any, Dict, Union
5
6
 
6
7
  if TYPE_CHECKING:
@@ -16,6 +17,8 @@ else: # pragma: no cover - runtime imports for isinstance checks
16
17
  # EventDispatcher might be needed if cogs register listeners directly
17
18
  # from disagreement.event_dispatcher import EventDispatcher
18
19
 
20
+ logger = logging.getLogger(__name__)
21
+
19
22
 
20
23
  class Cog:
21
24
  """
@@ -59,8 +62,10 @@ class Cog:
59
62
  cmd.cog = self # Assign the cog instance to the command
60
63
  if cmd.name in self._commands:
61
64
  # This should ideally be caught earlier or handled by CommandHandler
62
- print(
63
- f"Warning: Duplicate command name '{cmd.name}' in cog '{self.cog_name}'. Overwriting."
65
+ logger.warning(
66
+ "Duplicate command name '%s' in cog '%s'. Overwriting.",
67
+ cmd.name,
68
+ self.cog_name,
64
69
  )
65
70
  self._commands[cmd.name.lower()] = cmd
66
71
  # Also register aliases
@@ -79,8 +84,10 @@ class Cog:
79
84
  # For AppCommandGroup, its commands will have cog set individually if they are AppCommands
80
85
  self._app_commands_and_groups.append(app_cmd_obj)
81
86
  else:
82
- print(
83
- f"Warning: Member '{member_name}' in cog '{self.cog_name}' has '__app_command_object__' but it's not an AppCommand or AppCommandGroup."
87
+ logger.warning(
88
+ "Member '%s' in cog '%s' has '__app_command_object__' but it's not an AppCommand or AppCommandGroup.",
89
+ member_name,
90
+ self.cog_name,
84
91
  )
85
92
 
86
93
  elif isinstance(member, (AppCommand, AppCommandGroup)):
@@ -92,8 +99,10 @@ class Cog:
92
99
  # This is a method decorated with @commands.Cog.listener or @commands.listener
93
100
  if not inspect.iscoroutinefunction(member):
94
101
  # Decorator should have caught this, but double check
95
- print(
96
- f"Warning: Listener '{member_name}' in cog '{self.cog_name}' is not a coroutine. Skipping."
102
+ logger.warning(
103
+ "Listener '%s' in cog '%s' is not a coroutine. Skipping.",
104
+ member_name,
105
+ self.cog_name,
97
106
  )
98
107
  continue
99
108
 
@@ -3,6 +3,7 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import asyncio
6
+ import logging
6
7
  import inspect
7
8
  from typing import (
8
9
  TYPE_CHECKING,
@@ -31,6 +32,8 @@ from .errors import (
31
32
  from .converters import run_converters, DEFAULT_CONVERTERS, Converter
32
33
  from disagreement.typing import Typing
33
34
 
35
+ logger = logging.getLogger(__name__)
36
+
34
37
  if TYPE_CHECKING:
35
38
  from .cog import Cog
36
39
  from disagreement.client import Client
@@ -67,6 +70,10 @@ class Command:
67
70
  if hasattr(callback, "__command_checks__"):
68
71
  self.checks.extend(getattr(callback, "__command_checks__"))
69
72
 
73
+ self.max_concurrency: Optional[Tuple[int, str]] = None
74
+ if hasattr(callback, "__max_concurrency__"):
75
+ self.max_concurrency = getattr(callback, "__max_concurrency__")
76
+
70
77
  def add_check(
71
78
  self, predicate: Callable[["CommandContext"], Awaitable[bool] | bool]
72
79
  ) -> None:
@@ -212,6 +219,7 @@ class CommandHandler:
212
219
  ] = prefix
213
220
  self.commands: Dict[str, Command] = {}
214
221
  self.cogs: Dict[str, "Cog"] = {}
222
+ self._concurrency: Dict[str, Dict[str, int]] = {}
215
223
 
216
224
  from .help import HelpCommand
217
225
 
@@ -224,8 +232,10 @@ class CommandHandler:
224
232
  self.commands[command.name.lower()] = command
225
233
  for alias in command.aliases:
226
234
  if alias in self.commands:
227
- print(
228
- f"Warning: Alias '{alias}' for command '{command.name}' conflicts with an existing command or alias."
235
+ logger.warning(
236
+ "Alias '%s' for command '%s' conflicts with an existing command or alias.",
237
+ alias,
238
+ command.name,
229
239
  )
230
240
  self.commands[alias.lower()] = command
231
241
 
@@ -241,6 +251,7 @@ class CommandHandler:
241
251
 
242
252
  def add_cog(self, cog_to_add: "Cog") -> None:
243
253
  from .cog import Cog
254
+
244
255
  if not isinstance(cog_to_add, Cog):
245
256
  raise TypeError("Argument must be a subclass of Cog.")
246
257
 
@@ -258,8 +269,9 @@ class CommandHandler:
258
269
  for event_name, callback in cog_to_add.get_listeners():
259
270
  self.client._event_dispatcher.register(event_name.upper(), callback)
260
271
  else:
261
- print(
262
- f"Warning: Client does not have '_event_dispatcher'. Listeners for cog '{cog_to_add.cog_name}' not registered."
272
+ logger.warning(
273
+ "Client does not have '_event_dispatcher'. Listeners for cog '%s' not registered.",
274
+ cog_to_add.cog_name,
263
275
  )
264
276
 
265
277
  if hasattr(cog_to_add, "cog_load") and inspect.iscoroutinefunction(
@@ -267,7 +279,7 @@ class CommandHandler:
267
279
  ):
268
280
  asyncio.create_task(cog_to_add.cog_load())
269
281
 
270
- print(f"Cog '{cog_to_add.cog_name}' added.")
282
+ logger.info("Cog '%s' added.", cog_to_add.cog_name)
271
283
 
272
284
  def remove_cog(self, cog_name: str) -> Optional["Cog"]:
273
285
  cog_to_remove = self.cogs.pop(cog_name, None)
@@ -277,8 +289,11 @@ class CommandHandler:
277
289
 
278
290
  if hasattr(self.client, "_event_dispatcher"):
279
291
  for event_name, callback in cog_to_remove.get_listeners():
280
- print(
281
- f"Note: Listener '{callback.__name__}' for event '{event_name}' from cog '{cog_name}' needs manual unregistration logic in EventDispatcher."
292
+ logger.debug(
293
+ "Listener '%s' for event '%s' from cog '%s' needs manual unregistration logic in EventDispatcher.",
294
+ callback.__name__,
295
+ event_name,
296
+ cog_name,
282
297
  )
283
298
 
284
299
  if hasattr(cog_to_remove, "cog_unload") and inspect.iscoroutinefunction(
@@ -287,9 +302,50 @@ class CommandHandler:
287
302
  asyncio.create_task(cog_to_remove.cog_unload())
288
303
 
289
304
  cog_to_remove._eject()
290
- print(f"Cog '{cog_name}' removed.")
305
+ logger.info("Cog '%s' removed.", cog_name)
291
306
  return cog_to_remove
292
307
 
308
+ def _acquire_concurrency(self, ctx: CommandContext) -> None:
309
+ mc = getattr(ctx.command, "max_concurrency", None)
310
+ if not mc:
311
+ return
312
+ limit, scope = mc
313
+ if scope == "user":
314
+ key = ctx.author.id
315
+ elif scope == "guild":
316
+ key = ctx.message.guild_id or ctx.author.id
317
+ else:
318
+ key = "global"
319
+ buckets = self._concurrency.setdefault(ctx.command.name, {})
320
+ current = buckets.get(key, 0)
321
+ if current >= limit:
322
+ from .errors import MaxConcurrencyReached
323
+
324
+ raise MaxConcurrencyReached(limit)
325
+ buckets[key] = current + 1
326
+
327
+ def _release_concurrency(self, ctx: CommandContext) -> None:
328
+ mc = getattr(ctx.command, "max_concurrency", None)
329
+ if not mc:
330
+ return
331
+ _, scope = mc
332
+ if scope == "user":
333
+ key = ctx.author.id
334
+ elif scope == "guild":
335
+ key = ctx.message.guild_id or ctx.author.id
336
+ else:
337
+ key = "global"
338
+ buckets = self._concurrency.get(ctx.command.name)
339
+ if not buckets:
340
+ return
341
+ current = buckets.get(key, 0)
342
+ if current <= 1:
343
+ buckets.pop(key, None)
344
+ else:
345
+ buckets[key] = current - 1
346
+ if not buckets:
347
+ self._concurrency.pop(ctx.command.name, None)
348
+
293
349
  async def get_prefix(self, message: "Message") -> Union[str, List[str], None]:
294
350
  if callable(self.prefix):
295
351
  if inspect.iscoroutinefunction(self.prefix):
@@ -491,13 +547,17 @@ class CommandHandler:
491
547
  parsed_args, parsed_kwargs = await self._parse_arguments(command, ctx, view)
492
548
  ctx.args = parsed_args
493
549
  ctx.kwargs = parsed_kwargs
494
- await command.invoke(ctx, *parsed_args, **parsed_kwargs)
550
+ self._acquire_concurrency(ctx)
551
+ try:
552
+ await command.invoke(ctx, *parsed_args, **parsed_kwargs)
553
+ finally:
554
+ self._release_concurrency(ctx)
495
555
  except CommandError as e:
496
- print(f"Command error for '{command.name}': {e}")
556
+ logger.error("Command error for '%s': %s", command.name, e)
497
557
  if hasattr(self.client, "on_command_error"):
498
558
  await self.client.on_command_error(ctx, e)
499
559
  except Exception as e:
500
- print(f"Unexpected error invoking command '{command.name}': {e}")
560
+ logger.error("Unexpected error invoking command '%s': %s", command.name, e)
501
561
  exc = CommandInvokeError(e)
502
562
  if hasattr(self.client, "on_command_error"):
503
563
  await self.client.on_command_error(ctx, exc)
@@ -107,6 +107,33 @@ def check_any(
107
107
  return check(predicate)
108
108
 
109
109
 
110
def max_concurrency(
    number: int, per: str = "user"
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Build a decorator that caps simultaneous invocations of a command.

    Parameters
    ----------
    number:
        Highest number of invocations allowed to run at once; must be at
        least 1.
    per:
        Scope the cap applies to: ``"user"``, ``"guild"`` or ``"global"``.

    Returns
    -------
    A decorator that stores ``(number, per)`` on the decorated function as
    ``__max_concurrency__`` and returns that same function.

    Raises
    ------
    ValueError
        If ``number`` is below 1 or ``per`` is not a supported scope.
    """
    if number < 1:
        raise ValueError("Concurrency number must be at least 1.")

    supported_scopes = {"global", "guild", "user"}
    if per not in supported_scopes:
        raise ValueError("per must be 'user', 'guild', or 'global'.")

    settings = (number, per)

    def attach_limit(
        func: Callable[..., Awaitable[None]],
    ) -> Callable[..., Awaitable[None]]:
        # The function is only tagged, not wrapped.
        setattr(func, "__max_concurrency__", settings)
        return func

    return attach_limit
135
+
136
+
110
137
  def cooldown(
111
138
  rate: int, per: float
112
139
  ) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
@@ -72,5 +72,13 @@ class CommandInvokeError(CommandError):
72
72
  super().__init__(f"Error during command invocation: {original}")
73
73
 
74
74
 
75
class MaxConcurrencyReached(CommandError):
    """Error signalling that a command is already running at its concurrency cap."""

    def __init__(self, limit: int):
        message = f"Max concurrency of {limit} reached"
        # The limit is kept on the instance so handlers can read it back.
        self.limit = limit
        super().__init__(message)
81
+
82
+
75
83
  # Add more specific errors as needed, e.g., UserNotFound, ChannelNotFound, etc.
76
84
  # These might inherit from BadArgument.