agent-api-server 2.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_api_server/__init__.py +0 -0
- agent_api_server/api/__init__.py +0 -0
- agent_api_server/api/v1/__init__.py +0 -0
- agent_api_server/api/v1/api.py +25 -0
- agent_api_server/api/v1/config.py +57 -0
- agent_api_server/api/v1/graph.py +59 -0
- agent_api_server/api/v1/schema.py +57 -0
- agent_api_server/api/v1/thread.py +563 -0
- agent_api_server/cache/__init__.py +0 -0
- agent_api_server/cache/redis_cache.py +385 -0
- agent_api_server/callback_handler.py +18 -0
- agent_api_server/client/css/styles.css +1202 -0
- agent_api_server/client/favicon.ico +0 -0
- agent_api_server/client/index.html +102 -0
- agent_api_server/client/js/app.js +1499 -0
- agent_api_server/client/js/index.umd.js +824 -0
- agent_api_server/config_center/config_center.py +239 -0
- agent_api_server/configs/__init__.py +3 -0
- agent_api_server/configs/config.py +163 -0
- agent_api_server/dynamic_llm/__init__.py +0 -0
- agent_api_server/dynamic_llm/dynamic_llm.py +331 -0
- agent_api_server/listener.py +530 -0
- agent_api_server/log/__init__.py +0 -0
- agent_api_server/log/formatters.py +122 -0
- agent_api_server/log/logging.json +50 -0
- agent_api_server/mcp_convert/__init__.py +0 -0
- agent_api_server/mcp_convert/mcp_convert.py +375 -0
- agent_api_server/memeory/__init__.py +0 -0
- agent_api_server/memeory/postgres.py +233 -0
- agent_api_server/register/__init__.py +0 -0
- agent_api_server/register/register.py +65 -0
- agent_api_server/service.py +354 -0
- agent_api_server/service_hub/service_hub.py +233 -0
- agent_api_server/service_hub/service_hub_test.py +700 -0
- agent_api_server/shared/__init__.py +0 -0
- agent_api_server/shared/ase.py +54 -0
- agent_api_server/shared/base_model.py +103 -0
- agent_api_server/shared/common.py +110 -0
- agent_api_server/shared/decode_token.py +107 -0
- agent_api_server/shared/detect_message.py +410 -0
- agent_api_server/shared/get_model_info.py +491 -0
- agent_api_server/shared/message.py +419 -0
- agent_api_server/shared/util_func.py +372 -0
- agent_api_server/sso_service/__init__.py +1 -0
- agent_api_server/sso_service/sdk/__init__.py +1 -0
- agent_api_server/sso_service/sdk/client.py +224 -0
- agent_api_server/sso_service/sdk/credential.py +11 -0
- agent_api_server/sso_service/sdk/encoding.py +22 -0
- agent_api_server/sso_service/sso_service.py +177 -0
- agent_api_server-2.1.7.dist-info/METADATA +130 -0
- agent_api_server-2.1.7.dist-info/RECORD +52 -0
- agent_api_server-2.1.7.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,372 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Dict
|
|
3
|
+
from importlib import import_module
|
|
4
|
+
from fastapi import HTTPException
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
import aiofiles
|
|
9
|
+
import asyncio
|
|
10
|
+
import sys
|
|
11
|
+
from typing import List, Any
|
|
12
|
+
from agent_api_server.memeory.postgres import AsyncPostgresCheckpointer
|
|
13
|
+
from langgraph.pregel import Pregel
|
|
14
|
+
|
|
15
|
+
from agent_api_server.shared.common import ConfigCategory
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)

# Runtime data lives under ./data relative to the process working directory.
# The directory is created at import time with owner-only permissions (further
# restricted by the umask); mkdir is not asked to create missing parents.
BASE_DIR = Path.cwd()
DATA_DIR = BASE_DIR.joinpath('data')
DATA_DIR.mkdir(mode=0o700, exist_ok=True)
# Base name of the model-config env file; set_model_config/get_env append "_<tenant>".
LLM_CFG_ENV = DATA_DIR.joinpath('.llm_cfg_env')
|
|
23
|
+
|
|
24
|
+
async def get_all_graph_names() -> List[str]:
    """Return the names of every graph declared under ``graphs`` in the config.

    Returns:
        List[str]: The graph names in config order; an empty list when the
        ``graphs`` section is missing or empty.

    Raises:
        HTTPException: 500 with ``error="config_loading_failed"`` when the
            config cannot be loaded or its ``graphs`` section cannot be read.
    """
    try:
        graphs = (await load_graph_config()).get("graphs")
        return list(graphs.keys()) if graphs else []
    except Exception as e:
        detail = {
            "error": "config_loading_failed",
            "message": f"Cannot get graph names: {str(e)}",
        }
        raise HTTPException(status_code=500, detail=detail)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
async def load_graph(graph_name: str, config: Dict, with_checker: bool) -> tuple[str, Pregel, None] | tuple[
    str, Pregel, AsyncPostgresCheckpointer]:
    """Import and instantiate the graph registered under ``graph_name``.

    ``config["graphs"][graph_name]`` must be a ``"path/to/module.py:attribute"``
    string. The file path is turned into a dotted import path and imported,
    and the attribute is resolved. A ``Pregel`` instance is used as-is. Any
    other callable is treated as a factory (sync or async) and invoked to
    produce the graph.

    Args:
        graph_name: Key of the graph inside ``config["graphs"]``.
        config: Parsed ``langgraph.json`` content.
        with_checker: When true, attach the worker's ``AsyncPostgresCheckpointer``
            to ``graph.checkpointer``.

    Returns:
        ``(graph_name, graph, checkpointer)``. ``checkpointer`` is the
        ``AsyncPostgresCheckpointer`` worker instance when ``with_checker`` is
        true, otherwise ``None``.

    Raises:
        HTTPException: 404 if the graph is not configured; 400 for a malformed
            path; 500 for import errors, a missing attribute, or any failure
            while building the graph or its checkpointer.
    """
    import inspect

    # 1. Look up the configured "module:attribute" string.
    # ``or {}`` also covers an explicit ``"graphs": null`` entry.
    graphs = config.get("graphs") or {}
    if graph_name not in graphs:
        raise HTTPException(
            status_code=404,
            detail={
                "error": "graph_not_found",
                "available": list(graphs.keys()),
                "message": f"Graph '{graph_name}' not found in config"
            }
        )

    graph_path = graphs[graph_name]
    try:
        module_path, attr_name = graph_path.rsplit(":", 1)
        if not module_path or not attr_name:
            raise ValueError("Empty module path or attribute name")
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "invalid_path_format",
                "expected": "path/to/module.py:attribute_name",
                "actual": graph_path,
                "message": str(e)
            }
        ) from e

    # 2. Convert the file path into a dotted import path.
    try:
        # Normalise backslashes first, then let pathlib drop "." segments
        # (e.g. a leading "./"). The old order removed "./" before converting
        # backslashes, so ".\\graph.py" ended up as the invalid import "..graph".
        module_path = Path(module_path.replace("\\", "/")).as_posix()
        if module_path.endswith(".py"):
            module_path = module_path[:-3]
        import_path = module_path.replace("/", ".")
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "path_processing_failed",
                "message": f"Cannot process path '{module_path}': {str(e)}"
            }
        ) from e

    # 3. Import the module.
    try:
        # Put the module's own directory on sys.path (process-wide), presumably
        # so that imports of its sibling modules resolve.
        module_dir = str(Path(module_path).parent.resolve())
        if module_dir not in sys.path:
            sys.path.insert(0, module_dir)

        # Importing runs module-level code; do it off the event loop thread.
        module = await asyncio.to_thread(import_module, import_path)
    except ModuleNotFoundError as e:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "module_not_found",
                "path": import_path,
                "sys_path": str(sys.path),
                "message": f"Module not found: {str(e)}"
            }
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "module_import_error",
                "path": import_path,
                "message": f"Error importing module: {str(e)}"
            }
        ) from e

    # 4. Resolve the attribute. Only this lookup may report "missing_attribute".
    # AttributeErrors raised later by the factory or checkpointer are
    # initialization failures, not a missing attribute.
    try:
        graph_factory = getattr(module, attr_name)
    except AttributeError as e:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "missing_attribute",
                "module": import_path,
                "attribute": attr_name,
                "available": dir(module)
            }
        ) from e
    except Exception as e:
        raise _graph_init_error(e) from e

    # 5. Obtain the graph instance and optionally attach the checkpointer.
    try:
        if isinstance(graph_factory, Pregel):
            # Already a compiled graph: never call it as a factory.
            graph = graph_factory
        elif callable(graph_factory):
            graph = graph_factory()
            # Async factories (or sync ones returning a coroutine) give an awaitable.
            if inspect.isawaitable(graph):
                graph = await graph
        else:
            graph = graph_factory

        if not isinstance(graph, Pregel):
            raise TypeError(f"Expected Pregel instance, got {type(graph)}")

        if not with_checker:
            return graph_name, graph, None

        checkpointer = AsyncPostgresCheckpointer.get_worker_instance()
        try:
            graph.checkpointer = await checkpointer.checkpointer()
        except Exception as e:
            raise RuntimeError(f"Failed to initialize checkpointer: {str(e)}") from e

        return graph_name, graph, checkpointer
    except Exception as e:
        raise _graph_init_error(e) from e


def _graph_init_error(exc: Exception) -> HTTPException:
    """Build the 500 ``graph_initialization_failed`` error reported for ``exc``."""
    return HTTPException(
        status_code=500,
        detail={
            "error": "graph_initialization_failed",
            "message": f"Failed to initialize graph: {str(exc)}",
            "type": type(exc).__name__
        }
    )
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def resolve_config_path(file_name: str) -> Path:
    """Locate ``file_name`` and return the first candidate path that exists.

    Candidates are tried in this order:
      1. the current working directory;
      2. the parent of the directory containing this module
         (``Path(__file__).parent.parent``);
      3. every ancestor of the current working directory, nearest first.
    An absolute ``file_name`` yields itself for every candidate, because
    pathlib's ``/`` keeps an absolute right-hand operand.

    Raises:
        FileNotFoundError: if none of the candidates exists.
    """
    cwd = Path.cwd()

    def _candidates():
        yield cwd / file_name
        yield Path(__file__).parent.parent / file_name
        for ancestor in cwd.parents:
            yield ancestor / file_name

    found = next((candidate for candidate in _candidates() if candidate.exists()), None)
    if found is None:
        raise FileNotFoundError(f"Could not locate file {file_name}")
    return found
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
async def load_graph_config() -> Dict[str, Any]:
    """Read and parse ``langgraph.json`` found by ``resolve_config_path``.

    Returns:
        Dict[str, Any]: The parsed JSON document. Values are nested objects,
        e.g. ``graphs`` maps names to ``"module:attr"`` strings, so the old
        ``Dict[str, str]`` annotation was too narrow.

    Raises:
        HTTPException: 500 when the file cannot be located, read, or parsed.
    """
    try:
        json_path = resolve_config_path("langgraph.json")
        # JSON text is UTF-8 by specification. Do not depend on the platform's
        # default encoding, since the config may contain non-ASCII text.
        async with aiofiles.open(json_path, mode='r', encoding='utf-8') as f:
            content = await f.read()
        config = json.loads(content)
        logger.info(f"Successfully loaded config from {json_path}")
        return config
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Config loading failed: {str(e)}"
        ) from e
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
async def parse_agent_config() -> List[Dict[str, Any]]:
    """Parse ``langgraph.json`` into one normalized descriptor per configured agent.

    Every key of ``config["graphs"]`` yields one entry::

        {
            "agent_name": str,
            "agent_description": str,            # config["agent_description"][name], "" if absent
            "agent_api_version": str,            # default "v0.0.1"
            "agent_icon_url": str,               # reserved, default ""
            "agent_features": Dict[str, Any],    # reserved, default {}
            "agent_labels": list,                # default []
            "has_site": bool,                    # whether the agent should get a chatbot; default False
            "is_system_agent": bool,             # register as a system-level agent shared by all
                                                 # tenants (otherwise per tenant); default False
            "multilangs": Dict[str, Dict[str, str]],  # built-in "default" labels merged with
                                                      # config["multilangs"]
        }

    Every field except ``agent_name`` and ``agent_description`` is read from
    the top level of the config, so all agents get the same value (and share
    the same ``multilangs`` dict object).

    Raises:
        HTTPException: propagated from ``load_graph_config`` when the file
            cannot be loaded.
        ValueError: when ``graphs`` or ``agent_description`` is missing or empty.
    """
    config = await load_graph_config()

    if not config.get("graphs") or not config.get("agent_description"):
        raise ValueError("Missing required 'graphs' or 'agent_description' in config")

    # Built-in "default" labels; entries from config["multilangs"] override them.
    multilangs = {
        "default": {
            "en_US": "Default",
            "zh_CN": "默认",
            # Traditional Chinese for zh_TW (was the Simplified form "预设").
            "zh_TW": "預設",
        },
        **(config.get("multilangs") or {}),
    }

    descriptions = config["agent_description"]
    return [
        {
            "agent_name": agent_name,
            "agent_description": descriptions.get(agent_name, ""),
            "agent_api_version": config.get("agent_api_version", "v0.0.1"),
            "agent_icon_url": config.get("agent_icon_url", ""),
            "agent_features": config.get("agent_features", {}),
            "agent_labels": config.get("agent_labels", []),
            "has_site": config.get("has_site", False),
            "is_system_agent": config.get("is_system_agent", False),
            "multilangs": multilangs,
        }
        for agent_name in config["graphs"]
    ]
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def set_model_config(
        tool_name: str,
        model_type: str,
        model_provider: str,
        model_name: str,
        credentials: dict,
        agent_id: str,
        ts_id: str | None = None,
) -> None:
    """Persist a model selection (provider, model, credentials, agent id) to the env file.

    Keys have the form ``[<TOOL_NAME>_]<category key>[_<ts_id>]``. The category
    keys come from ``ConfigCategory`` for the ``model_type`` ("llm",
    "text-embedding" or "rerank"), and the tool prefix is omitted for
    ``tool_name == "default"``. Values are merged into the file
    ``str(LLM_CFG_ENV) + f"_{ts_id}"`` via ``update_env_file``. When ``ts_id``
    is falsy the keys get no suffix, but the file name still ends in
    ``_<ts_id>`` (``_None`` for ``None``).

    An unsupported ``model_type`` is logged as a warning and nothing is written.
    """
    category_map = {
        'llm': (ConfigCategory.CHAT_PROVIDER.value, ConfigCategory.CHAT_MODEL.value,
                ConfigCategory.CHAT_CREDENTIALS.value, ConfigCategory.AGENT_ID.value),

        'text-embedding': (ConfigCategory.EMBEDDING_PROVIDER.value, ConfigCategory.EMBEDDING_MODEL.value,
                           ConfigCategory.EMBEDDING_CREDENTIALS.value, ConfigCategory.AGENT_ID.value),

        'rerank': (ConfigCategory.RERANK_PROVIDER.value, ConfigCategory.RERANK_MODEL.value,
                   ConfigCategory.RERANK_CREDENTIALS.value, ConfigCategory.AGENT_ID.value)
    }

    categories = category_map.get(model_type)
    if categories is None:
        warning_msg = f"Unsupported model type {model_type}"
        if tool_name != "default":
            warning_msg += f" for tool {tool_name}"
        logger.warning(warning_msg)
        return

    base_prefix = "" if tool_name == "default" else f"{tool_name.upper()}_"
    tenant_suffix = f"_{ts_id}" if ts_id else ""
    provider_key, model_key, credentials_key, agent_id_key = (
        f"{base_prefix}{category}{tenant_suffix}" for category in categories
    )

    new_vars = {
        provider_key: model_provider,
        model_key: model_name,
        credentials_key: credentials,
        agent_id_key: agent_id
    }
    LLM_INFO_PATH = str(LLM_CFG_ENV) + f"_{ts_id}"
    update_env_file(LLM_INFO_PATH, new_vars)

    # Credentials are secrets: never write their values to the log.
    logger.debug(f"\nset model config to {LLM_INFO_PATH} ok\n"
                 f"provider_key is {provider_key}, provider_value is {model_provider}\n"
                 f"model_key is {model_key}, model_value is {model_name}\n"
                 f"agent_id_key is {agent_id_key}, agent id value is {agent_id}\n"
                 f"credentials_key is {credentials_key}, credentials is <redacted>\n")
|
|
317
|
+
|
|
318
|
+
def get_env(ts_tenant: str):
    """Return a copy of ``os.environ`` overlaid with the tenant's stored model settings.

    Reads ``str(LLM_CFG_ENV) + f"_{ts_tenant}"`` (located via
    ``resolve_config_path``). Blank lines and ``#`` comments are skipped, and
    only ``KEY=VALUE`` lines whose key ends with ``_<ts_tenant>`` are kept.
    Those pairs are applied to the returned dict AND to the process-wide
    ``os.environ``.

    A missing file is ignored silently. Any other error is logged, and the
    environment copy is returned without the file's values.
    """
    snapshot = dict(os.environ)
    tenant_suffix = f"_{ts_tenant}"
    tenant_file = str(LLM_CFG_ENV) + f"_{ts_tenant}"

    try:
        located = resolve_config_path(tenant_file)
        if located.exists():
            with open(located, 'r') as fh:
                raw_lines = fh.readlines()

            tenant_vars = {}
            for raw in raw_lines:
                entry = raw.strip()
                if not entry or entry.startswith('#'):
                    continue
                name, sep, val = entry.partition('=')
                if not sep:
                    logger.warning(f"Skipping malformed line in {located}: {entry}")
                    continue
                if name.endswith(tenant_suffix):
                    tenant_vars[name] = val

            os.environ.update(tenant_vars)
            snapshot.update(tenant_vars)
    except FileNotFoundError:
        pass
    except Exception as e:
        logger.error(f"Error decrypting environment: {str(e)}")
    return snapshot
|
|
348
|
+
|
|
349
|
+
def update_env_file(env_path, new_vars):
    """Merge ``new_vars`` into the ``KEY=VALUE`` file at ``env_path``.

    Existing ``KEY=VALUE`` lines are kept, and keys in ``new_vars`` overwrite
    them or are appended. Blank lines, ``#`` comments and lines without ``=``
    are dropped on rewrite. Values are written with ``str()`` formatting.

    The file stores credentials, so it is never readable by others. The new
    content goes to a private (0600) temporary file in the same directory,
    which then atomically replaces ``env_path``. A crash mid-write therefore
    leaves the previous file intact, and no world-readable copy exists even
    briefly.
    """
    import tempfile

    existing_vars = {}
    if os.path.exists(env_path):
        with open(env_path, 'r') as f:
            for raw in f:
                line = raw.strip()
                if not line or line.startswith('#'):
                    continue
                key, sep, value = line.partition('=')
                if not sep:
                    continue
                existing_vars[key] = value

    existing_vars.update(new_vars)

    target_dir = os.path.dirname(os.path.abspath(env_path))
    # mkstemp creates the file with mode 0600 from the start.
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, prefix=".env_tmp_")
    try:
        with os.fdopen(fd, 'w') as f:
            f.writelines(f"{key}={value}\n" for key, value in existing_vars.items())
        os.replace(tmp_path, env_path)
    except BaseException:
        # Do not leave a stray temp file with secrets behind.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise

    os.chmod(env_path, 0o600)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .sso_service import SSOConfig, SSOService
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .client import SSOClient, get_sso_client_id_and_secret
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from urllib.parse import quote
|
|
5
|
+
from typing import List, Optional
|
|
6
|
+
|
|
7
|
+
import requests
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)

# OAuth2 grant type sent to the SSO token endpoint.
CLIENT_CREDENTIALS: str = "client_credentials"
# Fallback app/service name used when SSOClient is built without a service name.
CURRENT_SERVICE_NAME: str = "KB-Insight"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ClientRequest:
    """Description of an SSO client, as sent to the SSO server.

    Attributes mirror the constructor arguments. ``to_dict`` renders them with
    the server's camelCase keys. Unknown keyword arguments are accepted and
    discarded.
    """

    def __init__(
            self,
            datacenter: Optional[str] = None,
            cluster: Optional[str] = None,
            workspace: Optional[str] = None,
            namespace: Optional[str] = None,
            app_id: Optional[str] = None,
            app_name: Optional[str] = None,
            scopes: Optional[List[str]] = None,
            service_name: Optional[str] = None,
            redirect_url: Optional[str] = None,
            **kwargs,
    ):
        self.datacenter = datacenter
        self.cluster = cluster
        self.workspace = workspace
        self.namespace = namespace
        self.app_id = app_id
        self.app_name = app_name
        # Any falsy value (None or an empty list) selects the default role set.
        self.scopes = scopes or ["Admin", "Editor", "Viewer"]
        self.service_name = service_name
        self.redirect_url = redirect_url

    def to_dict(self):
        """Return the request as a camelCase dict, ready for ``json.dumps``."""
        pairs = (
            ("datacenter", self.datacenter),
            ("cluster", self.cluster),
            ("workspace", self.workspace),
            ("namespace", self.namespace),
            ("appId", self.app_id),
            ("appName", self.app_name),
            ("scopes", self.scopes),
            ("serviceName", self.service_name),
            ("redirectUrl", self.redirect_url),
        )
        return dict(pairs)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ClientInfo(ClientRequest):
    """A registered SSO client as returned by the SSO server.

    Extends ``ClientRequest`` with the server-assigned id, timestamps, secret
    and client type.
    """

    def __init__(
            self,
            client_id: Optional[str] = None,
            creation_time: Optional[int] = None,
            last_modified_time: Optional[int] = None,
            client_secret: Optional[str] = None,
            client_type: Optional[str] = None,
            **kwargs,
    ):
        super().__init__(**kwargs)
        self.client_id = client_id
        self.creation_time = creation_time
        self.last_modified_time = last_modified_time
        self.client_secret = client_secret
        self.client_type = client_type

    @classmethod
    def from_dict(cls, data: dict):
        """Build an instance from the server's camelCase JSON object.

        Only the keys understood by ``ClientRequest`` or by this class are
        read. Anything else in ``data`` is ignored, and missing keys become
        ``None``.
        """
        request_fields = {
            attr: data.get(wire_key)
            for wire_key, attr in (
                ("datacenter", "datacenter"),
                ("cluster", "cluster"),
                ("workspace", "workspace"),
                ("namespace", "namespace"),
                ("appId", "app_id"),
                ("appName", "app_name"),
                ("scopes", "scopes"),
                ("serviceName", "service_name"),
                ("redirectUrl", "redirect_url"),
            )
        }
        return cls(
            client_id=data.get("clientId"),
            creation_time=data.get("creationTime"),
            last_modified_time=data.get("lastModifiedTime"),
            client_secret=data.get("clientSecret"),
            client_type=data.get("clientType"),
            **request_fields,
        )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class SsoToken:
    """OAuth token issued by the SSO server's ``/oauth/token`` endpoint."""

    def __init__(
            self,
            access_token: Optional[str] = None,
            token_type: Optional[str] = None,
            expires_in: Optional[int] = None,
            refresh_token: Optional[str] = None,
    ):
        self.access_token, self.token_type = access_token, token_type
        self.expires_in, self.refresh_token = expires_in, refresh_token

    @classmethod
    def from_dict(cls, data: dict):
        """Create a token from the camelCase JSON payload; missing keys become ``None``."""
        wire_to_arg = (
            ("accessToken", "access_token"),
            ("tokenType", "token_type"),
            ("expiresIn", "expires_in"),
            ("refreshToken", "refresh_token"),
        )
        return cls(**{arg: data.get(wire_key) for wire_key, arg in wire_to_arg})
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class SSOClient:
    """Look up, register, or update this service's OAuth client on the SSO server.

    The client description is read from the environment variables
    ``datacenter``, ``cluster``, ``workspace``, ``namespace`` and
    ``appID``/``appId`` when the object is created. ``set_sso_client`` then
    stores the server's view of the client in ``self.client_info``.
    """

    # Seconds to wait for each SSO HTTP call. Without a timeout, a stalled
    # server would block the caller indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, service_name: Optional[str] = None):
        name = service_name or CURRENT_SERVICE_NAME
        self.client_info = ClientInfo(
            datacenter=os.getenv("datacenter"),
            cluster=os.getenv("cluster"),
            workspace=os.getenv("workspace"),
            namespace=os.getenv("namespace"),
            app_id=os.getenv("appID") or os.getenv("appId"),
            app_name=name,
            service_name=name,
        )

    @staticmethod
    def _encode_query(params: dict) -> str:
        """URL-encode ``params`` as a query string.

        Each value is percent-encoded with ``urllib.parse.quote`` (``/`` kept
        literal), so characters such as ``&``, ``=``, ``+`` or spaces cannot
        corrupt the query. ``None`` is rendered as the text ``None``, exactly
        as the previous hand-built f-strings did.
        """
        from urllib.parse import urlencode
        return urlencode(params, safe="/", quote_via=quote)

    @staticmethod
    def _is_not_found(err: Exception) -> bool:
        """Return True when ``err`` represents an HTTP 404 response.

        Checks the attached response's status code first. Falls back to the
        "Not Found" reason text when no response is attached.
        """
        response = getattr(err, "response", None)
        # Compare with None explicitly: requests.Response is falsy for 4xx/5xx.
        if response is not None and getattr(response, "status_code", None) == 404:
            return True
        return "Not Found" in str(err)

    def query_sso_client(self, address: str, req: ClientRequest) -> ClientInfo:
        """Fetch the client registered under ``req.app_name``.

        Raises:
            requests.exceptions.HTTPError: on a non-2xx response; a missing
                client is expected to surface as 404 Not Found.
        """
        query = self._encode_query({
            "namespace": req.namespace,
            "workspace": req.workspace,
            "cluster": req.cluster,
            "datacenter": req.datacenter,
            "serviceName": req.service_name,
            "appId": req.app_id,
        })
        path = f"{address}/clients/{quote(req.app_name)}?{query}"

        credential = self.new_srp_token_credential(req.service_name)
        headers = {credential["header_name"]: credential["header_value"]}

        response = requests.get(path, headers=headers, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        return ClientInfo.from_dict(response.json())

    def register_client(self, address: str, req: ClientRequest) -> ClientInfo:
        """Register ``req`` as a new client and return the server's record."""
        path = f"{address}/clients"
        body = json.dumps(req.to_dict())

        # Log the payload itself (previously a one-element set wrapping it was logged).
        logger.info("Registering SSO client: %s", body)

        credential = self.new_srp_token_credential(req.service_name)
        headers = {
            credential["header_name"]: credential["header_value"],
            "Content-Type": "application/json"
        }

        response = requests.post(path, headers=headers, data=body, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        return ClientInfo.from_dict(response.json())

    def update_sso_client(self, address: str, client_id: str, req: ClientRequest) -> ClientInfo:
        """Overwrite client ``client_id`` with ``req`` and return the server's record."""
        path = f"{address}/clients/{client_id}"
        body = json.dumps(req.to_dict())

        credential = self.new_srp_token_credential(req.service_name)
        headers = {
            credential["header_name"]: credential["header_value"],
            "Content-Type": "application/json"
        }

        response = requests.put(path, headers=headers, data=body, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        return ClientInfo.from_dict(response.json())

    def query_sso_client_token(self, address: str) -> SsoToken:
        """Request a client-credentials token for the current ``client_info``."""
        # NOTE(review): the secret travels in the URL query (as before). If the
        # server accepts a form body, prefer that so proxies do not log it.
        query = self._encode_query({
            "grant_type": CLIENT_CREDENTIALS,
            "client_id": self.client_info.client_id,
            "client_secret": self.client_info.client_secret,
            "duration": "eternal",
        })
        path = f"{address}/oauth/token?{query}"

        response = requests.post(path, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        return SsoToken.from_dict(response.json())

    def set_sso_client(self, address: str) -> None:
        """Ensure the client exists on ``address`` and store it in ``self.client_info``.

        ``/v4.0`` is appended to ``address`` if missing. A not-found lookup
        triggers registration. If the server record's ``app_id`` differs from
        ours, an update is attempted; a failed update is only logged.
        """
        if not address.endswith("/v4.0"):
            address += "/v4.0"

        try:
            rsp = self.query_sso_client(address, self.client_info)
        except requests.exceptions.HTTPError as e:
            if not self._is_not_found(e):
                raise
            rsp = self.register_client(address, self.client_info)

        if rsp is None:
            raise Exception("Failed to set SSO client")

        if rsp.app_id != self.client_info.app_id:
            try:
                rsp = self.update_sso_client(address, rsp.client_id, self.client_info)
            except requests.exceptions.HTTPError as e:
                # Best effort: keep using the record the server already has.
                logger.error(f"Failed to update SSO client: {e}")

        self.client_info = rsp
        return None

    @staticmethod
    def new_srp_token_credential(service_name: str) -> dict:
        """Return the ``X-Auth-SRPToken`` header name/value pair for ``service_name``."""
        from .credential import get_srp_token
        return {
            "header_name": "X-Auth-SRPToken",
            "header_value": get_srp_token(service_name),
        }
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def get_sso_client_id_and_secret(address: str, service: str) -> tuple:
    """Ensure ``service`` has an SSO client on ``address`` and return its credentials.

    Returns:
        tuple: ``(client_id, client_secret)`` of the client held by ``SSOClient``
        after ``set_sso_client`` completes.

    Raises:
        ValueError: wrapping any error raised while querying or registering the
            client (the original exception is chained as ``__cause__``).
    """
    sso = SSOClient(service)
    logger.debug(f"{sso.client_info.__dict__}")

    try:
        sso.set_sso_client(address)
    except Exception as e:
        raise ValueError(f"get sso clientId and ClientSecret err: {e}") from e

    # Never write the client secret to the logs; only report whether one exists.
    secret_state = "<redacted>" if sso.client_info.client_secret else None
    logger.debug(
        f"ServiceName: {sso.client_info.service_name}, ClientID: {sso.client_info.client_id}, "
        f"ClientSecret: {secret_state}")

    return sso.client_info.client_id, sso.client_info.client_secret
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from .encoding import AesEncrypt, Base64UrlSafeEncode
|
|
3
|
+
|
|
4
|
+
# AES key used by get_srp_token. Its 24 UTF-8 bytes select AES-192 (AES keys
# must be 16, 24 or 32 bytes).
# NOTE(review): this is a hard-coded shared secret shipped inside the package;
# anyone holding the wheel can mint SRP tokens. Consider loading it from config.
SRP_TOKEN_KEY: str = "ssoisno12345678987654321"
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_srp_token(service_name: str) -> str:
    """Build the SRP token sent in the ``X-Auth-SRPToken`` header.

    The plaintext ``"<unix-seconds>-<service_name>"`` is encrypted with
    ``AesEncrypt`` under ``SRP_TOKEN_KEY`` and returned as unpadded URL-safe
    base64.
    """
    plaintext = "{}-{}".format(int(time.time()), service_name)
    return Base64UrlSafeEncode(AesEncrypt(plaintext, SRP_TOKEN_KEY))
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
|
3
|
+
from cryptography.hazmat.primitives import padding
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def AesEncrypt(src: str, key: str) -> bytes:
    """Encrypt the UTF-8 bytes of ``src`` with AES-ECB and PKCS#7 padding.

    Args:
        src: Plaintext string.
        key: Key string; its UTF-8 encoding must be 16, 24 or 32 bytes long.

    Returns:
        The raw ciphertext bytes.
    """
    # NOTE(review): ECB mode leaks plaintext patterns. It is kept unchanged
    # because the receiving SSO server presumably expects this exact scheme;
    # confirm before changing it.
    padder = padding.PKCS7(algorithms.AES.block_size).padder()
    plaintext = padder.update(src.encode("utf-8")) + padder.finalize()

    encryptor = Cipher(algorithms.AES(key.encode("utf-8")), modes.ECB()).encryptor()
    return encryptor.update(plaintext) + encryptor.finalize()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def Base64UrlSafeEncode(data: bytes) -> str:
    """Return ``data`` as URL-safe base64 text with the trailing ``=`` padding removed."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("utf-8")
|