skyplatform-iam 1.0.0__tar.gz → 1.0.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {skyplatform_iam-1.0.0 → skyplatform_iam-1.0.3}/PKG-INFO +2 -3
- {skyplatform_iam-1.0.0 → skyplatform_iam-1.0.3}/pyproject.toml +2 -3
- {skyplatform_iam-1.0.0 → skyplatform_iam-1.0.3}/skyplatform_iam/__init__.py +12 -1
- skyplatform_iam-1.0.3/skyplatform_iam/config.py +132 -0
- {skyplatform_iam-1.0.0 → skyplatform_iam-1.0.3}/skyplatform_iam/connect_agenterra_iam.py +82 -12
- skyplatform_iam-1.0.3/skyplatform_iam/middleware.py +418 -0
- skyplatform_iam-1.0.0/skyplatform_iam/auth_middleware.py +0 -173
- skyplatform_iam-1.0.0/skyplatform_iam/config.py +0 -68
- skyplatform_iam-1.0.0/skyplatform_iam/middleware.py +0 -186
- {skyplatform_iam-1.0.0 → skyplatform_iam-1.0.3}/README.md +0 -0
- {skyplatform_iam-1.0.0 → skyplatform_iam-1.0.3}/skyplatform_iam/exceptions.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: skyplatform-iam
|
|
3
|
-
Version: 1.0.
|
|
3
|
+
Version: 1.0.3
|
|
4
4
|
Summary: SkyPlatform IAM认证SDK,提供FastAPI中间件和认证路由
|
|
5
5
|
Project-URL: Homepage, https://github.com/xinmayoujiang12621/agenterra_iam
|
|
6
6
|
Project-URL: Documentation, https://skyplatform-iam.readthedocs.io/
|
|
@@ -24,10 +24,9 @@ Classifier: Programming Language :: Python :: 3.12
|
|
|
24
24
|
Classifier: Topic :: Internet :: WWW/HTTP :: HTTP Servers
|
|
25
25
|
Classifier: Topic :: Security
|
|
26
26
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
27
|
-
Requires-Python: >=3.
|
|
27
|
+
Requires-Python: >=3.9
|
|
28
28
|
Requires-Dist: fastapi>=0.68.0
|
|
29
29
|
Requires-Dist: pydantic>=1.8.0
|
|
30
|
-
Requires-Dist: python-dotenv>=0.19.0
|
|
31
30
|
Requires-Dist: requests>=2.25.0
|
|
32
31
|
Requires-Dist: starlette>=0.14.0
|
|
33
32
|
Provides-Extra: dev
|
|
@@ -4,13 +4,13 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "skyplatform-iam"
|
|
7
|
-
version = "1.0.
|
|
7
|
+
version = "1.0.3"
|
|
8
8
|
authors = [
|
|
9
9
|
{ name="x9", email="xuanxienanxunmobao@gmail.com" },
|
|
10
10
|
]
|
|
11
11
|
description = "SkyPlatform IAM认证SDK,提供FastAPI中间件和认证路由"
|
|
12
12
|
readme = "README.md"
|
|
13
|
-
requires-python = ">=3.
|
|
13
|
+
requires-python = ">=3.9"
|
|
14
14
|
license = { text = "MIT" }
|
|
15
15
|
keywords = ["fastapi", "authentication", "middleware", "iam", "skyplatform"]
|
|
16
16
|
classifiers = [
|
|
@@ -34,7 +34,6 @@ dependencies = [
|
|
|
34
34
|
"fastapi>=0.68.0",
|
|
35
35
|
"pydantic>=1.8.0",
|
|
36
36
|
"requests>=2.25.0",
|
|
37
|
-
"python-dotenv>=0.19.0",
|
|
38
37
|
"starlette>=0.14.0",
|
|
39
38
|
]
|
|
40
39
|
|
|
@@ -4,7 +4,7 @@ SkyPlatform IAM SDK
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
from .config import AuthConfig
|
|
7
|
-
from .middleware import AuthMiddleware
|
|
7
|
+
from .middleware import AuthMiddleware, AuthService, setup_auth_middleware, get_current_user, get_optional_user
|
|
8
8
|
from .connect_agenterra_iam import ConnectAgenterraIam
|
|
9
9
|
from .exceptions import (
|
|
10
10
|
SkyPlatformAuthException,
|
|
@@ -28,6 +28,10 @@ __all__ = [
|
|
|
28
28
|
|
|
29
29
|
# 中间件
|
|
30
30
|
"AuthMiddleware",
|
|
31
|
+
"AuthService",
|
|
32
|
+
"setup_auth_middleware",
|
|
33
|
+
"get_current_user",
|
|
34
|
+
"get_optional_user",
|
|
31
35
|
|
|
32
36
|
# 客户端
|
|
33
37
|
"ConnectAgenterraIam",
|
|
@@ -84,10 +88,17 @@ def setup_auth(app, config: AuthConfig = None):
|
|
|
84
88
|
Note:
|
|
85
89
|
此函数只设置认证中间件,不包含预制路由。
|
|
86
90
|
客户端应用需要根据业务需求自己实现认证相关的API接口。
|
|
91
|
+
建议传入完整的AuthConfig对象以避免环境变量配置问题。
|
|
87
92
|
"""
|
|
88
93
|
if config is None:
|
|
89
94
|
config = AuthConfig.from_env()
|
|
90
95
|
|
|
96
|
+
# 验证配置的完整性
|
|
97
|
+
config.validate_config()
|
|
98
|
+
|
|
99
|
+
# 初始化全局认证服务
|
|
100
|
+
setup_auth_middleware(config)
|
|
101
|
+
|
|
91
102
|
# 添加中间件
|
|
92
103
|
middleware = AuthMiddleware(app=app, config=config)
|
|
93
104
|
app.add_middleware(AuthMiddleware, config=config)
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SkyPlatform IAM SDK 配置模块
|
|
3
|
+
"""
|
|
4
|
+
import os
|
|
5
|
+
import fnmatch
|
|
6
|
+
from typing import Optional, List
|
|
7
|
+
from pydantic import BaseModel, Field
|
|
8
|
+
from dotenv import load_dotenv
|
|
9
|
+
|
|
10
|
+
# 加载环境变量
|
|
11
|
+
load_dotenv()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AuthConfig(BaseModel):
|
|
15
|
+
"""
|
|
16
|
+
认证配置类
|
|
17
|
+
支持环境变量和代码配置
|
|
18
|
+
"""
|
|
19
|
+
# IAM服务配置
|
|
20
|
+
agenterra_iam_host: str
|
|
21
|
+
server_name: str
|
|
22
|
+
access_key: str
|
|
23
|
+
|
|
24
|
+
# Token配置
|
|
25
|
+
token_header: str = "Authorization"
|
|
26
|
+
token_prefix: str = "Bearer "
|
|
27
|
+
|
|
28
|
+
# 错误处理配置
|
|
29
|
+
enable_debug: bool = False
|
|
30
|
+
|
|
31
|
+
# 白名单路径配置(实例变量)
|
|
32
|
+
whitelist_paths: List[str] = Field(default_factory=list)
|
|
33
|
+
|
|
34
|
+
class Config:
|
|
35
|
+
env_prefix = "AGENTERRA_"
|
|
36
|
+
|
|
37
|
+
@classmethod
|
|
38
|
+
def from_env(cls) -> "AuthConfig":
|
|
39
|
+
"""
|
|
40
|
+
从环境变量创建配置
|
|
41
|
+
"""
|
|
42
|
+
return cls(
|
|
43
|
+
agenterra_iam_host=os.environ.get('AGENTERRA_IAM_HOST', ''),
|
|
44
|
+
server_name=os.environ.get('AGENTERRA_SERVER_NAME', ''),
|
|
45
|
+
access_key=os.environ.get('AGENTERRA_ACCESS_KEY', ''),
|
|
46
|
+
enable_debug=os.environ.get('AGENTERRA_ENABLE_DEBUG', 'false').lower() == 'true',
|
|
47
|
+
whitelist_paths=[] # 初始化空的白名单路径列表
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
def validate_config(self) -> bool:
|
|
51
|
+
"""
|
|
52
|
+
验证配置是否完整
|
|
53
|
+
"""
|
|
54
|
+
required_fields = ['agenterra_iam_host', 'server_name', 'access_key']
|
|
55
|
+
for field in required_fields:
|
|
56
|
+
if not getattr(self, field):
|
|
57
|
+
raise ValueError(f"配置项 {field} 不能为空")
|
|
58
|
+
return True
|
|
59
|
+
|
|
60
|
+
def _normalize_path(self, path: str) -> str:
|
|
61
|
+
"""
|
|
62
|
+
标准化路径格式
|
|
63
|
+
"""
|
|
64
|
+
if not path:
|
|
65
|
+
return path
|
|
66
|
+
|
|
67
|
+
# 确保路径以 / 开头
|
|
68
|
+
if not path.startswith('/'):
|
|
69
|
+
path = '/' + path
|
|
70
|
+
|
|
71
|
+
# 移除重复的斜杠
|
|
72
|
+
while '//' in path:
|
|
73
|
+
path = path.replace('//', '/')
|
|
74
|
+
|
|
75
|
+
return path
|
|
76
|
+
|
|
77
|
+
def add_whitelist_path(self, path: str) -> None:
|
|
78
|
+
"""
|
|
79
|
+
添加白名单路径
|
|
80
|
+
"""
|
|
81
|
+
if not path:
|
|
82
|
+
return
|
|
83
|
+
|
|
84
|
+
normalized_path = self._normalize_path(path)
|
|
85
|
+
if normalized_path not in self.whitelist_paths:
|
|
86
|
+
self.whitelist_paths.append(normalized_path)
|
|
87
|
+
|
|
88
|
+
def add_whitelist_paths(self, paths: List[str]) -> None:
|
|
89
|
+
"""
|
|
90
|
+
批量添加白名单路径
|
|
91
|
+
"""
|
|
92
|
+
for path in paths:
|
|
93
|
+
self.add_whitelist_path(path)
|
|
94
|
+
|
|
95
|
+
def remove_whitelist_path(self, path: str) -> None:
|
|
96
|
+
"""
|
|
97
|
+
移除白名单路径
|
|
98
|
+
"""
|
|
99
|
+
if not path:
|
|
100
|
+
return
|
|
101
|
+
|
|
102
|
+
normalized_path = self._normalize_path(path)
|
|
103
|
+
if normalized_path in self.whitelist_paths:
|
|
104
|
+
self.whitelist_paths.remove(normalized_path)
|
|
105
|
+
|
|
106
|
+
def clear_whitelist_paths(self) -> None:
|
|
107
|
+
"""
|
|
108
|
+
清空所有白名单路径
|
|
109
|
+
"""
|
|
110
|
+
self.whitelist_paths.clear()
|
|
111
|
+
|
|
112
|
+
def get_whitelist_paths(self) -> List[str]:
|
|
113
|
+
"""
|
|
114
|
+
获取所有白名单路径
|
|
115
|
+
"""
|
|
116
|
+
return self.whitelist_paths.copy()
|
|
117
|
+
|
|
118
|
+
def is_path_whitelisted(self, path: str) -> bool:
|
|
119
|
+
"""
|
|
120
|
+
检查路径是否在白名单中(支持通配符匹配)
|
|
121
|
+
"""
|
|
122
|
+
if not path:
|
|
123
|
+
return False
|
|
124
|
+
|
|
125
|
+
normalized_path = self._normalize_path(path)
|
|
126
|
+
|
|
127
|
+
for whitelist_path in self.whitelist_paths:
|
|
128
|
+
# 支持通配符匹配
|
|
129
|
+
if fnmatch.fnmatch(normalized_path, whitelist_path):
|
|
130
|
+
return True
|
|
131
|
+
|
|
132
|
+
return False
|
|
@@ -1,13 +1,9 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import requests
|
|
3
2
|
import logging
|
|
4
3
|
import traceback
|
|
5
4
|
import copy
|
|
6
|
-
from dotenv import load_dotenv
|
|
7
5
|
from enum import Enum
|
|
8
|
-
|
|
9
|
-
# 加载环境变量
|
|
10
|
-
load_dotenv()
|
|
6
|
+
from fastapi import HTTPException, status
|
|
11
7
|
|
|
12
8
|
|
|
13
9
|
class CredentialTypeEnum(str, Enum):
|
|
@@ -19,14 +15,31 @@ class CredentialTypeEnum(str, Enum):
|
|
|
19
15
|
|
|
20
16
|
|
|
21
17
|
class ConnectAgenterraIam(object):
|
|
22
|
-
|
|
18
|
+
_instance = None
|
|
19
|
+
_initialized = False
|
|
20
|
+
|
|
21
|
+
def __new__(cls, config=None, logger_name="skyplatform_iam", log_level=logging.INFO):
|
|
22
|
+
"""
|
|
23
|
+
单例模式实现
|
|
24
|
+
确保整个应用中只有一个ConnectAgenterraIam实例
|
|
25
|
+
"""
|
|
26
|
+
if cls._instance is None:
|
|
27
|
+
cls._instance = super(ConnectAgenterraIam, cls).__new__(cls)
|
|
28
|
+
return cls._instance
|
|
29
|
+
|
|
30
|
+
def __init__(self, config=None, logger_name="skyplatform_iam", log_level=logging.INFO):
|
|
23
31
|
"""
|
|
24
32
|
初始化AgenterraIAM连接器
|
|
25
33
|
|
|
26
34
|
参数:
|
|
35
|
+
- config: AuthConfig配置对象,如果为None则从环境变量读取
|
|
27
36
|
- logger_name: 日志记录器名称
|
|
28
37
|
- log_level: 日志级别
|
|
29
38
|
"""
|
|
39
|
+
# 防止重复初始化
|
|
40
|
+
if self._initialized:
|
|
41
|
+
return
|
|
42
|
+
|
|
30
43
|
# 配置日志记录器
|
|
31
44
|
self.logger = logging.getLogger(logger_name)
|
|
32
45
|
if not self.logger.handlers:
|
|
@@ -38,10 +51,22 @@ class ConnectAgenterraIam(object):
|
|
|
38
51
|
self.logger.addHandler(handler)
|
|
39
52
|
self.logger.setLevel(log_level)
|
|
40
53
|
|
|
41
|
-
#
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
54
|
+
# 必须传入config参数,不再支持从环境变量读取
|
|
55
|
+
if config is None:
|
|
56
|
+
raise ValueError("必须传入AuthConfig配置对象,不再支持从环境变量读取配置")
|
|
57
|
+
|
|
58
|
+
self.agenterra_iam_host = config.agenterra_iam_host
|
|
59
|
+
self.server_name = config.server_name
|
|
60
|
+
self.access_key = config.access_key
|
|
61
|
+
self.logger.info("使用传入的AuthConfig配置")
|
|
62
|
+
|
|
63
|
+
# 验证必要的配置
|
|
64
|
+
if not self.agenterra_iam_host:
|
|
65
|
+
self.logger.warning("AGENTERRA_IAM_HOST 配置未设置")
|
|
66
|
+
if not self.server_name:
|
|
67
|
+
self.logger.warning("AGENTERRA_SERVER_NAME 配置未设置")
|
|
68
|
+
if not self.access_key:
|
|
69
|
+
self.logger.warning("AGENTERRA_ACCESS_KEY 配置未设置")
|
|
45
70
|
|
|
46
71
|
self.logger.info(f"初始化AgenterraIAM连接器 - Host: {self.agenterra_iam_host}, Server: {self._mask_sensitive(self.server_name)}")
|
|
47
72
|
|
|
@@ -54,6 +79,48 @@ class ConnectAgenterraIam(object):
|
|
|
54
79
|
"server_name": self.server_name,
|
|
55
80
|
"access_key": self.access_key
|
|
56
81
|
}
|
|
82
|
+
|
|
83
|
+
# 标记为已初始化
|
|
84
|
+
self._initialized = True
|
|
85
|
+
|
|
86
|
+
def reload_config(self, config):
|
|
87
|
+
"""
|
|
88
|
+
重新加载配置
|
|
89
|
+
用于在运行时更新配置
|
|
90
|
+
|
|
91
|
+
参数:
|
|
92
|
+
- config: AuthConfig配置对象
|
|
93
|
+
"""
|
|
94
|
+
if config is None:
|
|
95
|
+
raise ValueError("必须传入AuthConfig配置对象")
|
|
96
|
+
|
|
97
|
+
self.logger.info("重新加载配置")
|
|
98
|
+
|
|
99
|
+
# 更新配置
|
|
100
|
+
self.agenterra_iam_host = config.agenterra_iam_host
|
|
101
|
+
self.server_name = config.server_name
|
|
102
|
+
self.access_key = config.access_key
|
|
103
|
+
|
|
104
|
+
# 验证必要的配置
|
|
105
|
+
if not self.agenterra_iam_host:
|
|
106
|
+
self.logger.warning("AGENTERRA_IAM_HOST 配置未设置")
|
|
107
|
+
if not self.server_name:
|
|
108
|
+
self.logger.warning("AGENTERRA_SERVER_NAME 配置未设置")
|
|
109
|
+
if not self.access_key:
|
|
110
|
+
self.logger.warning("AGENTERRA_ACCESS_KEY 配置未设置")
|
|
111
|
+
|
|
112
|
+
# 更新headers和body
|
|
113
|
+
self.headers = {
|
|
114
|
+
"Content-Type": "application/json",
|
|
115
|
+
"SERVER-AK": self.server_name,
|
|
116
|
+
"SERVER-SK": self.access_key
|
|
117
|
+
}
|
|
118
|
+
self.body = {
|
|
119
|
+
"server_name": self.server_name,
|
|
120
|
+
"access_key": self.access_key
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
self.logger.info(f"配置重新加载完成 - Host: {self.agenterra_iam_host}, Server: {self._mask_sensitive(self.server_name)}")
|
|
57
124
|
|
|
58
125
|
def _mask_sensitive(self, value, mask_char="*", show_chars=4):
|
|
59
126
|
"""
|
|
@@ -430,6 +497,11 @@ class ConnectAgenterraIam(object):
|
|
|
430
497
|
"server_sk": server_sk,
|
|
431
498
|
}
|
|
432
499
|
uri = "/api/v2/service/verify"
|
|
500
|
+
|
|
501
|
+
# 检查agenterra_iam_host是否为None
|
|
502
|
+
if self.agenterra_iam_host is None:
|
|
503
|
+
raise ValueError("AGENTERRA_IAM_HOST 配置未设置或为空,请确保传入正确的AuthConfig对象")
|
|
504
|
+
|
|
433
505
|
url = self.agenterra_iam_host + uri
|
|
434
506
|
|
|
435
507
|
# 记录请求信息
|
|
@@ -463,7 +535,6 @@ class ConnectAgenterraIam(object):
|
|
|
463
535
|
else:
|
|
464
536
|
# token有效但无权限,抛出403异常
|
|
465
537
|
self.logger.warning(f"[{method_name}] token有效但用户无权限访问API: {api}")
|
|
466
|
-
from fastapi import HTTPException, status
|
|
467
538
|
raise HTTPException(
|
|
468
539
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
469
540
|
detail=result.get("message", "用户无权限访问此API")
|
|
@@ -475,7 +546,6 @@ class ConnectAgenterraIam(object):
|
|
|
475
546
|
result = response.json()
|
|
476
547
|
# 处理403响应
|
|
477
548
|
self.logger.warning(f"[{method_name}] 收到403响应 - {result.get('message', '用户无权限访问此API')}")
|
|
478
|
-
from fastapi import HTTPException, status
|
|
479
549
|
raise HTTPException(
|
|
480
550
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
481
551
|
detail=result.get("message", "用户无权限访问此API")
|
|
@@ -0,0 +1,418 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SkyPlatform IAM SDK 中间件模块
|
|
3
|
+
"""
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Optional, Callable, Dict, Any
|
|
6
|
+
from fastapi import Request, Response, HTTPException, status
|
|
7
|
+
from fastapi.responses import JSONResponse
|
|
8
|
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
9
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
10
|
+
import jwt
|
|
11
|
+
|
|
12
|
+
from .config import AuthConfig
|
|
13
|
+
from .connect_agenterra_iam import ConnectAgenterraIam
|
|
14
|
+
from .exceptions import (
|
|
15
|
+
AuthenticationError,
|
|
16
|
+
AuthorizationError,
|
|
17
|
+
ConfigurationError
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AuthMiddleware(BaseHTTPMiddleware):
|
|
24
|
+
"""
|
|
25
|
+
认证中间件
|
|
26
|
+
自动拦截请求进行Token验证和权限检查
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
def __init__(
|
|
30
|
+
self,
|
|
31
|
+
app,
|
|
32
|
+
config: AuthConfig,
|
|
33
|
+
skip_validation: Optional[Callable[[Request], bool]] = None
|
|
34
|
+
):
|
|
35
|
+
"""
|
|
36
|
+
初始化认证中间件
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
app: FastAPI应用实例
|
|
40
|
+
config: 认证配置
|
|
41
|
+
skip_validation: 自定义跳过验证的函数
|
|
42
|
+
"""
|
|
43
|
+
super().__init__(app)
|
|
44
|
+
self.config = config
|
|
45
|
+
self.iam_client = ConnectAgenterraIam(config=config)
|
|
46
|
+
self.skip_validation = skip_validation
|
|
47
|
+
|
|
48
|
+
# 验证配置
|
|
49
|
+
try:
|
|
50
|
+
self.config.validate_config()
|
|
51
|
+
except ValueError as e:
|
|
52
|
+
raise ConfigurationError(str(e))
|
|
53
|
+
|
|
54
|
+
def is_path_whitelisted(self, path: str) -> bool:
|
|
55
|
+
"""
|
|
56
|
+
检查路径是否在本地白名单中
|
|
57
|
+
"""
|
|
58
|
+
if not self.config:
|
|
59
|
+
return False
|
|
60
|
+
return self.config.is_path_whitelisted(path)
|
|
61
|
+
|
|
62
|
+
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
|
63
|
+
"""
|
|
64
|
+
中间件主要处理逻辑
|
|
65
|
+
"""
|
|
66
|
+
try:
|
|
67
|
+
# 获取请求路径
|
|
68
|
+
api_path = request.url.path
|
|
69
|
+
|
|
70
|
+
# 首先检查路径是否在本地白名单中
|
|
71
|
+
if self.is_path_whitelisted(api_path):
|
|
72
|
+
logger.info(f"路径 {api_path} 在本地白名单中,跳过认证直接允许访问")
|
|
73
|
+
# 设置白名单标识
|
|
74
|
+
request.state.user = None
|
|
75
|
+
request.state.authenticated = False
|
|
76
|
+
request.state.is_whitelist = True
|
|
77
|
+
# 直接调用下一个处理器
|
|
78
|
+
response = await call_next(request)
|
|
79
|
+
return response
|
|
80
|
+
|
|
81
|
+
# 提取Token(可能为空,白名单接口不需要token)
|
|
82
|
+
token = self._extract_token(request)
|
|
83
|
+
|
|
84
|
+
# 验证Token和权限(即使token为空也要调用IAM验证,因为可能是白名单接口)
|
|
85
|
+
user_info = await self._verify_token_and_permission(request, token)
|
|
86
|
+
if not user_info:
|
|
87
|
+
return self._create_error_response(
|
|
88
|
+
status_code=401,
|
|
89
|
+
message="Token验证失败",
|
|
90
|
+
detail="提供的Token无效或已过期"
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
# 检查是否为白名单接口
|
|
94
|
+
if user_info.get('is_whitelist', False):
|
|
95
|
+
# 白名单接口,允许访问但不设置用户信息
|
|
96
|
+
request.state.user = None
|
|
97
|
+
request.state.authenticated = False
|
|
98
|
+
request.state.is_whitelist = True
|
|
99
|
+
else:
|
|
100
|
+
# 正常认证接口,设置用户信息
|
|
101
|
+
request.state.user = user_info
|
|
102
|
+
request.state.authenticated = True
|
|
103
|
+
request.state.is_whitelist = False
|
|
104
|
+
|
|
105
|
+
# 继续处理请求
|
|
106
|
+
response = await call_next(request)
|
|
107
|
+
return response
|
|
108
|
+
|
|
109
|
+
except HTTPException as e:
|
|
110
|
+
return self._create_error_response(
|
|
111
|
+
status_code=e.status_code,
|
|
112
|
+
message=str(e.detail),
|
|
113
|
+
detail=getattr(e, 'detail', None)
|
|
114
|
+
)
|
|
115
|
+
except AuthenticationError as e:
|
|
116
|
+
return self._create_error_response(
|
|
117
|
+
status_code=e.status_code,
|
|
118
|
+
message=e.message,
|
|
119
|
+
detail=e.detail
|
|
120
|
+
)
|
|
121
|
+
except AuthorizationError as e:
|
|
122
|
+
return self._create_error_response(
|
|
123
|
+
status_code=e.status_code,
|
|
124
|
+
message=e.message,
|
|
125
|
+
detail=e.detail
|
|
126
|
+
)
|
|
127
|
+
except Exception as e:
|
|
128
|
+
logger.error(f"认证中间件处理异常: {str(e)}")
|
|
129
|
+
if self.config.enable_debug:
|
|
130
|
+
logger.exception("详细异常信息:")
|
|
131
|
+
|
|
132
|
+
return self._create_error_response(
|
|
133
|
+
status_code=500,
|
|
134
|
+
message="内部服务器错误",
|
|
135
|
+
detail=str(e) if self.config.enable_debug else None
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
def _extract_token(self, request: Request) -> Optional[str]:
|
|
139
|
+
"""
|
|
140
|
+
从请求中提取Token
|
|
141
|
+
"""
|
|
142
|
+
# 从Authorization头提取
|
|
143
|
+
auth_header = request.headers.get(self.config.token_header)
|
|
144
|
+
if auth_header and auth_header.startswith(self.config.token_prefix):
|
|
145
|
+
return auth_header[len(self.config.token_prefix):].strip()
|
|
146
|
+
|
|
147
|
+
# 从查询参数提取(备选方案)
|
|
148
|
+
token = request.query_params.get("token")
|
|
149
|
+
if token:
|
|
150
|
+
return token
|
|
151
|
+
|
|
152
|
+
return None
|
|
153
|
+
|
|
154
|
+
async def _verify_token_and_permission(self, request: Request, token: Optional[str]) -> Optional[Dict[str, Any]]:
|
|
155
|
+
"""
|
|
156
|
+
验证Token和权限
|
|
157
|
+
"""
|
|
158
|
+
try:
|
|
159
|
+
# 获取请求信息
|
|
160
|
+
api_path = request.url.path
|
|
161
|
+
method = request.method
|
|
162
|
+
|
|
163
|
+
# 从请求头获取服务认证信息(可选)
|
|
164
|
+
server_ak = request.headers.get("SERVER-AK", "")
|
|
165
|
+
server_sk = request.headers.get("SERVER-SK", "")
|
|
166
|
+
|
|
167
|
+
# 调用IAM验证接口(即使token为空也要调用,因为可能是白名单接口)
|
|
168
|
+
user_info = self.iam_client.verify_token(
|
|
169
|
+
token=token or "", # 如果token为None,传递空字符串
|
|
170
|
+
api=api_path,
|
|
171
|
+
method=method,
|
|
172
|
+
server_ak=server_ak,
|
|
173
|
+
server_sk=server_sk
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
return user_info
|
|
177
|
+
|
|
178
|
+
except HTTPException:
|
|
179
|
+
# 重新抛出HTTP异常
|
|
180
|
+
raise
|
|
181
|
+
except Exception as e:
|
|
182
|
+
logger.error(f"Token验证异常: {str(e)}")
|
|
183
|
+
if self.config.enable_debug:
|
|
184
|
+
logger.exception("详细异常信息:")
|
|
185
|
+
return None
|
|
186
|
+
|
|
187
|
+
def _create_error_response(
|
|
188
|
+
self,
|
|
189
|
+
status_code: int,
|
|
190
|
+
message: str,
|
|
191
|
+
detail: Optional[str] = None
|
|
192
|
+
) -> JSONResponse:
|
|
193
|
+
"""
|
|
194
|
+
创建错误响应
|
|
195
|
+
"""
|
|
196
|
+
error_data = {
|
|
197
|
+
"success": False,
|
|
198
|
+
"message": message,
|
|
199
|
+
"status_code": status_code
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
if detail:
|
|
203
|
+
error_data["detail"] = detail
|
|
204
|
+
|
|
205
|
+
return JSONResponse(
|
|
206
|
+
status_code=status_code,
|
|
207
|
+
content=error_data
|
|
208
|
+
)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class AuthService:
|
|
212
|
+
"""
|
|
213
|
+
认证服务类
|
|
214
|
+
提供依赖注入式的认证功能
|
|
215
|
+
"""
|
|
216
|
+
|
|
217
|
+
def __init__(self, auth_config: AuthConfig):
|
|
218
|
+
if auth_config is None:
|
|
219
|
+
raise ValueError("auth_config参数不能为None,必须传入AuthConfig配置对象")
|
|
220
|
+
self.security = HTTPBearer(auto_error=False)
|
|
221
|
+
self.iam_client = ConnectAgenterraIam(config=auth_config)
|
|
222
|
+
self.auth_config = auth_config
|
|
223
|
+
|
|
224
|
+
def is_path_whitelisted(self, path: str) -> bool:
|
|
225
|
+
"""
|
|
226
|
+
检查路径是否在白名单中
|
|
227
|
+
"""
|
|
228
|
+
if not self.auth_config:
|
|
229
|
+
return False
|
|
230
|
+
return self.auth_config.is_path_whitelisted(path)
|
|
231
|
+
|
|
232
|
+
async def verify_token(self, request: Request):
|
|
233
|
+
"""验证token和权限"""
|
|
234
|
+
# 通过token, server_ak, server_sk判断是否有权限
|
|
235
|
+
api_path = request.url.path
|
|
236
|
+
|
|
237
|
+
# 首先检查路径是否在白名单中
|
|
238
|
+
if self.is_path_whitelisted(api_path):
|
|
239
|
+
logger.info(f"路径 {api_path} 在白名单中,跳过IAM鉴权")
|
|
240
|
+
return True
|
|
241
|
+
|
|
242
|
+
credentials: HTTPAuthorizationCredentials = await self.security(request)
|
|
243
|
+
method = request.method
|
|
244
|
+
|
|
245
|
+
server_ak = request.headers.get("SERVER-AK", "")
|
|
246
|
+
server_sk = request.headers.get("SERVER-SK", "")
|
|
247
|
+
|
|
248
|
+
token = ""
|
|
249
|
+
if credentials is not None:
|
|
250
|
+
token = credentials.credentials
|
|
251
|
+
user_info_by_iam = self.iam_client.verify_token(token, api_path, method, server_ak, server_sk)
|
|
252
|
+
if user_info_by_iam:
|
|
253
|
+
return True
|
|
254
|
+
return False
|
|
255
|
+
|
|
256
|
+
async def get_current_user(self, request: Request) -> Optional[Dict]:
|
|
257
|
+
"""获取当前用户信息"""
|
|
258
|
+
try:
|
|
259
|
+
# 直接调用verify_token方法进行token验证
|
|
260
|
+
if not await self.verify_token(request):
|
|
261
|
+
return None
|
|
262
|
+
|
|
263
|
+
# 获取token用于后续用户信息获取
|
|
264
|
+
credentials: HTTPAuthorizationCredentials = await self.security(request)
|
|
265
|
+
if not credentials:
|
|
266
|
+
return None
|
|
267
|
+
|
|
268
|
+
token = credentials.credentials
|
|
269
|
+
|
|
270
|
+
# 直接解析JWT token获取payload
|
|
271
|
+
payload = self.decode_jwt_token(token)
|
|
272
|
+
if not payload:
|
|
273
|
+
logger.error("JWT token解析失败")
|
|
274
|
+
return None
|
|
275
|
+
|
|
276
|
+
# 从payload中提取用户信息
|
|
277
|
+
iam_user_id = payload.get("sub") # JWT标准中用户ID存储在sub字段
|
|
278
|
+
username = None
|
|
279
|
+
|
|
280
|
+
# 解析新的凭证信息结构
|
|
281
|
+
all_credentials = payload.get("all_credentials", [])
|
|
282
|
+
total_credentials = payload.get("total_credentials", 0)
|
|
283
|
+
|
|
284
|
+
# 从all_credentials中提取username(向后兼容)
|
|
285
|
+
for cred in all_credentials:
|
|
286
|
+
if cred.get("type") == "username":
|
|
287
|
+
username = cred.get("value")
|
|
288
|
+
break
|
|
289
|
+
|
|
290
|
+
# 向后兼容性:如果没有all_credentials,尝试从payload的其他字段构建
|
|
291
|
+
if not all_credentials:
|
|
292
|
+
credentials_list = []
|
|
293
|
+
# 检查payload中是否有直接的username字段
|
|
294
|
+
if payload.get("username"):
|
|
295
|
+
username = payload.get("username")
|
|
296
|
+
credentials_list.append({"type": "username", "value": username})
|
|
297
|
+
if payload.get("email"):
|
|
298
|
+
credentials_list.append({"type": "email", "value": payload.get("email")})
|
|
299
|
+
if payload.get("phone"):
|
|
300
|
+
credentials_list.append({"type": "phone", "value": payload.get("phone")})
|
|
301
|
+
all_credentials = credentials_list
|
|
302
|
+
total_credentials = len(credentials_list)
|
|
303
|
+
|
|
304
|
+
if not username:
|
|
305
|
+
return None
|
|
306
|
+
|
|
307
|
+
# 构建用户信息字典
|
|
308
|
+
user_info = {
|
|
309
|
+
"id": iam_user_id,
|
|
310
|
+
"username": username,
|
|
311
|
+
"all_credentials": all_credentials,
|
|
312
|
+
"total_credentials": total_credentials,
|
|
313
|
+
"microservice": payload.get("microservice") # 添加微服务信息
|
|
314
|
+
}
|
|
315
|
+
|
|
316
|
+
# 向后兼容:添加传统字段映射
|
|
317
|
+
for cred in all_credentials:
|
|
318
|
+
if cred.get("type") == "email":
|
|
319
|
+
user_info["email"] = cred.get("value")
|
|
320
|
+
elif cred.get("type") == "phone":
|
|
321
|
+
user_info["phone"] = cred.get("value")
|
|
322
|
+
elif cred.get("type") == "username" and not user_info.get("username"):
|
|
323
|
+
user_info["username"] = cred.get("value")
|
|
324
|
+
|
|
325
|
+
# 统计凭证类型分布
|
|
326
|
+
cred_types = [cred.get("type") for cred in all_credentials]
|
|
327
|
+
cred_type_count = {cred_type: cred_types.count(cred_type) for cred_type in set(cred_types)}
|
|
328
|
+
|
|
329
|
+
logger.info(
|
|
330
|
+
f"用户认证成功: user_id={iam_user_id}, username={username}, 凭证数量={total_credentials}, 凭证类型分布={cred_type_count}")
|
|
331
|
+
logger.debug(f"JWT payload: {payload}")
|
|
332
|
+
|
|
333
|
+
# 将用户信息添加到请求状态中
|
|
334
|
+
request.state.user = user_info
|
|
335
|
+
return user_info
|
|
336
|
+
|
|
337
|
+
except HTTPException as e:
|
|
338
|
+
logger.error(f"获取当前用户信息失败: {str(e)}")
|
|
339
|
+
# 重新抛出HTTP异常(403权限不足)
|
|
340
|
+
return None
|
|
341
|
+
except Exception as e:
|
|
342
|
+
logger.error(f"获取当前用户信息失败: {str(e)}")
|
|
343
|
+
return None
|
|
344
|
+
|
|
345
|
+
async def require_auth(self, request: Request) -> Dict:
|
|
346
|
+
"""要求用户必须登录"""
|
|
347
|
+
try:
|
|
348
|
+
user_info = await self.get_current_user(request)
|
|
349
|
+
if not user_info:
|
|
350
|
+
raise HTTPException(
|
|
351
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
352
|
+
detail="需要登录认证",
|
|
353
|
+
headers={"WWW-Authenticate": "Bearer"},
|
|
354
|
+
)
|
|
355
|
+
return user_info
|
|
356
|
+
except HTTPException:
|
|
357
|
+
# 重新抛出HTTP异常(可能是403权限不足或401未认证)
|
|
358
|
+
raise
|
|
359
|
+
|
|
360
|
+
async def optional_auth(self, request: Request) -> Optional[Dict]:
|
|
361
|
+
"""可选的用户认证(不强制要求登录)"""
|
|
362
|
+
try:
|
|
363
|
+
return await self.get_current_user(request)
|
|
364
|
+
except HTTPException:
|
|
365
|
+
# 对于可选认证,如果是403权限不足,仍然抛出异常
|
|
366
|
+
# 如果是401未认证,返回None
|
|
367
|
+
raise
|
|
368
|
+
|
|
369
|
+
def decode_jwt_token(self, token: str) -> Optional[Dict]:
|
|
370
|
+
"""直接解析JWT token获取payload"""
|
|
371
|
+
try:
|
|
372
|
+
# 不验证签名,只解析payload(因为token已经通过verify_token验证过)
|
|
373
|
+
decoded_payload = jwt.decode(token, options={"verify_signature": False})
|
|
374
|
+
logger.debug(f"JWT token解析成功: {decoded_payload}")
|
|
375
|
+
return decoded_payload
|
|
376
|
+
except jwt.InvalidTokenError as e:
|
|
377
|
+
logger.error(f"JWT token解析失败: {str(e)}")
|
|
378
|
+
return None
|
|
379
|
+
except Exception as e:
|
|
380
|
+
logger.error(f"JWT token解析异常: {str(e)}")
|
|
381
|
+
return None
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
# 全局认证服务实例(延迟初始化)
|
|
385
|
+
auth_service = None
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def setup_auth_middleware(auth_config: AuthConfig) -> None:
|
|
389
|
+
"""
|
|
390
|
+
设置认证中间件配置
|
|
391
|
+
|
|
392
|
+
Args:
|
|
393
|
+
auth_config: 认证配置实例,包含白名单路径等配置
|
|
394
|
+
"""
|
|
395
|
+
global auth_service
|
|
396
|
+
auth_service = AuthService(auth_config)
|
|
397
|
+
logger.info(f"认证中间件已配置,白名单路径数量: {len(auth_config.get_whitelist_paths())}")
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
# 便捷的依赖函数
|
|
401
|
+
async def get_current_user(request: Request) -> Dict:
|
|
402
|
+
"""获取当前用户的依赖函数"""
|
|
403
|
+
if auth_service is None:
|
|
404
|
+
raise HTTPException(
|
|
405
|
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
406
|
+
detail="认证服务未初始化,请先调用setup_auth_middleware函数进行配置"
|
|
407
|
+
)
|
|
408
|
+
return await auth_service.require_auth(request)
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
async def get_optional_user(request: Request) -> Optional[Dict]:
|
|
412
|
+
"""获取可选当前用户的依赖函数"""
|
|
413
|
+
if auth_service is None:
|
|
414
|
+
raise HTTPException(
|
|
415
|
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
416
|
+
detail="认证服务未初始化,请先调用setup_auth_middleware函数进行配置"
|
|
417
|
+
)
|
|
418
|
+
return await auth_service.optional_auth(request)
|
|
@@ -1,173 +0,0 @@
|
|
|
1
|
-
from fastapi import Request, HTTPException, status
|
|
2
|
-
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
3
|
-
from typing import Optional, Dict
|
|
4
|
-
import jwt
|
|
5
|
-
|
|
6
|
-
from .connect_agenterra_iam import ConnectAgenterraIam
|
|
7
|
-
import logging
|
|
8
|
-
|
|
9
|
-
logger = logging.getLogger(__name__)
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class AuthMiddleware:
|
|
13
|
-
def __init__(self):
|
|
14
|
-
self.security = HTTPBearer(auto_error=False)
|
|
15
|
-
self.iam_client = ConnectAgenterraIam()
|
|
16
|
-
|
|
17
|
-
async def verify_token(self, request: Request):
|
|
18
|
-
# 通过token, server_ak, server_sk判断是否有权限
|
|
19
|
-
credentials: HTTPAuthorizationCredentials = await self.security(request)
|
|
20
|
-
api_path = request.url.path
|
|
21
|
-
method = request.method
|
|
22
|
-
|
|
23
|
-
server_ak = request.headers.get("SERVER-AK", "")
|
|
24
|
-
server_sk = request.headers.get("SERVER-SK", "")
|
|
25
|
-
|
|
26
|
-
token = ""
|
|
27
|
-
if credentials is not None:
|
|
28
|
-
token = credentials.credentials
|
|
29
|
-
user_info_by_iam = self.iam_client.verify_token(token, api_path, method, server_ak, server_sk)
|
|
30
|
-
if user_info_by_iam:
|
|
31
|
-
return True
|
|
32
|
-
return False
|
|
33
|
-
|
|
34
|
-
async def get_current_user(self, request: Request) -> Optional[Dict]:
|
|
35
|
-
"""获取当前用户信息"""
|
|
36
|
-
try:
|
|
37
|
-
# 直接调用verify_token方法进行token验证
|
|
38
|
-
if not await self.verify_token(request):
|
|
39
|
-
return None
|
|
40
|
-
|
|
41
|
-
# 获取token用于后续用户信息获取
|
|
42
|
-
credentials: HTTPAuthorizationCredentials = await self.security(request)
|
|
43
|
-
token = credentials.credentials
|
|
44
|
-
|
|
45
|
-
# 直接解析JWT token获取payload
|
|
46
|
-
payload = self.decode_jwt_token(token)
|
|
47
|
-
if not payload:
|
|
48
|
-
logger.error("JWT token解析失败")
|
|
49
|
-
return None
|
|
50
|
-
|
|
51
|
-
# 从payload中提取用户信息
|
|
52
|
-
iam_user_id = payload.get("sub") # JWT标准中用户ID存储在sub字段
|
|
53
|
-
username = None
|
|
54
|
-
|
|
55
|
-
# 解析新的凭证信息结构
|
|
56
|
-
all_credentials = payload.get("all_credentials", [])
|
|
57
|
-
total_credentials = payload.get("total_credentials", 0)
|
|
58
|
-
|
|
59
|
-
# 从all_credentials中提取username(向后兼容)
|
|
60
|
-
for cred in all_credentials:
|
|
61
|
-
if cred.get("type") == "username":
|
|
62
|
-
username = cred.get("value")
|
|
63
|
-
break
|
|
64
|
-
|
|
65
|
-
# 向后兼容性:如果没有all_credentials,尝试从payload的其他字段构建
|
|
66
|
-
if not all_credentials:
|
|
67
|
-
credentials_list = []
|
|
68
|
-
# 检查payload中是否有直接的username字段
|
|
69
|
-
if payload.get("username"):
|
|
70
|
-
username = payload.get("username")
|
|
71
|
-
credentials_list.append({"type": "username", "value": username})
|
|
72
|
-
if payload.get("email"):
|
|
73
|
-
credentials_list.append({"type": "email", "value": payload.get("email")})
|
|
74
|
-
if payload.get("phone"):
|
|
75
|
-
credentials_list.append({"type": "phone", "value": payload.get("phone")})
|
|
76
|
-
all_credentials = credentials_list
|
|
77
|
-
total_credentials = len(credentials_list)
|
|
78
|
-
|
|
79
|
-
if not username:
|
|
80
|
-
return None
|
|
81
|
-
|
|
82
|
-
# 构建用户信息字典
|
|
83
|
-
user_info = {
|
|
84
|
-
"id": iam_user_id,
|
|
85
|
-
"username": username,
|
|
86
|
-
"all_credentials": all_credentials,
|
|
87
|
-
"total_credentials": total_credentials,
|
|
88
|
-
"microservice": payload.get("microservice") # 添加微服务信息
|
|
89
|
-
}
|
|
90
|
-
|
|
91
|
-
# 向后兼容:添加传统字段映射
|
|
92
|
-
for cred in all_credentials:
|
|
93
|
-
if cred.get("type") == "email":
|
|
94
|
-
user_info["email"] = cred.get("value")
|
|
95
|
-
elif cred.get("type") == "phone":
|
|
96
|
-
user_info["phone"] = cred.get("value")
|
|
97
|
-
elif cred.get("type") == "username" and not user_info.get("username"):
|
|
98
|
-
user_info["username"] = cred.get("value")
|
|
99
|
-
|
|
100
|
-
# 统计凭证类型分布
|
|
101
|
-
cred_types = [cred.get("type") for cred in all_credentials]
|
|
102
|
-
cred_type_count = {cred_type: cred_types.count(cred_type) for cred_type in set(cred_types)}
|
|
103
|
-
|
|
104
|
-
logger.info(
|
|
105
|
-
f"用户认证成功: user_id={iam_user_id}, username={username}, 凭证数量={total_credentials}, 凭证类型分布={cred_type_count}")
|
|
106
|
-
logger.debug(f"JWT payload: {payload}")
|
|
107
|
-
|
|
108
|
-
# 将用户信息添加到请求状态中
|
|
109
|
-
request.state.user = user_info
|
|
110
|
-
return user_info
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
except HTTPException as e:
|
|
114
|
-
print(403)
|
|
115
|
-
logger.error(f"获取当前用户信息失败: {str(e)}")
|
|
116
|
-
# 重新抛出HTTP异常(403权限不足)
|
|
117
|
-
return None
|
|
118
|
-
except Exception as e:
|
|
119
|
-
logger.error(f"获取当前用户信息失败: {str(e)}")
|
|
120
|
-
return None
|
|
121
|
-
|
|
122
|
-
async def require_auth(self, request: Request) -> Dict:
|
|
123
|
-
"""要求用户必须登录"""
|
|
124
|
-
try:
|
|
125
|
-
user_info = await self.get_current_user(request)
|
|
126
|
-
if not user_info:
|
|
127
|
-
raise HTTPException(
|
|
128
|
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
129
|
-
detail="需要登录认证",
|
|
130
|
-
headers={"WWW-Authenticate": "Bearer"},
|
|
131
|
-
)
|
|
132
|
-
return user_info
|
|
133
|
-
except HTTPException:
|
|
134
|
-
# 重新抛出HTTP异常(可能是403权限不足或401未认证)
|
|
135
|
-
raise
|
|
136
|
-
|
|
137
|
-
async def optional_auth(self, request: Request) -> Optional[Dict]:
|
|
138
|
-
"""可选的用户认证(不强制要求登录)"""
|
|
139
|
-
try:
|
|
140
|
-
return await self.get_current_user(request)
|
|
141
|
-
except HTTPException:
|
|
142
|
-
# 对于可选认证,如果是403权限不足,仍然抛出异常
|
|
143
|
-
# 如果是401未认证,返回None
|
|
144
|
-
raise
|
|
145
|
-
|
|
146
|
-
def decode_jwt_token(self, token: str) -> Optional[Dict]:
|
|
147
|
-
"""直接解析JWT token获取payload"""
|
|
148
|
-
try:
|
|
149
|
-
# 不验证签名,只解析payload(因为token已经通过verify_token验证过)
|
|
150
|
-
decoded_payload = jwt.decode(token, options={"verify_signature": False})
|
|
151
|
-
logger.debug(f"JWT token解析成功: {decoded_payload}")
|
|
152
|
-
return decoded_payload
|
|
153
|
-
except jwt.InvalidTokenError as e:
|
|
154
|
-
logger.error(f"JWT token解析失败: {str(e)}")
|
|
155
|
-
return None
|
|
156
|
-
except Exception as e:
|
|
157
|
-
logger.error(f"JWT token解析异常: {str(e)}")
|
|
158
|
-
return None
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
# 创建全局认证中间件实例
|
|
162
|
-
auth_middleware = AuthMiddleware()
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
# 便捷的依赖函数
|
|
166
|
-
async def get_current_user(request: Request) -> Dict:
|
|
167
|
-
"""获取当前用户的依赖函数"""
|
|
168
|
-
return await auth_middleware.require_auth(request)
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
async def get_optional_user(request: Request) -> Optional[Dict]:
|
|
172
|
-
"""获取可选当前用户的依赖函数"""
|
|
173
|
-
return await auth_middleware.optional_auth(request)
|
|
@@ -1,68 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
SkyPlatform IAM SDK 配置模块
|
|
3
|
-
"""
|
|
4
|
-
import os
|
|
5
|
-
from typing import Optional, List
|
|
6
|
-
from pydantic import BaseModel
|
|
7
|
-
from dotenv import load_dotenv
|
|
8
|
-
|
|
9
|
-
# 加载环境变量
|
|
10
|
-
load_dotenv()
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class AuthConfig(BaseModel):
|
|
14
|
-
"""
|
|
15
|
-
认证配置类
|
|
16
|
-
支持环境变量和代码配置
|
|
17
|
-
"""
|
|
18
|
-
# IAM服务配置
|
|
19
|
-
agenterra_iam_host: str
|
|
20
|
-
server_name: str
|
|
21
|
-
access_key: str
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
# Token配置
|
|
25
|
-
token_header: str = "Authorization"
|
|
26
|
-
token_prefix: str = "Bearer "
|
|
27
|
-
|
|
28
|
-
# 错误处理配置
|
|
29
|
-
enable_debug: bool = False
|
|
30
|
-
|
|
31
|
-
class Config:
|
|
32
|
-
env_prefix = "AGENTERRA_"
|
|
33
|
-
|
|
34
|
-
@classmethod
|
|
35
|
-
def from_env(cls) -> "AuthConfig":
|
|
36
|
-
"""
|
|
37
|
-
从环境变量创建配置
|
|
38
|
-
"""
|
|
39
|
-
return cls(
|
|
40
|
-
agenterra_iam_host=os.environ.get('AGENTERRA_IAM_HOST', ''),
|
|
41
|
-
server_name=os.environ.get('AGENTERRA_SERVER_NAME', ''),
|
|
42
|
-
access_key=os.environ.get('AGENTERRA_ACCESS_KEY', ''),
|
|
43
|
-
enable_debug=os.environ.get('AGENTERRA_ENABLE_DEBUG', 'false').lower() == 'true'
|
|
44
|
-
)
|
|
45
|
-
|
|
46
|
-
def validate_config(self) -> bool:
|
|
47
|
-
"""
|
|
48
|
-
验证配置是否完整
|
|
49
|
-
"""
|
|
50
|
-
required_fields = ['agenterra_iam_host', 'server_name', 'access_key']
|
|
51
|
-
for field in required_fields:
|
|
52
|
-
if not getattr(self, field):
|
|
53
|
-
raise ValueError(f"配置项 {field} 不能为空")
|
|
54
|
-
return True
|
|
55
|
-
|
|
56
|
-
def add_whitelist_path(self, path: str) -> None:
|
|
57
|
-
"""
|
|
58
|
-
添加白名单路径
|
|
59
|
-
"""
|
|
60
|
-
if path not in self.whitelist_paths:
|
|
61
|
-
self.whitelist_paths.append(path)
|
|
62
|
-
|
|
63
|
-
def remove_whitelist_path(self, path: str) -> None:
|
|
64
|
-
"""
|
|
65
|
-
移除白名单路径
|
|
66
|
-
"""
|
|
67
|
-
if path in self.whitelist_paths:
|
|
68
|
-
self.whitelist_paths.remove(path)
|
|
@@ -1,186 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
SkyPlatform IAM SDK 中间件模块
|
|
3
|
-
"""
|
|
4
|
-
import logging
|
|
5
|
-
from typing import Optional, Callable, Dict, Any
|
|
6
|
-
from fastapi import Request, Response, HTTPException
|
|
7
|
-
from fastapi.responses import JSONResponse
|
|
8
|
-
from starlette.middleware.base import BaseHTTPMiddleware
|
|
9
|
-
|
|
10
|
-
from .config import AuthConfig
|
|
11
|
-
from .connect_agenterra_iam import ConnectAgenterraIam
|
|
12
|
-
from .exceptions import (
|
|
13
|
-
AuthenticationError,
|
|
14
|
-
AuthorizationError,
|
|
15
|
-
ConfigurationError
|
|
16
|
-
)
|
|
17
|
-
|
|
18
|
-
logger = logging.getLogger(__name__)
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class AuthMiddleware(BaseHTTPMiddleware):
|
|
22
|
-
"""
|
|
23
|
-
认证中间件
|
|
24
|
-
自动拦截请求进行Token验证和权限检查
|
|
25
|
-
"""
|
|
26
|
-
|
|
27
|
-
def __init__(
|
|
28
|
-
self,
|
|
29
|
-
app,
|
|
30
|
-
config: AuthConfig,
|
|
31
|
-
skip_validation: Optional[Callable[[Request], bool]] = None
|
|
32
|
-
):
|
|
33
|
-
"""
|
|
34
|
-
初始化认证中间件
|
|
35
|
-
|
|
36
|
-
Args:
|
|
37
|
-
app: FastAPI应用实例
|
|
38
|
-
config: 认证配置
|
|
39
|
-
skip_validation: 自定义跳过验证的函数
|
|
40
|
-
"""
|
|
41
|
-
super().__init__(app)
|
|
42
|
-
self.config = config
|
|
43
|
-
self.iam_client = ConnectAgenterraIam()
|
|
44
|
-
self.skip_validation = skip_validation
|
|
45
|
-
|
|
46
|
-
# 验证配置
|
|
47
|
-
try:
|
|
48
|
-
self.config.validate_config()
|
|
49
|
-
except ValueError as e:
|
|
50
|
-
raise ConfigurationError(str(e))
|
|
51
|
-
|
|
52
|
-
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
|
53
|
-
"""
|
|
54
|
-
中间件主要处理逻辑
|
|
55
|
-
"""
|
|
56
|
-
try:
|
|
57
|
-
|
|
58
|
-
# 提取Token(可能为空,白名单接口不需要token)
|
|
59
|
-
token = self._extract_token(request)
|
|
60
|
-
|
|
61
|
-
# 验证Token和权限(即使token为空也要调用IAM验证,因为可能是白名单接口)
|
|
62
|
-
user_info = await self._verify_token_and_permission(request, token)
|
|
63
|
-
if not user_info:
|
|
64
|
-
return self._create_error_response(
|
|
65
|
-
status_code=401,
|
|
66
|
-
message="Token验证失败",
|
|
67
|
-
detail="提供的Token无效或已过期"
|
|
68
|
-
)
|
|
69
|
-
|
|
70
|
-
# 检查是否为白名单接口
|
|
71
|
-
if user_info.get('is_whitelist', False):
|
|
72
|
-
# 白名单接口,允许访问但不设置用户信息
|
|
73
|
-
request.state.user = None
|
|
74
|
-
request.state.authenticated = False
|
|
75
|
-
request.state.is_whitelist = True
|
|
76
|
-
else:
|
|
77
|
-
# 正常认证接口,设置用户信息
|
|
78
|
-
request.state.user = user_info
|
|
79
|
-
request.state.authenticated = True
|
|
80
|
-
request.state.is_whitelist = False
|
|
81
|
-
|
|
82
|
-
# 继续处理请求
|
|
83
|
-
response = await call_next(request)
|
|
84
|
-
return response
|
|
85
|
-
|
|
86
|
-
except HTTPException as e:
|
|
87
|
-
# FastAPI HTTPException直接返回
|
|
88
|
-
return self._create_error_response(
|
|
89
|
-
status_code=e.status_code,
|
|
90
|
-
message=str(e.detail),
|
|
91
|
-
detail=getattr(e, 'detail', None)
|
|
92
|
-
)
|
|
93
|
-
except AuthenticationError as e:
|
|
94
|
-
return self._create_error_response(
|
|
95
|
-
status_code=e.status_code,
|
|
96
|
-
message=e.message,
|
|
97
|
-
detail=e.detail
|
|
98
|
-
)
|
|
99
|
-
except AuthorizationError as e:
|
|
100
|
-
return self._create_error_response(
|
|
101
|
-
status_code=e.status_code,
|
|
102
|
-
message=e.message,
|
|
103
|
-
detail=e.detail
|
|
104
|
-
)
|
|
105
|
-
except Exception as e:
|
|
106
|
-
logger.error(f"认证中间件处理异常: {str(e)}")
|
|
107
|
-
if self.config.enable_debug:
|
|
108
|
-
logger.exception("详细异常信息:")
|
|
109
|
-
|
|
110
|
-
return self._create_error_response(
|
|
111
|
-
status_code=500,
|
|
112
|
-
message="内部服务器错误",
|
|
113
|
-
detail=str(e) if self.config.enable_debug else None
|
|
114
|
-
)
|
|
115
|
-
|
|
116
|
-
def _extract_token(self, request: Request) -> Optional[str]:
|
|
117
|
-
"""
|
|
118
|
-
从请求中提取Token
|
|
119
|
-
"""
|
|
120
|
-
# 从Authorization头提取
|
|
121
|
-
auth_header = request.headers.get(self.config.token_header)
|
|
122
|
-
if auth_header and auth_header.startswith(self.config.token_prefix):
|
|
123
|
-
return auth_header[len(self.config.token_prefix):].strip()
|
|
124
|
-
|
|
125
|
-
# 从查询参数提取(备选方案)
|
|
126
|
-
token = request.query_params.get("token")
|
|
127
|
-
if token:
|
|
128
|
-
return token
|
|
129
|
-
|
|
130
|
-
return None
|
|
131
|
-
|
|
132
|
-
async def _verify_token_and_permission(self, request: Request, token: Optional[str]) -> Optional[Dict[str, Any]]:
|
|
133
|
-
"""
|
|
134
|
-
验证Token和权限
|
|
135
|
-
"""
|
|
136
|
-
try:
|
|
137
|
-
# 获取请求信息
|
|
138
|
-
api_path = request.url.path
|
|
139
|
-
method = request.method
|
|
140
|
-
|
|
141
|
-
# 从请求头获取服务认证信息(可选)
|
|
142
|
-
server_ak = request.headers.get("SERVER-AK", "")
|
|
143
|
-
server_sk = request.headers.get("SERVER-SK", "")
|
|
144
|
-
|
|
145
|
-
# 调用IAM验证接口(即使token为空也要调用,因为可能是白名单接口)
|
|
146
|
-
user_info = self.iam_client.verify_token(
|
|
147
|
-
token=token or "", # 如果token为None,传递空字符串
|
|
148
|
-
api=api_path,
|
|
149
|
-
method=method,
|
|
150
|
-
server_ak=server_ak,
|
|
151
|
-
server_sk=server_sk
|
|
152
|
-
)
|
|
153
|
-
|
|
154
|
-
return user_info
|
|
155
|
-
|
|
156
|
-
except HTTPException:
|
|
157
|
-
# 重新抛出HTTP异常
|
|
158
|
-
raise
|
|
159
|
-
except Exception as e:
|
|
160
|
-
logger.error(f"Token验证异常: {str(e)}")
|
|
161
|
-
if self.config.enable_debug:
|
|
162
|
-
logger.exception("详细异常信息:")
|
|
163
|
-
return None
|
|
164
|
-
|
|
165
|
-
def _create_error_response(
|
|
166
|
-
self,
|
|
167
|
-
status_code: int,
|
|
168
|
-
message: str,
|
|
169
|
-
detail: Optional[str] = None
|
|
170
|
-
) -> JSONResponse:
|
|
171
|
-
"""
|
|
172
|
-
创建错误响应
|
|
173
|
-
"""
|
|
174
|
-
error_data = {
|
|
175
|
-
"success": False,
|
|
176
|
-
"message": message,
|
|
177
|
-
"status_code": status_code
|
|
178
|
-
}
|
|
179
|
-
|
|
180
|
-
if detail:
|
|
181
|
-
error_data["detail"] = detail
|
|
182
|
-
|
|
183
|
-
return JSONResponse(
|
|
184
|
-
status_code=status_code,
|
|
185
|
-
content=error_data
|
|
186
|
-
)
|
|
File without changes
|
|
File without changes
|