satori-python-server 0.11.5__tar.gz → 0.13.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,6 +13,7 @@ dependencies = [
13
13
  "graia-amnesia",
14
14
  "starlette",
15
15
  "uvicorn[standard]",
16
+ "python-multipart",
16
17
  ]
17
18
  description = "Satori Protocol SDK for python, specify server part"
18
19
  license = {text = "MIT"}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: satori-python-server
3
- Version: 0.11.5
3
+ Version: 0.13.0
4
4
  Summary: Satori Protocol SDK for python, specify server part
5
5
  Home-page: https://github.com/RF-Tar-Railt/satori-python
6
6
  Author-Email: RF-Tar-Railt <rf_tar_railt@qq.com>
@@ -20,8 +20,9 @@ Requires-Python: >=3.8
20
20
  Requires-Dist: aiohttp>=3.9.3
21
21
  Requires-Dist: launart>=0.8.2
22
22
  Requires-Dist: graia-amnesia>=0.9.0
23
- Requires-Dist: starlette>=0.37.2
23
+ Requires-Dist: starlette[python-multipart]>=0.37.2
24
24
  Requires-Dist: uvicorn[standard]>=0.28.0
25
+ Requires-Dist: python-multipart>=0.0.9
25
26
  Requires-Dist: satori-python-core>=0.11.4
26
27
  Description-Content-Type: text/markdown
27
28
 
@@ -45,6 +46,10 @@ Description-Content-Type: text/markdown
45
46
  - [Chronocat](https://chronocat.vercel.app)
46
47
  - Koishi (搭配 `@koishijs/plugin-server`)
47
48
 
49
+ ### 使用该 SDK 的框架
50
+
51
+ - [`Entari`](https://github.com/ArcletProject/Entari)
52
+
48
53
  ## 安装
49
54
 
50
55
  安装完整体:
@@ -72,14 +77,15 @@ pip install satori-python-server
72
77
  客户端:
73
78
 
74
79
  ```python
75
- from satori import Event, WebsocketsInfo
80
+ from satori import EventType, WebsocketsInfo
81
+ from satori.event import MessageEvent
76
82
  from satori.client import Account, App
77
83
 
78
84
  app = App(WebsocketsInfo(port=5140))
79
85
 
80
- @app.register
81
- async def on_message(account: Account, event: Event):
82
- if event.user and event.user.id == "xxxxxxxxxxx":
86
+ @app.register_on(EventType.MESSAGE_CREATED)
87
+ async def on_message(account: Account, event: MessageEvent):
88
+ if event.user.id == "xxxxxxxxxxx":
83
89
  await account.send(event, "Hello, World!")
84
90
 
85
91
  app.run()
@@ -18,6 +18,10 @@
18
18
  - [Chronocat](https://chronocat.vercel.app)
19
19
  - Koishi (搭配 `@koishijs/plugin-server`)
20
20
 
21
+ ### 使用该 SDK 的框架
22
+
23
+ - [`Entari`](https://github.com/ArcletProject/Entari)
24
+
21
25
  ## 安装
22
26
 
23
27
  安装完整体:
@@ -45,14 +49,15 @@ pip install satori-python-server
45
49
  客户端:
46
50
 
47
51
  ```python
48
- from satori import Event, WebsocketsInfo
52
+ from satori import EventType, WebsocketsInfo
53
+ from satori.event import MessageEvent
49
54
  from satori.client import Account, App
50
55
 
51
56
  app = App(WebsocketsInfo(port=5140))
52
57
 
53
- @app.register
54
- async def on_message(account: Account, event: Event):
55
- if event.user and event.user.id == "xxxxxxxxxxx":
58
+ @app.register_on(EventType.MESSAGE_CREATED)
59
+ async def on_message(account: Account, event: MessageEvent):
60
+ if event.user.id == "xxxxxxxxxxx":
56
61
  await account.send(event, "Hello, World!")
57
62
 
58
63
  app.run()
@@ -8,8 +8,9 @@ dependencies = [
8
8
  "aiohttp>=3.9.3",
9
9
  "launart>=0.8.2",
10
10
  "graia-amnesia>=0.9.0",
11
- "starlette>=0.37.2",
11
+ "starlette[python-multipart]>=0.37.2",
12
12
  "uvicorn[standard]>=0.28.0",
13
+ "python-multipart>=0.0.9",
13
14
  "satori-python-core >= 0.11.4",
14
15
  ]
15
16
  description = "Satori Protocol SDK for python, specify server part"
@@ -26,7 +27,7 @@ classifiers = [
26
27
  "Programming Language :: Python :: 3.12",
27
28
  "Operating System :: OS Independent",
28
29
  ]
29
- version = "0.11.5"
30
+ version = "0.13.0"
30
31
 
31
32
  [project.license]
32
33
  text = "MIT"
@@ -0,0 +1,338 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import functools
5
+ import mimetypes
6
+ import secrets
7
+ import signal
8
+ import threading
9
+ import urllib.parse
10
+ from contextlib import suppress
11
+ from pathlib import Path
12
+ from tempfile import TemporaryDirectory
13
+ from traceback import print_exc
14
+ from typing import Any, Iterable, cast
15
+
16
+ import aiohttp
17
+ from creart import it
18
+ from graia.amnesia.builtins.asgi import UvicornASGIService
19
+ from launart import Launart, Service, any_completed
20
+ from loguru import logger
21
+ from starlette.applications import Starlette
22
+ from starlette.datastructures import FormData
23
+ from starlette.requests import Request as StarletteRequest
24
+ from starlette.responses import JSONResponse, Response
25
+ from starlette.routing import Route, WebSocketRoute
26
+ from starlette.websockets import WebSocket
27
+ from yarl import URL
28
+
29
+ from satori.config import WebhookInfo
30
+ from satori.const import Api
31
+ from satori.model import Event, ModelBase, Opcode
32
+
33
+ from .adapter import Adapter as Adapter
34
+ from .conection import WebsocketConnection
35
+ from .formdata import parse_content_disposition
36
+ from .model import Provider as Provider
37
+ from .model import Request as Request
38
+ from .model import Router as Router
39
+ from .route import RouteCall as RouteCall
40
+ from .route import RouterMixin as RouterMixin
41
+
42
+
43
+ async def _request_handler(method: str, request: StarletteRequest, func: RouteCall):
44
+ if method == Api.UPLOAD_CREATE.value:
45
+ async with request.form() as form:
46
+ res = await func(
47
+ Request(
48
+ cast(dict, request.headers.mutablecopy()),
49
+ method,
50
+ form,
51
+ )
52
+ )
53
+ return JSONResponse(content=res)
54
+ res = await func(
55
+ Request(
56
+ cast(dict, request.headers.mutablecopy()),
57
+ method,
58
+ await request.json(),
59
+ )
60
+ )
61
+ if isinstance(res, ModelBase):
62
+ return JSONResponse(content=res.dump())
63
+ if res and isinstance(res, list) and isinstance(res[0], ModelBase):
64
+ return JSONResponse(content=[_.dump() for _ in res]) # type: ignore
65
+ return res if isinstance(res, Response) else JSONResponse(content=res)
66
+
67
+
68
+ class Server(Service, RouterMixin):
69
+ id = "satori-python.server"
70
+ required: set[str] = {"asgi.service/uvicorn"}
71
+ stages: set[str] = {"preparing", "blocking", "cleanup"}
72
+
73
+ version: str
74
+ providers: list[Provider]
75
+ routers: list[Router]
76
+ _adapters: list[Adapter]
77
+ connections: list[WebsocketConnection]
78
+
79
+ def __init__(
80
+ self,
81
+ host: str = "127.0.0.1",
82
+ port: int = 5140,
83
+ path: str = "",
84
+ version: str = "v1",
85
+ webhooks: list[WebhookInfo] | None = None,
86
+ ):
87
+ self.connections = []
88
+ manager = it(Launart)
89
+ manager.add_component(UvicornASGIService(host, port))
90
+ self.version = version
91
+ self.path = path
92
+ if self.path and not self.path.startswith("/"):
93
+ self.path = f"/{self.path}"
94
+ self._adapters = []
95
+ self.providers = []
96
+ self.routers = []
97
+ self.routes = {}
98
+ self.webhooks = webhooks or []
99
+ self.session = aiohttp.ClientSession()
100
+ self._tempdir = TemporaryDirectory()
101
+ self.proxy_url_mapping = {}
102
+ super().__init__()
103
+
104
+ def apply(self, item: Provider | Router | Adapter):
105
+ if isinstance(item, Adapter):
106
+ self._adapters.append(item)
107
+ self.providers.append(item)
108
+ for proxy_url_pf in item.proxy_urls():
109
+ self.proxy_url_mapping[proxy_url_pf] = item
110
+ elif isinstance(item, Provider):
111
+ self.providers.append(item)
112
+ for proxy_url_pf in item.proxy_urls():
113
+ self.proxy_url_mapping[proxy_url_pf] = item
114
+ elif isinstance(item, Router):
115
+ self.routers.append(item)
116
+ else:
117
+ raise TypeError(f"Unknown config type: {item}")
118
+
119
+ async def event_callback(self, event: Event):
120
+ for connection in self.connections:
121
+ try:
122
+ await connection.send({"op": Opcode.EVENT, "body": event.dump()})
123
+ except Exception as e:
124
+ print_exc()
125
+ logger.error(e)
126
+ for hook in self.webhooks:
127
+ try:
128
+ async with self.session.post(
129
+ URL(f"http://{hook.identity}"),
130
+ headers={
131
+ "Content-Type": "application/json",
132
+ "Authorization": f"Bearer {hook.token or ''}",
133
+ "X-Platform": event.platform,
134
+ "X-Self-ID": event.self_id,
135
+ },
136
+ json={"op": Opcode.EVENT, "body": event.dump()},
137
+ ) as resp:
138
+ resp.raise_for_status()
139
+ except Exception as e:
140
+ print_exc()
141
+ logger.error(e)
142
+
143
+ async def websocket_server_handler(self, ws: WebSocket):
144
+ await ws.accept()
145
+ connection = WebsocketConnection(ws)
146
+ identity = await ws.receive_json()
147
+ if not isinstance(identity, dict) or identity.get("op") != Opcode.IDENTIFY:
148
+ return await ws.close(code=3000, reason="Unauthorized")
149
+ token = identity["body"]["token"]
150
+ logins = []
151
+ for provider in self.providers:
152
+ if not provider.authenticate(token):
153
+ return await ws.close(code=3000, reason="Unauthorized")
154
+ logins.extend(await provider.get_logins())
155
+ await connection.send({"op": Opcode.READY, "body": {"logins": [lo.dump() for lo in logins]}})
156
+ self.connections.append(connection)
157
+
158
+ try:
159
+ await any_completed(connection.heartbeat(), connection.close_signal.wait())
160
+ finally:
161
+ self.connections.remove(connection)
162
+
163
+ async def admin_login_list_handler(self, request: StarletteRequest):
164
+ logins = []
165
+ for provider in self.providers:
166
+ logins.extend(await provider.get_logins())
167
+ return JSONResponse(content=[lo.dump() for lo in logins])
168
+
169
+ async def http_server_handler(self, request: StarletteRequest):
170
+ if not self._adapters and not self.routes:
171
+ return Response(status_code=404, content=request.path_params["method"])
172
+ method = request.path_params["method"]
173
+ if "X-Platform" not in request.headers:
174
+ return Response(status_code=401, content="Missing X-Platform header")
175
+ platform = request.headers["X-Platform"]
176
+ if "X-Self-ID" not in request.headers:
177
+ return Response(status_code=401, content="Missing X-Self-ID header")
178
+ self_id = request.headers["X-Self-ID"]
179
+
180
+ for _router in self._adapters:
181
+ if method not in _router.routes:
182
+ continue
183
+ if not _router.ensure(platform, self_id):
184
+ continue
185
+ return await _request_handler(method, request, _router.routes[method])
186
+ if method in self.routes:
187
+ return await _request_handler(method, request, self.routes[method])
188
+ for _router in self.routers:
189
+ if method not in _router.routes:
190
+ continue
191
+ return await _request_handler(method, request, _router.routes[method])
192
+ return Response(status_code=404, content=method)
193
+
194
+ async def proxy_url_handler(self, request: StarletteRequest):
195
+ url = request.path_params["upload_url"]
196
+ try:
197
+ return Response(content=await self.download(url))
198
+ except FileNotFoundError as e404:
199
+ return Response(status_code=404, content=str(e404))
200
+ except ValueError as e403:
201
+ return Response(status_code=403, content=str(e403))
202
+ except Exception as e:
203
+ return Response(status_code=400, content=str(e))
204
+
205
+ async def download(self, url: str):
206
+ pr = urllib.parse.urlparse(url.replace(":/", "://", 1).replace(":///", "://", 1))
207
+ if pr.scheme == "upload":
208
+ if pr.netloc == "temp":
209
+ _, inst, filename = pr.path.split("/", 2)
210
+ if inst == f"{self.id}:{id(self)}":
211
+ file = Path(self._tempdir.name) / filename
212
+ if file.exists():
213
+ return file.read_bytes()
214
+ raise FileNotFoundError(f"{filename} not found")
215
+ platform = pr.netloc
216
+ _, self_id, path = pr.path.split("/", 2)
217
+ for provider in self.providers:
218
+ if provider.ensure(platform, self_id):
219
+ return await provider.download_uploaded(platform, self_id, path)
220
+ for proxy_url_pf, provider in self.proxy_url_mapping.items():
221
+ if url.startswith(proxy_url_pf):
222
+ async with self.session.get(url) as resp:
223
+ return await resp.read()
224
+ raise ValueError(f"Unknown proxy url: {url}")
225
+
226
+ def get_local_file(self, url: str):
227
+ url = url.split("/")[-1]
228
+ file = Path(self._tempdir.name) / url
229
+ if file.exists():
230
+ return file.read_bytes()
231
+
232
+ async def _default_upload_create_handler(self, request: Request[FormData]):
233
+ res = {}
234
+ root = Path(self._tempdir.name)
235
+ for _, data in request.params.items():
236
+ if isinstance(data, str):
237
+ continue
238
+ ext = data.headers["content-type"]
239
+ disp = parse_content_disposition(data.headers["content-disposition"])
240
+ fid = secrets.token_urlsafe(16)
241
+ if "filename" in disp:
242
+ filename = f"{fid}-{disp['filename']}"
243
+ else:
244
+ filename = f"{fid}-{disp['name']}{mimetypes.guess_extension(ext) or '.png'}"
245
+ file = root / filename
246
+ with file.resolve().open("wb+") as f:
247
+ f.write(await data.read())
248
+
249
+ res[disp["name"]] = f"upload://temp/{self.id}:{id(self)}/{filename}"
250
+
251
+ loop = asyncio.get_running_loop()
252
+ loop.call_later(600, file.unlink, True)
253
+ return res
254
+
255
+ async def launch(self, manager: Launart):
256
+ for _adapter in self._adapters:
257
+ manager.add_component(_adapter)
258
+
259
+ if Api.UPLOAD_CREATE.value not in self.routes and not self._adapters:
260
+ self.routes[Api.UPLOAD_CREATE.value] = self._default_upload_create_handler
261
+
262
+ async with self.stage("preparing"):
263
+ asgi_service = manager.get_component(UvicornASGIService)
264
+ app = Starlette(
265
+ routes=[
266
+ WebSocketRoute(f"{self.path}/{self.version}/events", self.websocket_server_handler),
267
+ Route(
268
+ f"{self.path}/{self.version}/admin/login.list",
269
+ self.admin_login_list_handler,
270
+ methods=["POST"],
271
+ ),
272
+ Route(
273
+ f"{self.path}/{self.version}/proxy/{{upload_url:path}}",
274
+ self.proxy_url_handler,
275
+ methods=["GET"],
276
+ ),
277
+ Route(
278
+ f"{self.path}/{self.version}/{{method:path}}",
279
+ self.http_server_handler,
280
+ methods=["POST"],
281
+ ),
282
+ ]
283
+ )
284
+ asgi_service.middleware.mounts[""] = app # type: ignore
285
+
286
+ async def event_task(_provider: Provider):
287
+ async for event in _provider.publisher():
288
+ await self.event_callback(event)
289
+
290
+ async with self.stage("blocking"):
291
+ await any_completed(
292
+ manager.status.wait_for_sigexit(),
293
+ *(event_task(provider) for provider in self.providers),
294
+ *(_adapter.status.wait_for("blocking-completed") for _adapter in self._adapters),
295
+ )
296
+
297
+ async with self.stage("cleanup"):
298
+ with suppress(KeyError):
299
+ del asgi_service.middleware.mounts[""]
300
+ await self.session.close()
301
+ self._tempdir.cleanup()
302
+
303
+ def run(
304
+ self,
305
+ manager: Launart | None = None,
306
+ *,
307
+ loop: asyncio.AbstractEventLoop | None = None,
308
+ stop_signal: Iterable[signal.Signals] = (signal.SIGINT,),
309
+ ):
310
+ if manager is None:
311
+ manager = it(Launart)
312
+ manager.add_component(self)
313
+ manager.launch_blocking(loop=loop, stop_signal=stop_signal)
314
+
315
+ async def run_async(
316
+ self,
317
+ manager: Launart | None = None,
318
+ stop_signal: Iterable[signal.Signals] = (signal.SIGINT,),
319
+ ):
320
+ if manager is None:
321
+ manager = it(Launart)
322
+ manager.add_component(self)
323
+ handled_signals: dict[signal.Signals, Any] = {}
324
+ launch_task = asyncio.create_task(manager.launch(), name="amnesia-launch")
325
+ signal_handler = functools.partial(manager._on_sys_signal, main_task=launch_task)
326
+ if threading.current_thread() is threading.main_thread(): # pragma: worst case
327
+ try:
328
+ for sig in stop_signal:
329
+ handled_signals[sig] = signal.getsignal(sig)
330
+ signal.signal(sig, signal_handler)
331
+ except ValueError: # pragma: no cover
332
+ # `signal.signal` may throw if `threading.main_thread` does
333
+ # not support signals
334
+ handled_signals.clear()
335
+ await launch_task
336
+ for sig, handler in handled_signals.items():
337
+ if signal.getsignal(sig) is signal_handler:
338
+ signal.signal(sig, handler)
@@ -1,13 +1,13 @@
1
1
  from abc import abstractmethod
2
- from typing import Any, AsyncIterator, Dict, List
2
+ from typing import AsyncIterator, List
3
3
 
4
4
  from launart import Service
5
5
 
6
6
  from ..model import Event, Login
7
- from .model import Request
7
+ from .route import RouterMixin
8
8
 
9
9
 
10
- class Adapter(Service):
10
+ class Adapter(Service, RouterMixin):
11
11
  @abstractmethod
12
12
  def get_platform(self) -> str: ...
13
13
 
@@ -15,19 +15,24 @@ class Adapter(Service):
15
15
  def publisher(self) -> AsyncIterator[Event]: ...
16
16
 
17
17
  @abstractmethod
18
- def validate_headers(self, headers: Dict[str, Any]) -> bool: ...
18
+ def ensure(self, platform: str, self_id: str) -> bool: ...
19
19
 
20
20
  @abstractmethod
21
21
  def authenticate(self, token: str) -> bool: ...
22
22
 
23
- @abstractmethod
24
- async def get_logins(self) -> List[Login]: ...
23
+ @staticmethod
24
+ def proxy_urls() -> List[str]:
25
+ return []
25
26
 
26
27
  @abstractmethod
27
- async def call_api(self, request: Request[Any]) -> Any: ...
28
+ async def download_uploaded(self, platform: str, self_id: str, path: str) -> bytes: ...
28
29
 
29
30
  @abstractmethod
30
- async def call_internal_api(self, request: Request[Any]) -> Any: ...
31
+ async def get_logins(self) -> List[Login]: ...
32
+
33
+ def __init__(self):
34
+ super().__init__()
35
+ self.routes = {}
31
36
 
32
37
  @property
33
38
  def id(self):
@@ -0,0 +1,13 @@
1
+ import re
2
+
3
+
4
+ def parse_content_disposition(header_value):
5
+ match = re.match(r"""form-data; (?P<parameters>.+)""", header_value)
6
+ if match:
7
+ parameters = match.groupdict()["parameters"]
8
+ parsed_data = {}
9
+ for param in parameters.split(";"):
10
+ key, value = param.strip().split("=")
11
+ parsed_data[key.strip('"')] = value.strip('"')
12
+ return parsed_data
13
+ raise ValueError(header_value)
@@ -1,11 +1,14 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Any, AsyncIterator, Generic, Protocol, TypeVar, Union, runtime_checkable
4
+ from typing import TYPE_CHECKING, Any, AsyncIterator, Generic, Protocol, TypeVar, Union, runtime_checkable
5
5
 
6
6
  from satori.const import Api
7
7
  from satori.model import Event, Login
8
8
 
9
+ if TYPE_CHECKING:
10
+ from .route import RouteCall
11
+
9
12
  JsonType = Union[list, dict, str, int, bool, float, None]
10
13
  TA = TypeVar("TA", str, Api)
11
14
  TP = TypeVar("TP")
@@ -26,11 +29,14 @@ class Provider(Protocol):
26
29
 
27
30
  async def get_logins(self) -> list[Login]: ...
28
31
 
32
+ @staticmethod
33
+ def proxy_urls() -> list[str]: ...
29
34
 
30
- @runtime_checkable
31
- class Router(Protocol):
32
- def validate_headers(self, headers: dict[str, Any]) -> bool: ...
35
+ def ensure(self, platform: str, self_id: str) -> bool: ...
33
36
 
34
- async def call_api(self, request: Request[Any]) -> Any: ...
37
+ async def download_uploaded(self, platform: str, self_id: str, path: str) -> bytes: ...
35
38
 
36
- async def call_internal_api(self, request: Request[Any]) -> Any: ...
39
+
40
+ @runtime_checkable
41
+ class Router(Protocol):
42
+ routes: dict[str, RouteCall[Any, Any]]