satori-python 0.11.4__tar.gz → 0.12.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. {satori_python-0.11.4 → satori_python-0.12.0}/PKG-INFO +2 -2
  2. {satori_python-0.11.4 → satori_python-0.12.0}/pyproject.toml +5 -5
  3. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/__init__.py +1 -1
  4. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/client/__init__.py +70 -27
  5. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/client/account.py +11 -1
  6. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/client/account.pyi +28 -4
  7. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/client/network/base.py +1 -1
  8. satori_python-0.12.0/src/satori/client/network/util.py +32 -0
  9. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/client/network/webhook.py +13 -2
  10. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/client/network/websocket.py +2 -2
  11. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/client/session.py +58 -36
  12. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/element.py +10 -5
  13. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/model.py +79 -101
  14. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/parser.py +2 -0
  15. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/server/__init__.py +31 -178
  16. satori_python-0.12.0/src/satori/server/adapter.py +45 -0
  17. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/server/model.py +6 -5
  18. satori_python-0.12.0/src/satori/server/route.py +396 -0
  19. satori_python-0.11.4/src/satori/server/adapter.py +0 -34
  20. satori_python-0.11.4/src/satori/server/route.py +0 -240
  21. {satori_python-0.11.4 → satori_python-0.12.0}/LICENSE +0 -0
  22. {satori_python-0.11.4 → satori_python-0.12.0}/README.md +0 -0
  23. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/client/network/__init__.py +0 -0
  24. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/config.py +0 -0
  25. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/const.py +0 -0
  26. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/event.py +0 -0
  27. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/exception.py +0 -0
  28. {satori_python-0.11.4 → satori_python-0.12.0}/src/satori/server/conection.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: satori-python
3
- Version: 0.11.4
3
+ Version: 0.12.0
4
4
  Summary: Satori Protocol SDK for python
5
5
  Home-page: https://github.com/RF-Tar-Railt/satori-python
6
6
  Author-Email: RF-Tar-Railt <rf_tar_railt@qq.com>
@@ -20,7 +20,7 @@ Requires-Python: >=3.9
20
20
  Requires-Dist: aiohttp>=3.9.3
21
21
  Requires-Dist: loguru>=0.7.2
22
22
  Requires-Dist: launart>=0.8.2
23
- Requires-Dist: typing-extensions>=4.10.0
23
+ Requires-Dist: typing-extensions>=4.7.0
24
24
  Requires-Dist: graia-amnesia>=0.9.0
25
25
  Requires-Dist: starlette>=0.37.2
26
26
  Requires-Dist: uvicorn[standard]>=0.28.0
@@ -9,7 +9,7 @@ dependencies = [
9
9
  "aiohttp>=3.9.3",
10
10
  "loguru>=0.7.2",
11
11
  "launart>=0.8.2",
12
- "typing-extensions>=4.10.0",
12
+ "typing-extensions>=4.7.0",
13
13
  "graia-amnesia>=0.9.0",
14
14
  "starlette>=0.37.2",
15
15
  "uvicorn[standard]>=0.28.0",
@@ -28,7 +28,7 @@ classifiers = [
28
28
  "Programming Language :: Python :: 3.12",
29
29
  "Operating System :: OS Independent",
30
30
  ]
31
- version = "0.11.4"
31
+ version = "0.12.0"
32
32
 
33
33
  [project.license]
34
34
  text = "MIT"
@@ -46,9 +46,9 @@ build-backend = "mina.backend"
46
46
  [tool.pdm.dev-dependencies]
47
47
  dev = [
48
48
  "isort>=5.13.2",
49
- "black>=24.2.0",
50
- "ruff>=0.3.2",
51
- "pre-commit>=3.6.2",
49
+ "black>=24.4.0",
50
+ "ruff>=0.4.1",
51
+ "pre-commit>=3.7.0",
52
52
  "fix-future-annotations>=0.5.0",
53
53
  "mina-build<0.6,>=0.5.1",
54
54
  "pdm-mina>=0.3.2",
@@ -39,4 +39,4 @@ from .model import MessageObject as MessageObject
39
39
  from .model import Role as Role
40
40
  from .model import User as User
41
41
 
42
- __version__ = "0.11.4"
42
+ __version__ = "0.12.0"
@@ -5,13 +5,14 @@ import functools
5
5
  import signal
6
6
  import threading
7
7
  from functools import wraps
8
- from typing import Any, Awaitable, Callable, Iterable, Literal, TypeVar, overload
8
+ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, Literal, TypeVar, overload
9
9
 
10
10
  from creart import it
11
+ from graia.amnesia.builtins.aiohttp import AiohttpClientService
11
12
  from launart import Launart, Service, any_completed
12
13
  from loguru import logger
13
14
 
14
- from satori import event
15
+ from satori import event as events
15
16
  from satori.config import Config, WebhookInfo, WebsocketsInfo
16
17
  from satori.const import EventType
17
18
  from satori.model import Event, LoginStatus
@@ -33,7 +34,7 @@ MAPPING: dict[type[Config], type[BaseNetwork]] = {
33
34
 
34
35
  class App(Service):
35
36
  id = "satori-python.client"
36
- required: set[str] = set()
37
+ required: set[str] = {"http.client/aiohttp"}
37
38
  stages: set[str] = {"preparing", "blocking", "cleanup"}
38
39
 
39
40
  accounts: dict[str, Account]
@@ -69,8 +70,8 @@ class App(Service):
69
70
 
70
71
  @overload
71
72
  def register_on(self, event_type: Literal[EventType.FRIEND_REQUEST]) -> Callable[
72
- [Callable[[Account, event.UserEvent], Awaitable[Any]]],
73
- Callable[[Account, event.UserEvent], Awaitable[Any]],
73
+ [Callable[[Account, events.UserEvent], Awaitable[Any]]],
74
+ Callable[[Account, events.UserEvent], Awaitable[Any]],
74
75
  ]: ...
75
76
 
76
77
  @overload
@@ -80,8 +81,8 @@ class App(Service):
80
81
  EventType.GUILD_ADDED, EventType.GUILD_REMOVED, EventType.GUILD_REQUEST, EventType.GUILD_UPDATED
81
82
  ],
82
83
  ) -> Callable[
83
- [Callable[[Account, event.GuildEvent], Awaitable[Any]]],
84
- Callable[[Account, event.GuildEvent], Awaitable[Any]],
84
+ [Callable[[Account, events.GuildEvent], Awaitable[Any]]],
85
+ Callable[[Account, events.GuildEvent], Awaitable[Any]],
85
86
  ]: ...
86
87
 
87
88
  @overload
@@ -94,8 +95,8 @@ class App(Service):
94
95
  EventType.GUILD_MEMBER_REQUEST,
95
96
  ],
96
97
  ) -> Callable[
97
- [Callable[[Account, event.GuildMemberEvent], Awaitable[Any]]],
98
- Callable[[Account, event.GuildMemberEvent], Awaitable[Any]],
98
+ [Callable[[Account, events.GuildMemberEvent], Awaitable[Any]]],
99
+ Callable[[Account, events.GuildMemberEvent], Awaitable[Any]],
99
100
  ]: ...
100
101
 
101
102
  @overload
@@ -105,16 +106,16 @@ class App(Service):
105
106
  EventType.GUILD_ROLE_CREATED, EventType.GUILD_ROLE_DELETED, EventType.GUILD_ROLE_UPDATED
106
107
  ],
107
108
  ) -> Callable[
108
- [Callable[[Account, event.GuildRoleEvent], Awaitable[Any]]],
109
- Callable[[Account, event.GuildRoleEvent], Awaitable[Any]],
109
+ [Callable[[Account, events.GuildRoleEvent], Awaitable[Any]]],
110
+ Callable[[Account, events.GuildRoleEvent], Awaitable[Any]],
110
111
  ]: ...
111
112
 
112
113
  @overload
113
114
  def register_on(
114
115
  self, event_type: Literal[EventType.LOGIN_ADDED, EventType.LOGIN_REMOVED, EventType.LOGIN_UPDATED]
115
116
  ) -> Callable[
116
- [Callable[[Account, event.LoginEvent], Awaitable[Any]]],
117
- Callable[[Account, event.LoginEvent], Awaitable[Any]],
117
+ [Callable[[Account, events.LoginEvent], Awaitable[Any]]],
118
+ Callable[[Account, events.LoginEvent], Awaitable[Any]],
118
119
  ]: ...
119
120
 
120
121
  @overload
@@ -122,34 +123,34 @@ class App(Service):
122
123
  self,
123
124
  event_type: Literal[EventType.MESSAGE_CREATED, EventType.MESSAGE_DELETED, EventType.MESSAGE_UPDATED],
124
125
  ) -> Callable[
125
- [Callable[[Account, event.MessageEvent], Awaitable[Any]]],
126
- Callable[[Account, event.MessageEvent], Awaitable[Any]],
126
+ [Callable[[Account, events.MessageEvent], Awaitable[Any]]],
127
+ Callable[[Account, events.MessageEvent], Awaitable[Any]],
127
128
  ]: ...
128
129
 
129
130
  @overload
130
131
  def register_on(
131
132
  self, event_type: Literal[EventType.REACTION_ADDED, EventType.REACTION_REMOVED]
132
133
  ) -> Callable[
133
- [Callable[[Account, event.ReactionEvent], Awaitable[Any]]],
134
- Callable[[Account, event.ReactionEvent], Awaitable[Any]],
134
+ [Callable[[Account, events.ReactionEvent], Awaitable[Any]]],
135
+ Callable[[Account, events.ReactionEvent], Awaitable[Any]],
135
136
  ]: ...
136
137
 
137
138
  @overload
138
139
  def register_on(self, event_type: Literal[EventType.INTERACTION_BUTTON]) -> Callable[
139
- [Callable[[Account, event.ButtonInteractionEvent], Awaitable[Any]]],
140
- Callable[[Account, event.ButtonInteractionEvent], Awaitable[Any]],
140
+ [Callable[[Account, events.ButtonInteractionEvent], Awaitable[Any]]],
141
+ Callable[[Account, events.ButtonInteractionEvent], Awaitable[Any]],
141
142
  ]: ...
142
143
 
143
144
  @overload
144
145
  def register_on(self, event_type: Literal[EventType.INTERACTION_COMMAND]) -> Callable[
145
- [Callable[[Account, event.ArgvInteractionEvent | event.MessageEvent], Awaitable[Any]]],
146
- Callable[[Account, event.ArgvInteractionEvent | event.MessageEvent], Awaitable[Any]],
146
+ [Callable[[Account, events.ArgvInteractionEvent | events.MessageEvent], Awaitable[Any]]],
147
+ Callable[[Account, events.ArgvInteractionEvent | events.MessageEvent], Awaitable[Any]],
147
148
  ]: ...
148
149
 
149
150
  @overload
150
151
  def register_on(self, event_type: Literal[EventType.INTERNAL]) -> Callable[
151
- [Callable[[Account, event.InternalEvent], Awaitable[Any]]],
152
- Callable[[Account, event.InternalEvent], Awaitable[Any]],
152
+ [Callable[[Account, events.InternalEvent], Awaitable[Any]]],
153
+ Callable[[Account, events.InternalEvent], Awaitable[Any]],
153
154
  ]: ...
154
155
 
155
156
  @overload
@@ -180,16 +181,56 @@ class App(Service):
180
181
  if self.lifecycle_callbacks:
181
182
  await asyncio.gather(*(callback(account, state) for callback in self.lifecycle_callbacks))
182
183
 
183
- async def post(self, event: Event):
184
+ async def post(self, event: Event, conn: BaseNetwork):
184
185
  if not self.event_callbacks:
185
186
  return
186
187
  identity = f"{event.platform}/{event.self_id}"
187
188
  if identity not in self.accounts:
188
- logger.warning(f"Received event for unknown account: {event}")
189
- return
190
- account = self.accounts[identity]
189
+ if event.type == EventType.LOGIN_ADDED:
190
+ if TYPE_CHECKING:
191
+ assert isinstance(event, events.LoginEvent)
192
+ account = Account(
193
+ event.platform,
194
+ event.self_id,
195
+ event.login,
196
+ conn.config,
197
+ )
198
+ logger.info(f"account added: {account}")
199
+ (
200
+ account.connected.set()
201
+ if event.login.status == LoginStatus.ONLINE
202
+ else account.connected.clear()
203
+ )
204
+ self.accounts[identity] = account
205
+ conn.accounts[identity] = account
206
+ await self.account_update(account, LoginStatus.ONLINE)
207
+ await self.account_update(account, LoginStatus.CONNECT)
208
+ else:
209
+ logger.warning(f"Received event for unknown account: {event}")
210
+ return
211
+ else:
212
+ account = self.accounts[identity]
213
+ if event.type == EventType.LOGIN_UPDATED:
214
+ if TYPE_CHECKING:
215
+ assert isinstance(event, events.LoginEvent)
216
+ logger.info(f"account updated: {account}")
217
+ (
218
+ account.connected.set()
219
+ if event.login.status in (LoginStatus.ONLINE, LoginStatus.CONNECT)
220
+ else account.connected.clear()
221
+ )
222
+
191
223
  await asyncio.gather(*(callback(account, event) for callback in self.event_callbacks))
192
224
 
225
+ if event.type == EventType.LOGIN_REMOVED:
226
+ if TYPE_CHECKING:
227
+ assert isinstance(event, events.LoginEvent)
228
+ logger.info(f"account removed: {account}")
229
+ account.connected.clear()
230
+ await self.account_update(account, LoginStatus.DISCONNECT)
231
+ del self.accounts[identity]
232
+ del conn.accounts[identity]
233
+
193
234
  async def launch(self, manager: Launart):
194
235
  for conn in self.connections:
195
236
  manager.add_component(conn)
@@ -217,6 +258,7 @@ class App(Service):
217
258
  ):
218
259
  if manager is None:
219
260
  manager = it(Launart)
261
+ manager.add_component(AiohttpClientService())
220
262
  manager.add_component(self)
221
263
  manager.launch_blocking(loop=loop, stop_signal=stop_signal)
222
264
 
@@ -227,6 +269,7 @@ class App(Service):
227
269
  ):
228
270
  if manager is None:
229
271
  manager = it(Launart)
272
+ manager.add_component(AiohttpClientService())
230
273
  manager.add_component(self)
231
274
  handled_signals: dict[signal.Signals, Any] = {}
232
275
  launch_task = asyncio.create_task(manager.launch(), name="amnesia-launch")
@@ -6,6 +6,8 @@ from typing import TypeVar
6
6
 
7
7
  from yarl import URL
8
8
 
9
+ from satori.model import Login
10
+
9
11
  from .session import Session
10
12
 
11
13
  TS = TypeVar("TS", bound="Session")
@@ -28,9 +30,17 @@ class ApiInfo:
28
30
 
29
31
 
30
32
  class Account:
31
- def __init__(self, platform: str, self_id: str, config: ApiInfo, session_cls: type[Session] = Session):
33
+ def __init__(
34
+ self,
35
+ platform: str,
36
+ self_id: str,
37
+ self_info: Login,
38
+ config: ApiInfo,
39
+ session_cls: type[Session] = Session,
40
+ ):
32
41
  self.platform = platform
33
42
  self.self_id = self_id
43
+ self.self_info = self_info
34
44
  self.config = config
35
45
  self.session = session_cls(self) # type: ignore
36
46
  self.connected = asyncio.Event()
@@ -4,7 +4,20 @@ from typing import Any, Iterable, Protocol, TypeVar, overload
4
4
  from yarl import URL
5
5
 
6
6
  from satori.element import Element
7
- from satori.model import Channel, Event, Guild, Login, Member, MessageObject, PageResult, Role, User
7
+ from satori.model import (
8
+ Channel,
9
+ Direction,
10
+ Event,
11
+ Guild,
12
+ Login,
13
+ Member,
14
+ MessageObject,
15
+ Order,
16
+ PageDequeResult,
17
+ PageResult,
18
+ Role,
19
+ User,
20
+ )
8
21
 
9
22
  from .session import Session
10
23
 
@@ -24,11 +37,14 @@ class ApiInfo(Api):
24
37
  class Account:
25
38
  platform: str
26
39
  self_id: str
40
+ self_info: Login
27
41
  config: Api
28
42
  session: Session
29
43
  connected: asyncio.Event
30
44
 
31
- def __init__(self, platform: str, self_id: str, config: Api, session_cls: type[Session] = Session): ...
45
+ def __init__(
46
+ self, platform: str, self_id: str, self_info: Login, config: Api, session_cls: type[Session] = Session
47
+ ): ...
32
48
  @property
33
49
  def identity(self) -> str: ...
34
50
  @overload
@@ -100,8 +116,13 @@ class Account:
100
116
  content: str,
101
117
  ) -> None: ...
102
118
  async def message_list(
103
- self, *, channel_id: str, next_token: str | None = None
104
- ) -> PageResult[MessageObject]: ...
119
+ self,
120
+ channel_id: str,
121
+ next_token: str | None = None,
122
+ direction: Direction = "before",
123
+ limit: int = 50,
124
+ order: Order = "asc",
125
+ ) -> PageDequeResult[MessageObject]: ...
105
126
  async def channel_get(self, *, channel_id: str) -> Channel: ...
106
127
  async def channel_list(self, *, guild_id: str, next_token: str | None = None) -> PageResult[Channel]: ...
107
128
  async def channel_create(self, *, guild_id: str, data: Channel) -> Channel: ...
@@ -112,6 +133,7 @@ class Account:
112
133
  data: Channel,
113
134
  ) -> None: ...
114
135
  async def channel_delete(self, *, channel_id: str) -> None: ...
136
+ async def channel_mute(self, *, channel_id: str, duration: float = 0) -> None: ...
115
137
  async def user_channel_create(self, *, user_id: str, guild_id: str | None = None) -> Channel: ...
116
138
  async def guild_get(self, *, guild_id: str) -> Guild: ...
117
139
  async def guild_list(self, *, next_token: str | None = None) -> PageResult[Guild]: ...
@@ -121,6 +143,7 @@ class Account:
121
143
  ) -> PageResult[Member]: ...
122
144
  async def guild_member_get(self, *, guild_id: str, user_id: str) -> Member: ...
123
145
  async def guild_member_kick(self, *, guild_id: str, user_id: str, permanent: bool = False) -> None: ...
146
+ async def guild_member_mute(self, *, guild_id: str, user_id: str, duration: float = 0) -> None: ...
124
147
  async def guild_member_approve(self, *, request_id: str, approve: bool, comment: str) -> None: ...
125
148
  async def guild_member_role_set(self, *, guild_id: str, user_id: str, role_id: str) -> None: ...
126
149
  async def guild_member_role_unset(self, *, guild_id: str, user_id: str, role_id: str) -> None: ...
@@ -179,3 +202,4 @@ class Account:
179
202
  action: str,
180
203
  **kwargs,
181
204
  ) -> Any: ...
205
+ async def admin_login_list(self) -> list[Login]: ...
@@ -43,6 +43,6 @@ class BaseNetwork(Generic[TConfig], Service):
43
43
  logger.warning(f"Failed to parse event: {raw}\nCaused by {e!r}")
44
44
  else:
45
45
  self.sequence = event.id
46
- await self.app.post(event)
46
+ await self.app.post(event, self)
47
47
 
48
48
  return asyncio.create_task(event_parse_task(body))
@@ -0,0 +1,32 @@
1
+ import json
2
+
3
+ from aiohttp import ClientResponse
4
+
5
+ from satori.exception import (
6
+ ApiNotImplementedException,
7
+ BadRequestException,
8
+ ForbiddenException,
9
+ MethodNotAllowedException,
10
+ NotFoundException,
11
+ UnauthorizedException,
12
+ )
13
+
14
+
15
+ async def validate_response(resp: ClientResponse):
16
+ if 200 <= resp.status < 300:
17
+ return json.loads(content) if (content := await resp.text()) else {}
18
+ elif resp.status == 400:
19
+ raise BadRequestException(await resp.text())
20
+ elif resp.status == 401:
21
+ raise UnauthorizedException(await resp.text())
22
+ elif resp.status == 403:
23
+ raise ForbiddenException(await resp.text())
24
+ elif resp.status == 404:
25
+ raise NotFoundException(await resp.text())
26
+ elif resp.status == 405:
27
+ raise MethodNotAllowedException(await resp.text())
28
+ elif resp.status == 500:
29
+ raise ApiNotImplementedException(await resp.text())
30
+ else:
31
+ resp.raise_for_status()
32
+ return json.loads(content) if (content := await resp.text()) else {}
@@ -3,15 +3,17 @@ from __future__ import annotations
3
3
  import asyncio
4
4
 
5
5
  from aiohttp import web
6
+ from graia.amnesia.builtins.aiohttp import AiohttpClientService
6
7
  from launart.manager import Launart
7
8
  from launart.utilles import any_completed
8
9
  from loguru import logger
9
10
 
10
11
  from satori.config import WebhookInfo as WebhookInfo
11
- from satori.model import LoginStatus, Opcode
12
+ from satori.model import Login, LoginStatus, Opcode
12
13
 
13
14
  from ..account import Account
14
15
  from .base import BaseNetwork
16
+ from .util import validate_response
15
17
 
16
18
 
17
19
  class WebhookNetwork(BaseNetwork[WebhookInfo]):
@@ -40,12 +42,21 @@ class WebhookNetwork(BaseNetwork[WebhookInfo]):
40
42
  account.connected.set()
41
43
  account.config = self.config
42
44
  else:
43
- account = Account(platform, self_id, self.config)
45
+ assert self.manager
46
+ aio = self.manager.get_component(AiohttpClientService)
47
+ async with aio.session.post(self.config.api_base / "admin/login.list") as resp:
48
+ logins = [Login.parse(i) for i in await validate_response(resp)]
49
+ login = next(
50
+ (i for i in logins if i.self_id == self_id and i.platform == platform),
51
+ Login(LoginStatus.CONNECT, self_id=self_id, platform=platform),
52
+ )
53
+ account = Account(platform, self_id, login, self.config)
44
54
  logger.info(f"account registered: {account}")
45
55
  account.connected.set()
46
56
  self.app.accounts[identity] = account
47
57
  self.accounts[identity] = account
48
58
  await self.app.account_update(account, LoginStatus.ONLINE)
59
+ await self.app.account_update(account, LoginStatus.CONNECT)
49
60
  data = await req.json()
50
61
  op = data["op"]
51
62
  if op != Opcode.EVENT:
@@ -11,7 +11,7 @@ from launart.utilles import any_completed
11
11
  from loguru import logger
12
12
 
13
13
  from satori.config import WebsocketsInfo as WebsocketsInfo
14
- from satori.model import LoginStatus, Opcode
14
+ from satori.model import Login, LoginStatus, Opcode
15
15
 
16
16
  from ..account import Account
17
17
  from .base import BaseNetwork
@@ -99,7 +99,7 @@ class WsNetwork(BaseNetwork[WebsocketsInfo]):
99
99
  account.connected.clear()
100
100
  account.config = self.config
101
101
  else:
102
- account = Account(platform, self_id, self.config)
102
+ account = Account(platform, self_id, Login.parse(login), self.config)
103
103
  logger.info(f"account registered: {account}")
104
104
  (
105
105
  account.connected.set()
@@ -1,21 +1,28 @@
1
1
  from __future__ import annotations
2
2
 
3
- import json
4
3
  from typing import TYPE_CHECKING, Any, Iterable, List, cast
5
4
 
6
- import aiohttp
5
+ from graia.amnesia.builtins.aiohttp import AiohttpClientService
6
+ from launart import Launart
7
7
 
8
8
  from satori.const import Api
9
9
  from satori.element import Element
10
- from satori.exception import (
11
- ApiNotImplementedException,
12
- BadRequestException,
13
- ForbiddenException,
14
- MethodNotAllowedException,
15
- NotFoundException,
16
- UnauthorizedException,
10
+ from satori.model import (
11
+ Channel,
12
+ Direction,
13
+ Event,
14
+ Guild,
15
+ Login,
16
+ Member,
17
+ MessageObject,
18
+ Order,
19
+ PageDequeResult,
20
+ PageResult,
21
+ Role,
22
+ User,
17
23
  )
18
- from satori.model import Channel, Event, Guild, Login, Member, MessageObject, PageResult, Role, User
24
+
25
+ from .network.util import validate_response
19
26
 
20
27
  if TYPE_CHECKING:
21
28
  from .account import Account
@@ -33,29 +40,13 @@ class Session:
33
40
  "X-Platform": self.account.platform,
34
41
  "X-Self-ID": self.account.self_id,
35
42
  }
36
- async with aiohttp.ClientSession() as session:
37
- async with session.post(
38
- endpoint,
39
- json=params or {},
40
- headers=headers,
41
- ) as resp:
42
- if 200 <= resp.status < 300:
43
- return json.loads(content) if (content := await resp.text()) else {}
44
- elif resp.status == 400:
45
- raise BadRequestException(await resp.text())
46
- elif resp.status == 401:
47
- raise UnauthorizedException(await resp.text())
48
- elif resp.status == 403:
49
- raise ForbiddenException(await resp.text())
50
- elif resp.status == 404:
51
- raise NotFoundException(await resp.text())
52
- elif resp.status == 405:
53
- raise MethodNotAllowedException(await resp.text())
54
- elif resp.status == 500:
55
- raise ApiNotImplementedException(await resp.text())
56
- else:
57
- resp.raise_for_status()
58
- return json.loads(content) if (content := await resp.text()) else {}
43
+ aio = Launart.current().get_component(AiohttpClientService)
44
+ async with aio.session.post(
45
+ endpoint,
46
+ json=params or {},
47
+ headers=headers,
48
+ ) as resp:
49
+ return await validate_response(resp)
59
50
 
60
51
  async def send(
61
52
  self,
@@ -153,12 +144,27 @@ class Session:
153
144
  {"channel_id": channel_id, "message_id": message_id, "content": content},
154
145
  )
155
146
 
156
- async def message_list(self, channel_id: str, next_token: str | None = None) -> PageResult[MessageObject]:
147
+ async def message_list(
148
+ self,
149
+ channel_id: str,
150
+ next_token: str | None = None,
151
+ direction: Direction = "before",
152
+ limit: int = 50,
153
+ order: Order = "asc",
154
+ ) -> PageDequeResult[MessageObject]:
155
+ if not next_token and direction != "before":
156
+ raise ValueError("Invalid direction")
157
157
  res = await self.call_api(
158
158
  Api.MESSAGE_LIST,
159
- {"channel_id": channel_id, "next": next_token},
159
+ {
160
+ "channel_id": channel_id,
161
+ "next": next_token,
162
+ "direction": direction,
163
+ "limit": limit,
164
+ "order": order,
165
+ },
160
166
  )
161
- return PageResult.parse(res, MessageObject.parse)
167
+ return PageDequeResult.parse(res, MessageObject.parse)
162
168
 
163
169
  async def channel_get(self, channel_id: str) -> Channel:
164
170
  res = await self.call_api(
@@ -197,6 +203,12 @@ class Session:
197
203
  {"channel_id": channel_id},
198
204
  )
199
205
 
206
+ async def channel_mute(self, channel_id: str, duration: float = 0) -> None:
207
+ await self.call_api(
208
+ Api.CHANNEL_MUTE,
209
+ {"channel_id": channel_id, "duration": duration},
210
+ )
211
+
200
212
  async def user_channel_create(self, user_id: str, guild_id: str | None = None) -> Channel:
201
213
  data = {"user_id": user_id}
202
214
  if guild_id is not None:
@@ -247,6 +259,12 @@ class Session:
247
259
  {"guild_id": guild_id, "user_id": user_id, "permanent": permanent},
248
260
  )
249
261
 
262
+ async def guild_member_mute(self, guild_id: str, user_id: str, duration: float = 0) -> None:
263
+ await self.call_api(
264
+ Api.GUILD_MEMBER_MUTE,
265
+ {"guild_id": guild_id, "user_id": user_id, "duration": duration},
266
+ )
267
+
250
268
  async def guild_member_approve(self, request_id: str, approve: bool, comment: str) -> None:
251
269
  await self.call_api(
252
270
  Api.GUILD_MEMBER_APPROVE,
@@ -382,3 +400,7 @@ class Session:
382
400
  **kwargs,
383
401
  ) -> Any:
384
402
  return await self.call_api(f"internal/{action}", kwargs)
403
+
404
+ async def admin_login_list(self) -> list[Login]:
405
+ res = await self.call_api("admin/login.list")
406
+ return [Login.parse(i) for i in res]
@@ -18,13 +18,17 @@ class Element:
18
18
 
19
19
  __names__: ClassVar[Tuple[str, ...]]
20
20
 
21
+ @property
22
+ def children(self) -> List["Element"]:
23
+ return self._children
24
+
21
25
  @property
22
26
  def tag(self) -> str:
23
27
  return self.__class__.__name__.lower()
24
28
 
25
29
  @classmethod
26
30
  def unpack(cls, attrs: Dict[str, Any]):
27
- obj = cls(**{k: v for k, v in attrs.items() if k in cls.__names__})
31
+ obj = cls(**{k: v for k, v in attrs.items() if k in cls.__names__}) # type: ignore
28
32
  obj._attrs.update({k: v for k, v in attrs.items() if k not in cls.__names__})
29
33
  return obj
30
34
 
@@ -132,7 +136,7 @@ class Sharp(Element):
132
136
  name: Optional[str] = None
133
137
 
134
138
 
135
- @dataclass
139
+ @dataclass(repr=False)
136
140
  class Link(Element):
137
141
  """<a> 元素用于显示一个链接。"""
138
142
 
@@ -163,7 +167,7 @@ class Resource(Element):
163
167
  title: Optional[str] = None
164
168
  extra: InitVar[Optional[Dict[str, Any]]] = None
165
169
  cache: Optional[bool] = None
166
- timeout: Optional[str] = None
170
+ timeout: Optional[int] = None
167
171
 
168
172
  __names__ = ("src", "title")
169
173
 
@@ -178,7 +182,7 @@ class Resource(Element):
178
182
  poster: Optional[str] = None,
179
183
  extra: Optional[Dict[str, Any]] = None,
180
184
  cache: Optional[bool] = None,
181
- timeout: Optional[str] = None,
185
+ timeout: Optional[int] = None,
182
186
  **kwargs,
183
187
  ):
184
188
  data: Dict[str, Any] = {"extra": extra}
@@ -537,7 +541,7 @@ def transform(elements: List[RawElement]) -> List[Element]:
537
541
  tag = elem.tag()
538
542
  if tag in ELEMENT_TYPE_MAP:
539
543
  seg_cls = ELEMENT_TYPE_MAP[tag]
540
- msg.append(seg_cls.unpack(elem.attrs))
544
+ msg.append(seg_cls.unpack(elem.attrs)(*transform(elem.children)))
541
545
  elif tag in ("a", "link"):
542
546
  link = Link.unpack(elem.attrs)
543
547
  if elem.children:
@@ -547,6 +551,7 @@ def transform(elements: List[RawElement]) -> List[Element]:
547
551
  button = Button.unpack(elem.attrs)
548
552
  if elem.children:
549
553
  button(*transform(elem.children))
554
+ msg.append(button)
550
555
  elif tag in STYLE_TYPE_MAP:
551
556
  seg_cls = STYLE_TYPE_MAP[tag]
552
557
  msg.append(seg_cls.unpack(elem.attrs)(*transform(elem.children)))