satori-python 0.14.4.tar.gz → 0.15.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {satori_python-0.14.4 → satori_python-0.15.0}/PKG-INFO +2 -1
  2. {satori_python-0.14.4 → satori_python-0.15.0}/README.md +1 -0
  3. {satori_python-0.14.4 → satori_python-0.15.0}/pyproject.toml +2 -1
  4. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/__init__.py +1 -1
  5. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/client/__init__.py +23 -4
  6. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/client/account.py +2 -2
  7. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/client/account.pyi +10 -9
  8. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/client/network/base.py +9 -1
  9. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/client/network/webhook.py +15 -6
  10. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/client/network/websocket.py +10 -9
  11. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/client/protocol.py +22 -17
  12. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/element.py +1 -1
  13. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/model.py +98 -8
  14. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/server/__init__.py +44 -24
  15. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/server/adapter.py +4 -3
  16. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/server/model.py +3 -6
  17. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/server/route.py +2 -2
  18. satori_python-0.14.4/src/satori/server/deque.py → satori_python-0.15.0/src/satori/server/utils.py +5 -0
  19. {satori_python-0.14.4 → satori_python-0.15.0}/LICENSE +0 -0
  20. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/client/network/__init__.py +0 -0
  21. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/client/network/util.py +0 -0
  22. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/config.py +0 -0
  23. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/const.py +0 -0
  24. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/event.py +0 -0
  25. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/exception.py +0 -0
  26. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/parser.py +0 -0
  27. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/server/conection.py +0 -0
  28. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/server/formdata.py +0 -0
  29. {satori_python-0.14.4 → satori_python-0.15.0}/src/satori/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: satori-python
3
- Version: 0.14.4
3
+ Version: 0.15.0
4
4
  Summary: Satori Protocol SDK for python
5
5
  Home-page: https://github.com/RF-Tar-Railt/satori-python
6
6
  Author-Email: RF-Tar-Railt <rf_tar_railt@qq.com>
@@ -46,6 +46,7 @@ Description-Content-Type: text/markdown
46
46
  目前提供了 `satori` 协议实现的有:
47
47
 
48
48
  - [Chronocat](https://chronocat.vercel.app)
49
+ - [nekobox](https://github.com/wyapx/nekobox)
49
50
  - Koishi (搭配 `@koishijs/plugin-server`)
50
51
 
51
52
  ### 使用该 SDK 的框架
@@ -16,6 +16,7 @@
16
16
  目前提供了 `satori` 协议实现的有:
17
17
 
18
18
  - [Chronocat](https://chronocat.vercel.app)
19
+ - [nekobox](https://github.com/wyapx/nekobox)
19
20
  - Koishi (搭配 `@koishijs/plugin-server`)
20
21
 
21
22
  ### 使用该 SDK 的框架
@@ -29,7 +29,7 @@ classifiers = [
29
29
  "Programming Language :: Python :: 3.12",
30
30
  "Operating System :: OS Independent",
31
31
  ]
32
- version = "0.14.4"
32
+ version = "0.15.0"
33
33
 
34
34
  [project.license]
35
35
  text = "MIT"
@@ -41,6 +41,7 @@ repository = "https://github.com/RF-Tar-Railt/satori-python"
41
41
  [build-system]
42
42
  requires = [
43
43
  "mina-build<0.6,>=0.5.1",
44
+ "pdm-backend<2.4.0",
44
45
  ]
45
46
  build-backend = "mina.backend"
46
47
 
@@ -43,4 +43,4 @@ from .model import Role as Role
43
43
  from .model import Upload as Upload
44
44
  from .model import User as User
45
45
 
46
- __version__ = "0.14.4"
46
+ __version__ = "0.15.0"
@@ -185,14 +185,14 @@ class App(Service):
185
185
  async def post(self, event: Event, conn: BaseNetwork):
186
186
  if not self.event_callbacks:
187
187
  return
188
- identity = f"{event.platform}/{event.self_id}"
188
+ identity = f"{event.platform_}/{event.self_id_}"
189
189
  if identity not in self.accounts:
190
190
  if event.type == EventType.LOGIN_ADDED:
191
191
  if TYPE_CHECKING:
192
192
  assert isinstance(event, events.LoginEvent)
193
193
  account = Account(
194
- event.platform,
195
- event.self_id,
194
+ event.platform_,
195
+ event.self_id_,
196
196
  event.login,
197
197
  conn.config,
198
198
  )
@@ -204,8 +204,27 @@ class App(Service):
204
204
  )
205
205
  self.accounts[identity] = account
206
206
  conn.accounts[identity] = account
207
- await self.account_update(account, LoginStatus.ONLINE)
208
207
  await self.account_update(account, LoginStatus.CONNECT)
208
+ await self.account_update(account, LoginStatus.ONLINE)
209
+ elif event.type == EventType.LOGIN_UPDATED:
210
+ if TYPE_CHECKING:
211
+ assert isinstance(event, events.LoginEvent)
212
+ if event.login.status == LoginStatus.ONLINE:
213
+ account = Account(
214
+ event.platform_,
215
+ event.self_id_,
216
+ event.login,
217
+ conn.config,
218
+ )
219
+ logger.info(f"account added: {account}")
220
+ account.connected.set()
221
+ self.accounts[identity] = account
222
+ conn.accounts[identity] = account
223
+ await self.account_update(account, LoginStatus.CONNECT)
224
+ await self.account_update(account, LoginStatus.ONLINE)
225
+ else:
226
+ logger.warning(f"Received event for unknown account: {event}")
227
+ return
209
228
  else:
210
229
  logger.warning(f"Received event for unknown account: {event}")
211
230
  return
@@ -6,7 +6,7 @@ from typing import Generic, TypeVar
6
6
 
7
7
  from yarl import URL
8
8
 
9
- from satori.model import Login
9
+ from satori.model import LoginType
10
10
 
11
11
  from .protocol import ApiProtocol
12
12
 
@@ -35,7 +35,7 @@ class Account(Generic[TP]):
35
35
  self,
36
36
  platform: str,
37
37
  self_id: str,
38
- self_info: Login,
38
+ self_info: LoginType,
39
39
  config: ApiInfo,
40
40
  protocol_cls: type[TP] = ApiProtocol,
41
41
  ):
@@ -10,9 +10,10 @@ from satori.model import (
10
10
  Direction,
11
11
  Event,
12
12
  Guild,
13
- Login,
13
+ LoginType,
14
14
  Member,
15
15
  MessageObject,
16
+ MessageReceipt,
16
17
  Order,
17
18
  PageDequeResult,
18
19
  PageResult,
@@ -40,7 +41,7 @@ class ApiInfo(Api):
40
41
  class Account(Generic[TP]):
41
42
  platform: str
42
43
  self_id: str
43
- self_info: Login
44
+ self_info: LoginType
44
45
  config: Api
45
46
  protocol: TP
46
47
  connected: asyncio.Event
@@ -49,7 +50,7 @@ class Account(Generic[TP]):
49
50
  self,
50
51
  platform: str,
51
52
  self_id: str,
52
- self_info: Login,
53
+ self_info: LoginType,
53
54
  config: Api,
54
55
  protocol_cls: type[TP] = ApiProtocol,
55
56
  ): ...
@@ -71,7 +72,7 @@ class Account(Generic[TP]):
71
72
  - 链接开头出现在 self_info.proxy_urls 中的某一项
72
73
  """
73
74
 
74
- async def send(self, event: Event, message: str | Iterable[str | Element]) -> list[MessageObject]:
75
+ async def send(self, event: Event, message: str | Iterable[str | Element]) -> list[MessageReceipt]:
75
76
  """发送消息。返回一个 `MessageObject` 对象构成的数组。
76
77
 
77
78
  Args:
@@ -87,7 +88,7 @@ class Account(Generic[TP]):
87
88
 
88
89
  async def send_message(
89
90
  self, channel: str | Channel, message: str | Iterable[str | Element]
90
- ) -> list[MessageObject]:
91
+ ) -> list[MessageReceipt]:
91
92
  """发送消息。返回一个 `MessageObject` 对象构成的数组。
92
93
 
93
94
  Args:
@@ -100,7 +101,7 @@ class Account(Generic[TP]):
100
101
 
101
102
  async def send_private_message(
102
103
  self, user: str | User, message: str | Iterable[str | Element]
103
- ) -> list[MessageObject]:
104
+ ) -> list[MessageReceipt]:
104
105
  """发送私聊消息。返回一个 `MessageObject` 对象构成的数组。
105
106
 
106
107
  Args:
@@ -125,7 +126,7 @@ class Account(Generic[TP]):
125
126
  None: 该方法无返回值
126
127
  """
127
128
 
128
- async def message_create(self, channel_id: str, content: str) -> list[MessageObject]:
129
+ async def message_create(self, channel_id: str, content: str) -> list[MessageReceipt]:
129
130
  """发送消息。返回一个 `MessageObject` 对象构成的数组。
130
131
 
131
132
  Args:
@@ -495,7 +496,7 @@ class Account(Generic[TP]):
495
496
  PageResult[User]: `User` 的分页列表
496
497
  """
497
498
 
498
- async def login_get(self) -> Login:
499
+ async def login_get(self) -> LoginType:
499
500
  """获取当前登录信息。返回一个 `Login` 对象。
500
501
 
501
502
  Returns:
@@ -542,7 +543,7 @@ class Account(Generic[TP]):
542
543
  **kwargs: 参数
543
544
  """
544
545
 
545
- async def admin_login_list(self) -> list[Login]:
546
+ async def admin_login_list(self) -> list[LoginType]:
546
547
  """获取登录信息列表。返回一个 `Login` 对象构成的数组。
547
548
 
548
549
  Returns:
@@ -40,8 +40,16 @@ class BaseNetwork(Generic[TConfig], Service):
40
40
  try:
41
41
  event = Event.parse(raw)
42
42
  except Exception as e:
43
- logger.warning(f"Failed to parse event: {raw}\nCaused by {e!r}")
43
+ if (
44
+ "self_id" in raw
45
+ or ("login" in raw and "self_id" in raw["login"])
46
+ or ("login" in raw and "user" in raw["login"] and "self_id" in raw["login"]["user"])
47
+ ):
48
+ logger.warning(f"Failed to parse event: {raw}\nCaused by {e!r}")
49
+ else:
50
+ logger.trace(f"Failed to parse event: {raw}\nCaused by {e!r}")
44
51
  else:
52
+ logger.trace(f"Received event: {event}")
45
53
  self.sequence = event.id
46
54
  await self.app.post(event, self)
47
55
 
@@ -9,7 +9,7 @@ from launart.utilles import any_completed
9
9
  from loguru import logger
10
10
 
11
11
  from satori.config import WebhookInfo as WebhookInfo
12
- from satori.model import Login, LoginStatus, Opcode
12
+ from satori.model import Login, LoginPreview, LoginStatus, Opcode
13
13
 
14
14
  from ..account import Account
15
15
  from .base import BaseNetwork
@@ -33,8 +33,14 @@ class WebhookNetwork(BaseNetwork[WebhookInfo]):
33
33
  token = auth.split(" ", 1)[1]
34
34
  if self.config.token and self.config.token != token:
35
35
  return web.Response(status=401)
36
- platform = header["X-Platform"]
37
- self_id = header["X-Self-ID"]
36
+ if "X-Platform" in header and "X-Self-ID" in header:
37
+ platform = header["X-Platform"]
38
+ self_id = header["X-Self-ID"]
39
+ elif "Satori-Platform" in header and "Satori-Login-ID" in header:
40
+ platform = header["Satori-Platform"]
41
+ self_id = header["Satori-Login-ID"]
42
+ else:
43
+ return web.Response(status=400)
38
44
  identity = f"{platform}/{self_id}"
39
45
  if identity in self.app.accounts:
40
46
  account = self.app.accounts[identity]
@@ -45,9 +51,12 @@ class WebhookNetwork(BaseNetwork[WebhookInfo]):
45
51
  assert self.manager
46
52
  aio = self.manager.get_component(AiohttpClientService)
47
53
  async with aio.session.post(self.config.api_base / "admin/login.list") as resp:
48
- logins = [Login.parse(i) for i in await validate_response(resp)]
54
+ logins = [
55
+ LoginPreview.parse(i) if "user" in i else Login.parse(i)
56
+ for i in await validate_response(resp)
57
+ ]
49
58
  login = next(
50
- (i for i in logins if i.self_id == self_id and i.platform == platform),
59
+ (i for i in logins if i.id == self_id and i.platform == platform),
51
60
  Login(LoginStatus.CONNECT, self_id=self_id, platform=platform),
52
61
  )
53
62
  account = Account(platform, self_id, login, self.config)
@@ -55,8 +64,8 @@ class WebhookNetwork(BaseNetwork[WebhookInfo]):
55
64
  account.connected.set()
56
65
  self.app.accounts[identity] = account
57
66
  self.accounts[identity] = account
58
- await self.app.account_update(account, LoginStatus.ONLINE)
59
67
  await self.app.account_update(account, LoginStatus.CONNECT)
68
+ await self.app.account_update(account, LoginStatus.ONLINE)
60
69
  data = await req.json()
61
70
  op = data["op"]
62
71
  if op != Opcode.EVENT:
@@ -11,7 +11,7 @@ from launart.utilles import any_completed
11
11
  from loguru import logger
12
12
 
13
13
  from satori.config import WebsocketsInfo as WebsocketsInfo
14
- from satori.model import Login, LoginStatus, Opcode
14
+ from satori.model import Login, LoginPreview, LoginStatus, Opcode
15
15
 
16
16
  from ..account import Account
17
17
  from .base import BaseNetwork
@@ -86,31 +86,32 @@ class WsNetwork(BaseNetwork[WebsocketsInfo]):
86
86
  logger.error(f"Received unexpected payload: {data}")
87
87
  return False
88
88
  for login in data["body"]["logins"]:
89
- if "self_id" not in login:
89
+ obj = LoginPreview.parse(login) if "user" in login else Login.parse(login)
90
+ if obj.id is None:
90
91
  continue
91
- platform = login.get("platform", "satori")
92
- self_id = login["self_id"]
92
+ platform = obj.platform or "satori"
93
+ self_id = obj.id
93
94
  identity = f"{platform}/{self_id}"
94
95
  if identity in self.app.accounts:
95
96
  account = self.app.accounts[identity]
96
97
  self.accounts[identity] = account
97
- if login["status"] == LoginStatus.ONLINE:
98
+ if not obj.status or obj.status == LoginStatus.ONLINE:
98
99
  account.connected.set()
99
100
  else:
100
101
  account.connected.clear()
101
102
  account.config = self.config
102
103
  else:
103
- account = Account(platform, self_id, Login.parse(login), self.config)
104
+ account = Account(platform, self_id, obj, self.config)
104
105
  logger.info(f"account registered: {account}")
105
106
  (
106
107
  account.connected.set()
107
- if login["status"] == LoginStatus.ONLINE
108
+ if not obj.status or obj.status == LoginStatus.ONLINE
108
109
  else account.connected.clear()
109
110
  )
110
111
  self.app.accounts[identity] = account
111
112
  self.accounts[identity] = account
112
- await self.app.account_update(account, LoginStatus.ONLINE)
113
- await self.app.account_update(account, LoginStatus.CONNECT)
113
+ await self.app.account_update(account, LoginStatus.CONNECT)
114
+ await self.app.account_update(account, LoginStatus.ONLINE)
114
115
  if not self.accounts:
115
116
  logger.warning(f"No account available for {self.config}")
116
117
  return False
@@ -15,8 +15,11 @@ from satori.model import (
15
15
  Event,
16
16
  Guild,
17
17
  Login,
18
+ LoginPreview,
19
+ LoginType,
18
20
  Member,
19
21
  MessageObject,
22
+ MessageReceipt,
20
23
  Order,
21
24
  PageDequeResult,
22
25
  PageResult,
@@ -50,6 +53,8 @@ class ApiProtocol:
50
53
  "Authorization": f"Bearer {self.account.config.token or ''}",
51
54
  "X-Platform": self.account.platform,
52
55
  "X-Self-ID": self.account.self_id,
56
+ "Satori-Platform": self.account.platform,
57
+ "Satori-Login-ID": self.account.self_id,
53
58
  }
54
59
  aio = Launart.current().get_component(AiohttpClientService)
55
60
  if multipart:
@@ -75,15 +80,15 @@ class ApiProtocol:
75
80
  ) as resp:
76
81
  return await validate_response(resp)
77
82
 
78
- async def send(self, event: Event, message: str | Iterable[str | Element]) -> list[MessageObject]:
79
- """发送消息。返回一个 `MessageObject` 对象构成的数组。
83
+ async def send(self, event: Event, message: str | Iterable[str | Element]) -> list[MessageReceipt]:
84
+ """发送消息。返回一个 `MessageReceipt` 对象构成的数组。
80
85
 
81
86
  Args:
82
87
  event (Event): 当前事件(上下文)
83
88
  message (str | Iterable[str | Element]): 要发送的消息
84
89
 
85
90
  Returns:
86
- list[MessageObject]: `MessageObject` 对象构成的数组
91
+ list[MessageReceipt]: `MessageReceipt` 对象构成的数组
87
92
 
88
93
  Raises:
89
94
  RuntimeError: 传入的事件缺少 `channel` 对象
@@ -94,15 +99,15 @@ class ApiProtocol:
94
99
 
95
100
  async def send_message(
96
101
  self, channel: str | Channel, message: str | Iterable[str | Element]
97
- ) -> list[MessageObject]:
98
- """发送消息。返回一个 `MessageObject` 对象构成的数组。
102
+ ) -> list[MessageReceipt]:
103
+ """发送消息。返回一个 `MessageReceipt` 对象构成的数组。
99
104
 
100
105
  Args:
101
106
  channel (str | Channel): 要发送的频道 ID
102
107
  message (str | Iterable[str | Element]): 要发送的消息
103
108
 
104
109
  Returns:
105
- list[MessageObject]: `MessageObject` 对象构成的数组
110
+ list[MessageReceipt]: `MessageReceipt` 对象构成的数组
106
111
  """
107
112
  channel_id = channel.id if isinstance(channel, Channel) else channel
108
113
  msg = message if isinstance(message, str) else "".join(str(i) for i in message)
@@ -110,15 +115,15 @@ class ApiProtocol:
110
115
 
111
116
  async def send_private_message(
112
117
  self, user: str | User, message: str | Iterable[str | Element]
113
- ) -> list[MessageObject]:
114
- """发送私聊消息。返回一个 `MessageObject` 对象构成的数组。
118
+ ) -> list[MessageReceipt]:
119
+ """发送私聊消息。返回一个 `MessageReceipt` 对象构成的数组。
115
120
 
116
121
  Args:
117
122
  user (str | User): 要发送的用户 ID
118
123
  message (str | Iterable[str | Element]): 要发送的消息
119
124
 
120
125
  Returns:
121
- list[MessageObject]: `MessageObject` 对象构成的数组
126
+ list[MessageReceipt]: `MessageReceipt` 对象构成的数组
122
127
  """
123
128
  user_id = user.id if isinstance(user, User) else user
124
129
  channel = await self.user_channel_create(user_id=user_id)
@@ -145,22 +150,22 @@ class ApiProtocol:
145
150
  content=msg,
146
151
  )
147
152
 
148
- async def message_create(self, channel_id: str, content: str) -> list[MessageObject]:
149
- """发送消息。返回一个 `MessageObject` 对象构成的数组。
153
+ async def message_create(self, channel_id: str, content: str) -> list[MessageReceipt]:
154
+ """发送消息。返回一个 `MessageReceipt` 对象构成的数组。
150
155
 
151
156
  Args:
152
157
  channel_id (str): 频道 ID
153
158
  content (str): 消息内容
154
159
 
155
160
  Returns:
156
- list[MessageObject]: `MessageObject` 对象构成的数组
161
+ list[MessageReceipt]: `MessageReceipt` 对象构成的数组
157
162
  """
158
163
  res = await self.call_api(
159
164
  Api.MESSAGE_CREATE,
160
165
  {"channel_id": channel_id, "content": content},
161
166
  )
162
167
  res = cast("list[dict]", res)
163
- return [MessageObject.parse(i) for i in res]
168
+ return [MessageReceipt.parse(i) for i in res]
164
169
 
165
170
  async def message_get(self, channel_id: str, message_id: str) -> MessageObject:
166
171
  """获取特定消息。返回一个 `MessageObject` 对象。
@@ -672,14 +677,14 @@ class ApiProtocol:
672
677
  )
673
678
  return PageResult.parse(res, User.parse)
674
679
 
675
- async def login_get(self) -> Login:
680
+ async def login_get(self) -> LoginType:
676
681
  """获取当前登录信息。返回一个 `Login` 对象。
677
682
 
678
683
  Returns:
679
684
  Login: `Login` 对象
680
685
  """
681
686
  res = await self.call_api(Api.LOGIN_GET, {})
682
- return Login.parse(res)
687
+ return LoginPreview.parse(res) if "user" in res else Login.parse(res)
683
688
 
684
689
  async def user_get(self, user_id: str) -> User:
685
690
  """获取用户信息。返回一个 `User` 对象。
@@ -730,14 +735,14 @@ class ApiProtocol:
730
735
  """
731
736
  return await self.call_api(f"internal/{action}", kwargs)
732
737
 
733
- async def admin_login_list(self) -> list[Login]:
738
+ async def admin_login_list(self) -> list[LoginType]:
734
739
  """获取登录信息列表。返回一个 `Login` 对象构成的数组。
735
740
 
736
741
  Returns:
737
742
  list[Login]: `Login` 对象构成的数组
738
743
  """
739
744
  res = await self.call_api("admin/login.list")
740
- return [Login.parse(i) for i in res]
745
+ return [LoginPreview.parse(i) if "user" in i else Login.parse(i) for i in res]
741
746
 
742
747
  @overload
743
748
  async def upload_create(self, *uploads: Upload) -> list[str]: ...
@@ -44,7 +44,7 @@ class Element:
44
44
  raise TypeError(f.name, attr)
45
45
  setattr(self, f.name, attr.lower() == "true")
46
46
  else:
47
- setattr(self, f.name, _type(attr))
47
+ setattr(self, f.name, _type(attr)) # type: ignore
48
48
  self._attrs[f.name] = getattr(self, f.name)
49
49
  self._attrs = {k: v for k, v in self._attrs.items() if v is not None}
50
50
 
@@ -160,6 +160,39 @@ class Login(ModelBase):
160
160
  res["platform"] = self.platform
161
161
  return res
162
162
 
163
+ @property
164
+ def id(self) -> Optional[str]:
165
+ return self.self_id or (self.user.id if self.user else None)
166
+
167
+
168
+ @dataclass
169
+ class LoginPreview(ModelBase):
170
+ user: User
171
+ platform: str
172
+ status: Optional[LoginStatus] = None
173
+ features: list[str] = field(default_factory=list)
174
+ proxy_urls: list[str] = field(default_factory=list)
175
+
176
+ __converter__ = {"user": User.parse, "status": LoginStatus}
177
+
178
+ def dump(self):
179
+ res: dict[str, Any] = {
180
+ "user": self.user.dump(),
181
+ "platform": self.platform,
182
+ "features": self.features,
183
+ "proxy_urls": self.proxy_urls,
184
+ }
185
+ if self.status:
186
+ res["status"] = self.status.value
187
+ return res
188
+
189
+ @property
190
+ def id(self) -> str:
191
+ return self.user.id
192
+
193
+
194
+ LoginType = Union[Login, LoginPreview]
195
+
163
196
 
164
197
  @dataclass
165
198
  class ArgvInteraction(ModelBase):
@@ -260,18 +293,49 @@ class MessageObject(ModelBase):
260
293
  return res
261
294
 
262
295
 
296
+ @dataclass
297
+ class MessageReceipt(ModelBase):
298
+ id: str
299
+ content: Optional[str] = None
300
+
301
+ @classmethod
302
+ def from_elements(
303
+ cls,
304
+ id: str,
305
+ content: Optional[list[Element]] = None,
306
+ ):
307
+ return cls(id, "".join(str(i) for i in content) if content else None)
308
+
309
+ @property
310
+ def message(self) -> Optional[list[Element]]:
311
+ return transform(parse(self.content)) if self.content else None
312
+
313
+ @classmethod
314
+ def parse(cls, raw: dict):
315
+ if "elements" in raw and "content" not in raw:
316
+ content = [RawElement(*item.values()) for item in raw["elements"]]
317
+ raw["content"] = "".join(str(i) for i in content)
318
+ return super().parse(raw)
319
+
320
+ def dump(self):
321
+ res = {"id": self.id}
322
+ if self.content:
323
+ res["content"] = self.content
324
+ return res
325
+
326
+
263
327
  @dataclass
264
328
  class Event(ModelBase):
265
329
  id: int
266
330
  type: str
267
- platform: str
268
- self_id: str
269
331
  timestamp: datetime
332
+ platform: Optional[str] = None
333
+ self_id: Optional[str] = None
270
334
  argv: Optional[ArgvInteraction] = None
271
335
  button: Optional[ButtonInteraction] = None
272
336
  channel: Optional[Channel] = None
273
337
  guild: Optional[Guild] = None
274
- login: Optional[Login] = None
338
+ login: Optional[LoginType] = None
275
339
  member: Optional[Member] = None
276
340
  message: Optional[MessageObject] = None
277
341
  operator: Optional[User] = None
@@ -287,7 +351,13 @@ class Event(ModelBase):
287
351
  "button": ButtonInteraction.parse,
288
352
  "channel": Channel.parse,
289
353
  "guild": Guild.parse,
290
- "login": Login.parse,
354
+ "login": lambda raw: (
355
+ LoginPreview.parse(
356
+ raw if raw["user"] else {**raw, "user": {"id": raw["self_id"]}} if "self_id" in raw else raw
357
+ )
358
+ if "user" in raw
359
+ else Login.parse(raw)
360
+ ),
291
361
  "member": Member.parse,
292
362
  "message": MessageObject.parse,
293
363
  "operator": User.parse,
@@ -295,12 +365,32 @@ class Event(ModelBase):
295
365
  "user": User.parse,
296
366
  }
297
367
 
368
+ @property
369
+ def platform_(self):
370
+ if self.platform:
371
+ return self.platform
372
+ if self.login and self.login.platform:
373
+ return self.login.platform
374
+ raise ValueError("platform not found")
375
+
376
+ @property
377
+ def self_id_(self):
378
+ if self.self_id:
379
+ return self.self_id
380
+ if self.login and self.login.id:
381
+ return self.login.id
382
+ raise ValueError("self_id not found")
383
+
384
+ def __post_init__(self):
385
+ _ = self.platform_
386
+ _ = self.self_id_
387
+
298
388
  def dump(self):
299
389
  res = {
300
390
  "id": self.id,
301
391
  "type": self.type,
302
- "platform": self.platform,
303
- "self_id": self.self_id,
392
+ "platform": self.platform_,
393
+ "self_id": self.self_id_,
304
394
  "timestamp": int(self.timestamp.timestamp() * 1000),
305
395
  }
306
396
  if self.argv:
@@ -341,7 +431,7 @@ class PageResult(ModelBase, Generic[T]):
341
431
  @classmethod
342
432
  def parse(cls, raw: dict, parser: Optional[Callable[[dict], T]] = None) -> "PageResult[T]":
343
433
  data = [(parser or ModelBase.parse)(item) for item in raw["data"]]
344
- return cls(data, raw.get("next"))
434
+ return cls(data, raw.get("next")) # type: ignore
345
435
 
346
436
  def dump(self):
347
437
  res: dict = {"data": [item.dump() for item in self.data]}
@@ -357,7 +447,7 @@ class PageDequeResult(PageResult[T]):
357
447
  @classmethod
358
448
  def parse(cls, raw: dict, parser: Optional[Callable[[dict], T]] = None) -> "PageDequeResult[T]":
359
449
  data = [(parser or ModelBase.parse)(item) for item in raw["data"]]
360
- return cls(data, raw.get("next"), raw.get("prev"))
450
+ return cls(data, raw.get("next"), raw.get("prev")) # type: ignore
361
451
 
362
452
  def dump(self):
363
453
  res: dict = {"data": [item.dump() for item in self.data]}
@@ -22,9 +22,10 @@ from loguru import logger
22
22
  from starlette.applications import Starlette
23
23
  from starlette.datastructures import FormData as FormData
24
24
  from starlette.requests import Request as StarletteRequest
25
- from starlette.responses import JSONResponse, Response, StreamingResponse
25
+ from starlette.responses import FileResponse, JSONResponse, Response, StreamingResponse
26
26
  from starlette.routing import Route, WebSocketRoute
27
- from starlette.websockets import WebSocket
27
+ from starlette.staticfiles import StaticFiles
28
+ from starlette.websockets import WebSocket, WebSocketDisconnect
28
29
  from yarl import URL
29
30
 
30
31
  from satori.config import WebhookInfo
@@ -33,13 +34,13 @@ from satori.model import Event, ModelBase, Opcode
33
34
 
34
35
  from .adapter import Adapter as Adapter
35
36
  from .conection import WebsocketConnection
36
- from .deque import Deque
37
37
  from .formdata import parse_content_disposition as parse_content_disposition
38
38
  from .model import Provider as Provider
39
39
  from .model import Request as Request
40
40
  from .model import Router as Router
41
41
  from .route import RouteCall as RouteCall
42
42
  from .route import RouterMixin as RouterMixin
43
+ from .utils import Deque
43
44
 
44
45
 
45
46
  async def _request_handler(method: str, request: StarletteRequest, func: RouteCall):
@@ -103,11 +104,11 @@ class Server(Service, RouterMixin):
103
104
  self.routes = {}
104
105
  self.webhooks = webhooks or []
105
106
  self._tempdir = TemporaryDirectory()
106
- self.proxy_url_mapping = {}
107
107
  self._sequence = 0
108
108
  self._event_cache = Deque(maxlen=100)
109
109
  self.stream_threshold = stream_threshold
110
110
  self.stream_chunk_size = stream_chunk_size
111
+ self.resources: dict[str, Path] = {}
111
112
  super().__init__()
112
113
 
113
114
  def apply(self, item: Provider | Router | Adapter):
@@ -115,15 +116,17 @@ class Server(Service, RouterMixin):
115
116
  item.ensure_server(self)
116
117
  self._adapters.append(item)
117
118
  self.providers.append(item)
118
- self.proxy_url_mapping[item.id] = item.proxy_urls()
119
119
  elif isinstance(item, Provider):
120
120
  self.providers.append(item)
121
- self.proxy_url_mapping[item.id] = item.proxy_urls()
122
121
  elif isinstance(item, Router):
123
122
  self.routers.append(item)
124
123
  else:
125
124
  raise TypeError(f"Unknown config type: {item}")
126
125
 
126
+ def mount(self, route_path: str, file: Path):
127
+ """在指定路径挂载静态文件"""
128
+ self.resources[route_path] = file
129
+
127
130
  async def event_callback(self, event: Event):
128
131
  event.id = self._sequence
129
132
  self._event_cache.append(event)
@@ -131,6 +134,8 @@ class Server(Service, RouterMixin):
131
134
  for connection in self.connections:
132
135
  try:
133
136
  await connection.send({"op": Opcode.EVENT, "body": event.dump()})
137
+ except WebSocketDisconnect:
138
+ break
134
139
  except Exception as e:
135
140
  print_exc()
136
141
  logger.error(e)
@@ -141,8 +146,10 @@ class Server(Service, RouterMixin):
141
146
  headers={
142
147
  "Content-Type": "application/json",
143
148
  "Authorization": f"Bearer {hook.token or ''}",
144
- "X-Platform": event.platform,
145
- "X-Self-ID": event.self_id,
149
+ "X-Platform": event.platform_,
150
+ "Satori-Platform": event.platform_,
151
+ "X-Self-ID": event.self_id_,
152
+ "Satori-Login-ID": event.self_id_,
146
153
  },
147
154
  json={"op": Opcode.EVENT, "body": event.dump()},
148
155
  ) as resp:
@@ -191,19 +198,22 @@ class Server(Service, RouterMixin):
191
198
  async def admin_login_list_handler(self, request: StarletteRequest):
192
199
  logins = []
193
200
  for provider in self.providers:
194
- logins.extend(await provider.get_logins())
201
+ _logins = await provider.get_logins()
202
+ for _login in _logins:
203
+ _login.proxy_urls.extend(provider.proxy_urls())
204
+ logins.extend(_logins)
195
205
  return JSONResponse(content=[lo.dump() for lo in logins])
196
206
 
197
207
  async def http_server_handler(self, request: StarletteRequest):
198
208
  if not self._adapters and not self.routes:
199
209
  return Response(status_code=404, content=request.path_params["method"])
200
210
  method = request.path_params["method"]
201
- if "X-Platform" not in request.headers:
202
- return Response(status_code=401, content="Missing X-Platform header")
203
- platform = request.headers["X-Platform"]
204
- if "X-Self-ID" not in request.headers:
205
- return Response(status_code=401, content="Missing X-Self-ID header")
206
- self_id = request.headers["X-Self-ID"]
211
+ if "X-Platform" not in request.headers and "Satori-Platform" not in request.headers:
212
+ return Response(status_code=401, content="Missing header X-Platform or Satori-Platform")
213
+ platform: str = request.headers.get("X-Platform") or request.headers.get("Satori-Platform") # type: ignore
214
+ if "X-Self-ID" not in request.headers and "Satori-Login-ID" not in request.headers:
215
+ return Response(status_code=401, content="Missing header X-Self-ID or Satori-Login-ID")
216
+ self_id: str = request.headers.get("X-Self-ID") or request.headers.get("Satori-Login-ID") # type: ignore
207
217
 
208
218
  for _router in self._adapters:
209
219
  if method not in _router.routes:
@@ -223,6 +233,8 @@ class Server(Service, RouterMixin):
223
233
  url = request.path_params["upload_url"]
224
234
  try:
225
235
  content = await self.download(url)
236
+ if isinstance(content, Path):
237
+ return FileResponse(path=content)
226
238
  # if content size > stream_limit, use streaming response
227
239
  if len(content) > self.stream_threshold:
228
240
 
@@ -243,24 +255,30 @@ class Server(Service, RouterMixin):
243
255
  return Response(status_code=500, content=repr(e))
244
256
 
245
257
  async def download(self, url: str):
246
- pr = urllib.parse.urlparse(url.replace(":/", "://", 1).replace(":///", "://", 1))
258
+ url = url.replace(":/", "://", 1).replace(":///", "://", 1)
259
+ pr = urllib.parse.urlparse(url)
247
260
  if pr.scheme == "upload":
248
261
  if pr.netloc == "temp":
249
262
  _, inst, filename = pr.path.split("/", 2)
250
263
  if inst == f"{self.id}:{id(self)}":
251
264
  file = Path(self._tempdir.name) / filename
252
265
  if file.exists():
253
- return file.read_bytes()
266
+ return file
254
267
  raise FileNotFoundError(f"{filename} not found")
255
- platform = pr.netloc
256
- _, self_id, path = pr.path.split("/", 2)
257
- for provider in self.providers:
268
+ for provider in self.providers:
269
+ if pr.scheme == "upload":
270
+ platform = pr.netloc
271
+ _, self_id, path = pr.path.split("/", 2)
258
272
  if provider.ensure(platform, self_id):
259
273
  return await provider.download_uploaded(platform, self_id, path)
260
- for provider in self.providers:
261
- for proxy_url_pf in self.proxy_url_mapping[provider.id]:
262
- if url.startswith(proxy_url_pf):
263
- return await provider.download_proxied(proxy_url_pf, url)
274
+
275
+ for proxy_url_pf in provider.proxy_urls():
276
+ if not url.startswith(proxy_url_pf):
277
+ continue
278
+ resp = await provider.download_proxied(proxy_url_pf, url)
279
+ if resp is None:
280
+ continue
281
+ return resp
264
282
  raise ValueError(f"Unknown proxy url: {url}")
265
283
 
266
284
  def get_local_file(self, url: str):
@@ -322,6 +340,8 @@ class Server(Service, RouterMixin):
322
340
  ),
323
341
  ]
324
342
  )
343
+ for path, file in self.resources.items():
344
+ app.mount(path, StaticFiles(directory=file.parent, html=file.suffix == ".html"))
325
345
  asgi_service.middleware.mounts[""] = app # type: ignore
326
346
 
327
347
  async def event_task(_provider: Provider):
@@ -4,8 +4,9 @@ from typing import TYPE_CHECKING, Optional
4
4
 
5
5
  from launart import Service
6
6
 
7
- from ..model import Event, Login
7
+ from ..model import Event, LoginType
8
8
  from .route import RouterMixin
9
+ from .utils import ctx
9
10
 
10
11
  if TYPE_CHECKING:
11
12
  from . import Server
@@ -34,11 +35,11 @@ class Adapter(Service, RouterMixin):
34
35
  async def download_uploaded(self, platform: str, self_id: str, path: str) -> bytes: ...
35
36
 
36
37
  async def download_proxied(self, prefix: str, url: str) -> bytes:
37
- async with self.server.session.get(url) as resp:
38
+ async with self.server.session.get(url, ssl=ctx) as resp:
38
39
  return await resp.read()
39
40
 
40
41
  @abstractmethod
41
- async def get_logins(self) -> list[Login]: ...
42
+ async def get_logins(self) -> list[LoginType]: ...
42
43
 
43
44
  def __init__(self):
44
45
  super().__init__()
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
3
  from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar, Union, runtime_checkable
4
4
 
5
5
  from satori.const import Api
6
- from satori.model import Event, Login
6
+ from satori.model import Event, Login, LoginPreview, LoginType
7
7
 
8
8
  if TYPE_CHECKING:
9
9
  from .route import RouteCall
@@ -22,14 +22,11 @@ class Request(Generic[TP]):
22
22
 
23
23
  @runtime_checkable
24
24
  class Provider(Protocol):
25
- @property
26
- def id(self) -> str: ...
27
-
28
25
  def publisher(self) -> AsyncIterator[Event]: ...
29
26
 
30
27
  def authenticate(self, token: Optional[str]) -> bool: ...
31
28
 
32
- async def get_logins(self) -> list[Login]: ...
29
+ async def get_logins(self) -> Union[list[Login], list[LoginPreview], list[LoginType]]: ...
33
30
 
34
31
  @staticmethod
35
32
  def proxy_urls() -> list[str]: ...
@@ -38,7 +35,7 @@ class Provider(Protocol):
38
35
 
39
36
  async def download_uploaded(self, platform: str, self_id: str, path: str) -> bytes: ...
40
37
 
41
- async def download_proxied(self, prefix: str, url: str) -> bytes: ...
38
+ async def download_proxied(self, prefix: str, url: str) -> Optional[bytes]: ...
42
39
 
43
40
 
44
41
  @runtime_checkable
@@ -8,7 +8,7 @@ from satori.model import (
8
8
  Channel,
9
9
  Direction,
10
10
  Guild,
11
- Login,
11
+ LoginType,
12
12
  Member,
13
13
  MessageObject,
14
14
  ModelBase,
@@ -241,7 +241,7 @@ class ReactionListParam(TypedDict):
241
241
 
242
242
 
243
243
  REACTION_LIST: TypeAlias = RouteCall[ReactionListParam, Union[PageResult[User], dict[str, Any]]]
244
- LOGIN_GET: TypeAlias = RouteCall[Any, Union[Login, dict[str, Any]]]
244
+ LOGIN_GET: TypeAlias = RouteCall[Any, Union[LoginType, dict[str, Any]]]
245
245
 
246
246
 
247
247
  class UserGetParam(TypedDict):
@@ -1,3 +1,4 @@
1
+ import ssl
1
2
  from collections import deque
2
3
 
3
4
 
@@ -23,6 +24,10 @@ class Deque:
23
24
  return list(self.data)[i + 1 - self.offset :]
24
25
 
25
26
 
27
+ ctx = ssl.create_default_context()
28
+ ctx.set_ciphers("DEFAULT")
29
+
30
+
26
31
  if __name__ == "__main__":
27
32
  d = Deque(3)
28
33
  d.append(0)
File without changes