huace-aigc-auth-client 1.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huace_aigc_auth_client/__init__.py +101 -0
- huace_aigc_auth_client/legacy_adapter.py +625 -0
- huace_aigc_auth_client/sdk.py +726 -0
- huace_aigc_auth_client/webhook.py +128 -0
- huace_aigc_auth_client-1.1.7.dist-info/METADATA +797 -0
- huace_aigc_auth_client-1.1.7.dist-info/RECORD +9 -0
- huace_aigc_auth_client-1.1.7.dist-info/WHEEL +5 -0
- huace_aigc_auth_client-1.1.7.dist-info/licenses/LICENSE +22 -0
- huace_aigc_auth_client-1.1.7.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,726 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AIGC Auth Python SDK
|
|
3
|
+
|
|
4
|
+
提供以下功能:
|
|
5
|
+
1. Token 验证
|
|
6
|
+
2. 获取用户信息
|
|
7
|
+
3. 权限检查
|
|
8
|
+
4. FastAPI/Flask 请求拦截中间件
|
|
9
|
+
5. 旧系统接入支持(用户同步)
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import asyncio
import copy
import hashlib
import inspect
import json
import logging
import os
import time
from dataclasses import dataclass
from functools import wraps
from typing import Optional, List, Dict, Any, Callable, Tuple

import requests
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)

# Load a .env file at import time, if python-dotenv is installed, so the
# AIGC_AUTH_* variables that AigcAuthClient reads via os.getenv() can be
# supplied that way. python-dotenv is optional: when it is missing, the
# ImportError is swallowed and only the real process environment is used.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class UserInfo:
    """User profile as returned by the auth service.

    ``roles`` and ``permissions`` default to empty lists. Passing ``None``
    explicitly is also normalised to an empty list by ``__post_init__``, and
    every instance gets its own list object.
    """
    id: int
    username: str
    nickname: Optional[str] = None
    email: Optional[str] = None
    phone: Optional[str] = None
    avatar: Optional[str] = None
    # Annotated Optional because the declared default is None; __post_init__
    # swaps None for a fresh list (a bare ``List[str] = None`` is a type error).
    roles: Optional[List[str]] = None
    permissions: Optional[List[str]] = None
    department: Optional[str] = None
    company: Optional[str] = None
    is_admin: Optional[bool] = None
    status: Optional[str] = None

    def __post_init__(self):
        # Create per-instance lists here instead of using a shared mutable default.
        if self.roles is None:
            self.roles = []
        if self.permissions is None:
            self.permissions = []

    def has_role(self, role: str) -> bool:
        """Return True if ``role`` is one of the user's roles."""
        return role in self.roles

    def has_permission(self, permission: str) -> bool:
        """Return True if the user holds ``permission``."""
        return permission in self.permissions

    def has_any_permission(self, permissions: List[str]) -> bool:
        """Return True if the user holds at least one of ``permissions``.

        An empty ``permissions`` argument yields False.
        """
        return any(p in self.permissions for p in permissions)

    def has_all_permissions(self, permissions: List[str]) -> bool:
        """Return True if the user holds every entry of ``permissions``.

        An empty ``permissions`` argument yields True.
        """
        return all(p in self.permissions for p in permissions)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@dataclass
class TokenVerifyResult:
    """Outcome of a token verification call.

    Attributes:
        valid: whether the auth service accepted the token.
        user_id: user identifier reported by the service, if any.
        username: username reported by the service, if any.
        expires_at: expiry value reported by the service, if any (kept in
            whatever string form the server sends).
    """

    valid: bool
    # Everything below is optional and stays None when the service omits it.
    user_id: Optional[str] = None
    username: Optional[str] = None
    expires_at: Optional[str] = None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class AigcAuthError(Exception):
    """Error raised by the AIGC Auth SDK.

    Attributes:
        code: numeric error code. This is the service's ``code`` field, a
            401/403 raised by the auth helpers, or -1 for request failures.
        message: human-readable error message.
    """

    def __init__(self, code: int, message: str):
        self.code = code
        self.message = message
        super().__init__(f"[{code}] {message}")

    def __reduce__(self):
        # Default exception pickling re-creates the object as cls(*self.args),
        # i.e. with the single formatted string, which does not fit this
        # two-argument __init__. Rebuild from (code, message) instead so the
        # error survives pickle / multiprocessing; extra attributes are kept
        # through the state dict.
        return (self.__class__, (self.code, self.message), self.__dict__)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class AigcAuthClient:
    """
    AIGC Auth client.

    Usage:
        client = AigcAuthClient(
            app_id="your_app_id",
            app_secret="your_app_secret",
            base_url="https://aigc-auth.huacemedia.com/aigc-auth/api/v1"
        )

        # Verify a token
        result = client.verify_token(token)

        # Fetch user info
        user = client.get_user_info(token)

        # Check permissions
        results = client.check_permissions(token, ["user:read", "user:write"])

    Caching: calls made with a ``token`` (verify_token, get_user_info,
    check_permissions) cache the response ``data`` in memory for ``cache_ttl``
    seconds. Entries are keyed on token + URL + method + the non-token request
    fields. Cache hits return deep copies, so mutating a returned value never
    changes later results. At most ``_CACHE_MAX_ENTRIES`` entries are kept:
    once that is exceeded, expired entries are purged first and then the
    oldest remaining entries are evicted.
    """

    # Upper bound on cached responses so memory stays bounded even when many
    # distinct tokens are seen within a single TTL window.
    _CACHE_MAX_ENTRIES = 1000

    def __init__(
        self,
        app_id: Optional[str] = None,
        app_secret: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: int = 30,
        cache_ttl: int = 300  # cache lifetime in seconds, default 5 minutes
    ):
        """
        Initialise the client.

        Args:
            app_id: application ID; falls back to env var AIGC_AUTH_APP_ID.
            app_secret: application secret; falls back to env var
                AIGC_AUTH_APP_SECRET.
            base_url: API base URL; falls back to env var AIGC_AUTH_BASE_URL,
                then to the public default. A trailing "/" is stripped so
                request URLs never contain "//sdk".
            timeout: per-request timeout in seconds.
            cache_ttl: cache lifetime in seconds (default 300, i.e. 5 minutes).

        Raises:
            ValueError: if app_id or app_secret is missing after the env fallback.
        """
        self.app_id = app_id or os.getenv("AIGC_AUTH_APP_ID")
        self.app_secret = app_secret or os.getenv("AIGC_AUTH_APP_SECRET")
        self.base_url = (
            base_url or
            os.getenv("AIGC_AUTH_BASE_URL") or
            "https://aigc-auth.huacemedia.com/aigc-auth/api/v1"
        ).rstrip("/")
        self.timeout = timeout
        self.cache_ttl = cache_ttl

        # Cache storage: {cache_key: (data, timestamp)}; insertion order is
        # kept equal to write time so the first key is always the oldest.
        self._cache: Dict[str, Tuple[Dict, float]] = {}

        if not self.app_id or not self.app_secret:
            raise ValueError(
                "必须提供 app_id 和 app_secret,"
                "可通过参数传入或设置环境变量 AIGC_AUTH_APP_ID 和 AIGC_AUTH_APP_SECRET"
            )

    def _get_headers(self) -> Dict[str, str]:
        """Build the default request headers carrying the app credentials."""
        return {
            "Content-Type": "application/json",
            "X-App-ID": self.app_id,
            "X-App-Secret": self.app_secret
        }

    def _generate_cache_key(self, token: str, url: str, method: str, extra_data: Dict = None) -> str:
        """
        Build a cache key.

        Args:
            token: user token.
            url: request URL.
            method: HTTP method.
            extra_data: additional request fields. They are serialised with
                sorted keys, so dict ordering does not affect the key.

        Returns:
            str: hex digest of the combined key material. Hashing keeps keys
            short and keeps raw tokens out of the cache dict.
        """
        key_string = f"{token}:{url}:{method}"
        if extra_data:
            key_string += ":" + json.dumps(extra_data, sort_keys=True, ensure_ascii=False)
        # sha256 rather than md5: md5 is rejected by FIPS-mode OpenSSL builds.
        return hashlib.sha256(key_string.encode("utf-8")).hexdigest()

    def _get_from_cache(self, cache_key: str) -> Optional[Dict]:
        """
        Look up a cached response.

        Args:
            cache_key: cache key.

        Returns:
            Optional[Dict]: a deep copy of the cached data, or None if the key
            is absent or the entry is older than ``cache_ttl`` (expired
            entries are removed).
        """
        entry = self._cache.get(cache_key)
        if entry is None:
            return None

        data, timestamp = entry
        if time.time() - timestamp > self.cache_ttl:
            # pop() with a default tolerates the key having been removed
            # concurrently (check-then-del could raise KeyError).
            self._cache.pop(cache_key, None)
            return None

        # Hand out a copy so callers mutating the result (e.g. UserInfo.roles,
        # check_permissions results) cannot alter later cache hits.
        return copy.deepcopy(data)

    def _set_cache(self, cache_key: str, data: Dict):
        """
        Store a response in the cache.

        Args:
            cache_key: cache key.
            data: data to cache. A private deep copy is stored; the caller's
                object is not retained.
        """
        # Re-insert so dict order matches write time (oldest first).
        self._cache.pop(cache_key, None)
        self._cache[cache_key] = (copy.deepcopy(data), time.time())

        if len(self._cache) > self._CACHE_MAX_ENTRIES:
            self._clean_expired_cache()
            # Still over the limit, i.e. everything is fresh: evict the oldest.
            excess = len(self._cache) - self._CACHE_MAX_ENTRIES
            if excess > 0:
                for key in list(self._cache)[:excess]:
                    self._cache.pop(key, None)

    def _clean_expired_cache(self):
        """Remove every cache entry older than ``cache_ttl``."""
        now = time.time()
        # Iterate over a snapshot so concurrent writes cannot break iteration.
        expired_keys = [
            key for key, (_, timestamp) in list(self._cache.items())
            if now - timestamp > self.cache_ttl
        ]
        for key in expired_keys:
            self._cache.pop(key, None)

    def clear_cache(self):
        """Drop all cached responses."""
        self._cache.clear()

    def _request(self, method: str, endpoint: str, data: Dict = None, token: str = None, headers: Dict[str, str] = None) -> Dict:
        """
        Send a request to ``{base_url}/sdk{endpoint}``.

        Args:
            method: HTTP method.
            endpoint: endpoint path, appended after "/sdk".
            data: JSON request body.
            token: user token. When given, the response is cached and served
                from cache, keyed on it.
            headers: extra request headers, merged over the defaults (custom
                values win).

        Returns:
            Dict: the response envelope's ``data`` field ({} if it is null or
            missing).

        Raises:
            AigcAuthError: with code -1 on transport, HTTP-status or JSON
                decoding failures; with the service's code when ``code`` != 0.
        """
        url = f"{self.base_url}/sdk{endpoint}"

        cache_key = None
        if token:
            # Every request field except the token itself contributes to the key.
            extra_data = None
            if data:
                extra_data = {k: v for k, v in data.items() if k != 'token'}

            cache_key = self._generate_cache_key(token, url, method, extra_data)
            cached_data = self._get_from_cache(cache_key)
            if cached_data is not None:
                return cached_data

        # Merge headers: defaults first, then caller-supplied overrides.
        request_headers = self._get_headers()
        if headers:
            request_headers.update(headers)

        try:
            response = requests.request(
                method=method,
                url=url,
                json=data,
                headers=request_headers,
                timeout=self.timeout
            )
            response.raise_for_status()
            result = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers non-JSON bodies on requests < 2.27, where
            # response.json() raises a plain json/simplejson decode error.
            logger.error("AigcAuthClient 请求失败: %s", e)
            raise AigcAuthError(-1, f"请求失败: {str(e)}") from e

        if result.get("code") != 0:
            logger.error("AigcAuthClient 请求错误: %s", result)
            raise AigcAuthError(
                result.get("code", -1),
                result.get("message", "未知错误")
            )

        # A JSON null (or missing) "data" is normalised to {} so callers can
        # always call .get() on the result.
        response_data = result.get("data")
        if response_data is None:
            response_data = {}

        # Cache successful token-scoped responses.
        if cache_key is not None:
            self._set_cache(cache_key, response_data)

        return response_data

    def verify_token(self, token: str) -> TokenVerifyResult:
        """
        Verify a token.

        Results are cached for ``cache_ttl`` seconds, so a token revoked on
        the server may still verify as valid until its cache entry expires.

        Args:
            token: the user's access_token.

        Returns:
            TokenVerifyResult: verification result.

        Raises:
            AigcAuthError: if the request fails or the service returns an error code.
        """
        data = self._request("POST", "/token/verify", {"token": token}, token=token)
        return TokenVerifyResult(
            valid=data.get("valid", False),
            user_id=data.get("userId"),
            username=data.get("username"),
            expires_at=data.get("expiresAt")
        )

    def get_user_info(self, token: str) -> UserInfo:
        """
        Fetch the user's profile.

        Args:
            token: the user's access_token.

        Returns:
            UserInfo: user information.

        Raises:
            AigcAuthError: when the token is invalid, the user does not exist,
                or the request fails.
        """
        data = self._request("POST", "/user/info", {"token": token}, token=token)
        return UserInfo(
            id=data.get("id"),
            username=data.get("username"),
            nickname=data.get("nickname"),
            email=data.get("email"),
            phone=data.get("phone"),
            avatar=data.get("avatar"),
            roles=data.get("roles", []),
            permissions=data.get("permissions", []),
            department=data.get("department"),
            company=data.get("company"),
            is_admin=data.get("is_admin"),
            status=data.get("status")
        )

    def check_permissions(
        self,
        token: str,
        permission_codes: List[str]
    ) -> Dict[str, bool]:
        """
        Check several permissions in one call.

        Args:
            token: the user's access_token.
            permission_codes: permission codes to check.

        Returns:
            Dict[str, bool]: maps each permission code to whether the user has it.
        """
        data = self._request("POST", "/permission/check", {
            "token": token,
            "permissionCodes": permission_codes
        }, token=token)
        return data.get("results", {})

    def get_user_info_from_header(self, authorization: str) -> Optional[UserInfo]:
        """
        Resolve the user from an Authorization header value.

        Args:
            authorization: header value of the form "Bearer {token}". The
                scheme is matched case-insensitively (RFC 7235) and surrounding
                whitespace around the token is ignored.

        Returns:
            UserInfo: user information, or None if the header is missing or
            malformed, the token is empty, or the lookup raises AigcAuthError.
        """
        if not authorization:
            return None

        scheme, _, token = authorization.partition(" ")
        if scheme.lower() != "bearer":
            return None

        token = token.strip()
        if not token:
            # Avoid a pointless (and uncached) round-trip for "Bearer ".
            return None

        try:
            return self.get_user_info(token)
        except AigcAuthError:
            return None

    # ============ User synchronisation ============

    def sync_user_to_auth(self, user_data: Dict[str, Any], headers: Dict[str, str] = None) -> Dict[str, Any]:
        """
        Sync a user into aigc-auth (initial sync from a legacy system).

        Args:
            user_data: user data. It must contain username plus exactly one of:
                - password: plain-text password
                - password_hashed: already-hashed password (bcrypt format)
            headers: extra request headers (optional).

        Returns:
            Dict: sync result
                - success: bool, whether it succeeded
                - created: bool, whether the user was created (False = already existed)
                - user_id: int, user ID
                - message: str, message
        """
        return self._request("POST", "/sync/user", user_data, headers=headers)

    def batch_sync_users_to_auth(self, users: List[Dict[str, Any]], headers: Dict[str, str] = None) -> Dict[str, Any]:
        """
        Sync many users into aigc-auth.

        Args:
            users: user records. Each must contain username plus one of
                password or password_hashed.
            headers: extra request headers (optional).

        Returns:
            Dict: batch result
                - total: int, number of users
                - success: int, number succeeded
                - failed: int, number failed
                - skipped: int, number skipped (already existed)
                - errors: List[Dict], error details
        """
        return self._request("POST", "/sync/batch", {"users": users}, headers=headers)

    def register_webhook(self, webhook_url: str, events: List[str], secret: Optional[str] = None) -> Dict[str, Any]:
        """
        Register a webhook to receive incremental user events.

        Args:
            webhook_url: URL that will receive webhook calls.
            events: subscribed events, e.g. ["user.created", "user.updated"].
            secret: webhook signing secret.

        Returns:
            Dict: registration result
                - webhook_id: str, webhook ID
                - success: bool, whether it succeeded
        """
        return self._request("POST", "/webhook/register", {
            "url": webhook_url,
            "events": events,
            "secret": secret
        })

    def unregister_webhook(self, webhook_id: str) -> Dict[str, Any]:
        """
        Unregister a webhook.

        Args:
            webhook_id: webhook ID.

        Returns:
            Dict: unregistration result.
        """
        return self._request("POST", "/webhook/unregister", {"webhookId": webhook_id})
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
def require_auth(
    client: AigcAuthClient,
    permissions: List[str] = None,
    any_permission: bool = False
):
    """
    Route decorator that requires an authenticated user (FastAPI-oriented).

    Both plain functions and ``async def`` endpoints are supported. A coroutine
    function gets an ``async`` wrapper, so frameworks still see a coroutine
    function and the endpoint is awaited instead of its coroutine being
    returned unawaited. In the async case, the blocking auth lookup runs in
    the default thread pool.

    The request object is taken from the ``request`` keyword argument, or else
    from the first positional argument that has a ``headers`` attribute. On
    success, the resolved ``UserInfo`` is passed to the endpoint as the
    ``user_info`` keyword argument.

    NOTE(review): ``functools.wraps`` exposes the endpoint's own signature to
    the framework. The endpoint therefore has to declare ``request`` for it to
    be injected, and ``user_info`` must be declared in a way the framework
    does not treat as a request parameter. Confirm against the FastAPI version
    in use.

    Args:
        client: AigcAuthClient instance.
        permissions: required permission codes (optional).
        any_permission: if True, any one of ``permissions`` suffices;
            otherwise all of them are required (default).

    Raises:
        AigcAuthError: 401 when no request object is found or the user is not
            logged in / the token expired; 403 when permissions are missing.

    Usage:
        @app.get("/protected")
        @require_auth(client)
        def protected_route(user_info: UserInfo):
            return {"user": user_info.username}

        @app.get("/admin")
        @require_auth(client, permissions=["admin:access"])
        def admin_route(user_info: UserInfo):
            return {"admin": True}
    """
    def _authenticate(args, kwargs) -> UserInfo:
        """Locate the request, resolve the user and enforce permissions."""
        # Prefer an explicit ``request`` kwarg, then any header-bearing positional arg.
        request = kwargs.get("request")
        if request is None:
            request = next((arg for arg in args if hasattr(arg, "headers")), None)

        if request is None:
            raise AigcAuthError(401, "无法获取请求对象")

        authorization = request.headers.get("Authorization")
        user_info = client.get_user_info_from_header(authorization)

        if user_info is None:
            raise AigcAuthError(401, "未登录或 Token 已过期")

        if permissions:
            granted = (
                user_info.has_any_permission(permissions)
                if any_permission
                else user_info.has_all_permissions(permissions)
            )
            if not granted:
                raise AigcAuthError(403, "权限不足")

        return user_info

    def decorator(func: Callable):
        if inspect.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                # The auth lookup does blocking HTTP; keep it off the event loop.
                loop = asyncio.get_running_loop()
                user_info = await loop.run_in_executor(None, _authenticate, args, kwargs)
                kwargs["user_info"] = user_info
                return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def wrapper(*args, **kwargs):
            kwargs["user_info"] = _authenticate(args, kwargs)
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
class AuthMiddleware:
    """
    Generic authentication middleware.

    Supports FastAPI and Flask.

    FastAPI usage:
        from fastapi import FastAPI, Request
        from sdk import AigcAuthClient, AuthMiddleware

        app = FastAPI()
        client = AigcAuthClient(app_id="xxx", app_secret="xxx")
        auth_middleware = AuthMiddleware(client)

        @app.middleware("http")
        async def auth_middleware_handler(request: Request, call_next):
            return await auth_middleware.fastapi_middleware(request, call_next)

    Flask usage:
        from flask import Flask
        from sdk import AigcAuthClient, AuthMiddleware

        app = Flask(__name__)
        client = AigcAuthClient(app_id="xxx", app_secret="xxx")
        auth_middleware = AuthMiddleware(client)

        @app.before_request
        def before_request():
            return auth_middleware.flask_before_request()
    """

    def __init__(
        self,
        client: AigcAuthClient,
        exclude_paths: List[str] = None,
        exclude_prefixes: List[str] = None
    ):
        """
        Initialise the middleware.

        Args:
            client: AigcAuthClient instance.
            exclude_paths: paths that skip authentication (exact match).
            exclude_prefixes: path prefixes that skip authentication.
        """
        self.client = client
        self.exclude_paths = exclude_paths or []
        self.exclude_prefixes = exclude_prefixes or []

    def _should_skip(self, path: str) -> bool:
        """Return True if ``path`` is excluded from authentication."""
        # str.startswith accepts a tuple of prefixes; an empty tuple never matches.
        return path in self.exclude_paths or path.startswith(tuple(self.exclude_prefixes))

    def _extract_token(self, authorization: str) -> Optional[str]:
        """Extract the token from an Authorization header value.

        The "Bearer" scheme is matched case-insensitively (RFC 7235) and the
        token is stripped. Returns None when the header is missing or
        malformed, or the token is empty.
        """
        if not authorization:
            return None
        scheme, _, token = authorization.partition(" ")
        if scheme.lower() != "bearer":
            return None
        return token.strip() or None

    async def fastapi_middleware(self, request, call_next, user_info_callback: Callable = None):
        """
        FastAPI HTTP middleware.

        On success, stores the user in ``request.state.user_info`` and invokes
        ``user_info_callback(request, user_info)`` if given. The callback may
        be a plain function or a coroutine function. Responds with a JSON 401
        if the token is missing or the client raises AigcAuthError.

        Usage:
            @app.middleware("http")
            async def auth(request: Request, call_next):
                return await auth_middleware.fastapi_middleware(request, call_next)
        """
        from fastapi.responses import JSONResponse

        path = request.url.path

        # Excluded paths bypass authentication entirely.
        if self._should_skip(path):
            return await call_next(request)

        authorization = request.headers.get("Authorization")
        token = self._extract_token(authorization)

        if not token:
            logger.warning("AuthMiddleware未提供认证信息")
            return JSONResponse(
                status_code=401,
                content={"code": 401, "message": "未提供认证信息", "data": None}
            )

        try:
            # get_user_info performs a blocking HTTP call; run it in the default
            # thread pool so the event loop keeps serving other requests.
            loop = asyncio.get_running_loop()
            user_info = await loop.run_in_executor(None, self.client.get_user_info, token)
            request.state.user_info = user_info
            # Honour the proxy's scheme so any redirects use the right protocol.
            # NOTE(review): the header is trusted as sent; this is only safe
            # behind a proxy that sets/overwrites X-Forwarded-Proto.
            forwarded_proto = request.headers.get("x-forwarded-proto")
            if forwarded_proto:
                request.scope["scheme"] = forwarded_proto
            if user_info_callback:
                # Accept both sync and async callbacks.
                outcome = user_info_callback(request, user_info)
                if inspect.isawaitable(outcome):
                    await outcome
        except AigcAuthError as e:
            logger.error("AuthMiddleware认证失败: %s", e.message)
            return JSONResponse(
                status_code=401,
                content={"code": e.code, "message": e.message, "data": None}
            )
        return await call_next(request)

    def flask_before_request(self, user_info_callback: Callable = None):
        """
        Flask ``before_request`` handler.

        On success, stores the user in ``flask.g.user_info``, invokes
        ``user_info_callback(request, user_info)`` if given, and returns None
        so Flask continues. Otherwise returns a JSON body with status 401.

        Usage:
            @app.before_request
            def before_request():
                return auth_middleware.flask_before_request(user_info_callback=user_info_callback)
        """
        from flask import request, jsonify, g

        path = request.path

        # Excluded paths bypass authentication entirely.
        if self._should_skip(path):
            return None

        authorization = request.headers.get("Authorization")
        token = self._extract_token(authorization)

        if not token:
            logger.warning("AuthMiddleware未提供认证信息")
            return jsonify({
                "code": 401,
                "message": "未提供认证信息",
                "data": None
            }), 401

        try:
            user_info = self.client.get_user_info(token)
            g.user_info = user_info
            if user_info_callback:
                user_info_callback(request, user_info)
        except AigcAuthError as e:
            logger.error("AuthMiddleware认证失败: %s", e.message)
            return jsonify({
                "code": e.code,
                "message": e.message,
                "data": None
            }), 401

        return None

    def get_current_user_fastapi(self, request) -> Optional[UserInfo]:
        """
        Return the current user in FastAPI.

        Args:
            request: FastAPI Request object.

        Returns:
            UserInfo: the user stored by ``fastapi_middleware``, or None if the
            request was not authenticated by it.
        """
        return getattr(request.state, "user_info", None)

    def get_current_user_flask(self) -> Optional[UserInfo]:
        """
        Return the current user in Flask.

        Returns:
            UserInfo: the user stored by ``flask_before_request``, or None if
            the request was not authenticated by it.
        """
        from flask import g
        return getattr(g, "user_info", None)
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
# Convenience helper: build a FastAPI auth dependency.
def create_fastapi_auth_dependency(client: AigcAuthClient):
    """
    Create a FastAPI authentication dependency.

    The returned coroutine function resolves the user from the request's
    Authorization header. It raises HTTP 401, with a
    ``WWW-Authenticate: Bearer`` header as RFC 6750 asks for, when no valid
    user can be resolved.

    Usage:
        from fastapi import Depends
        from sdk import AigcAuthClient, create_fastapi_auth_dependency

        client = AigcAuthClient(app_id="xxx", app_secret="xxx")
        get_current_user = create_fastapi_auth_dependency(client)

        @app.get("/me")
        async def get_me(user: UserInfo = Depends(get_current_user)):
            return {"username": user.username}
    """
    from fastapi import Request, HTTPException

    async def get_current_user(request: Request) -> UserInfo:
        """Resolve the authenticated user or raise HTTPException(401)."""
        authorization = request.headers.get("Authorization")
        # The lookup performs a blocking HTTP call; run it in the default
        # thread pool so it does not stall the event loop.
        loop = asyncio.get_running_loop()
        user_info = await loop.run_in_executor(
            None, client.get_user_info_from_header, authorization
        )

        if user_info is None:
            logger.warning("FastAPI依赖未登录或Token已过期")
            raise HTTPException(
                status_code=401,
                detail="未登录或 Token 已过期",
                headers={"WWW-Authenticate": "Bearer"},
            )

        return user_info

    return get_current_user
|