disagreement 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
disagreement/http.py ADDED
@@ -0,0 +1,657 @@
1
+ # disagreement/http.py
2
+
3
+ """
4
+ HTTP client for interacting with the Discord REST API.
5
+ """
6
+
7
+ import asyncio
8
+ import aiohttp # pylint: disable=import-error
9
+ import json
10
+ from urllib.parse import quote
11
+ from typing import Optional, Dict, Any, Union, TYPE_CHECKING, List
12
+
13
+ from .errors import (
14
+ HTTPException,
15
+ RateLimitError,
16
+ AuthenticationError,
17
+ DisagreementException,
18
+ )
19
+ from . import __version__ # For User-Agent
20
+
21
+ if TYPE_CHECKING:
22
+ from .client import Client
23
+ from .models import Message
24
+ from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake
25
+
26
# Root of every REST route used below; pinned to Discord API version 10.
API_BASE_URL = "https://discord.com/api/v10"
28
+
29
+
30
class HTTPClient:
    """Async client that issues requests against the Discord REST API."""

    def __init__(
        self,
        token: str,
        client_session: Optional[aiohttp.ClientSession] = None,
        verbose: bool = False,
    ):
        """Store credentials and options; no network I/O happens here.

        Parameters
        ----------
        token: str
            Bot token, later sent as ``Authorization: Bot <token>``.
        client_session: Optional[aiohttp.ClientSession]
            Optional externally managed session to reuse. When omitted (or once
            it is closed) ``_ensure_session`` creates a fresh one on demand.
        verbose: bool
            When true, ``request`` prints each request and response.
        """
        self.token = token
        self.verbose = verbose
        # Possibly caller-owned; otherwise created lazily by ``_ensure_session``.
        self._session: Optional[aiohttp.ClientSession] = client_session
        # TODO: replace the placeholder repository URL with the real one.
        self.user_agent = (
            f"DiscordBot (https://github.com/yourusername/disagreement, {__version__})"
        )
        # Set means requests may proceed; ``request`` clears it while a global
        # rate limit is being waited out.
        gate = asyncio.Event()
        gate.set()
        self._global_rate_limit_lock = gate
49
+
50
+ async def _ensure_session(self):
51
+ if self._session is None or self._session.closed:
52
+ self._session = aiohttp.ClientSession()
53
+
54
+ async def close(self):
55
+ """Closes the underlying aiohttp.ClientSession."""
56
+ if self._session and not self._session.closed:
57
+ await self._session.close()
58
+
59
+ async def request(
60
+ self,
61
+ method: str,
62
+ endpoint: str,
63
+ payload: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
64
+ params: Optional[Dict[str, Any]] = None,
65
+ is_json: bool = True,
66
+ use_auth_header: bool = True,
67
+ custom_headers: Optional[Dict[str, str]] = None,
68
+ ) -> Any:
69
+ """Makes an HTTP request to the Discord API."""
70
+ await self._ensure_session()
71
+
72
+ url = f"{API_BASE_URL}{endpoint}"
73
+ final_headers: Dict[str, str] = { # Renamed to final_headers
74
+ "User-Agent": self.user_agent,
75
+ }
76
+ if use_auth_header:
77
+ final_headers["Authorization"] = f"Bot {self.token}"
78
+
79
+ if is_json and payload:
80
+ final_headers["Content-Type"] = "application/json"
81
+
82
+ if custom_headers: # Merge custom headers
83
+ final_headers.update(custom_headers)
84
+
85
+ if self.verbose:
86
+ print(f"HTTP REQUEST: {method} {url} | payload={payload} params={params}")
87
+
88
+ # Global rate limit handling
89
+ await self._global_rate_limit_lock.wait()
90
+
91
+ for attempt in range(5): # Max 5 retries for rate limits
92
+ assert self._session is not None, "ClientSession not initialized"
93
+ async with self._session.request(
94
+ method,
95
+ url,
96
+ json=payload if is_json else None,
97
+ data=payload if not is_json else None,
98
+ headers=final_headers,
99
+ params=params,
100
+ ) as response:
101
+
102
+ data = None
103
+ try:
104
+ if response.headers.get("Content-Type", "").startswith(
105
+ "application/json"
106
+ ):
107
+ data = await response.json()
108
+ else:
109
+ # For non-JSON responses, like fetching images or other files
110
+ # We might return the raw response or handle it differently
111
+ # For now, let's assume most API calls expect JSON
112
+ data = await response.text()
113
+ except (aiohttp.ContentTypeError, json.JSONDecodeError):
114
+ data = (
115
+ await response.text()
116
+ ) # Fallback to text if JSON parsing fails
117
+
118
+ if self.verbose:
119
+ print(f"HTTP RESPONSE: {response.status} {url} | {data}")
120
+
121
+ if 200 <= response.status < 300:
122
+ if response.status == 204:
123
+ return None
124
+ return data
125
+
126
+ # Rate limit handling
127
+ if response.status == 429: # Rate limited
128
+ retry_after_str = response.headers.get("Retry-After", "1")
129
+ try:
130
+ retry_after = float(retry_after_str)
131
+ except ValueError:
132
+ retry_after = 1.0 # Default retry if header is malformed
133
+
134
+ is_global = (
135
+ response.headers.get("X-RateLimit-Global", "false").lower()
136
+ == "true"
137
+ )
138
+
139
+ error_message = f"Rate limited on {method} {endpoint}."
140
+ if data and isinstance(data, dict) and "message" in data:
141
+ error_message += f" Discord says: {data['message']}"
142
+
143
+ if is_global:
144
+ self._global_rate_limit_lock.clear()
145
+ await asyncio.sleep(retry_after)
146
+ self._global_rate_limit_lock.set()
147
+ else:
148
+ await asyncio.sleep(retry_after)
149
+
150
+ if attempt < 4: # Don't log on the last attempt before raising
151
+ print(
152
+ f"{error_message} Retrying after {retry_after}s (Attempt {attempt + 1}/5). Global: {is_global}"
153
+ )
154
+ continue # Retry the request
155
+ else: # Last attempt failed
156
+ raise RateLimitError(
157
+ response,
158
+ message=error_message,
159
+ retry_after=retry_after,
160
+ is_global=is_global,
161
+ )
162
+
163
+ # Other error handling
164
+ if response.status == 401: # Unauthorized
165
+ raise AuthenticationError(response, "Invalid token provided.")
166
+ if response.status == 403: # Forbidden
167
+ raise HTTPException(
168
+ response,
169
+ "Missing permissions or access denied.",
170
+ status=response.status,
171
+ text=str(data),
172
+ )
173
+
174
+ # General HTTP error
175
+ error_text = str(data) if data else "Unknown error"
176
+ discord_error_code = (
177
+ data.get("code") if isinstance(data, dict) else None
178
+ )
179
+ raise HTTPException(
180
+ response,
181
+ f"API Error on {method} {endpoint}: {error_text}",
182
+ status=response.status,
183
+ text=error_text,
184
+ error_code=discord_error_code,
185
+ )
186
+
187
+ # Should not be reached if retries are exhausted by RateLimitError
188
+ raise DisagreementException(
189
+ f"Failed request to {method} {endpoint} after multiple retries."
190
+ )
191
+
192
+ # --- Specific API call methods ---
193
+
194
+ async def get_gateway_bot(self) -> Dict[str, Any]:
195
+ """Gets the WSS URL and sharding information for the Gateway."""
196
+ return await self.request("GET", "/gateway/bot")
197
+
198
+ async def send_message(
199
+ self,
200
+ channel_id: str,
201
+ content: Optional[str] = None,
202
+ tts: bool = False,
203
+ embeds: Optional[List[Dict[str, Any]]] = None,
204
+ components: Optional[List[Dict[str, Any]]] = None,
205
+ allowed_mentions: Optional[dict] = None,
206
+ message_reference: Optional[Dict[str, Any]] = None,
207
+ flags: Optional[int] = None,
208
+ ) -> Dict[str, Any]:
209
+ """Sends a message to a channel.
210
+
211
+ Returns the created message data as a dict.
212
+ """
213
+ payload: Dict[str, Any] = {}
214
+ if content is not None: # Content is optional if embeds/components are present
215
+ payload["content"] = content
216
+ if tts:
217
+ payload["tts"] = True
218
+ if embeds:
219
+ payload["embeds"] = embeds
220
+ if components:
221
+ payload["components"] = components
222
+ if allowed_mentions:
223
+ payload["allowed_mentions"] = allowed_mentions
224
+ if flags:
225
+ payload["flags"] = flags
226
+ if message_reference:
227
+ payload["message_reference"] = message_reference
228
+
229
+ if not payload:
230
+ raise ValueError("Message must have content, embeds, or components.")
231
+
232
+ return await self.request(
233
+ "POST", f"/channels/{channel_id}/messages", payload=payload
234
+ )
235
+
236
+ async def edit_message(
237
+ self,
238
+ channel_id: str,
239
+ message_id: str,
240
+ payload: Dict[str, Any],
241
+ ) -> Dict[str, Any]:
242
+ """Edits a message in a channel."""
243
+
244
+ return await self.request(
245
+ "PATCH",
246
+ f"/channels/{channel_id}/messages/{message_id}",
247
+ payload=payload,
248
+ )
249
+
250
+ async def get_message(
251
+ self, channel_id: "Snowflake", message_id: "Snowflake"
252
+ ) -> Dict[str, Any]:
253
+ """Fetches a message from a channel."""
254
+
255
+ return await self.request(
256
+ "GET", f"/channels/{channel_id}/messages/{message_id}"
257
+ )
258
+
259
+ async def create_reaction(
260
+ self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
261
+ ) -> None:
262
+ """Adds a reaction to a message as the current user."""
263
+ encoded = quote(emoji)
264
+ await self.request(
265
+ "PUT",
266
+ f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me",
267
+ )
268
+
269
+ async def delete_reaction(
270
+ self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
271
+ ) -> None:
272
+ """Removes the current user's reaction from a message."""
273
+ encoded = quote(emoji)
274
+ await self.request(
275
+ "DELETE",
276
+ f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}/@me",
277
+ )
278
+
279
+ async def get_reactions(
280
+ self, channel_id: "Snowflake", message_id: "Snowflake", emoji: str
281
+ ) -> List[Dict[str, Any]]:
282
+ """Fetches the users that reacted with a specific emoji."""
283
+ encoded = quote(emoji)
284
+ return await self.request(
285
+ "GET",
286
+ f"/channels/{channel_id}/messages/{message_id}/reactions/{encoded}",
287
+ )
288
+
289
+ async def delete_channel(
290
+ self, channel_id: str, reason: Optional[str] = None
291
+ ) -> None:
292
+ """Deletes a channel.
293
+
294
+ If the channel is a guild channel, requires the MANAGE_CHANNELS permission.
295
+ If the channel is a thread, requires the MANAGE_THREADS permission (if locked) or
296
+ be the thread creator (if not locked).
297
+ Deleting a category does not delete its child channels.
298
+ """
299
+ custom_headers = {}
300
+ if reason:
301
+ custom_headers["X-Audit-Log-Reason"] = reason
302
+
303
+ await self.request(
304
+ "DELETE",
305
+ f"/channels/{channel_id}",
306
+ custom_headers=custom_headers if custom_headers else None,
307
+ )
308
+
309
+ async def get_channel(self, channel_id: str) -> Dict[str, Any]:
310
+ """Fetches a channel by ID."""
311
+ return await self.request("GET", f"/channels/{channel_id}")
312
+
313
+ async def get_user(self, user_id: "Snowflake") -> Dict[str, Any]:
314
+ """Fetches a user object for a given user ID."""
315
+ return await self.request("GET", f"/users/{user_id}")
316
+
317
+ async def get_guild_member(
318
+ self, guild_id: "Snowflake", user_id: "Snowflake"
319
+ ) -> Dict[str, Any]:
320
+ """Returns a guild member object for the specified user."""
321
+ return await self.request("GET", f"/guilds/{guild_id}/members/{user_id}")
322
+
323
+ async def kick_member(
324
+ self, guild_id: "Snowflake", user_id: "Snowflake", reason: Optional[str] = None
325
+ ) -> None:
326
+ """Kicks a member from the guild."""
327
+ headers = {"X-Audit-Log-Reason": reason} if reason else None
328
+ await self.request(
329
+ "DELETE",
330
+ f"/guilds/{guild_id}/members/{user_id}",
331
+ custom_headers=headers,
332
+ )
333
+
334
+ async def ban_member(
335
+ self,
336
+ guild_id: "Snowflake",
337
+ user_id: "Snowflake",
338
+ *,
339
+ delete_message_seconds: int = 0,
340
+ reason: Optional[str] = None,
341
+ ) -> None:
342
+ """Bans a member from the guild."""
343
+ payload = {}
344
+ if delete_message_seconds:
345
+ payload["delete_message_seconds"] = delete_message_seconds
346
+ headers = {"X-Audit-Log-Reason": reason} if reason else None
347
+ await self.request(
348
+ "PUT",
349
+ f"/guilds/{guild_id}/bans/{user_id}",
350
+ payload=payload if payload else None,
351
+ custom_headers=headers,
352
+ )
353
+
354
+ async def timeout_member(
355
+ self,
356
+ guild_id: "Snowflake",
357
+ user_id: "Snowflake",
358
+ *,
359
+ until: Optional[str],
360
+ reason: Optional[str] = None,
361
+ ) -> Dict[str, Any]:
362
+ """Times out a member until the given ISO8601 timestamp."""
363
+ payload = {"communication_disabled_until": until}
364
+ headers = {"X-Audit-Log-Reason": reason} if reason else None
365
+ return await self.request(
366
+ "PATCH",
367
+ f"/guilds/{guild_id}/members/{user_id}",
368
+ payload=payload,
369
+ custom_headers=headers,
370
+ )
371
+
372
+ async def get_guild_roles(self, guild_id: "Snowflake") -> List[Dict[str, Any]]:
373
+ """Returns a list of role objects for the guild."""
374
+ return await self.request("GET", f"/guilds/{guild_id}/roles")
375
+
376
+ async def get_guild(self, guild_id: "Snowflake") -> Dict[str, Any]:
377
+ """Fetches a guild object for a given guild ID."""
378
+ return await self.request("GET", f"/guilds/{guild_id}")
379
+
380
    # Further REST endpoints can be added here following the same pattern:
    # build the route, attach an ``X-Audit-Log-Reason`` header when a reason
    # is given, and delegate to ``request``.
384
+ # --- Application Command Endpoints ---
385
+
386
+ # Global Application Commands
387
+ async def get_global_application_commands(
388
+ self, application_id: "Snowflake", with_localizations: bool = False
389
+ ) -> List["ApplicationCommand"]:
390
+ """Fetches all global commands for your application."""
391
+ params = {"with_localizations": str(with_localizations).lower()}
392
+ data = await self.request(
393
+ "GET", f"/applications/{application_id}/commands", params=params
394
+ )
395
+ from .interactions import ApplicationCommand # Ensure constructor is available
396
+
397
+ return [ApplicationCommand(cmd_data) for cmd_data in data]
398
+
399
+ async def create_global_application_command(
400
+ self, application_id: "Snowflake", payload: Dict[str, Any]
401
+ ) -> "ApplicationCommand":
402
+ """Creates a new global command."""
403
+ data = await self.request(
404
+ "POST", f"/applications/{application_id}/commands", payload=payload
405
+ )
406
+ from .interactions import ApplicationCommand
407
+
408
+ return ApplicationCommand(data)
409
+
410
+ async def get_global_application_command(
411
+ self, application_id: "Snowflake", command_id: "Snowflake"
412
+ ) -> "ApplicationCommand":
413
+ """Fetches a specific global command."""
414
+ data = await self.request(
415
+ "GET", f"/applications/{application_id}/commands/{command_id}"
416
+ )
417
+ from .interactions import ApplicationCommand
418
+
419
+ return ApplicationCommand(data)
420
+
421
+ async def edit_global_application_command(
422
+ self,
423
+ application_id: "Snowflake",
424
+ command_id: "Snowflake",
425
+ payload: Dict[str, Any],
426
+ ) -> "ApplicationCommand":
427
+ """Edits a specific global command."""
428
+ data = await self.request(
429
+ "PATCH",
430
+ f"/applications/{application_id}/commands/{command_id}",
431
+ payload=payload,
432
+ )
433
+ from .interactions import ApplicationCommand
434
+
435
+ return ApplicationCommand(data)
436
+
437
+ async def delete_global_application_command(
438
+ self, application_id: "Snowflake", command_id: "Snowflake"
439
+ ) -> None:
440
+ """Deletes a specific global command."""
441
+ await self.request(
442
+ "DELETE", f"/applications/{application_id}/commands/{command_id}"
443
+ )
444
+
445
+ async def bulk_overwrite_global_application_commands(
446
+ self, application_id: "Snowflake", payload: List[Dict[str, Any]]
447
+ ) -> List["ApplicationCommand"]:
448
+ """Bulk overwrites all global commands for your application."""
449
+ data = await self.request(
450
+ "PUT", f"/applications/{application_id}/commands", payload=payload
451
+ )
452
+ from .interactions import ApplicationCommand
453
+
454
+ return [ApplicationCommand(cmd_data) for cmd_data in data]
455
+
456
+ # Guild Application Commands
457
+ async def get_guild_application_commands(
458
+ self,
459
+ application_id: "Snowflake",
460
+ guild_id: "Snowflake",
461
+ with_localizations: bool = False,
462
+ ) -> List["ApplicationCommand"]:
463
+ """Fetches all commands for your application for a specific guild."""
464
+ params = {"with_localizations": str(with_localizations).lower()}
465
+ data = await self.request(
466
+ "GET",
467
+ f"/applications/{application_id}/guilds/{guild_id}/commands",
468
+ params=params,
469
+ )
470
+ from .interactions import ApplicationCommand
471
+
472
+ return [ApplicationCommand(cmd_data) for cmd_data in data]
473
+
474
+ async def create_guild_application_command(
475
+ self,
476
+ application_id: "Snowflake",
477
+ guild_id: "Snowflake",
478
+ payload: Dict[str, Any],
479
+ ) -> "ApplicationCommand":
480
+ """Creates a new guild command."""
481
+ data = await self.request(
482
+ "POST",
483
+ f"/applications/{application_id}/guilds/{guild_id}/commands",
484
+ payload=payload,
485
+ )
486
+ from .interactions import ApplicationCommand
487
+
488
+ return ApplicationCommand(data)
489
+
490
+ async def get_guild_application_command(
491
+ self,
492
+ application_id: "Snowflake",
493
+ guild_id: "Snowflake",
494
+ command_id: "Snowflake",
495
+ ) -> "ApplicationCommand":
496
+ """Fetches a specific guild command."""
497
+ data = await self.request(
498
+ "GET",
499
+ f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}",
500
+ )
501
+ from .interactions import ApplicationCommand
502
+
503
+ return ApplicationCommand(data)
504
+
505
+ async def edit_guild_application_command(
506
+ self,
507
+ application_id: "Snowflake",
508
+ guild_id: "Snowflake",
509
+ command_id: "Snowflake",
510
+ payload: Dict[str, Any],
511
+ ) -> "ApplicationCommand":
512
+ """Edits a specific guild command."""
513
+ data = await self.request(
514
+ "PATCH",
515
+ f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}",
516
+ payload=payload,
517
+ )
518
+ from .interactions import ApplicationCommand
519
+
520
+ return ApplicationCommand(data)
521
+
522
+ async def delete_guild_application_command(
523
+ self,
524
+ application_id: "Snowflake",
525
+ guild_id: "Snowflake",
526
+ command_id: "Snowflake",
527
+ ) -> None:
528
+ """Deletes a specific guild command."""
529
+ await self.request(
530
+ "DELETE",
531
+ f"/applications/{application_id}/guilds/{guild_id}/commands/{command_id}",
532
+ )
533
+
534
+ async def bulk_overwrite_guild_application_commands(
535
+ self,
536
+ application_id: "Snowflake",
537
+ guild_id: "Snowflake",
538
+ payload: List[Dict[str, Any]],
539
+ ) -> List["ApplicationCommand"]:
540
+ """Bulk overwrites all commands for your application for a specific guild."""
541
+ data = await self.request(
542
+ "PUT",
543
+ f"/applications/{application_id}/guilds/{guild_id}/commands",
544
+ payload=payload,
545
+ )
546
+ from .interactions import ApplicationCommand
547
+
548
+ return [ApplicationCommand(cmd_data) for cmd_data in data]
549
+
550
    # --- Interaction Response Endpoints ---
    # Note: methods below that return message data return it as a raw
    # Dict[str, Any]. The caller (e.g., AppCommandHandler) is responsible for
    # constructing Message models if needed, because Message model
    # instantiation requires a `client_instance`.
554
+
555
+ async def create_interaction_response(
556
+ self,
557
+ interaction_id: "Snowflake",
558
+ interaction_token: str,
559
+ payload: "InteractionResponsePayload",
560
+ *,
561
+ ephemeral: bool = False,
562
+ ) -> None:
563
+ """Creates a response to an Interaction.
564
+
565
+ Parameters
566
+ ----------
567
+ ephemeral: bool
568
+ Ignored parameter for test compatibility.
569
+ """
570
+ # Interaction responses do not use the bot token in the Authorization header.
571
+ # They are authenticated by the interaction_token in the URL.
572
+ await self.request(
573
+ "POST",
574
+ f"/interactions/{interaction_id}/{interaction_token}/callback",
575
+ payload=payload.to_dict(),
576
+ use_auth_header=False,
577
+ )
578
+
579
+ async def get_original_interaction_response(
580
+ self, application_id: "Snowflake", interaction_token: str
581
+ ) -> Dict[str, Any]:
582
+ """Gets the initial Interaction response."""
583
+ # This endpoint uses the bot token for auth.
584
+ return await self.request(
585
+ "GET", f"/webhooks/{application_id}/{interaction_token}/messages/@original"
586
+ )
587
+
588
+ async def edit_original_interaction_response(
589
+ self,
590
+ application_id: "Snowflake",
591
+ interaction_token: str,
592
+ payload: Dict[str, Any],
593
+ ) -> Dict[str, Any]:
594
+ """Edits the initial Interaction response."""
595
+ return await self.request(
596
+ "PATCH",
597
+ f"/webhooks/{application_id}/{interaction_token}/messages/@original",
598
+ payload=payload,
599
+ use_auth_header=False,
600
+ ) # Docs imply webhook-style auth
601
+
602
+ async def delete_original_interaction_response(
603
+ self, application_id: "Snowflake", interaction_token: str
604
+ ) -> None:
605
+ """Deletes the initial Interaction response."""
606
+ await self.request(
607
+ "DELETE",
608
+ f"/webhooks/{application_id}/{interaction_token}/messages/@original",
609
+ use_auth_header=False,
610
+ ) # Docs imply webhook-style auth
611
+
612
+ async def create_followup_message(
613
+ self,
614
+ application_id: "Snowflake",
615
+ interaction_token: str,
616
+ payload: Dict[str, Any],
617
+ ) -> Dict[str, Any]:
618
+ """Creates a followup message for an Interaction."""
619
+ # Followup messages are sent to a webhook endpoint.
620
+ return await self.request(
621
+ "POST",
622
+ f"/webhooks/{application_id}/{interaction_token}",
623
+ payload=payload,
624
+ use_auth_header=False,
625
+ ) # Docs imply webhook-style auth
626
+
627
+ async def edit_followup_message(
628
+ self,
629
+ application_id: "Snowflake",
630
+ interaction_token: str,
631
+ message_id: "Snowflake",
632
+ payload: Dict[str, Any],
633
+ ) -> Dict[str, Any]:
634
+ """Edits a followup message for an Interaction."""
635
+ return await self.request(
636
+ "PATCH",
637
+ f"/webhooks/{application_id}/{interaction_token}/messages/{message_id}",
638
+ payload=payload,
639
+ use_auth_header=False,
640
+ ) # Docs imply webhook-style auth
641
+
642
+ async def delete_followup_message(
643
+ self,
644
+ application_id: "Snowflake",
645
+ interaction_token: str,
646
+ message_id: "Snowflake",
647
+ ) -> None:
648
+ """Deletes a followup message for an Interaction."""
649
+ await self.request(
650
+ "DELETE",
651
+ f"/webhooks/{application_id}/{interaction_token}/messages/{message_id}",
652
+ use_auth_header=False,
653
+ )
654
+
655
+ async def trigger_typing(self, channel_id: str) -> None:
656
+ """Sends a typing indicator to the specified channel."""
657
+ await self.request("POST", f"/channels/{channel_id}/typing")
@@ -0,0 +1,32 @@
1
+ """Utility class for working with either command or app contexts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Union
6
+
7
+ from .ext.commands.core import CommandContext
8
+ from .ext.app_commands.context import AppCommandContext
9
+
10
+
11
class HybridContext:
    """Wrap a :class:`CommandContext` or an :class:`AppCommandContext`.

    Lets command code run unchanged for prefix and slash invocations.
    :meth:`send` proxies to ``reply`` for prefix commands and to ``send`` for
    slash commands. Any other attribute is looked up on the wrapped context.
    """

    def __init__(self, ctx: "Union[CommandContext, AppCommandContext]"):
        """Store ``ctx``, the context being wrapped."""
        self._ctx = ctx

    async def send(self, *args: Any, **kwargs: Any):
        """Send a response through the wrapped context.

        Slash contexts use their own ``send``. Any other context is answered
        with ``reply``. Arguments are forwarded unchanged and the underlying
        call's result is returned.
        """
        if isinstance(self._ctx, AppCommandContext):
            return await self._ctx.send(*args, **kwargs)
        return await self._ctx.reply(*args, **kwargs)

    async def edit(self, *args: Any, **kwargs: Any):
        """Forward to the wrapped context's ``edit`` and return its result.

        Raises
        ------
        AttributeError
            If the wrapped context has no ``edit`` attribute.
        """
        if hasattr(self._ctx, "edit"):
            return await self._ctx.edit(*args, **kwargs)
        raise AttributeError("Underlying context does not support editing.")

    def __getattr__(self, name: str) -> Any:
        """Delegate attribute lookups that fail on the wrapper to the wrapped context."""
        # ``__getattr__`` only runs after normal lookup fails. On an instance
        # whose ``__init__`` never ran (``copy.copy``, unpickling, ``__new__``),
        # ``self._ctx`` would itself land here and recurse until RecursionError.
        # Report a plain AttributeError instead.
        if name == "_ctx":
            raise AttributeError(name)
        return getattr(self._ctx, name)