disagreement 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- disagreement/__init__.py +36 -0
- disagreement/cache.py +55 -0
- disagreement/client.py +1144 -0
- disagreement/components.py +166 -0
- disagreement/enums.py +357 -0
- disagreement/error_handler.py +33 -0
- disagreement/errors.py +112 -0
- disagreement/event_dispatcher.py +243 -0
- disagreement/gateway.py +490 -0
- disagreement/http.py +657 -0
- disagreement/hybrid_context.py +32 -0
- disagreement/i18n.py +22 -0
- disagreement/interactions.py +572 -0
- disagreement/logging_config.py +26 -0
- disagreement/models.py +1642 -0
- disagreement/oauth.py +109 -0
- disagreement/permissions.py +99 -0
- disagreement/rate_limiter.py +75 -0
- disagreement/shard_manager.py +65 -0
- disagreement/typing.py +42 -0
- disagreement/ui/__init__.py +17 -0
- disagreement/ui/button.py +99 -0
- disagreement/ui/item.py +38 -0
- disagreement/ui/modal.py +132 -0
- disagreement/ui/select.py +92 -0
- disagreement/ui/view.py +165 -0
- disagreement/voice_client.py +120 -0
- disagreement-0.0.1.dist-info/METADATA +163 -0
- disagreement-0.0.1.dist-info/RECORD +32 -0
- disagreement-0.0.1.dist-info/WHEEL +5 -0
- disagreement-0.0.1.dist-info/licenses/LICENSE +26 -0
- disagreement-0.0.1.dist-info/top_level.txt +1 -0
disagreement/oauth.py
ADDED
@@ -0,0 +1,109 @@
|
|
1
|
+
"""OAuth2 utilities."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import aiohttp
|
6
|
+
from typing import List, Optional, Dict, Any, Union
|
7
|
+
from urllib.parse import urlencode
|
8
|
+
|
9
|
+
from .errors import HTTPException
|
10
|
+
|
11
|
+
|
12
|
+
def build_authorization_url(
    client_id: str,
    redirect_uri: str,
    scope: Union[str, List[str]],
    *,
    state: Optional[str] = None,
    response_type: str = "code",
    prompt: Optional[str] = None,
) -> str:
    """Return the Discord OAuth2 authorization URL.

    Args:
        client_id: The application's client ID.
        redirect_uri: Where Discord should redirect after authorization.
        scope: Either a single space-separated string of scopes, or any
            iterable of scope names (list, tuple, ...), which is joined with
            spaces.
        state: Optional opaque value echoed back on redirect.
        response_type: OAuth2 response type; defaults to ``"code"``.
        prompt: Optional ``prompt`` query parameter.

    Returns:
        The full authorization URL with all parameters URL-encoded.
    """
    # Previously only ``list`` was joined; a tuple ended up encoded as its
    # Python repr.  Treat every non-string value as an iterable of scopes.
    if not isinstance(scope, str):
        scope = " ".join(scope)

    params = {
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "response_type": response_type,
        "scope": scope,
    }
    if state is not None:
        params["state"] = state
    if prompt is not None:
        params["prompt"] = prompt

    return "https://discord.com/oauth2/authorize?" + urlencode(params)
|
37
|
+
|
38
|
+
|
39
|
+
async def _request_token(
    data: Dict[str, Any],
    session: Optional[aiohttp.ClientSession],
    error_message: str,
) -> Dict[str, Any]:
    """POST ``data`` to Discord's OAuth2 token endpoint and return the JSON body.

    If ``session`` is ``None`` a temporary :class:`aiohttp.ClientSession` is
    created for this call and is always closed before returning, including
    when the request itself raises (e.g. a connection error).

    Raises:
        HTTPException: If Discord answers with a status other than 200.
    """
    owns_session = session is None
    if owns_session:
        session = aiohttp.ClientSession()
    try:
        # The request is inside the ``try`` so a failing POST cannot leak a
        # session we created; ``async with`` releases the response as well.
        async with session.post(
            "https://discord.com/api/v10/oauth2/token",
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        ) as resp:
            json_data = await resp.json()
            if resp.status != 200:
                raise HTTPException(resp, message=error_message)
            return json_data
    finally:
        if owns_session:
            await session.close()


async def exchange_code_for_token(
    client_id: str,
    client_secret: str,
    code: str,
    redirect_uri: str,
    *,
    session: Optional[aiohttp.ClientSession] = None,
) -> Dict[str, Any]:
    """Exchange an authorization code for an access token.

    Args:
        client_id: The application's client ID.
        client_secret: The application's client secret.
        code: The authorization code received on the redirect.
        redirect_uri: The redirect URI used when requesting ``code``.
        session: Optional session to reuse; a temporary one is created and
            closed otherwise.

    Returns:
        The decoded JSON token response.

    Raises:
        HTTPException: If the token endpoint does not return status 200.
    """
    data = {
        "client_id": client_id,
        "client_secret": client_secret,
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
    }
    return await _request_token(data, session, "OAuth token exchange failed")


async def refresh_access_token(
    refresh_token: str,
    client_id: str,
    client_secret: str,
    *,
    session: Optional[aiohttp.ClientSession] = None,
) -> Dict[str, Any]:
    """Refresh an access token using a refresh token.

    Args:
        refresh_token: The refresh token previously issued by Discord.
        client_id: The application's client ID.
        client_secret: The application's client secret.
        session: Optional session to reuse; a temporary one is created and
            closed otherwise.

    Returns:
        The decoded JSON token response.

    Raises:
        HTTPException: If the token endpoint does not return status 200.
    """
    data = {
        "client_id": client_id,
        "client_secret": client_secret,
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
    }
    return await _request_token(data, session, "OAuth token refresh failed")
|
@@ -0,0 +1,99 @@
|
|
1
|
+
"""Utility helpers for working with Discord permission bitmasks."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from enum import IntFlag
|
6
|
+
from typing import Iterable, List
|
7
|
+
|
8
|
+
|
9
|
+
class Permissions(IntFlag):
    """Discord guild and channel permissions.

    Each member is a single bit of the permission bitfield; combine members
    with ``|`` and test them with ``&``.
    """

    CREATE_INSTANT_INVITE = 1 << 0
    KICK_MEMBERS = 1 << 1
    BAN_MEMBERS = 1 << 2
    ADMINISTRATOR = 1 << 3
    MANAGE_CHANNELS = 1 << 4
    MANAGE_GUILD = 1 << 5
    ADD_REACTIONS = 1 << 6
    VIEW_AUDIT_LOG = 1 << 7
    PRIORITY_SPEAKER = 1 << 8
    STREAM = 1 << 9
    VIEW_CHANNEL = 1 << 10
    SEND_MESSAGES = 1 << 11
    SEND_TTS_MESSAGES = 1 << 12
    MANAGE_MESSAGES = 1 << 13
    EMBED_LINKS = 1 << 14
    ATTACH_FILES = 1 << 15
    READ_MESSAGE_HISTORY = 1 << 16
    MENTION_EVERYONE = 1 << 17
    USE_EXTERNAL_EMOJIS = 1 << 18
    VIEW_GUILD_INSIGHTS = 1 << 19
    CONNECT = 1 << 20
    SPEAK = 1 << 21
    MUTE_MEMBERS = 1 << 22
    DEAFEN_MEMBERS = 1 << 23
    MOVE_MEMBERS = 1 << 24
    USE_VAD = 1 << 25
    CHANGE_NICKNAME = 1 << 26
    MANAGE_NICKNAMES = 1 << 27
    MANAGE_ROLES = 1 << 28
    MANAGE_WEBHOOKS = 1 << 29
    MANAGE_GUILD_EXPRESSIONS = 1 << 30
    USE_APPLICATION_COMMANDS = 1 << 31
    REQUEST_TO_SPEAK = 1 << 32
    MANAGE_EVENTS = 1 << 33
    MANAGE_THREADS = 1 << 34
    CREATE_PUBLIC_THREADS = 1 << 35
    CREATE_PRIVATE_THREADS = 1 << 36
    USE_EXTERNAL_STICKERS = 1 << 37
    SEND_MESSAGES_IN_THREADS = 1 << 38
    USE_EMBEDDED_ACTIVITIES = 1 << 39
    MODERATE_MEMBERS = 1 << 40
    VIEW_CREATOR_MONETIZATION_ANALYTICS = 1 << 41
    USE_SOUNDBOARD = 1 << 42
    CREATE_GUILD_EXPRESSIONS = 1 << 43
    CREATE_EVENTS = 1 << 44
    USE_EXTERNAL_SOUNDS = 1 << 45
    SEND_VOICE_MESSAGES = 1 << 46
    # Newer flags from Discord's permission table; without them these bits
    # only surface as anonymous pseudo-members.
    SET_VOICE_CHANNEL_STATUS = 1 << 48
    SEND_POLLS = 1 << 49
    USE_EXTERNAL_APPS = 1 << 50
|
59
|
+
|
60
|
+
|
61
|
+
def permissions_value(
    *perms: Permissions | int | str | Iterable[Permissions | int],
) -> int:
    """Return the bitwise OR of every permission in ``perms``.

    Each argument may be a :class:`Permissions` member, a plain ``int``, a
    string holding an integer (e.g. ``"8"``), or an iterable (possibly
    nested) of any of those.  Calling with no arguments returns ``0``.
    """
    value = 0
    for perm in perms:
        # Permissions is an IntFlag (an int subclass), so the int check also
        # covers flag members, which are iterable on recent Python versions.
        # Strings must be treated as scalars: iterating a 1-char string yields
        # that same string again, which previously recursed forever.
        if isinstance(perm, (int, str)) or not isinstance(perm, Iterable):
            value |= int(perm)
        else:
            value |= permissions_value(*perm)
    return value
|
71
|
+
|
72
|
+
|
73
|
+
def has_permissions(
    current: int | str | Permissions,
    *perms: Permissions | int | Iterable[Permissions | int],
) -> bool:
    """Check whether ``current`` grants every permission in ``perms``.

    Args:
        current: The permission bitfield to test; anything ``int()`` accepts.
        perms: Permissions to require, combined via :func:`permissions_value`.

    Returns:
        ``True`` when every required bit is set in ``current`` (trivially
        ``True`` when no permissions are requested).
    """
    granted = int(current)
    # Any required bit that is not granted survives this mask.
    lacking = permissions_value(*perms) & ~granted
    return not lacking
|
82
|
+
|
83
|
+
|
84
|
+
def missing_permissions(
    current: int | str | Permissions,
    *perms: Permissions | int | str | Iterable[Permissions | int],
) -> List[Permissions]:
    """Return the entries of ``perms`` that ``current`` does not fully grant.

    Args:
        current: The permission bitfield to test; anything ``int()`` accepts.
        perms: Permissions to check. Each may be a :class:`Permissions`
            member, an ``int``, an integer string, or a (nested) iterable of
            those; iterables are flattened.

    Returns:
        One :class:`Permissions` value per checked entry that is not fully
        present, in argument order.
    """
    current_val = int(current)
    missing: List[Permissions] = []
    for perm in perms:
        # Strings are scalars here; iterating them used to recurse forever.
        if isinstance(perm, (int, str)) or not isinstance(perm, Iterable):
            perm_val = int(perm)
            # A combined flag is missing unless *all* of its bits are set
            # (same rule as has_permissions); this also means 0 is never
            # reported as missing.
            if (current_val & perm_val) != perm_val:
                missing.append(Permissions(perm_val))
        else:
            missing.extend(missing_permissions(current_val, *perm))
    return missing
|
@@ -0,0 +1,75 @@
|
|
1
|
+
"""Asynchronous rate limiter for Discord HTTP requests."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import asyncio
|
6
|
+
import time
|
7
|
+
from typing import Dict, Mapping
|
8
|
+
|
9
|
+
|
10
|
+
class _Bucket:
    """Mutable rate-limit state for a single route."""

    def __init__(self) -> None:
        # Start optimistic: one request may go out before any response
        # headers have told us the real limit.
        self.remaining: int = 1
        # time.monotonic() timestamp at which this bucket's window resets.
        self.reset_at: float = 0.0
        # Serialises acquire() calls for this route.
        self.lock = asyncio.Lock()


class RateLimiter:
    """Rate limiter implementing per-route buckets and a global queue.

    Call :meth:`acquire` before sending a request on a route and
    :meth:`release` with the response headers afterwards.
    :meth:`handle_rate_limit` records an explicit rate-limit response.
    """

    def __init__(self) -> None:
        self._buckets: Dict[str, _Bucket] = {}
        # Set while requests may proceed; cleared during a global rate limit.
        self._global_event = asyncio.Event()
        self._global_event.set()
        # Strong references to pending _lift_global tasks. asyncio keeps only
        # weak references to tasks, so an unreferenced task could be garbage
        # collected before it re-opens the global gate.
        self._lift_tasks: set[asyncio.Task] = set()

    def _get_bucket(self, route: str) -> _Bucket:
        """Return the bucket for ``route``, creating it on first use."""
        bucket = self._buckets.get(route)
        if bucket is None:
            bucket = _Bucket()
            self._buckets[route] = bucket
        return bucket

    async def acquire(self, route: str) -> _Bucket:
        """Wait until a request on ``route`` may be sent and reserve a slot."""
        bucket = self._get_bucket(route)
        while True:
            await self._global_event.wait()
            async with bucket.lock:
                if bucket.remaining <= 0:
                    now = time.monotonic()
                    if now < bucket.reset_at:
                        await asyncio.sleep(bucket.reset_at - now)
                        continue
                    # The window has expired but no response has refreshed
                    # the count (e.g. release() saw no rate-limit headers).
                    # Previously this case looped forever without ever
                    # yielding to the event loop. Allow one probe request;
                    # release() will record the real remaining count.
                    bucket.remaining = 1
                bucket.remaining -= 1
                return bucket

    def release(self, route: str, headers: Mapping[str, str]) -> None:
        """Update ``route``'s bucket (and the global gate) from response headers."""
        bucket = self._get_bucket(route)
        try:
            remaining = int(headers.get("X-RateLimit-Remaining", bucket.remaining))
            reset_after = float(headers.get("X-RateLimit-Reset-After", "0"))
        except ValueError:
            # Malformed headers: keep the previous bucket state.
            pass
        else:
            bucket.remaining = remaining
            bucket.reset_at = time.monotonic() + reset_after

        if headers.get("X-RateLimit-Global", "false").lower() == "true":
            try:
                retry_after = float(headers.get("Retry-After", "0"))
            except ValueError:
                # Don't let a bad header escape release(); fall back to the
                # same default used when the header is absent.
                retry_after = 0.0
            self._global_event.clear()
            task = asyncio.create_task(self._lift_global(retry_after))
            self._lift_tasks.add(task)
            task.add_done_callback(self._lift_tasks.discard)

    async def handle_rate_limit(
        self, route: str, retry_after: float, is_global: bool
    ) -> None:
        """Record a rate-limit response for ``route`` and wait ``retry_after`` seconds.

        When ``is_global`` is true, every :meth:`acquire` call is blocked for
        the duration as well.
        """
        bucket = self._get_bucket(route)
        bucket.remaining = 0
        bucket.reset_at = time.monotonic() + retry_after
        if not is_global:
            await asyncio.sleep(retry_after)
            return
        self._global_event.clear()
        try:
            await asyncio.sleep(retry_after)
        finally:
            # Re-open the gate even if this waiter is cancelled; otherwise
            # every future acquire() would block forever.
            self._global_event.set()

    async def _lift_global(self, delay: float) -> None:
        """Re-open the global gate after ``delay`` seconds."""
        await asyncio.sleep(delay)
        self._global_event.set()
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# disagreement/shard_manager.py
|
2
|
+
|
3
|
+
"""Sharding utilities for managing multiple gateway connections."""
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
from typing import List, TYPE_CHECKING
|
9
|
+
|
10
|
+
from .gateway import GatewayClient
|
11
|
+
|
12
|
+
if TYPE_CHECKING: # pragma: no cover - for type checking only
|
13
|
+
from .client import Client
|
14
|
+
|
15
|
+
|
16
|
+
class Shard:
    """A single gateway connection, shard ``id`` out of ``count`` shards."""

    def __init__(self, shard_id: int, shard_count: int, gateway: GatewayClient) -> None:
        # Keep the identifying numbers next to the connection they describe.
        self.id: int = shard_id
        self.count: int = shard_count
        self.gateway: GatewayClient = gateway

    async def connect(self) -> None:
        """Open the underlying gateway connection for this shard."""
        gw = self.gateway
        await gw.connect()

    async def close(self) -> None:
        """Shut down the underlying gateway connection for this shard."""
        gw = self.gateway
        await gw.close()
|
31
|
+
|
32
|
+
|
33
|
+
class ShardManager:
    """Manages multiple :class:`Shard` instances."""

    def __init__(self, client: "Client", shard_count: int) -> None:
        self.client: "Client" = client
        self.shard_count: int = shard_count
        self.shards: List[Shard] = []

    def _create_shards(self) -> None:
        """Build one :class:`Shard` per shard id, unless shards already exist."""
        if self.shards:
            return
        for shard_id in range(self.shard_count):
            gateway = GatewayClient(
                http_client=self.client._http,
                event_dispatcher=self.client._event_dispatcher,
                token=self.client.token,
                intents=self.client.intents,
                client_instance=self.client,
                verbose=self.client.verbose,
                shard_id=shard_id,
                shard_count=self.shard_count,
            )
            self.shards.append(Shard(shard_id, self.shard_count, gateway))

    async def start(self) -> None:
        """Starts all shards, connecting them concurrently."""
        self._create_shards()
        await asyncio.gather(*(s.connect() for s in self.shards))

    async def close(self) -> None:
        """Closes all shards concurrently and forgets them.

        The shard list is cleared even if a shard fails to close, and the
        first error is re-raised afterwards.
        """
        try:
            await asyncio.gather(*(s.close() for s in self.shards))
        finally:
            # Without this, a failed close left the old shards in place and
            # _create_shards() (which returns early on a non-empty list) would
            # reuse the stale gateways on the next start().
            self.shards.clear()
|
disagreement/typing.py
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
import asyncio
|
2
|
+
from contextlib import suppress
|
3
|
+
from typing import Optional, TYPE_CHECKING
|
4
|
+
|
5
|
+
from .errors import DisagreementException
|
6
|
+
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
from .client import Client
|
9
|
+
|
10
|
+
# NOTE(review): this branch only contains ``pass``, so it has no effect no
# matter how the module is imported or run; it looks like a leftover
# placeholder and could be removed.
if __name__ == "typing":
    # For direct module execution testing
    pass
|
13
|
+
|
14
|
+
|
15
|
+
class Typing:
    """Async context manager for Discord typing indicator.

    On entry the indicator is triggered once. A background task then
    re-triggers it every 5 seconds until the ``async with`` block exits.
    """

    def __init__(self, client: "Client", channel_id: str) -> None:
        self._client = client
        self._channel_id = channel_id
        self._task: Optional[asyncio.Task] = None

    async def _run(self) -> None:
        """Keep re-triggering the indicator until cancelled."""
        try:
            while True:
                # __aenter__ has already sent the first trigger, so wait before
                # each refresh instead of immediately sending a duplicate.
                await asyncio.sleep(5)
                await self._client._http.trigger_typing(self._channel_id)
        except asyncio.CancelledError:
            pass

    async def __aenter__(self) -> "Typing":
        """Trigger typing once and start the background refresh task.

        Raises:
            DisagreementException: If the client has been closed.
        """
        if self._client._closed:
            raise DisagreementException("Client is closed.")
        await self._client._http.trigger_typing(self._channel_id)
        self._task = asyncio.create_task(self._run())
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Stop the background refresh task and wait for it to finish."""
        task, self._task = self._task, None
        if task is not None:
            task.cancel()
            with suppress(asyncio.CancelledError):
                await task
|
@@ -0,0 +1,17 @@
|
|
1
|
+
from .view import View
|
2
|
+
from .item import Item
|
3
|
+
from .button import Button, button
|
4
|
+
from .select import Select, select
|
5
|
+
from .modal import Modal, TextInput, text_input
|
6
|
+
|
7
|
+
# Public API of the ``ui`` package, grouped by the submodule each name
# comes from.
__all__ = [
    # view.py / item.py
    "View",
    "Item",
    # button.py
    "Button",
    "button",
    # select.py
    "Select",
    "select",
    # modal.py
    "Modal",
    "TextInput",
    "text_input",
]
|
@@ -0,0 +1,99 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
from typing import Any, Callable, Coroutine, Optional, TYPE_CHECKING
|
5
|
+
|
6
|
+
from .item import Item
|
7
|
+
from ..enums import ComponentType, ButtonStyle
|
8
|
+
from ..models import PartialEmoji, to_partial_emoji
|
9
|
+
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
from ..interactions import Interaction
|
12
|
+
|
13
|
+
|
14
|
+
class Button(Item):
    """A clickable button component that can be placed in a View.

    Args:
        style (ButtonStyle): Visual style of the button.
        label (Optional[str]): Text shown on the button.
        emoji (Optional[str | PartialEmoji]): Emoji shown on the button;
            converted with ``to_partial_emoji``.
        custom_id (Optional[str]): Developer-defined identifier for the button.
        url (Optional[str]): URL the button points to.
        disabled (bool): Whether the button is disabled.
        row (Optional[int]): The row the button should be placed in, from 0 to 4.

    Raises:
        ValueError: If neither ``label`` nor ``emoji`` is provided, or if both
            ``url`` and ``custom_id`` are provided.
    """

    def __init__(
        self,
        *,
        style: ButtonStyle = ButtonStyle.SECONDARY,
        label: Optional[str] = None,
        emoji: Optional[str | PartialEmoji] = None,
        custom_id: Optional[str] = None,
        url: Optional[str] = None,
        disabled: bool = False,
        row: Optional[int] = None,
    ):
        super().__init__(type=ComponentType.BUTTON)

        # Something must be visible on the button.
        if not (label or emoji):
            raise ValueError("A button must have a label and/or an emoji.")
        # A button carries either a URL or a custom_id, never both.
        if url and custom_id:
            raise ValueError("A button cannot have both a URL and a custom_id.")

        self.style = style
        self.label = label
        self.emoji = to_partial_emoji(emoji)
        self.custom_id = custom_id
        self.url = url
        self.disabled = disabled
        self._row = row

    def to_dict(self) -> dict[str, Any]:
        """Serialise the button into the component payload sent to Discord."""
        data: dict[str, Any] = {
            "type": ComponentType.BUTTON.value,
            "style": self.style.value,
            "disabled": self.disabled,
        }
        if self.label:
            data["label"] = self.label
        if self.emoji:
            data["emoji"] = self.emoji.to_dict()
        # Plain optional string fields are only included when set.
        for key in ("url", "custom_id"):
            value = getattr(self, key)
            if value:
                data[key] = value
        return data
|
69
|
+
|
70
|
+
|
71
|
+
def button(
    *,
    label: Optional[str] = None,
    custom_id: Optional[str] = None,
    style: ButtonStyle = ButtonStyle.SECONDARY,
    emoji: Optional[str | PartialEmoji] = None,
    url: Optional[str] = None,
    disabled: bool = False,
    row: Optional[int] = None,
) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Button]:
    """Decorator that turns a coroutine function into a :class:`Button`.

    All keyword arguments are forwarded unchanged to :class:`Button`; the
    decorated coroutine becomes the button's ``callback``.

    Raises:
        TypeError: If the decorated object is not a coroutine function.
    """
    options = {
        "label": label,
        "custom_id": custom_id,
        "style": style,
        "emoji": emoji,
        "url": url,
        "disabled": disabled,
        "row": row,
    }

    def decorator(func: Callable[..., Coroutine[Any, Any, Any]]) -> Button:
        if not asyncio.iscoroutinefunction(func):
            raise TypeError("Button callback must be a coroutine function.")
        created = Button(**options)
        created.callback = func
        return created

    return decorator
|
disagreement/ui/item.py
ADDED
@@ -0,0 +1,38 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Any, Callable, Coroutine, Optional, TYPE_CHECKING
|
4
|
+
|
5
|
+
from ..models import Component
|
6
|
+
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
from .view import View
|
9
|
+
from ..interactions import Interaction
|
10
|
+
|
11
|
+
|
12
|
+
class Item(Component):
    """Base class for UI elements that can be attached to a View.

    Not meant to be used directly; concrete components such as
    :class:`Button` subclass it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The view this item belongs to, if any.
        self._view: Optional[View] = None
        # Requested row; subclasses may overwrite it after calling super().
        self._row: Optional[int] = None
        # Coroutine associated with this item (set e.g. by the decorators).
        self.callback: Optional[
            Callable[["View", Interaction], Coroutine[Any, Any, Any]]
        ] = None

    @property
    def view(self) -> Optional[View]:
        """The view this item is attached to, or ``None``."""
        return self._view

    @property
    def row(self) -> Optional[int]:
        """The row this item was assigned, or ``None`` if unset."""
        return self._row

    def _refresh_from_data(self, data: dict[str, Any]) -> None:
        """Hook for updating this item from incoming interaction data.

        The base implementation does nothing; subclasses may override it
        (for instance to sync a button's disabled state).
        """
|
disagreement/ui/modal.py
ADDED
@@ -0,0 +1,132 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Any, Callable, Coroutine, Optional, List, TYPE_CHECKING
|
4
|
+
import asyncio
|
5
|
+
|
6
|
+
from .item import Item
|
7
|
+
from .view import View
|
8
|
+
from ..enums import ComponentType, TextInputStyle
|
9
|
+
from ..models import ActionRow
|
10
|
+
|
11
|
+
if TYPE_CHECKING: # pragma: no cover - for type hints only
|
12
|
+
from ..interactions import Interaction
|
13
|
+
|
14
|
+
|
15
|
+
class TextInput(Item):
    """A text field component that lives inside a :class:`Modal`.

    Args:
        label: Label shown above the field.
        custom_id: Developer-defined identifier for the field.
        style: Input style (defaults to ``TextInputStyle.SHORT``).
        placeholder: Placeholder text shown while the field is empty.
        required: Whether the field must be filled in.
        min_length: Minimum accepted length, if any.
        max_length: Maximum accepted length, if any.
        row: Requested row for the field.
    """

    def __init__(
        self,
        *,
        label: str,
        custom_id: Optional[str] = None,
        style: TextInputStyle = TextInputStyle.SHORT,
        placeholder: Optional[str] = None,
        required: bool = True,
        min_length: Optional[int] = None,
        max_length: Optional[int] = None,
        row: Optional[int] = None,
    ) -> None:
        super().__init__(type=ComponentType.TEXT_INPUT)
        self.label = label
        self.custom_id = custom_id
        self.style = style
        self.placeholder = placeholder
        self.required = required
        self.min_length = min_length
        self.max_length = max_length
        self._row = row

    def to_dict(self) -> dict[str, Any]:
        """Serialise this text input into a Discord component payload."""
        payload: dict[str, Any] = {
            "type": ComponentType.TEXT_INPUT.value,
            "style": self.style.value,
            "label": self.label,
            "required": self.required,
        }
        # String options are sent only when non-empty...
        string_opts = {"custom_id": self.custom_id, "placeholder": self.placeholder}
        payload.update((key, val) for key, val in string_opts.items() if val)
        # ...while length limits are sent whenever set, including 0.
        length_opts = {"min_length": self.min_length, "max_length": self.max_length}
        payload.update(
            (key, val) for key, val in length_opts.items() if val is not None
        )
        return payload
|
56
|
+
|
57
|
+
|
58
|
+
def text_input(
    *,
    label: str,
    custom_id: Optional[str] = None,
    style: TextInputStyle = TextInputStyle.SHORT,
    placeholder: Optional[str] = None,
    required: bool = True,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    row: Optional[int] = None,
) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], TextInput]:
    """Decorator that turns a coroutine function into a :class:`TextInput`.

    The keyword arguments are forwarded unchanged to :class:`TextInput` and
    the decorated coroutine becomes its ``callback``. Assign the result as a
    class attribute of a :class:`Modal` subclass so the modal collects it.

    Raises:
        TypeError: If the decorated object is not a coroutine function.
    """
    settings = {
        "label": label,
        "custom_id": custom_id,
        "style": style,
        "placeholder": placeholder,
        "required": required,
        "min_length": min_length,
        "max_length": max_length,
        "row": row,
    }

    def decorator(func: Callable[..., Coroutine[Any, Any, Any]]) -> TextInput:
        if not asyncio.iscoroutinefunction(func):
            raise TypeError("TextInput callback must be a coroutine function.")
        field = TextInput(**settings)
        field.callback = func
        return field

    return decorator
|
89
|
+
|
90
|
+
|
91
|
+
class Modal:
    """Represents a modal dialog made of up to five :class:`TextInput` rows.

    :class:`TextInput` objects assigned as class attributes (for example via
    the :func:`text_input` decorator) are added automatically on
    instantiation, including those inherited from intermediate base classes.
    """

    def __init__(self, *, title: str, custom_id: str) -> None:
        self.title = title
        self.custom_id = custom_id
        self._children: List[TextInput] = []

        for item in self._collect_class_inputs():
            self.add_item(item)

    @classmethod
    def _collect_class_inputs(cls) -> List[TextInput]:
        """Return class-level TextInputs from the whole MRO, base classes first.

        Previously only the most-derived class ``__dict__`` was scanned, so
        inputs defined on an intermediate Modal subclass were dropped. A
        subclass attribute of the same name replaces the inherited input, or
        removes it if the new attribute is not a TextInput.
        """
        found: dict[str, TextInput] = {}
        for klass in reversed(cls.__mro__):
            for name, attr in vars(klass).items():
                if isinstance(attr, TextInput):
                    found[name] = attr
                else:
                    found.pop(name, None)
        return list(found.values())

    @property
    def children(self) -> List[TextInput]:
        """The text inputs in this modal, in insertion order."""
        return self._children

    def add_item(self, item: TextInput) -> None:
        """Append a text input to the modal.

        Raises:
            TypeError: If ``item`` is not a :class:`TextInput`.
            ValueError: If the modal already holds 5 text inputs.
        """
        if not isinstance(item, TextInput):
            raise TypeError("Only TextInput items can be added to a Modal.")
        if len(self._children) >= 5:
            raise ValueError("A modal can only have up to 5 text inputs.")
        item._view = None  # Not part of a view but reuse item base
        self._children.append(item)

    def to_components(self) -> List[ActionRow]:
        """Wrap each text input in its own :class:`ActionRow`."""
        return [ActionRow(components=[child]) for child in self.children]

    def to_dict(self) -> dict[str, Any]:
        """Serialise the modal into its payload form."""
        return {
            "title": self.title,
            "custom_id": self.custom_id,
            "components": [r.to_dict() for r in self.to_components()],
        }

    async def callback(
        self, interaction: Interaction
    ) -> None:  # pragma: no cover - default
        """Override to handle an interaction for this modal; does nothing by default."""
|