ForcomeBot 2.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
src/api/websocket.py ADDED
@@ -0,0 +1,280 @@
1
+ """WebSocket推送模块 - 实时状态更新
2
+
3
+ 功能:
4
+ - WebSocketManager 连接管理
5
+ - 日志实时推送
6
+ - 状态变更推送
7
+ """
8
+ import asyncio
9
+ import logging
10
+ from typing import Set, Optional, Dict, Any, TYPE_CHECKING
11
+
12
+ from fastapi import WebSocket, WebSocketDisconnect
13
+
14
+ if TYPE_CHECKING:
15
+ from ..core.log_collector import LogCollector
16
+ from ..clients.langbot import LangBotClient
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class WebSocketManager:
    """Registry of live WebSocket clients with a hard connection cap.

    Clients beyond ``max_connections`` are turned away with close code 1013
    before their handshake is accepted. Messages can be sent to one client
    or fanned out to all of them; clients whose send fails are dropped.
    """

    def __init__(self, max_connections: int = 100):
        """Create an empty registry.

        Args:
            max_connections: maximum number of clients tracked at once.
        """
        self._max_connections = max_connections
        self._connections: Set[WebSocket] = set()
        # Serializes admission so the capacity check and the add are atomic.
        self._lock = asyncio.Lock()

    async def connect(self, websocket: WebSocket) -> bool:
        """Admit a client if there is room, accepting its handshake.

        Args:
            websocket: the incoming connection.

        Returns:
            True when the client was accepted and registered, False when it
            was rejected because the cap was reached.
        """
        async with self._lock:
            at_capacity = len(self._connections) >= self._max_connections
            if at_capacity:
                logger.warning(f"WebSocket连接数已达上限: {self._max_connections}")
                await websocket.close(code=1013, reason="连接数已达上限")
                return False

            await websocket.accept()
            self._connections.add(websocket)
            logger.info(f"WebSocket连接已建立,当前连接数: {len(self._connections)}")
        return True

    def disconnect(self, websocket: WebSocket):
        """Forget a client; unknown clients are ignored.

        Args:
            websocket: the connection to remove.
        """
        self._connections.discard(websocket)
        remaining = len(self._connections)
        logger.info(f"WebSocket连接已断开,当前连接数: {remaining}")

    async def broadcast(self, message: Dict[str, Any]):
        """Send ``message`` as JSON to every registered client, one by one.

        Clients whose send raises are removed after the pass completes.

        Args:
            message: JSON-serializable payload.
        """
        if not self._connections:
            return

        failed = []
        # Iterate a snapshot: the live set may change while we await sends.
        snapshot = self._connections.copy()
        for client in snapshot:
            try:
                await client.send_json(message)
            except Exception as exc:
                logger.debug(f"发送WebSocket消息失败: {exc}")
                failed.append(client)

        for client in failed:
            self._connections.discard(client)

    async def send_to(self, websocket: WebSocket, message: Dict[str, Any]) -> bool:
        """Send ``message`` as JSON to a single client.

        Args:
            websocket: the target connection.
            message: JSON-serializable payload.

        Returns:
            True on success. On failure the client is unregistered and
            False is returned; the exception is not propagated.
        """
        try:
            await websocket.send_json(message)
        except Exception as exc:
            logger.debug(f"发送WebSocket消息失败: {exc}")
            self._connections.discard(websocket)
            return False
        return True

    @property
    def connection_count(self) -> int:
        """Number of currently registered clients."""
        return len(self._connections)
107
+
108
+
109
# Collaborators injected by main.py via set_websocket_dependencies();
# they stay None until that call happens.
_log_collector: Optional["LogCollector"] = None
_langbot_client: Optional["LangBotClient"] = None

# Single process-wide registry shared by the endpoint and broadcast helpers.
ws_manager = WebSocketManager()
115
+
116
+
117
def set_websocket_dependencies(
    log_collector: "LogCollector",
    langbot_client: "LangBotClient"
):
    """Inject the objects the WebSocket endpoint reads from.

    Stores both arguments in module-level globals consulted by
    ``websocket_endpoint`` and ``_get_current_status``.

    Args:
        log_collector: source of live log entries (``subscribe``/``unsubscribe``).
        langbot_client: client whose connection flags are reported in status.
    """
    global _langbot_client, _log_collector
    _langbot_client = langbot_client
    _log_collector = log_collector
    logger.info("WebSocket依赖已设置")
131
+
132
+
133
async def _send_heartbeat(websocket: WebSocket) -> bool:
    """Send a ``{"type": "ping"}`` payload; return False if the socket is unusable."""
    try:
        await websocket.send_json({"type": "ping"})
    except Exception:
        return False
    return True


async def websocket_endpoint(websocket: WebSocket):
    """Serve one WebSocket client by streaming log entries and status updates.

    Flow:
      1. Register the socket with ``ws_manager``; return if it is rejected.
      2. Subscribe to the injected log collector, if any.
      3. Start a background task that pushes status changes.
      4. Send the initial status. Then forward each log entry as
         ``{"type": "log", "data": ...}``. When nothing arrives for 30 s, or
         there is no collector, send a ``{"type": "ping"}`` heartbeat.
      5. Stop as soon as any send fails. Then cancel the status task,
         unsubscribe from the collector and unregister the socket.

    Args:
        websocket: the incoming connection.
    """
    if not await ws_manager.connect(websocket):
        return

    # Snapshot the collector so cleanup unsubscribes from the same object
    # that was subscribed to, even if the global is replaced meanwhile.
    collector = _log_collector
    log_queue = None
    status_task = None

    try:
        # Inside the try so a failing subscribe() still unregisters the socket.
        if collector is not None:
            log_queue = collector.subscribe()

        status_task = asyncio.create_task(_push_status_updates(websocket))

        await _send_initial_status(websocket)

        while True:
            if log_queue is not None:
                try:
                    # The timeout lets us emit a heartbeat on idle connections.
                    log_entry = await asyncio.wait_for(log_queue.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    if not await _send_heartbeat(websocket):
                        break
                    continue

                # send_to swallows errors and returns False. A failed send means
                # the client is gone, so stop instead of draining the queue forever.
                if not await ws_manager.send_to(websocket, {
                    "type": "log",
                    "data": log_entry
                }):
                    break
            else:
                # No log collector: heartbeat only.
                await asyncio.sleep(30)
                if not await _send_heartbeat(websocket):
                    break

    except WebSocketDisconnect:
        logger.debug("WebSocket客户端断开连接")
    except Exception as e:
        logger.error(f"WebSocket处理异常: {e}")
    finally:
        if status_task is not None:
            status_task.cancel()
            try:
                await status_task
            except asyncio.CancelledError:
                pass

        # log_queue is only set when collector is not None.
        if log_queue is not None:
            collector.unsubscribe(log_queue)

        ws_manager.disconnect(websocket)
199
+
200
+
201
async def _send_initial_status(websocket: WebSocket):
    """Push one status snapshot to a freshly connected client.

    Send failures are handled by ``ws_manager.send_to``, which unregisters
    the client and does not raise.

    Args:
        websocket: the client to send to.
    """
    snapshot = _get_current_status()
    payload = {"type": "status", "data": snapshot}
    await ws_manager.send_to(websocket, payload)
212
+
213
+
214
async def _push_status_updates(websocket: WebSocket):
    """Poll the status every 5 s and push it to one client whenever it changes.

    Intended to run as a per-connection background task. It stops on its own
    when a send fails. The first poll always pushes, because nothing has been
    recorded yet. Cancellation is propagated rather than swallowed, so the
    task finishes in the cancelled state, as asyncio expects. Other
    exceptions are logged at debug level and end the task.

    Args:
        websocket: the client to push to.
    """
    last_status = None

    try:
        while True:
            await asyncio.sleep(5)

            current_status = _get_current_status()
            if current_status == last_status:
                continue

            sent = await ws_manager.send_to(websocket, {
                "type": "status",
                "data": current_status
            })
            if not sent:
                break
            last_status = current_status

    except asyncio.CancelledError:
        # Re-raise explicitly; on Python 3.7 CancelledError is an Exception
        # subclass and would otherwise be caught and swallowed below.
        raise
    except Exception as e:
        logger.debug(f"状态推送任务异常: {e}")
242
+
243
+
244
def _get_current_status() -> Dict[str, Any]:
    """Build the status snapshot pushed to WebSocket clients.

    Returns:
        A dict with ``langbot_connected`` and ``langbot_reconnecting``, both
        False while no LangBot client has been injected, plus
        ``websocket_connections``, the current client count.
    """
    client = _langbot_client
    connected = False
    reconnecting = False
    if client:
        connected = client.is_connected
        reconnecting = client.is_reconnecting

    return {
        "langbot_connected": connected,
        "langbot_reconnecting": reconnecting,
        "websocket_connections": ws_manager.connection_count,
    }
255
+
256
+
257
async def broadcast_status_change(status_type: str, data: Dict[str, Any]):
    """Announce a status change to every connected client.

    Args:
        status_type: kind of change, e.g. ``langbot_connected`` or ``config_updated``.
        data: payload describing the change.
    """
    envelope = {"type": "status_change", "status_type": status_type, "data": data}
    await ws_manager.broadcast(envelope)
269
+
270
+
271
async def broadcast_log(log_entry: Dict[str, Any]):
    """Fan a single log entry out to every connected client.

    Args:
        log_entry: the log record to send as the message's ``data``.
    """
    envelope = {"type": "log", "data": log_entry}
    await ws_manager.broadcast(envelope)
src/auth/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ """认证模块"""
2
+ from .database import Database, get_db, init_database
3
+ from .models import User, OperationLog
4
+ from .jwt_handler import JWTHandler
5
+ from .dingtalk import DingTalkClient, init_dingtalk_client, get_dingtalk_client
6
+ from .middleware import (
7
+ AuthMiddleware,
8
+ get_current_user,
9
+ get_optional_user,
10
+ init_auth,
11
+ is_auth_enabled,
12
+ log_operation
13
+ )
14
+ from .routes import router as auth_router
15
+
16
# Public API of the auth package, grouped by the submodule each name comes from.
__all__ = [
    # .database
    "Database", "get_db", "init_database",
    # .models
    "User", "OperationLog",
    # .jwt_handler
    "JWTHandler",
    # .dingtalk
    "DingTalkClient", "init_dingtalk_client", "get_dingtalk_client",
    # .middleware
    "AuthMiddleware", "get_current_user", "get_optional_user",
    "init_auth", "is_auth_enabled", "log_operation",
    # .routes
    "auth_router",
]
src/auth/database.py ADDED
@@ -0,0 +1,87 @@
1
+ """数据库连接管理"""
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import AsyncGenerator
5
+
6
+ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
7
+ from sqlalchemy.orm import DeclarativeBase
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class Base(DeclarativeBase):
    """Declarative base for the ORM models.

    Tables of subclasses are registered on ``Base.metadata``, which
    ``Database.init`` passes to ``create_all``.
    """
15
+
16
+
17
class Database:
    """Async SQLite database manager built on SQLAlchemy's asyncio extension.

    Owns one async engine and the session factory bound to it. Call
    :meth:`init` before :meth:`get_session`, and :meth:`close` on shutdown.
    """

    def __init__(self, db_path: str = "data/app.db"):
        """Record where the database lives; no I/O happens until :meth:`init`.

        Args:
            db_path: path of the SQLite file. Its parent directory is created
                by :meth:`init`.
        """
        self.db_path = db_path
        self.engine = None
        self.session_factory = None

    async def init(self):
        """Create the engine and session factory, then create missing tables.

        If this instance was already initialized, the previous engine is
        disposed first so its connection pool is not leaked.
        """
        if self.engine is not None:
            # Re-initialization: release the old pool before replacing it.
            await self.engine.dispose()
            self.engine = None
            self.session_factory = None

        # Make sure the directory holding the SQLite file exists.
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        self.engine = create_async_engine(
            f"sqlite+aiosqlite:///{self.db_path}",
            echo=False,
            future=True
        )

        # expire_on_commit=False keeps ORM objects usable after commit.
        self.session_factory = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False
        )

        async with self.engine.begin() as conn:
            # Importing the models registers their tables on Base.metadata.
            from .models import User, OperationLog  # noqa: F401
            await conn.run_sync(Base.metadata.create_all)

        logger.info(f"数据库初始化完成: {self.db_path}")

    async def close(self):
        """Dispose the engine's connection pool and mark the instance closed.

        Afterwards :meth:`get_session` raises until :meth:`init` is called
        again. Calling this more than once is harmless.
        """
        if self.engine is not None:
            await self.engine.dispose()
            self.engine = None
            self.session_factory = None
            logger.info("数据库连接已关闭")

    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield a session, committing on success and rolling back on error.

        Raises:
            RuntimeError: if the database has not been initialized (or was closed).
        """
        if self.session_factory is None:
            raise RuntimeError("数据库未初始化")

        async with self.session_factory() as session:
            try:
                yield session
                await session.commit()
            except Exception:
                await session.rollback()
                raise
69
+
70
+
71
# Process-wide Database instance: set by init_database(), read by get_db().
# None until init_database() has completed successfully.
_database: "Database | None" = None
73
+
74
+
75
async def init_database(db_path: str = "data/app.db") -> Database:
    """Create, initialize and publish the process-wide Database.

    The new instance is stored in the global only after ``init()`` succeeds.
    A failed initialization therefore leaves any previous instance in place,
    and the half-built engine is released before the error propagates. When
    an earlier instance is replaced, it is closed to free its engine.

    Args:
        db_path: path of the SQLite file.

    Returns:
        The initialized Database, now also returned by ``get_db()``.
    """
    global _database
    database = Database(db_path)
    try:
        await database.init()
    except Exception:
        # Dispose whatever the failed init managed to create, then re-raise.
        await database.close()
        raise

    previous = _database
    _database = database
    if previous is not None:
        await previous.close()
    return database
81
+
82
+
83
def get_db() -> Database:
    """Return the process-wide Database published by ``init_database()``.

    Raises:
        RuntimeError: if ``init_database()`` has not been called yet.
    """
    database = _database
    if database is not None:
        return database
    raise RuntimeError("数据库未初始化,请先调用 init_database()")