skyplatform-iam 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- skyplatform_iam/__init__.py +11 -2
- skyplatform_iam/connect_agenterra_iam.py +82 -12
- skyplatform_iam/middleware.py +235 -3
- {skyplatform_iam-1.0.1.dist-info → skyplatform_iam-1.0.3.dist-info}/METADATA +1 -2
- skyplatform_iam-1.0.3.dist-info/RECORD +8 -0
- skyplatform_iam/auth_middleware.py +0 -201
- skyplatform_iam-1.0.1.dist-info/RECORD +0 -9
- {skyplatform_iam-1.0.1.dist-info → skyplatform_iam-1.0.3.dist-info}/WHEEL +0 -0
skyplatform_iam/__init__.py
CHANGED
|
@@ -4,9 +4,8 @@ SkyPlatform IAM SDK
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
from .config import AuthConfig
|
|
7
|
-
from .middleware import AuthMiddleware
|
|
7
|
+
from .middleware import AuthMiddleware, AuthService, setup_auth_middleware, get_current_user, get_optional_user
|
|
8
8
|
from .connect_agenterra_iam import ConnectAgenterraIam
|
|
9
|
-
from .auth_middleware import setup_auth_middleware
|
|
10
9
|
from .exceptions import (
|
|
11
10
|
SkyPlatformAuthException,
|
|
12
11
|
AuthenticationError,
|
|
@@ -29,7 +28,10 @@ __all__ = [
|
|
|
29
28
|
|
|
30
29
|
# 中间件
|
|
31
30
|
"AuthMiddleware",
|
|
31
|
+
"AuthService",
|
|
32
32
|
"setup_auth_middleware",
|
|
33
|
+
"get_current_user",
|
|
34
|
+
"get_optional_user",
|
|
33
35
|
|
|
34
36
|
# 客户端
|
|
35
37
|
"ConnectAgenterraIam",
|
|
@@ -86,10 +88,17 @@ def setup_auth(app, config: AuthConfig = None):
|
|
|
86
88
|
Note:
|
|
87
89
|
此函数只设置认证中间件,不包含预制路由。
|
|
88
90
|
客户端应用需要根据业务需求自己实现认证相关的API接口。
|
|
91
|
+
建议传入完整的AuthConfig对象以避免环境变量配置问题。
|
|
89
92
|
"""
|
|
90
93
|
if config is None:
|
|
91
94
|
config = AuthConfig.from_env()
|
|
92
95
|
|
|
96
|
+
# 验证配置的完整性
|
|
97
|
+
config.validate_config()
|
|
98
|
+
|
|
99
|
+
# 初始化全局认证服务
|
|
100
|
+
setup_auth_middleware(config)
|
|
101
|
+
|
|
93
102
|
# 添加中间件
|
|
94
103
|
middleware = AuthMiddleware(app=app, config=config)
|
|
95
104
|
app.add_middleware(AuthMiddleware, config=config)
|
|
@@ -1,13 +1,9 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import requests
|
|
3
2
|
import logging
|
|
4
3
|
import traceback
|
|
5
4
|
import copy
|
|
6
|
-
from dotenv import load_dotenv
|
|
7
5
|
from enum import Enum
|
|
8
|
-
|
|
9
|
-
# 加载环境变量
|
|
10
|
-
load_dotenv()
|
|
6
|
+
from fastapi import HTTPException, status
|
|
11
7
|
|
|
12
8
|
|
|
13
9
|
class CredentialTypeEnum(str, Enum):
|
|
@@ -19,14 +15,31 @@ class CredentialTypeEnum(str, Enum):
|
|
|
19
15
|
|
|
20
16
|
|
|
21
17
|
class ConnectAgenterraIam(object):
|
|
22
|
-
|
|
18
|
+
_instance = None
|
|
19
|
+
_initialized = False
|
|
20
|
+
|
|
21
|
+
def __new__(cls, config=None, logger_name="skyplatform_iam", log_level=logging.INFO):
|
|
22
|
+
"""
|
|
23
|
+
单例模式实现
|
|
24
|
+
确保整个应用中只有一个ConnectAgenterraIam实例
|
|
25
|
+
"""
|
|
26
|
+
if cls._instance is None:
|
|
27
|
+
cls._instance = super(ConnectAgenterraIam, cls).__new__(cls)
|
|
28
|
+
return cls._instance
|
|
29
|
+
|
|
30
|
+
def __init__(self, config=None, logger_name="skyplatform_iam", log_level=logging.INFO):
|
|
23
31
|
"""
|
|
24
32
|
初始化AgenterraIAM连接器
|
|
25
33
|
|
|
26
34
|
参数:
|
|
35
|
+
- config: AuthConfig配置对象,如果为None则从环境变量读取
|
|
27
36
|
- logger_name: 日志记录器名称
|
|
28
37
|
- log_level: 日志级别
|
|
29
38
|
"""
|
|
39
|
+
# 防止重复初始化
|
|
40
|
+
if self._initialized:
|
|
41
|
+
return
|
|
42
|
+
|
|
30
43
|
# 配置日志记录器
|
|
31
44
|
self.logger = logging.getLogger(logger_name)
|
|
32
45
|
if not self.logger.handlers:
|
|
@@ -38,10 +51,22 @@ class ConnectAgenterraIam(object):
|
|
|
38
51
|
self.logger.addHandler(handler)
|
|
39
52
|
self.logger.setLevel(log_level)
|
|
40
53
|
|
|
41
|
-
#
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
54
|
+
# 必须传入config参数,不再支持从环境变量读取
|
|
55
|
+
if config is None:
|
|
56
|
+
raise ValueError("必须传入AuthConfig配置对象,不再支持从环境变量读取配置")
|
|
57
|
+
|
|
58
|
+
self.agenterra_iam_host = config.agenterra_iam_host
|
|
59
|
+
self.server_name = config.server_name
|
|
60
|
+
self.access_key = config.access_key
|
|
61
|
+
self.logger.info("使用传入的AuthConfig配置")
|
|
62
|
+
|
|
63
|
+
# 验证必要的配置
|
|
64
|
+
if not self.agenterra_iam_host:
|
|
65
|
+
self.logger.warning("AGENTERRA_IAM_HOST 配置未设置")
|
|
66
|
+
if not self.server_name:
|
|
67
|
+
self.logger.warning("AGENTERRA_SERVER_NAME 配置未设置")
|
|
68
|
+
if not self.access_key:
|
|
69
|
+
self.logger.warning("AGENTERRA_ACCESS_KEY 配置未设置")
|
|
45
70
|
|
|
46
71
|
self.logger.info(f"初始化AgenterraIAM连接器 - Host: {self.agenterra_iam_host}, Server: {self._mask_sensitive(self.server_name)}")
|
|
47
72
|
|
|
@@ -54,6 +79,48 @@ class ConnectAgenterraIam(object):
|
|
|
54
79
|
"server_name": self.server_name,
|
|
55
80
|
"access_key": self.access_key
|
|
56
81
|
}
|
|
82
|
+
|
|
83
|
+
# 标记为已初始化
|
|
84
|
+
self._initialized = True
|
|
85
|
+
|
|
86
|
+
def reload_config(self, config):
|
|
87
|
+
"""
|
|
88
|
+
重新加载配置
|
|
89
|
+
用于在运行时更新配置
|
|
90
|
+
|
|
91
|
+
参数:
|
|
92
|
+
- config: AuthConfig配置对象
|
|
93
|
+
"""
|
|
94
|
+
if config is None:
|
|
95
|
+
raise ValueError("必须传入AuthConfig配置对象")
|
|
96
|
+
|
|
97
|
+
self.logger.info("重新加载配置")
|
|
98
|
+
|
|
99
|
+
# 更新配置
|
|
100
|
+
self.agenterra_iam_host = config.agenterra_iam_host
|
|
101
|
+
self.server_name = config.server_name
|
|
102
|
+
self.access_key = config.access_key
|
|
103
|
+
|
|
104
|
+
# 验证必要的配置
|
|
105
|
+
if not self.agenterra_iam_host:
|
|
106
|
+
self.logger.warning("AGENTERRA_IAM_HOST 配置未设置")
|
|
107
|
+
if not self.server_name:
|
|
108
|
+
self.logger.warning("AGENTERRA_SERVER_NAME 配置未设置")
|
|
109
|
+
if not self.access_key:
|
|
110
|
+
self.logger.warning("AGENTERRA_ACCESS_KEY 配置未设置")
|
|
111
|
+
|
|
112
|
+
# 更新headers和body
|
|
113
|
+
self.headers = {
|
|
114
|
+
"Content-Type": "application/json",
|
|
115
|
+
"SERVER-AK": self.server_name,
|
|
116
|
+
"SERVER-SK": self.access_key
|
|
117
|
+
}
|
|
118
|
+
self.body = {
|
|
119
|
+
"server_name": self.server_name,
|
|
120
|
+
"access_key": self.access_key
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
self.logger.info(f"配置重新加载完成 - Host: {self.agenterra_iam_host}, Server: {self._mask_sensitive(self.server_name)}")
|
|
57
124
|
|
|
58
125
|
def _mask_sensitive(self, value, mask_char="*", show_chars=4):
|
|
59
126
|
"""
|
|
@@ -430,6 +497,11 @@ class ConnectAgenterraIam(object):
|
|
|
430
497
|
"server_sk": server_sk,
|
|
431
498
|
}
|
|
432
499
|
uri = "/api/v2/service/verify"
|
|
500
|
+
|
|
501
|
+
# 检查agenterra_iam_host是否为None
|
|
502
|
+
if self.agenterra_iam_host is None:
|
|
503
|
+
raise ValueError("AGENTERRA_IAM_HOST 配置未设置或为空,请确保传入正确的AuthConfig对象")
|
|
504
|
+
|
|
433
505
|
url = self.agenterra_iam_host + uri
|
|
434
506
|
|
|
435
507
|
# 记录请求信息
|
|
@@ -463,7 +535,6 @@ class ConnectAgenterraIam(object):
|
|
|
463
535
|
else:
|
|
464
536
|
# token有效但无权限,抛出403异常
|
|
465
537
|
self.logger.warning(f"[{method_name}] token有效但用户无权限访问API: {api}")
|
|
466
|
-
from fastapi import HTTPException, status
|
|
467
538
|
raise HTTPException(
|
|
468
539
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
469
540
|
detail=result.get("message", "用户无权限访问此API")
|
|
@@ -475,7 +546,6 @@ class ConnectAgenterraIam(object):
|
|
|
475
546
|
result = response.json()
|
|
476
547
|
# 处理403响应
|
|
477
548
|
self.logger.warning(f"[{method_name}] 收到403响应 - {result.get('message', '用户无权限访问此API')}")
|
|
478
|
-
from fastapi import HTTPException, status
|
|
479
549
|
raise HTTPException(
|
|
480
550
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
481
551
|
detail=result.get("message", "用户无权限访问此API")
|
skyplatform_iam/middleware.py
CHANGED
|
@@ -3,9 +3,11 @@ SkyPlatform IAM SDK 中间件模块
|
|
|
3
3
|
"""
|
|
4
4
|
import logging
|
|
5
5
|
from typing import Optional, Callable, Dict, Any
|
|
6
|
-
from fastapi import Request, Response, HTTPException
|
|
6
|
+
from fastapi import Request, Response, HTTPException, status
|
|
7
7
|
from fastapi.responses import JSONResponse
|
|
8
|
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
8
9
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
10
|
+
import jwt
|
|
9
11
|
|
|
10
12
|
from .config import AuthConfig
|
|
11
13
|
from .connect_agenterra_iam import ConnectAgenterraIam
|
|
@@ -40,7 +42,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
40
42
|
"""
|
|
41
43
|
super().__init__(app)
|
|
42
44
|
self.config = config
|
|
43
|
-
self.iam_client = ConnectAgenterraIam()
|
|
45
|
+
self.iam_client = ConnectAgenterraIam(config=config)
|
|
44
46
|
self.skip_validation = skip_validation
|
|
45
47
|
|
|
46
48
|
# 验证配置
|
|
@@ -49,11 +51,32 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
49
51
|
except ValueError as e:
|
|
50
52
|
raise ConfigurationError(str(e))
|
|
51
53
|
|
|
54
|
+
def is_path_whitelisted(self, path: str) -> bool:
|
|
55
|
+
"""
|
|
56
|
+
检查路径是否在本地白名单中
|
|
57
|
+
"""
|
|
58
|
+
if not self.config:
|
|
59
|
+
return False
|
|
60
|
+
return self.config.is_path_whitelisted(path)
|
|
61
|
+
|
|
52
62
|
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
|
53
63
|
"""
|
|
54
64
|
中间件主要处理逻辑
|
|
55
65
|
"""
|
|
56
66
|
try:
|
|
67
|
+
# 获取请求路径
|
|
68
|
+
api_path = request.url.path
|
|
69
|
+
|
|
70
|
+
# 首先检查路径是否在本地白名单中
|
|
71
|
+
if self.is_path_whitelisted(api_path):
|
|
72
|
+
logger.info(f"路径 {api_path} 在本地白名单中,跳过认证直接允许访问")
|
|
73
|
+
# 设置白名单标识
|
|
74
|
+
request.state.user = None
|
|
75
|
+
request.state.authenticated = False
|
|
76
|
+
request.state.is_whitelist = True
|
|
77
|
+
# 直接调用下一个处理器
|
|
78
|
+
response = await call_next(request)
|
|
79
|
+
return response
|
|
57
80
|
|
|
58
81
|
# 提取Token(可能为空,白名单接口不需要token)
|
|
59
82
|
token = self._extract_token(request)
|
|
@@ -84,7 +107,6 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
84
107
|
return response
|
|
85
108
|
|
|
86
109
|
except HTTPException as e:
|
|
87
|
-
# FastAPI HTTPException直接返回
|
|
88
110
|
return self._create_error_response(
|
|
89
111
|
status_code=e.status_code,
|
|
90
112
|
message=str(e.detail),
|
|
@@ -184,3 +206,213 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
184
206
|
status_code=status_code,
|
|
185
207
|
content=error_data
|
|
186
208
|
)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class AuthService:
|
|
212
|
+
"""
|
|
213
|
+
认证服务类
|
|
214
|
+
提供依赖注入式的认证功能
|
|
215
|
+
"""
|
|
216
|
+
|
|
217
|
+
def __init__(self, auth_config: AuthConfig):
|
|
218
|
+
if auth_config is None:
|
|
219
|
+
raise ValueError("auth_config参数不能为None,必须传入AuthConfig配置对象")
|
|
220
|
+
self.security = HTTPBearer(auto_error=False)
|
|
221
|
+
self.iam_client = ConnectAgenterraIam(config=auth_config)
|
|
222
|
+
self.auth_config = auth_config
|
|
223
|
+
|
|
224
|
+
def is_path_whitelisted(self, path: str) -> bool:
|
|
225
|
+
"""
|
|
226
|
+
检查路径是否在白名单中
|
|
227
|
+
"""
|
|
228
|
+
if not self.auth_config:
|
|
229
|
+
return False
|
|
230
|
+
return self.auth_config.is_path_whitelisted(path)
|
|
231
|
+
|
|
232
|
+
async def verify_token(self, request: Request):
|
|
233
|
+
"""验证token和权限"""
|
|
234
|
+
# 通过token, server_ak, server_sk判断是否有权限
|
|
235
|
+
api_path = request.url.path
|
|
236
|
+
|
|
237
|
+
# 首先检查路径是否在白名单中
|
|
238
|
+
if self.is_path_whitelisted(api_path):
|
|
239
|
+
logger.info(f"路径 {api_path} 在白名单中,跳过IAM鉴权")
|
|
240
|
+
return True
|
|
241
|
+
|
|
242
|
+
credentials: HTTPAuthorizationCredentials = await self.security(request)
|
|
243
|
+
method = request.method
|
|
244
|
+
|
|
245
|
+
server_ak = request.headers.get("SERVER-AK", "")
|
|
246
|
+
server_sk = request.headers.get("SERVER-SK", "")
|
|
247
|
+
|
|
248
|
+
token = ""
|
|
249
|
+
if credentials is not None:
|
|
250
|
+
token = credentials.credentials
|
|
251
|
+
user_info_by_iam = self.iam_client.verify_token(token, api_path, method, server_ak, server_sk)
|
|
252
|
+
if user_info_by_iam:
|
|
253
|
+
return True
|
|
254
|
+
return False
|
|
255
|
+
|
|
256
|
+
async def get_current_user(self, request: Request) -> Optional[Dict]:
|
|
257
|
+
"""获取当前用户信息"""
|
|
258
|
+
try:
|
|
259
|
+
# 直接调用verify_token方法进行token验证
|
|
260
|
+
if not await self.verify_token(request):
|
|
261
|
+
return None
|
|
262
|
+
|
|
263
|
+
# 获取token用于后续用户信息获取
|
|
264
|
+
credentials: HTTPAuthorizationCredentials = await self.security(request)
|
|
265
|
+
if not credentials:
|
|
266
|
+
return None
|
|
267
|
+
|
|
268
|
+
token = credentials.credentials
|
|
269
|
+
|
|
270
|
+
# 直接解析JWT token获取payload
|
|
271
|
+
payload = self.decode_jwt_token(token)
|
|
272
|
+
if not payload:
|
|
273
|
+
logger.error("JWT token解析失败")
|
|
274
|
+
return None
|
|
275
|
+
|
|
276
|
+
# 从payload中提取用户信息
|
|
277
|
+
iam_user_id = payload.get("sub") # JWT标准中用户ID存储在sub字段
|
|
278
|
+
username = None
|
|
279
|
+
|
|
280
|
+
# 解析新的凭证信息结构
|
|
281
|
+
all_credentials = payload.get("all_credentials", [])
|
|
282
|
+
total_credentials = payload.get("total_credentials", 0)
|
|
283
|
+
|
|
284
|
+
# 从all_credentials中提取username(向后兼容)
|
|
285
|
+
for cred in all_credentials:
|
|
286
|
+
if cred.get("type") == "username":
|
|
287
|
+
username = cred.get("value")
|
|
288
|
+
break
|
|
289
|
+
|
|
290
|
+
# 向后兼容性:如果没有all_credentials,尝试从payload的其他字段构建
|
|
291
|
+
if not all_credentials:
|
|
292
|
+
credentials_list = []
|
|
293
|
+
# 检查payload中是否有直接的username字段
|
|
294
|
+
if payload.get("username"):
|
|
295
|
+
username = payload.get("username")
|
|
296
|
+
credentials_list.append({"type": "username", "value": username})
|
|
297
|
+
if payload.get("email"):
|
|
298
|
+
credentials_list.append({"type": "email", "value": payload.get("email")})
|
|
299
|
+
if payload.get("phone"):
|
|
300
|
+
credentials_list.append({"type": "phone", "value": payload.get("phone")})
|
|
301
|
+
all_credentials = credentials_list
|
|
302
|
+
total_credentials = len(credentials_list)
|
|
303
|
+
|
|
304
|
+
if not username:
|
|
305
|
+
return None
|
|
306
|
+
|
|
307
|
+
# 构建用户信息字典
|
|
308
|
+
user_info = {
|
|
309
|
+
"id": iam_user_id,
|
|
310
|
+
"username": username,
|
|
311
|
+
"all_credentials": all_credentials,
|
|
312
|
+
"total_credentials": total_credentials,
|
|
313
|
+
"microservice": payload.get("microservice") # 添加微服务信息
|
|
314
|
+
}
|
|
315
|
+
|
|
316
|
+
# 向后兼容:添加传统字段映射
|
|
317
|
+
for cred in all_credentials:
|
|
318
|
+
if cred.get("type") == "email":
|
|
319
|
+
user_info["email"] = cred.get("value")
|
|
320
|
+
elif cred.get("type") == "phone":
|
|
321
|
+
user_info["phone"] = cred.get("value")
|
|
322
|
+
elif cred.get("type") == "username" and not user_info.get("username"):
|
|
323
|
+
user_info["username"] = cred.get("value")
|
|
324
|
+
|
|
325
|
+
# 统计凭证类型分布
|
|
326
|
+
cred_types = [cred.get("type") for cred in all_credentials]
|
|
327
|
+
cred_type_count = {cred_type: cred_types.count(cred_type) for cred_type in set(cred_types)}
|
|
328
|
+
|
|
329
|
+
logger.info(
|
|
330
|
+
f"用户认证成功: user_id={iam_user_id}, username={username}, 凭证数量={total_credentials}, 凭证类型分布={cred_type_count}")
|
|
331
|
+
logger.debug(f"JWT payload: {payload}")
|
|
332
|
+
|
|
333
|
+
# 将用户信息添加到请求状态中
|
|
334
|
+
request.state.user = user_info
|
|
335
|
+
return user_info
|
|
336
|
+
|
|
337
|
+
except HTTPException as e:
|
|
338
|
+
logger.error(f"获取当前用户信息失败: {str(e)}")
|
|
339
|
+
# 重新抛出HTTP异常(403权限不足)
|
|
340
|
+
return None
|
|
341
|
+
except Exception as e:
|
|
342
|
+
logger.error(f"获取当前用户信息失败: {str(e)}")
|
|
343
|
+
return None
|
|
344
|
+
|
|
345
|
+
async def require_auth(self, request: Request) -> Dict:
|
|
346
|
+
"""要求用户必须登录"""
|
|
347
|
+
try:
|
|
348
|
+
user_info = await self.get_current_user(request)
|
|
349
|
+
if not user_info:
|
|
350
|
+
raise HTTPException(
|
|
351
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
352
|
+
detail="需要登录认证",
|
|
353
|
+
headers={"WWW-Authenticate": "Bearer"},
|
|
354
|
+
)
|
|
355
|
+
return user_info
|
|
356
|
+
except HTTPException:
|
|
357
|
+
# 重新抛出HTTP异常(可能是403权限不足或401未认证)
|
|
358
|
+
raise
|
|
359
|
+
|
|
360
|
+
async def optional_auth(self, request: Request) -> Optional[Dict]:
|
|
361
|
+
"""可选的用户认证(不强制要求登录)"""
|
|
362
|
+
try:
|
|
363
|
+
return await self.get_current_user(request)
|
|
364
|
+
except HTTPException:
|
|
365
|
+
# 对于可选认证,如果是403权限不足,仍然抛出异常
|
|
366
|
+
# 如果是401未认证,返回None
|
|
367
|
+
raise
|
|
368
|
+
|
|
369
|
+
def decode_jwt_token(self, token: str) -> Optional[Dict]:
|
|
370
|
+
"""直接解析JWT token获取payload"""
|
|
371
|
+
try:
|
|
372
|
+
# 不验证签名,只解析payload(因为token已经通过verify_token验证过)
|
|
373
|
+
decoded_payload = jwt.decode(token, options={"verify_signature": False})
|
|
374
|
+
logger.debug(f"JWT token解析成功: {decoded_payload}")
|
|
375
|
+
return decoded_payload
|
|
376
|
+
except jwt.InvalidTokenError as e:
|
|
377
|
+
logger.error(f"JWT token解析失败: {str(e)}")
|
|
378
|
+
return None
|
|
379
|
+
except Exception as e:
|
|
380
|
+
logger.error(f"JWT token解析异常: {str(e)}")
|
|
381
|
+
return None
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
# 全局认证服务实例(延迟初始化)
|
|
385
|
+
auth_service = None
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def setup_auth_middleware(auth_config: AuthConfig) -> None:
|
|
389
|
+
"""
|
|
390
|
+
设置认证中间件配置
|
|
391
|
+
|
|
392
|
+
Args:
|
|
393
|
+
auth_config: 认证配置实例,包含白名单路径等配置
|
|
394
|
+
"""
|
|
395
|
+
global auth_service
|
|
396
|
+
auth_service = AuthService(auth_config)
|
|
397
|
+
logger.info(f"认证中间件已配置,白名单路径数量: {len(auth_config.get_whitelist_paths())}")
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
# 便捷的依赖函数
|
|
401
|
+
async def get_current_user(request: Request) -> Dict:
|
|
402
|
+
"""获取当前用户的依赖函数"""
|
|
403
|
+
if auth_service is None:
|
|
404
|
+
raise HTTPException(
|
|
405
|
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
406
|
+
detail="认证服务未初始化,请先调用setup_auth_middleware函数进行配置"
|
|
407
|
+
)
|
|
408
|
+
return await auth_service.require_auth(request)
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
async def get_optional_user(request: Request) -> Optional[Dict]:
|
|
412
|
+
"""获取可选当前用户的依赖函数"""
|
|
413
|
+
if auth_service is None:
|
|
414
|
+
raise HTTPException(
|
|
415
|
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
416
|
+
detail="认证服务未初始化,请先调用setup_auth_middleware函数进行配置"
|
|
417
|
+
)
|
|
418
|
+
return await auth_service.optional_auth(request)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: skyplatform-iam
|
|
3
|
-
Version: 1.0.1
|
|
3
|
+
Version: 1.0.3
|
|
4
4
|
Summary: SkyPlatform IAM认证SDK,提供FastAPI中间件和认证路由
|
|
5
5
|
Project-URL: Homepage, https://github.com/xinmayoujiang12621/agenterra_iam
|
|
6
6
|
Project-URL: Documentation, https://skyplatform-iam.readthedocs.io/
|
|
@@ -27,7 +27,6 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
|
27
27
|
Requires-Python: >=3.9
|
|
28
28
|
Requires-Dist: fastapi>=0.68.0
|
|
29
29
|
Requires-Dist: pydantic>=1.8.0
|
|
30
|
-
Requires-Dist: python-dotenv>=0.19.0
|
|
31
30
|
Requires-Dist: requests>=2.25.0
|
|
32
31
|
Requires-Dist: starlette>=0.14.0
|
|
33
32
|
Provides-Extra: dev
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
skyplatform_iam/__init__.py,sha256=3I9OSLQS8-5CLwWobi2Zxuw1yw1Fro3ez9gd-HSGL_s,2835
|
|
2
|
+
skyplatform_iam/config.py,sha256=s4tctVpguKZv4O1Fhf7_Fo7zELNX6KYviMjkE1WPbQM,3715
|
|
3
|
+
skyplatform_iam/connect_agenterra_iam.py,sha256=uC4SoKRPHOaY_99o8TYfUDXHOvsFqxPiVAb5NYax0D0,34540
|
|
4
|
+
skyplatform_iam/exceptions.py,sha256=Rt55QIzVK1F_kn6yzKQKKakD6PZDFdPLCGaCphKKms8,2166
|
|
5
|
+
skyplatform_iam/middleware.py,sha256=XNJxvjw3O55TW-ff_uORK-C9Wy4BTAfkcnNjy1SQkx0,15721
|
|
6
|
+
skyplatform_iam-1.0.3.dist-info/METADATA,sha256=y2Lby7o6Z2meqorX5s0PpxOfdlfWIc-gyQZ8YVWwgXw,6990
|
|
7
|
+
skyplatform_iam-1.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
8
|
+
skyplatform_iam-1.0.3.dist-info/RECORD,,
|
|
@@ -1,201 +0,0 @@
|
|
|
1
|
-
from fastapi import Request, HTTPException, status
|
|
2
|
-
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
3
|
-
from typing import Optional, Dict
|
|
4
|
-
import jwt
|
|
5
|
-
|
|
6
|
-
from .connect_agenterra_iam import ConnectAgenterraIam
|
|
7
|
-
from .config import AuthConfig
|
|
8
|
-
import logging
|
|
9
|
-
|
|
10
|
-
logger = logging.getLogger(__name__)
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class AuthMiddleware:
|
|
14
|
-
def __init__(self, auth_config: Optional[AuthConfig] = None):
|
|
15
|
-
self.security = HTTPBearer(auto_error=False)
|
|
16
|
-
self.iam_client = ConnectAgenterraIam()
|
|
17
|
-
self.auth_config = auth_config
|
|
18
|
-
|
|
19
|
-
def is_path_whitelisted(self, path: str) -> bool:
|
|
20
|
-
"""
|
|
21
|
-
检查路径是否在白名单中
|
|
22
|
-
"""
|
|
23
|
-
if not self.auth_config:
|
|
24
|
-
return False
|
|
25
|
-
return self.auth_config.is_path_whitelisted(path)
|
|
26
|
-
|
|
27
|
-
async def verify_token(self, request: Request):
|
|
28
|
-
# 通过token, server_ak, server_sk判断是否有权限
|
|
29
|
-
api_path = request.url.path
|
|
30
|
-
|
|
31
|
-
# 首先检查路径是否在白名单中
|
|
32
|
-
if self.is_path_whitelisted(api_path):
|
|
33
|
-
logger.info(f"路径 {api_path} 在白名单中,跳过IAM鉴权")
|
|
34
|
-
return True
|
|
35
|
-
|
|
36
|
-
credentials: HTTPAuthorizationCredentials = await self.security(request)
|
|
37
|
-
method = request.method
|
|
38
|
-
|
|
39
|
-
server_ak = request.headers.get("SERVER-AK", "")
|
|
40
|
-
server_sk = request.headers.get("SERVER-SK", "")
|
|
41
|
-
|
|
42
|
-
token = ""
|
|
43
|
-
if credentials is not None:
|
|
44
|
-
token = credentials.credentials
|
|
45
|
-
user_info_by_iam = self.iam_client.verify_token(token, api_path, method, server_ak, server_sk)
|
|
46
|
-
if user_info_by_iam:
|
|
47
|
-
return True
|
|
48
|
-
return False
|
|
49
|
-
|
|
50
|
-
async def get_current_user(self, request: Request) -> Optional[Dict]:
|
|
51
|
-
"""获取当前用户信息"""
|
|
52
|
-
try:
|
|
53
|
-
# 直接调用verify_token方法进行token验证
|
|
54
|
-
if not await self.verify_token(request):
|
|
55
|
-
return None
|
|
56
|
-
|
|
57
|
-
# 获取token用于后续用户信息获取
|
|
58
|
-
credentials: HTTPAuthorizationCredentials = await self.security(request)
|
|
59
|
-
token = credentials.credentials
|
|
60
|
-
|
|
61
|
-
# 直接解析JWT token获取payload
|
|
62
|
-
payload = self.decode_jwt_token(token)
|
|
63
|
-
if not payload:
|
|
64
|
-
logger.error("JWT token解析失败")
|
|
65
|
-
return None
|
|
66
|
-
|
|
67
|
-
# 从payload中提取用户信息
|
|
68
|
-
iam_user_id = payload.get("sub") # JWT标准中用户ID存储在sub字段
|
|
69
|
-
username = None
|
|
70
|
-
|
|
71
|
-
# 解析新的凭证信息结构
|
|
72
|
-
all_credentials = payload.get("all_credentials", [])
|
|
73
|
-
total_credentials = payload.get("total_credentials", 0)
|
|
74
|
-
|
|
75
|
-
# 从all_credentials中提取username(向后兼容)
|
|
76
|
-
for cred in all_credentials:
|
|
77
|
-
if cred.get("type") == "username":
|
|
78
|
-
username = cred.get("value")
|
|
79
|
-
break
|
|
80
|
-
|
|
81
|
-
# 向后兼容性:如果没有all_credentials,尝试从payload的其他字段构建
|
|
82
|
-
if not all_credentials:
|
|
83
|
-
credentials_list = []
|
|
84
|
-
# 检查payload中是否有直接的username字段
|
|
85
|
-
if payload.get("username"):
|
|
86
|
-
username = payload.get("username")
|
|
87
|
-
credentials_list.append({"type": "username", "value": username})
|
|
88
|
-
if payload.get("email"):
|
|
89
|
-
credentials_list.append({"type": "email", "value": payload.get("email")})
|
|
90
|
-
if payload.get("phone"):
|
|
91
|
-
credentials_list.append({"type": "phone", "value": payload.get("phone")})
|
|
92
|
-
all_credentials = credentials_list
|
|
93
|
-
total_credentials = len(credentials_list)
|
|
94
|
-
|
|
95
|
-
if not username:
|
|
96
|
-
return None
|
|
97
|
-
|
|
98
|
-
# 构建用户信息字典
|
|
99
|
-
user_info = {
|
|
100
|
-
"id": iam_user_id,
|
|
101
|
-
"username": username,
|
|
102
|
-
"all_credentials": all_credentials,
|
|
103
|
-
"total_credentials": total_credentials,
|
|
104
|
-
"microservice": payload.get("microservice") # 添加微服务信息
|
|
105
|
-
}
|
|
106
|
-
|
|
107
|
-
# 向后兼容:添加传统字段映射
|
|
108
|
-
for cred in all_credentials:
|
|
109
|
-
if cred.get("type") == "email":
|
|
110
|
-
user_info["email"] = cred.get("value")
|
|
111
|
-
elif cred.get("type") == "phone":
|
|
112
|
-
user_info["phone"] = cred.get("value")
|
|
113
|
-
elif cred.get("type") == "username" and not user_info.get("username"):
|
|
114
|
-
user_info["username"] = cred.get("value")
|
|
115
|
-
|
|
116
|
-
# 统计凭证类型分布
|
|
117
|
-
cred_types = [cred.get("type") for cred in all_credentials]
|
|
118
|
-
cred_type_count = {cred_type: cred_types.count(cred_type) for cred_type in set(cred_types)}
|
|
119
|
-
|
|
120
|
-
logger.info(
|
|
121
|
-
f"用户认证成功: user_id={iam_user_id}, username={username}, 凭证数量={total_credentials}, 凭证类型分布={cred_type_count}")
|
|
122
|
-
logger.debug(f"JWT payload: {payload}")
|
|
123
|
-
|
|
124
|
-
# 将用户信息添加到请求状态中
|
|
125
|
-
request.state.user = user_info
|
|
126
|
-
return user_info
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
except HTTPException as e:
|
|
130
|
-
print(403)
|
|
131
|
-
logger.error(f"获取当前用户信息失败: {str(e)}")
|
|
132
|
-
# 重新抛出HTTP异常(403权限不足)
|
|
133
|
-
return None
|
|
134
|
-
except Exception as e:
|
|
135
|
-
logger.error(f"获取当前用户信息失败: {str(e)}")
|
|
136
|
-
return None
|
|
137
|
-
|
|
138
|
-
async def require_auth(self, request: Request) -> Dict:
|
|
139
|
-
"""要求用户必须登录"""
|
|
140
|
-
try:
|
|
141
|
-
user_info = await self.get_current_user(request)
|
|
142
|
-
if not user_info:
|
|
143
|
-
raise HTTPException(
|
|
144
|
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
145
|
-
detail="需要登录认证",
|
|
146
|
-
headers={"WWW-Authenticate": "Bearer"},
|
|
147
|
-
)
|
|
148
|
-
return user_info
|
|
149
|
-
except HTTPException:
|
|
150
|
-
# 重新抛出HTTP异常(可能是403权限不足或401未认证)
|
|
151
|
-
raise
|
|
152
|
-
|
|
153
|
-
async def optional_auth(self, request: Request) -> Optional[Dict]:
|
|
154
|
-
"""可选的用户认证(不强制要求登录)"""
|
|
155
|
-
try:
|
|
156
|
-
return await self.get_current_user(request)
|
|
157
|
-
except HTTPException:
|
|
158
|
-
# 对于可选认证,如果是403权限不足,仍然抛出异常
|
|
159
|
-
# 如果是401未认证,返回None
|
|
160
|
-
raise
|
|
161
|
-
|
|
162
|
-
def decode_jwt_token(self, token: str) -> Optional[Dict]:
|
|
163
|
-
"""直接解析JWT token获取payload"""
|
|
164
|
-
try:
|
|
165
|
-
# 不验证签名,只解析payload(因为token已经通过verify_token验证过)
|
|
166
|
-
decoded_payload = jwt.decode(token, options={"verify_signature": False})
|
|
167
|
-
logger.debug(f"JWT token解析成功: {decoded_payload}")
|
|
168
|
-
return decoded_payload
|
|
169
|
-
except jwt.InvalidTokenError as e:
|
|
170
|
-
logger.error(f"JWT token解析失败: {str(e)}")
|
|
171
|
-
return None
|
|
172
|
-
except Exception as e:
|
|
173
|
-
logger.error(f"JWT token解析异常: {str(e)}")
|
|
174
|
-
return None
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
# 创建全局认证中间件实例
|
|
178
|
-
auth_middleware = AuthMiddleware()
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
def setup_auth_middleware(auth_config: AuthConfig) -> None:
|
|
182
|
-
"""
|
|
183
|
-
设置认证中间件配置
|
|
184
|
-
|
|
185
|
-
Args:
|
|
186
|
-
auth_config: 认证配置实例,包含白名单路径等配置
|
|
187
|
-
"""
|
|
188
|
-
global auth_middleware
|
|
189
|
-
auth_middleware = AuthMiddleware(auth_config)
|
|
190
|
-
logger.info(f"认证中间件已配置,白名单路径数量: {len(auth_config.get_whitelist_paths())}")
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
# 便捷的依赖函数
|
|
194
|
-
async def get_current_user(request: Request) -> Dict:
|
|
195
|
-
"""获取当前用户的依赖函数"""
|
|
196
|
-
return await auth_middleware.require_auth(request)
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
async def get_optional_user(request: Request) -> Optional[Dict]:
|
|
200
|
-
"""获取可选当前用户的依赖函数"""
|
|
201
|
-
return await auth_middleware.optional_auth(request)
|
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
skyplatform_iam/__init__.py,sha256=w_HFG7ddDO-fFsaGMZbBCkacD3VSk8o3ttvv7wluDIg,2516
|
|
2
|
-
skyplatform_iam/auth_middleware.py,sha256=1CvrjR-rFDAxn9YdwD9xWPEDnPmgDWUiELgX_7e4114,8119
|
|
3
|
-
skyplatform_iam/config.py,sha256=s4tctVpguKZv4O1Fhf7_Fo7zELNX6KYviMjkE1WPbQM,3715
|
|
4
|
-
skyplatform_iam/connect_agenterra_iam.py,sha256=kF4iWMhV-NoxCHgV7pyoClK9UliqC16n-E9V1aDPfKw,31843
|
|
5
|
-
skyplatform_iam/exceptions.py,sha256=Rt55QIzVK1F_kn6yzKQKKakD6PZDFdPLCGaCphKKms8,2166
|
|
6
|
-
skyplatform_iam/middleware.py,sha256=Yg-pX-wI8ROYCIHtwW7F1ABiFK5FSNlDJUgMc6fD_b4,6231
|
|
7
|
-
skyplatform_iam-1.0.1.dist-info/METADATA,sha256=cXkh0Utk6SnZkikCKU2dutHaxJ38YMisIEk1rP_vov8,7027
|
|
8
|
-
skyplatform_iam-1.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
9
|
-
skyplatform_iam-1.0.1.dist-info/RECORD,,
|
|
File without changes
|