maim-message 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maim_message-0.2.0/.gitignore +21 -0
- maim_message-0.2.0/LICENSE +21 -0
- maim_message-0.2.0/PKG-INFO +20 -0
- maim_message-0.2.0/README.md +258 -0
- maim_message-0.2.0/api.py +91 -0
- maim_message-0.2.0/pyproject.toml +34 -0
- maim_message-0.2.0/setup.cfg +4 -0
- maim_message-0.2.0/setup.py +17 -0
- maim_message-0.2.0/src/maim_message/__init__.py +30 -0
- maim_message-0.2.0/src/maim_message/api.py +325 -0
- maim_message-0.2.0/src/maim_message/message_base.py +257 -0
- maim_message-0.2.0/src/maim_message/router.py +211 -0
- maim_message-0.2.0/src/maim_message.egg-info/PKG-INFO +20 -0
- maim_message-0.2.0/src/maim_message.egg-info/SOURCES.txt +15 -0
- maim_message-0.2.0/src/maim_message.egg-info/dependency_links.txt +1 -0
- maim_message-0.2.0/src/maim_message.egg-info/requires.txt +5 -0
- maim_message-0.2.0/src/maim_message.egg-info/top_level.txt +1 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 tcmofashi
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: maim_message
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: A message handling library
|
|
5
|
+
Home-page: https://github.com/MaiM-with-u/maim_message
|
|
6
|
+
Author: tcmofashi
|
|
7
|
+
Author-email: mofashiforzbx@qq.com
|
|
8
|
+
Requires-Python: >=3.9
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Requires-Dist: fastapi>=0.70.0
|
|
12
|
+
Requires-Dist: uvicorn>=0.15.0
|
|
13
|
+
Requires-Dist: aiohttp>=3.8.0
|
|
14
|
+
Requires-Dist: pydantic>=1.9.0
|
|
15
|
+
Requires-Dist: websockets>=10.0
|
|
16
|
+
Dynamic: author-email
|
|
17
|
+
Dynamic: description-content-type
|
|
18
|
+
Dynamic: home-page
|
|
19
|
+
Dynamic: license-file
|
|
20
|
+
Dynamic: requires-python
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
# Maim Message
|
|
2
|
+
|
|
3
|
+
一个用于定义maimcore消息通用接口的Python库。
|
|
4
|
+
|
|
5
|
+
## 安装
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
git clone https://github.com/MaiM-with-u/maim_message
|
|
9
|
+
cd maim_message
|
|
10
|
+
pip install -e .
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
## 使用方法
|
|
14
|
+
|
|
15
|
+
```python
|
|
16
|
+
from maim_message import MessageBase
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
## 介绍
|
|
20
|
+
maim_message是从maimbot项目衍生出来的各个组件之间的消息定义和交换的库,为了实现多平台和方便开发者开发,maimbot从与nonebot耦合发展到maimcore阶段,maimcore只暴露自己的websocket接口,别的组件可以通过maim_message提供的MessageClient类与maimcore连接,也可以构造MessageServer接受MessageClient连接。
|
|
21
|
+
|
|
22
|
+
在这种通信协议下,可以产生多种组件类型,例如最简单的nonebot_adapter与maimcore之间的通信,nonebot作为MessageClient,连接作为MessageServer的maimcore,从而实现一个消息客户端的开发。另外还支持类似代理或中间件的插件形式,例如让nonebot连接到插件的MessageServer,插件再通过MessageClient连接到maimcore。
|
|
23
|
+
|
|
24
|
+
消息的构造使用MessageBase,MessageBase提供了序列化和反序列化的方法,是消息通信的基本格式,使用时可以直接构造MessageBase,也可以继承MessageBase后构造。
|
|
25
|
+
|
|
26
|
+
消息的基本内容由maim_message的Seg定义,Seg有type和data两个属性,除了特殊类型seglist之外,type并无限制,但maimcore只能处理类型为text,image,emoji,seglist的Seg,Seg支持嵌套,seglist类型的Seg的data定义为一个Seg列表,便于含有多种类型组合的消息,以及含有嵌套的消息可以直接递归的解析成Seg。
|
|
27
|
+
|
|
28
|
+
目前maimcore可以处理的Seg类型以及定义如下:
|
|
29
|
+
```python
|
|
30
|
+
Seg(
|
|
31
|
+
"seglist",
|
|
32
|
+
[
|
|
33
|
+
Seg("text", "111(raw text)"),
|
|
34
|
+
Seg("emoji", "base64(无头base64)"),
|
|
35
|
+
Seg("image", "base64(无头base64)"),
|
|
36
|
+
],
|
|
37
|
+
)
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
## 构造一个合法消息
|
|
41
|
+
#### 消息构造一图流
|
|
42
|
+
<img src="./doc/img/maim_message_messagebase.png" alt="MessageBase 消息构造示意图" />
|
|
43
|
+
|
|
44
|
+
#### 消息构造代码参考
|
|
45
|
+
```python
|
|
46
|
+
from maim_message import (
|
|
47
|
+
BaseMessageInfo,
|
|
48
|
+
UserInfo,
|
|
49
|
+
GroupInfo,
|
|
50
|
+
FormatInfo,
|
|
51
|
+
TemplateInfo,
|
|
52
|
+
MessageBase,
|
|
53
|
+
Seg
|
|
54
|
+
)
|
|
55
|
+
import asyncio
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def construct_message(platform):
    """Build a sample MessageBase for ``platform`` that uses every demo Seg type."""
    # Sender. platform and user_id are required; the two names are optional.
    user_info = UserInfo(
        platform=platform,
        user_id='12348765',
        user_nickname="maimai",
        user_cardname="mai god",
    )

    # Group. Keep platform identical across all infos; group_id is required,
    # group_name is optional.
    group_info = GroupInfo(platform=platform, group_id='12345678', group_name="aaabbb")

    format_info = FormatInfo(
        # Seg types contained in this message's content
        content_format=["text", "image", "emoji", "at", "reply", "voice"],
        # Seg types the final message is expected to contain once sent; lets
        # plugins decide whether to add certain segment types to the message.
        accept_format=["text", "image", "emoji", "reply"],
    )

    # Templates are not enabled yet; template_info may simply be None.
    template_info_custom = TemplateInfo(
        template_items={
            "detailed_text": "[{user_nickname}({user_nickname})]{user_cardname}: {processed_text}",
            "main_prompt_template": "...",
        },
        template_name="qq123_default",
        template_default=False,
    )
    template_info_default = TemplateInfo(template_default=True)

    message_info = BaseMessageInfo(
        # required
        platform=platform,
        message_id="12345678",  # only used by reply/recall features; need not be unique
        time=1234567.001,  # timestamp
        group_info=group_info,
        user_info=user_info,
        # optional / not yet enabled
        format_info=format_info,
        template_info=None,
        additional_config={
            "maimcore_reply_probability_gain": 0.5  # reply-probability boost
        },
    )

    # One segment of each demo type, wrapped in a seglist.
    segments = [
        Seg("text", "111(raw text)"),
        Seg("emoji", "base64(raw base64)"),
        Seg("image", "base64(raw base64)"),
        Seg("at", "111222333(qq number)"),
        Seg("reply", "123456(message id)"),
        Seg("voice", "wip"),
    ]
    message_segment = Seg("seglist", segments)

    # message_info and message_segment are required; raw_message is optional.
    return MessageBase(
        message_info=message_info,
        message_segment=message_segment,
        raw_message="可有可无",
    )
|
|
133
|
+
```
|
|
134
|
+
|
|
135
|
+
## 简要构造一个消息客户端
|
|
136
|
+
在构建标准消息并建立客户端时,可以使用maim_message提供的Router类:当一个客户端程序需要处理多个不同平台的数据时,Router可统一管理为各平台建立的多个MessageClient。示例如下(该示例承接上一节的代码,复用其中导入的 `asyncio` 以及定义的 `construct_message` 函数)。
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
```python
|
|
140
|
+
from maim_message import (
|
|
141
|
+
Router,
|
|
142
|
+
RouteConfig,
|
|
143
|
+
TargetConfig,
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
# Routing configuration: one TargetConfig per platform, each naming the
# WebSocket server to connect to and an optional authentication token.
targets = {
    "platform1": TargetConfig(
        url="ws://127.0.0.1:19000/ws",
        token=None,  # set a token here if the server requires one
    ),
    # Any number of platform connections can be configured.
    "platform2": TargetConfig(
        url="ws://127.0.0.1:19000/ws",
        token="your_auth_token_here",  # example: a token-authenticated connection
    ),
    "platform3": TargetConfig(url="ws://127.0.0.1:19000/ws", token=None),
}
route_config = RouteConfig(route_config=targets)

# Router instance built from the configuration above.
router = Router(route_config)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
async def main():
    """Start the router, send one sample message, then keep running until interrupted."""
    # Register the handler for incoming messages (message_handler is defined below).
    router.register_class_handler(message_handler)

    try:
        # Start the router; it connects to every configured platform.
        router_task = asyncio.create_task(router.run())

        # Give the connections time to come up.
        await asyncio.sleep(2)

        # router.send_message() takes a MessageBase. Use a platform that is
        # actually configured in route_config above ("test" is not one of
        # them); the router presumably picks the connection by
        # message_info.platform — confirm against Router.send_message.
        await router.send_message(construct_message("platform1"))

        # Keep running until interrupted.
        await router_task

    finally:
        print("正在关闭连接...")
        await router.stop()
        print("已关闭所有连接")
|
|
192
|
+
|
|
193
|
+
async def message_handler(message):
    """Example handler: every message sent from mmc (maimcore) arrives here.

    A real adapter would parse the message and forge a suitable one for the
    target platform, e.g. convert mmc's MessageBase into a OneBot v11 message
    for QQ, or into another protocol for another platform.
    """
    print(f"收到消息: {message}")


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C ends the demo quietly; asyncio.run() handles loop cleanup.
        pass
|
|
209
|
+
|
|
210
|
+
```
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
## 构造一个maimcore plugin
|
|
214
|
+
实际上只是模仿了maimcore的结构,真正的plugins应该继续向下游发送消息。
|
|
215
|
+
```python
|
|
216
|
+
from maim_message import MessageBase, Seg, MessageServer
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
async def process_seg(seg: Seg):
    """Recursively walk a segment tree, rewriting voice/@ segments as text placeholders."""
    if seg.type == "seglist":
        for child in seg.data:
            await process_seg(child)

    # maimcore only handles text/image/emoji/seglist, so other types are
    # replaced by a readable text placeholder.
    placeholder = {"voice": "[音频]", "at": "[@某人]"}.get(seg.type)
    if placeholder is not None:
        seg.type, seg.data = "text", placeholder
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
async def handle_message(message_data):
    """Parse an incoming message, rewrite its segments, and send it back out."""
    message = MessageBase.from_dict(message_data)
    await process_seg(message.message_segment)

    # MessageServer.send_message delivers only to the client registered for
    # message.message_info.platform; it does NOT broadcast to every client
    # (server.broadcast_message(message.to_dict()) does that).
    await server.send_message(message)


if __name__ == "__main__":
    # Create the server instance.
    server = MessageServer(host="0.0.0.0", port=19000)

    # Register the message handler.
    server.register_message_handler(handle_message)

    # Run the server (blocks until it exits).
    server.run_sync()
|
|
252
|
+
|
|
253
|
+
```
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
## 许可证
|
|
257
|
+
|
|
258
|
+
MIT
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from fastapi import FastAPI, HTTPException
|
|
2
|
+
from pydantic import BaseModel
|
|
3
|
+
import json
|
|
4
|
+
from typing import Optional, Dict, Any, Callable, List
|
|
5
|
+
import aiohttp
|
|
6
|
+
import asyncio
|
|
7
|
+
import uvicorn
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseMessageAPI:
    """Minimal HTTP message API built on FastAPI and served by uvicorn.

    JSON posted to ``/api/message`` is passed to every registered handler in
    registration order; :meth:`send_message` POSTs JSON to a remote endpoint.
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 18000):
        """Create the FastAPI app and register the built-in route.

        Args:
            host: Interface uvicorn binds to.
            port: TCP port uvicorn listens on.
        """
        self.app = FastAPI()
        self.host = host
        self.port = port
        self.message_handlers: List[Callable] = []
        self._setup_routes()
        self._running = False

    def _setup_routes(self):
        """Register the ``POST /api/message`` route on ``self.app``."""

        @self.app.post("/api/message")
        async def handle_message(message: Dict[str, Any]):
            # Handlers are awaited one after another; an exception from any
            # handler propagates out of the route to FastAPI.
            for handler in self.message_handlers:
                await handler(message)
            return {"status": "success"}

    def register_message_handler(self, handler: Callable):
        """Register an async callable that receives each incoming message dict."""
        self.message_handlers.append(handler)

    async def send_message(
        self, url: str, data: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """POST ``data`` as JSON to ``url`` and return the decoded JSON reply.

        Best-effort: any failure (connection error, non-JSON reply, ...) yields
        ``None`` instead of raising.
        """
        async with aiohttp.ClientSession() as session:
            try:
                async with session.post(
                    url, json=data, headers={"Content-Type": "application/json"}
                ) as response:
                    return await response.json()
            except Exception:
                # Deliberately best-effort: callers receive None on failure.
                return None

    def run_sync(self):
        """Run the server in the foreground, blocking until uvicorn exits."""
        uvicorn.run(self.app, host=self.host, port=self.port)

    async def run(self):
        """Serve the app on the current event loop until uvicorn exits."""
        config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
        self.server = uvicorn.Server(config)
        try:
            await self.server.serve()
        except KeyboardInterrupt:
            # stop() is a coroutine: without await it would never execute.
            await self.stop()
            raise

    async def start_server(self):
        """Call run() unless this instance is already marked as running."""
        if not self._running:
            self._running = True
            await self.run()

    async def stop(self):
        """Shut down the uvicorn server and drop all registered handlers.

        Does nothing if run() has not created a server yet.
        """
        if hasattr(self, "server"):
            self._running = False
            # Signal uvicorn's main loop to exit, then run its shutdown sequence.
            self.server.should_exit = True
            await self.server.shutdown()
            # Wait for the server to stop completely.
            if hasattr(self.server, "started") and self.server.started:
                await self.server.main_loop()
            self.message_handlers.clear()

    def start(self):
        """Blocking convenience wrapper: run start_server() on a fresh event loop.

        KeyboardInterrupt is swallowed; the loop is always closed afterwards.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self.start_server())
        except KeyboardInterrupt:
            pass
        finally:
            loop.close()
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "maim_message"
|
|
3
|
+
version = "0.2.0"
|
|
4
|
+
description = "A message handling library"
|
|
5
|
+
authors = [
|
|
6
|
+
{name = "tcmofashi"}
|
|
7
|
+
]
|
|
8
|
+
requires-python = ">=3.9"
|
|
9
|
+
dependencies = [
|
|
10
|
+
"fastapi>=0.70.0",
|
|
11
|
+
"uvicorn>=0.15.0",
|
|
12
|
+
"aiohttp>=3.8.0",
|
|
13
|
+
"pydantic>=1.9.0",
|
|
14
|
+
"websockets>=10.0",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[build-system]
# PEP 621 `[project]` metadata (used above) is only understood by setuptools >= 61.
requires = ["setuptools>=61", "wheel", "setuptools-scm"]
build-backend = "setuptools.build_meta"
|
|
20
|
+
|
|
21
|
+
[tool.black]
|
|
22
|
+
line-length = 100
|
|
23
|
+
target-version = ['py39']
|
|
24
|
+
include = '\.pyi?$'
|
|
25
|
+
|
|
26
|
+
[tool.isort]
|
|
27
|
+
profile = "black"
|
|
28
|
+
multi_line_output = 3
|
|
29
|
+
line_length = 100
|
|
30
|
+
|
|
31
|
+
[tool.mypy]
|
|
32
|
+
python_version = "3.9"
|
|
33
|
+
strict = true
|
|
34
|
+
ignore_missing_imports = true
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import setuptools
|
|
2
|
+
|
|
3
|
+
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setuptools.setup(
    name="maim_message",
    version="0.2.0",
    author="tcmofashi",
    url="https://github.com/MaiM-with-u/maim_message",
    author_email="mofashiforzbx@qq.com",
    description="A message handling library for maimcore",
    # Pass on the README read above; it was previously read but never used,
    # even though its content type is declared on the next line.
    long_description=long_description,
    long_description_content_type="text/markdown",
    package_dir={"": "src"},
    packages=setuptools.find_packages(where="src"),
    python_requires=">=3.9",
)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Maim Message - A message handling library"""
|
|
2
|
+
|
|
3
|
+
__version__ = "0.2.0"  # keep in sync with the version in pyproject.toml and setup.py
|
|
4
|
+
|
|
5
|
+
from .api import MessageClient, MessageServer
|
|
6
|
+
from .router import Router, RouteConfig, TargetConfig
|
|
7
|
+
from .message_base import (
|
|
8
|
+
Seg,
|
|
9
|
+
GroupInfo,
|
|
10
|
+
UserInfo,
|
|
11
|
+
FormatInfo,
|
|
12
|
+
TemplateInfo,
|
|
13
|
+
BaseMessageInfo,
|
|
14
|
+
MessageBase,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
"MessageClient",
|
|
19
|
+
"MessageServer",
|
|
20
|
+
"Router",
|
|
21
|
+
"RouteConfig",
|
|
22
|
+
"TargetConfig",
|
|
23
|
+
"Seg",
|
|
24
|
+
"GroupInfo",
|
|
25
|
+
"UserInfo",
|
|
26
|
+
"FormatInfo",
|
|
27
|
+
"TemplateInfo",
|
|
28
|
+
"BaseMessageInfo",
|
|
29
|
+
"MessageBase",
|
|
30
|
+
]
|
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
2
|
+
from typing import Dict, Any, Callable, List, Set, Optional
|
|
3
|
+
import aiohttp
|
|
4
|
+
import asyncio
|
|
5
|
+
import uvicorn
|
|
6
|
+
import json
|
|
7
|
+
from .message_base import MessageBase
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseMessageHandler:
    """Base class that stores message handlers and dispatches messages to them."""

    def __init__(self):
        self.message_handlers: List[Callable] = []
        # Background tasks held by strong reference so they can be awaited or
        # cancelled later.
        self.background_tasks: Set[asyncio.Task] = set()

    def register_message_handler(self, handler: Callable):
        """Append ``handler`` to the handler list (duplicates are not filtered)."""
        self.message_handlers.append(handler)

    async def process_message(self, message: Dict[str, Any]):
        """Invoke every registered handler with ``message``.

        Handlers may be coroutine functions or plain callables. Each handler is
        called in registration order; sync handlers finish during that call, and
        any coroutines/futures returned are then awaited concurrently.
        Exceptions from awaited handlers are collected and discarded
        (``return_exceptions=True``).

        Raises:
            RuntimeError: if calling a handler raises; coroutines already
                created for earlier handlers are closed first.
        """
        pending = []
        for handler in self.message_handlers:
            try:
                result = handler(message)
            except Exception as e:
                # Close never-awaited coroutines to avoid "never awaited" warnings.
                for awaitable in pending:
                    if asyncio.iscoroutine(awaitable):
                        awaitable.close()
                raise RuntimeError(str(e)) from e
            # Only awaitables can be gathered; a sync handler's return value
            # (usually None) would make asyncio.gather raise TypeError.
            if asyncio.iscoroutine(result) or asyncio.isfuture(result):
                pending.append(result)
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)

    async def _handle_message(self, message: Dict[str, Any]):
        """Background entry point: run process_message(), wrapping errors in RuntimeError."""
        try:
            await self.process_message(message)
        except Exception as e:
            raise RuntimeError(str(e)) from e
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class MessageServer(BaseMessageHandler):
    """WebSocket server that receives client messages and sends replies per platform.

    Each client names its platform in a ``platform`` header; the most recent
    connection for a platform is the one :meth:`send_message` delivers to.
    """

    _class_handlers: List[Callable] = []  # class-level handlers copied into each new instance

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 18000,
        enable_token=False,
        app: Optional[FastAPI] = None,
        path: str = "/ws",
    ):
        """Create the server and register the WebSocket route.

        Args:
            host: Interface to bind when this instance runs its own uvicorn server.
            port: Port to bind when this instance runs its own uvicorn server.
            enable_token: If True, clients must send an ``authorization`` header
                whose value was registered with :meth:`add_valid_token`.
            app: Existing FastAPI app to mount the route on; a new one is
                created when omitted.
            path: URL path of the WebSocket endpoint.
        """
        super().__init__()
        # Snapshot of class-level handlers registered before this instance existed.
        self.message_handlers.extend(self._class_handlers)
        self.host = host
        self.port = port
        self.path = path
        self.app = app or FastAPI()
        self.own_app = app is None  # True when this instance created (and may run) the app
        self.active_websockets: Set[WebSocket] = set()
        self.platform_websockets: Dict[str, WebSocket] = {}  # platform -> current websocket
        self.valid_tokens: Set[str] = set()
        self.enable_token = enable_token
        self._setup_routes()
        self._running = False

    @classmethod
    def register_class_handler(cls, handler: Callable):
        """Register a handler for every MessageServer created afterwards (no duplicates)."""
        if handler not in cls._class_handlers:
            cls._class_handlers.append(handler)

    def register_message_handler(self, handler: Callable):
        """Register a handler on this instance only (no duplicates)."""
        if handler not in self.message_handlers:
            self.message_handlers.append(handler)

    async def verify_token(self, token: str) -> bool:
        """Return True when token checks are disabled or ``token`` is registered."""
        if not self.enable_token:
            return True
        return token in self.valid_tokens

    def add_valid_token(self, token: str):
        """Accept ``token`` for future connections."""
        self.valid_tokens.add(token)

    def remove_valid_token(self, token: str):
        """Stop accepting ``token`` (no error if absent); open connections stay open."""
        self.valid_tokens.discard(token)

    def _setup_routes(self):
        """Register the WebSocket endpoint at ``self.path`` on ``self.app``."""

        @self.app.websocket(self.path)
        async def websocket_endpoint(websocket: WebSocket):
            headers = dict(websocket.headers)
            token = headers.get("authorization")
            platform = headers.get("platform", "default")
            if self.enable_token:
                if not token or not await self.verify_token(token):
                    await websocket.close(code=1008, reason="Invalid or missing token")
                    return

            await websocket.accept()
            self.active_websockets.add(websocket)
            # Newest connection wins: a client that reconnects before its stale
            # socket has been cleaned up must still be reachable.
            self.platform_websockets[platform] = websocket

            try:
                while True:
                    message = await websocket.receive_json()
                    self._spawn_message_task(message)
            except WebSocketDisconnect:
                # Normal client disconnect; cleanup happens in finally.
                pass
            except Exception as e:
                raise RuntimeError(str(e)) from e
            finally:
                # Single cleanup point for every exit path.
                self._remove_websocket(websocket, platform)

    def _spawn_message_task(self, message: Dict[str, Any]) -> None:
        """Handle ``message`` in a background task that stays referenced until done."""
        task = asyncio.create_task(self._handle_message(message))
        # The event loop keeps only weak references to tasks; hold a strong one
        # so the task is not garbage-collected mid-run and stop() can cancel it.
        self.background_tasks.add(task)
        task.add_done_callback(self.background_tasks.discard)

    def _remove_websocket(self, websocket: WebSocket, platform: str):
        """Forget ``websocket``; drop the platform mapping only if it still points to it."""
        self.active_websockets.discard(websocket)
        if self.platform_websockets.get(platform) is websocket:
            del self.platform_websockets[platform]

    async def broadcast_message(self, message: Dict[str, Any]):
        """Send ``message`` to every connected client, dropping sockets that fail."""
        disconnected = set()
        # Iterate over a snapshot: each await lets other connections join or
        # leave, which would otherwise mutate the set during iteration.
        for websocket in list(self.active_websockets):
            try:
                await websocket.send_json(message)
            except Exception:
                disconnected.add(websocket)
        for websocket in disconnected:
            self.active_websockets.discard(websocket)
        # Also forget platform mappings that point at sockets found dead above.
        stale = [p for p, ws in self.platform_websockets.items() if ws in disconnected]
        for platform in stale:
            del self.platform_websockets[platform]

    async def broadcast_to_platform(self, platform: str, message: Dict[str, Any]):
        """Send ``message`` to the websocket currently registered for ``platform``.

        Does nothing if no client is registered for the platform; a socket that
        fails to send is removed.
        """
        websocket = self.platform_websockets.get(platform)
        if websocket is None:
            return
        try:
            await websocket.send_json(message)
        except Exception:
            # Use the socket captured above: the mapping may have changed while
            # awaiting send_json.
            self._remove_websocket(websocket, platform)

    async def send_message(self, message: MessageBase):
        """Send ``message`` to the client registered for its ``message_info.platform``."""
        await self.broadcast_to_platform(
            message.message_info.platform, message.to_dict()
        )

    def run_sync(self):
        """Run uvicorn in the foreground (only when this instance created the app).

        Raises:
            RuntimeError: if an external FastAPI app was supplied.
        """
        if not self.own_app:
            raise RuntimeError("当使用外部FastAPI实例时,请使用该实例的运行方法")
        uvicorn.run(self.app, host=self.host, port=self.port)

    async def run(self):
        """Run until the server exits, then stop() it.

        With its own app this serves it via uvicorn; with an external app it
        idles until stop() clears the running flag.

        Raises:
            RuntimeError: wrapping any other exception raised while running.
        """
        self._running = True
        try:
            if self.own_app:
                config = uvicorn.Config(
                    self.app, host=self.host, port=self.port, loop="asyncio"
                )
                self.server = uvicorn.Server(config)
                await self.server.serve()
            else:
                # The external app's owner runs the server; stay alive so
                # messages keep being handled until stop() is called.
                while self._running:
                    await asyncio.sleep(1)
        except Exception as e:
            raise RuntimeError(f"服务器运行错误: {str(e)}") from e
        finally:
            # One cleanup for normal exit, errors and KeyboardInterrupt alike.
            await self.stop()

    async def start_server(self):
        """Call run() unless this instance is already marked as running."""
        if not self._running:
            self._running = True
            await self.run()

    async def stop(self):
        """Cancel handler tasks, close all sockets and shut down the own uvicorn server."""
        # Clear first so an external-app run() loop exits as well.
        self._running = False
        self.platform_websockets.clear()

        # Cancel tracked handler tasks (but never the task calling stop()).
        current = asyncio.current_task()
        tasks = [t for t in self.background_tasks if t is not current]
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        self.background_tasks.clear()

        # Snapshot: closing a socket can trigger its endpoint's cleanup, which
        # removes it from active_websockets while this loop is running.
        for websocket in list(self.active_websockets):
            try:
                await websocket.close()
            except Exception:
                # Already closed or broken; keep closing the others.
                pass
        self.active_websockets.clear()

        if hasattr(self, "server") and self.own_app:
            # Signal uvicorn's main loop to exit, then run its shutdown sequence.
            self.server.should_exit = True
            await self.server.shutdown()
            # Wait for the server to stop completely.
            if hasattr(self.server, "started") and self.server.started:
                await self.server.main_loop()
            self.message_handlers.clear()
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class MessageClient(BaseMessageHandler):
    """WebSocket client that connects to a MessageServer and reconnects on failure.

    Call :meth:`connect` to set the target, then await :meth:`run`; incoming
    JSON messages are dispatched to the registered handlers in background tasks.
    """

    _class_handlers: List[Callable] = []  # class-level handlers copied into each new instance

    def __init__(self):
        super().__init__()
        # Snapshot of class-level handlers registered before this instance existed.
        self.message_handlers.extend(self._class_handlers)
        self.platform = None
        self.remote_ws = None
        self.remote_ws_url = None
        self.remote_ws_token = None
        self.remote_ws_connected = False
        self.remote_reconnect_interval = 5  # seconds to wait before reconnecting
        self._running = False
        self.retry_count = 0

    @classmethod
    def register_class_handler(cls, handler: Callable):
        """Register a handler for every MessageClient created afterwards (no duplicates)."""
        if handler not in cls._class_handlers:
            cls._class_handlers.append(handler)

    def register_message_handler(self, handler: Callable):
        """Register a handler on this instance only (no duplicates)."""
        if handler not in self.message_handlers:
            self.message_handlers.append(handler)

    async def connect(self, url: str, platform: str, token: Optional[str] = None):
        """Store the connection parameters and mark the client as running.

        No network I/O happens here; :meth:`run` opens the connection.
        """
        self.remote_ws_url = url
        self.remote_ws_token = token
        self.platform = platform
        self._running = True

    def _spawn_message_task(self, message: Dict[str, Any]) -> None:
        """Handle ``message`` in a background task that stays referenced until done."""
        task = asyncio.create_task(self._handle_message(message))
        # The event loop keeps only weak references to tasks; hold a strong one
        # so the task is not garbage-collected before it finishes.
        self.background_tasks.add(task)
        task.add_done_callback(self.background_tasks.discard)

    async def run(self):
        """Keep a WebSocket connection open and dispatch incoming messages.

        After a connection failure or disconnect it waits
        ``remote_reconnect_interval`` seconds and reconnects, until :meth:`stop`
        clears the running flag. A cancellation raised while connecting or
        receiving ends the loop without being re-raised.
        """
        self.retry_count = 0
        headers = {"platform": self.platform}
        if self.remote_ws_token:
            headers["Authorization"] = str(self.remote_ws_token)

        while self._running:
            try:
                print(f"正在连接到 {self.remote_ws_url}")
                async with aiohttp.ClientSession() as session:
                    ws = await session.ws_connect(self.remote_ws_url, headers=headers)
                    self.remote_ws = ws
                    self.remote_ws_connected = True
                    print(f"已连接到 {self.remote_ws_url}")
                    self.retry_count = 0

                    try:
                        async for msg in ws:
                            if not self._running:
                                break
                            if msg.type == aiohttp.WSMsgType.TEXT:
                                try:
                                    message = msg.json()
                                    self._spawn_message_task(message)
                                except json.JSONDecodeError as e:
                                    print(f"收到无效的JSON消息: {e}")
                            elif msg.type in (
                                aiohttp.WSMsgType.CLOSED,
                                aiohttp.WSMsgType.ERROR,
                            ):
                                break
                    finally:
                        await ws.close()

            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                print(f"连接失败 ({self.retry_count}): {e}")
                self.retry_count += 1
            except asyncio.CancelledError:
                break
            finally:
                self.remote_ws_connected = False
                self.remote_ws = None

            if self._running:
                await asyncio.sleep(self.remote_reconnect_interval)

    async def stop(self):
        """Stop reconnecting and close the current connection, if any."""
        self._running = False
        if self.remote_ws and not self.remote_ws.closed:
            await self.remote_ws.close()
        self.remote_ws_connected = False
        self.remote_ws = None

    async def send_message(self, message: Dict[str, Any]) -> bool:
        """Send ``message`` as JSON to the server and return True on success.

        Raises:
            RuntimeError: if not connected, or if sending fails (the client is
                then marked as disconnected).
        """
        if not self.remote_ws_connected:
            raise RuntimeError("未连接到服务器")
        try:
            await self.remote_ws.send_json(message)
            return True
        except Exception as e:
            self.remote_ws_connected = False
            raise RuntimeError(f"发送消息失败: {e}") from e
|
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
from dataclasses import dataclass, asdict
|
|
2
|
+
from typing import List, Optional, Union, Dict
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@dataclass
class Seg:
    """A message segment: one typed piece of a message.

    Attributes:
        type: Segment type such as ``'text'``, ``'image'`` or ``'seglist'``.
        data: Segment payload:
            - for ``text``, a string
            - for ``image``, a base64 string
            - for ``seglist``, a list of ``Seg``
    """

    type: str
    data: Union[str, List["Seg"]]

    @classmethod
    def from_dict(cls, data: Dict) -> "Seg":
        """Build a Seg from its dict form, recursing into ``seglist`` payloads.

        Nested segments are built with ``cls`` so subclasses round-trip as
        themselves.
        """
        # Distinct local names avoid shadowing the builtin ``type`` and the
        # ``data`` argument itself.
        seg_type = data.get("type")
        payload = data.get("data")
        if seg_type == "seglist":
            payload = [cls.from_dict(item) for item in payload]
        return cls(type=seg_type, data=payload)

    def to_dict(self) -> Dict:
        """Convert to a plain dict, recursing into ``seglist`` children."""
        result = {"type": self.type}
        if self.type == "seglist":
            result["data"] = [seg.to_dict() for seg in self.data]
        else:
            result["data"] = self.data
        return result
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class GroupInfo:
    """Group (chat room) information."""

    platform: Optional[str] = None
    group_id: Optional[str] = None
    group_name: Optional[str] = None  # group name

    def to_dict(self) -> Dict:
        """Return the fields as a dict, omitting those whose value is None."""
        return {k: v for k, v in asdict(self).items() if v is not None}

    @classmethod
    def from_dict(cls, data: Optional[Dict]) -> Optional["GroupInfo"]:
        """Create a GroupInfo from a dict.

        Args:
            data: Dict with ``platform``, ``group_id`` and optional ``group_name``.

        Returns:
            A new GroupInfo, or None when ``data`` is empty/None or carries no
            ``group_id``.
        """
        if not data or data.get("group_id") is None:
            return None
        return cls(
            platform=data.get("platform"),
            group_id=data.get("group_id"),
            group_name=data.get("group_name"),
        )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@dataclass
class UserInfo:
    """Information about a message's user."""

    platform: Optional[str] = None
    user_id: Optional[str] = None
    user_nickname: Optional[str] = None  # user's nickname
    user_cardname: Optional[str] = None  # user's in-group display name

    def to_dict(self) -> Dict:
        """Serialize to a dict, leaving out every field that is None."""
        fields_as_dict = asdict(self)
        return {name: value for name, value in fields_as_dict.items() if value is not None}

    @classmethod
    def from_dict(cls, data: Dict) -> "UserInfo":
        """Build a UserInfo from a dict; missing keys become None.

        Args:
            data: Dict that may contain ``platform``, ``user_id``,
                ``user_nickname`` and ``user_cardname``.

        Returns:
            UserInfo: the new instance.
        """
        field_names = ("platform", "user_id", "user_nickname", "user_cardname")
        return cls(**{name: data.get(name) for name in field_names})
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@dataclass
class FormatInfo:
    """Message format information.

    At present maimcore accepts the formats ``text``, ``image`` and ``emoji``,
    and can send the formats ``text``, ``emoji`` and ``reply``.
    """

    content_format: Optional[List[str]] = None
    accept_format: Optional[List[str]] = None

    def to_dict(self) -> Dict:
        """Return a dict containing only the fields that are set."""
        converted = asdict(self)
        return {key: value for key, value in converted.items() if value is not None}

    @classmethod
    def from_dict(cls, data: Dict) -> "FormatInfo":
        """Build a FormatInfo from a dict.

        Args:
            data: Mapping that may hold ``content_format`` and
                ``accept_format``; absent keys become ``None``.

        Returns:
            FormatInfo: a fresh instance.
        """
        content_format = data.get("content_format")
        accept_format = data.get("accept_format")
        return cls(content_format=content_format, accept_format=accept_format)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@dataclass
class TemplateInfo:
    """Template information."""

    template_items: Optional[Dict[str, str]] = None
    template_name: Optional[Dict[str, str]] = None
    template_default: bool = True

    def to_dict(self) -> Dict:
        """Return the non-``None`` fields as a plain dict."""
        fields_as_dict = asdict(self)
        return {
            attr: content
            for attr, content in fields_as_dict.items()
            if content is not None
        }

    @classmethod
    def from_dict(cls, data: Dict) -> "TemplateInfo":
        """Build a TemplateInfo from a dict.

        Args:
            data: Mapping that may hold ``template_items``, ``template_name``
                and ``template_default`` (the latter defaults to ``True``).

        Returns:
            TemplateInfo: a fresh instance.
        """
        items = data.get("template_items")
        name = data.get("template_name")
        use_default = data.get("template_default", True)
        return cls(items, name, use_default)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@dataclass
class BaseMessageInfo:
    """Message metadata: platform, id, timestamp and nested info objects."""

    platform: Optional[str] = None
    message_id: Optional[str] = None
    time: Optional[float] = None
    group_info: Optional[GroupInfo] = None
    user_info: Optional[UserInfo] = None
    format_info: Optional[FormatInfo] = None
    template_info: Optional[TemplateInfo] = None
    additional_config: Optional[dict] = None

    def to_dict(self) -> Dict:
        """Convert to a dict, omitting fields whose value is ``None``.

        Nested GroupInfo/UserInfo/FormatInfo/TemplateInfo objects are
        serialized with their own ``to_dict``, so their ``None`` fields are
        omitted as well. Every other value is taken from ``dataclasses.asdict``.
        """
        # ``asdict`` recursively converts nested dataclasses into plain dicts,
        # so the isinstance check has to inspect the live attribute rather than
        # the ``asdict`` output (which never matches and would leak nested
        # ``None`` values into the result).
        plain = asdict(self)
        result = {}
        for name, value in plain.items():
            if value is None:
                continue
            attr = getattr(self, name)
            if isinstance(attr, (GroupInfo, UserInfo, FormatInfo, TemplateInfo)):
                result[name] = attr.to_dict()
            else:
                result[name] = value
        return result

    @classmethod
    def from_dict(cls, data: Optional[Dict]) -> "BaseMessageInfo":
        """Create a BaseMessageInfo from a dict (e.g. the output of ``to_dict``).

        Args:
            data: Dict with the message-info fields. ``None`` data, and nested
                info entries that are missing or explicitly ``None``, are
                treated as empty dicts.

        Returns:
            BaseMessageInfo: a new instance. ``group_info`` is ``None`` when no
            ``group_id`` is present; ``user_info``, ``format_info`` and
            ``template_info`` are always instances (possibly all-default).
        """
        data = data or {}
        # ``or {}`` also covers keys that are present but set to ``None``,
        # which would otherwise reach the nested from_dict and fail on ``.get``.
        group_info = GroupInfo.from_dict(data.get("group_info") or {})
        user_info = UserInfo.from_dict(data.get("user_info") or {})
        format_info = FormatInfo.from_dict(data.get("format_info") or {})
        template_info = TemplateInfo.from_dict(data.get("template_info") or {})
        return cls(
            platform=data.get("platform"),
            message_id=data.get("message_id"),
            time=data.get("time"),
            additional_config=data.get("additional_config", None),
            group_info=group_info,
            user_info=user_info,
            format_info=format_info,
            template_info=template_info,
        )
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
@dataclass
class MessageBase:
    """A complete message: metadata, content segment and optional raw text."""

    message_info: BaseMessageInfo
    message_segment: Seg
    raw_message: Optional[str] = None  # original raw message, including unparsed CQ codes

    def to_dict(self) -> Dict:
        """Serialize the message to a dict.

        Returns:
            Dict: ``message_info`` and ``message_segment`` (each produced by
            their own ``to_dict``), plus ``raw_message`` only when it is set.
        """
        payload = {
            "message_info": self.message_info.to_dict(),
            "message_segment": self.message_segment.to_dict(),
        }
        if self.raw_message is None:
            return payload
        payload["raw_message"] = self.raw_message
        return payload

    @classmethod
    def from_dict(cls, data: Dict) -> "MessageBase":
        """Rebuild a MessageBase from a dict produced by ``to_dict``.

        Args:
            data: Mapping with ``message_info``, ``message_segment`` and an
                optional ``raw_message``; missing sub-dicts default to ``{}``.

        Returns:
            MessageBase: a fresh instance.
        """
        info = BaseMessageInfo.from_dict(data.get("message_info", {}))
        segment = Seg.from_dict(data.get("message_segment", {}))
        return cls(
            message_info=info,
            message_segment=segment,
            raw_message=data.get("raw_message", None),
        )
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
from typing import Optional, Dict, Any, Callable, List, Set
|
|
2
|
+
from dataclasses import dataclass, asdict
|
|
3
|
+
from .message_base import MessageBase
|
|
4
|
+
from .api import MessageClient
|
|
5
|
+
from fastapi import WebSocket
|
|
6
|
+
import asyncio
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class TargetConfig:
    """Connection target for one platform: URL plus optional auth token."""

    # Defaults to None, so the annotation must be Optional (was ``str``).
    url: Optional[str] = None
    token: Optional[str] = None

    def to_dict(self) -> Dict:
        """Convert to a plain dict (``None`` values are kept)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict) -> "TargetConfig":
        """Create a TargetConfig from a dict with ``url`` and ``token`` keys.

        Missing keys become ``None``.
        """
        return cls(url=data.get("url"), token=data.get("token"))
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class RouteConfig:
    """Routing table mapping a platform name to its TargetConfig."""

    route_config: Optional[Dict[str, TargetConfig]] = None

    def to_dict(self) -> Dict:
        """Convert to a plain dict (nested TargetConfig objects become dicts)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict) -> "RouteConfig":
        """Create a RouteConfig from a dict such as the output of ``to_dict``.

        Args:
            data: Dict whose ``route_config`` key maps platform names to
                TargetConfig dicts (TargetConfig instances are accepted as-is).
                A missing or ``None`` ``route_config`` yields an empty table.

        Returns:
            RouteConfig: a new instance. ``data`` is not modified.
        """
        raw_routes = data.get("route_config") or {}
        # Build a fresh dict instead of overwriting entries of the caller's
        # dict: in-place mutation made a second from_dict() on the same input
        # fail, because TargetConfig objects have no ``.get``.
        route_config = {
            platform: (
                target
                if isinstance(target, TargetConfig)
                else TargetConfig.from_dict(target)
            )
            for platform, target in raw_routes.items()
        }
        return cls(route_config=route_config)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Router:
    """Maintain one MessageClient per configured platform and route messages.

    Clients are created from ``config.route_config``. While the router is
    running, each client has a background ``client.run()`` task, and a monitor
    task reconnects platforms whose client reports ``remote_ws_connected`` as
    False.
    """

    def __init__(self, config: RouteConfig):
        self.config = config
        self.clients: Dict[str, MessageClient] = {}
        self._running = False
        self._client_tasks: Dict[str, asyncio.Task] = {}
        self._monitor_task = None

    async def _monitor_connections(self):
        """Check every client periodically and reconnect disconnected ones.

        Loops until ``self._running`` becomes False or the task is cancelled.
        """
        await asyncio.sleep(3)  # give the initial connections time to come up
        while self._running:
            for platform in list(self.clients):
                # The platform may have been removed (remove_platform, or an
                # earlier reconnect in this pass) since the snapshot was taken.
                client = self.clients.get(platform)
                if client is None or client.remote_ws_connected:
                    continue
                print(f"检测到平台 {platform} 连接断开,尝试重连...")
                try:
                    await self._reconnect_platform(platform)
                except Exception as e:
                    # A single failed reconnect must not kill the monitor task;
                    # it will be retried on the next pass. (CancelledError is a
                    # BaseException on Python >= 3.8 and still propagates.)
                    print(f"平台 {platform} 重连失败: {e}")
            await asyncio.sleep(5)  # check every 5 seconds

    async def _teardown_platform(self, platform: str):
        """Cancel the platform's run task and stop/remove its client, if any."""
        task = self._client_tasks.pop(platform, None)
        if task is not None and not task.done():
            task.cancel()
            await asyncio.gather(task, return_exceptions=True)

        # Pop before stopping so a failing stop() cannot leave a stale entry.
        client = self.clients.pop(platform, None)
        if client is not None:
            await client.stop()

    async def _reconnect_platform(self, platform: str):
        """Tear down the platform's current client/task and connect again."""
        await self._teardown_platform(platform)
        await self.connect(platform)

    async def add_platform(self, platform: str, config: TargetConfig):
        """Register (or replace) a platform's target; connect now if running."""
        self.config.route_config[platform] = config
        if self._running:
            # Stop any existing client first; otherwise connect() would
            # overwrite it and leak the old client and its run task.
            await self._teardown_platform(platform)
            await self.connect(platform)

    async def remove_platform(self, platform: str):
        """Remove a platform's config and stop its client and run task."""
        self.config.route_config.pop(platform, None)
        await self._teardown_platform(platform)

    async def connect(self, platform: str):
        """Create and connect a client for ``platform`` from its route config.

        Raises:
            ValueError: if ``platform`` has no entry in the route config.
        """
        if platform not in self.config.route_config:
            raise ValueError(f"未找到平台配置: {platform}")

        config = self.config.route_config[platform]
        client = MessageClient()
        await client.connect(config.url, platform, config.token)
        self.clients[platform] = client

        if self._running:
            self._client_tasks[platform] = asyncio.create_task(client.run())

    async def run(self):
        """Connect every configured platform and run until stopped.

        If the task running this coroutine is cancelled, all clients are
        stopped and the CancelledError is re-raised so the cancellation is
        visible to whoever awaits the task.
        """
        self._running = True
        try:
            # Snapshot the keys: connect() awaits, and the config may be
            # changed concurrently (add_platform / update_config).
            for platform in list(self.config.route_config):
                if platform not in self.clients:
                    await self.connect(platform)
                elif platform not in self._client_tasks:
                    # A client created before run() (e.g. lazily by
                    # send_message) has no background run() task yet.
                    client = self.clients[platform]
                    self._client_tasks[platform] = asyncio.create_task(client.run())

            # Start the connection monitor.
            self._monitor_task = asyncio.create_task(self._monitor_connections())

            # Wait until stop() flips the running flag.
            while self._running:
                await asyncio.sleep(1)

        except asyncio.CancelledError:
            await self.stop()
            raise
        finally:
            if self._monitor_task:
                self._monitor_task.cancel()
                await asyncio.gather(self._monitor_task, return_exceptions=True)

    async def stop(self):
        """Stop the monitor, cancel every run task and stop every client."""
        self._running = False

        # Cancel the monitor first so it cannot start new reconnects.
        if self._monitor_task and not self._monitor_task.done():
            self._monitor_task.cancel()
            await asyncio.gather(self._monitor_task, return_exceptions=True)

        # Cancel all background run tasks and wait for them to finish.
        tasks = list(self._client_tasks.values())
        self._client_tasks.clear()
        for task in tasks:
            if not task.done():
                task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        # Then stop the clients; best effort, one failure does not block others.
        clients = list(self.clients.values())
        self.clients.clear()
        if clients:
            await asyncio.gather(
                *(client.stop() for client in clients), return_exceptions=True
            )

    def register_class_handler(self, handler):
        """Register ``handler`` via ``MessageClient.register_class_handler``."""
        MessageClient.register_class_handler(handler)

    def get_target_url(self, message: MessageBase):
        """Return the configured URL for the message's platform, or None."""
        target = self.config.route_config.get(message.message_info.platform)
        return target.url if target is not None else None

    async def send_message(self, message: MessageBase):
        """Send ``message`` through the client for its platform.

        A client is created via connect() on first use, so it also gets a
        background run() task when the router is running.

        Raises:
            ValueError: if the message's platform has no URL configured.
        """
        url = self.get_target_url(message)
        platform = message.message_info.platform
        if url is None:
            raise ValueError(f"不存在该平台url配置: {platform}")
        if platform not in self.clients:
            await self.connect(platform)
        await self.clients[platform].send_message(message.to_dict())

    async def _adjust_connections(self, new_config: RouteConfig):
        """Reconcile live connections with ``new_config``."""
        old_platforms = set(self.config.route_config)
        new_platforms = set(new_config.route_config)

        # Platforms that disappeared from the config.
        for platform in old_platforms - new_platforms:
            await self.remove_platform(platform)

        # Platforms that were added, or whose URL/token changed.
        for platform in new_platforms:
            new_target = new_config.route_config[platform]
            old_target = self.config.route_config.get(platform)
            if old_target is None:
                await self.add_platform(platform, new_target)
            elif (
                new_target.url != old_target.url
                or new_target.token != old_target.token
            ):
                await self.remove_platform(platform)
                await self.add_platform(platform, new_target)

    async def update_config(self, config_data: Dict):
        """Replace the route config from ``config_data``; reconnect if running."""
        new_config = RouteConfig.from_dict(config_data)
        if self._running:
            await self._adjust_connections(new_config)
        self.config = new_config
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: maim_message
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: A message handling library
|
|
5
|
+
Home-page: https://github.com/MaiM-with-u/maim_message
|
|
6
|
+
Author: tcmofashi
|
|
7
|
+
Author-email: mofashiforzbx@qq.com
|
|
8
|
+
Requires-Python: >=3.9
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Requires-Dist: fastapi>=0.70.0
|
|
12
|
+
Requires-Dist: uvicorn>=0.15.0
|
|
13
|
+
Requires-Dist: aiohttp>=3.8.0
|
|
14
|
+
Requires-Dist: pydantic>=1.9.0
|
|
15
|
+
Requires-Dist: websockets>=10.0
|
|
16
|
+
Dynamic: author-email
|
|
17
|
+
Dynamic: description-content-type
|
|
18
|
+
Dynamic: home-page
|
|
19
|
+
Dynamic: license-file
|
|
20
|
+
Dynamic: requires-python
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
.gitignore
|
|
2
|
+
LICENSE
|
|
3
|
+
README.md
|
|
4
|
+
api.py
|
|
5
|
+
pyproject.toml
|
|
6
|
+
setup.py
|
|
7
|
+
src/maim_message/__init__.py
|
|
8
|
+
src/maim_message/api.py
|
|
9
|
+
src/maim_message/message_base.py
|
|
10
|
+
src/maim_message/router.py
|
|
11
|
+
src/maim_message.egg-info/PKG-INFO
|
|
12
|
+
src/maim_message.egg-info/SOURCES.txt
|
|
13
|
+
src/maim_message.egg-info/dependency_links.txt
|
|
14
|
+
src/maim_message.egg-info/requires.txt
|
|
15
|
+
src/maim_message.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
maim_message
|