agent-api-server 2.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52):
  1. agent_api_server/__init__.py +0 -0
  2. agent_api_server/api/__init__.py +0 -0
  3. agent_api_server/api/v1/__init__.py +0 -0
  4. agent_api_server/api/v1/api.py +25 -0
  5. agent_api_server/api/v1/config.py +57 -0
  6. agent_api_server/api/v1/graph.py +59 -0
  7. agent_api_server/api/v1/schema.py +57 -0
  8. agent_api_server/api/v1/thread.py +563 -0
  9. agent_api_server/cache/__init__.py +0 -0
  10. agent_api_server/cache/redis_cache.py +385 -0
  11. agent_api_server/callback_handler.py +18 -0
  12. agent_api_server/client/css/styles.css +1202 -0
  13. agent_api_server/client/favicon.ico +0 -0
  14. agent_api_server/client/index.html +102 -0
  15. agent_api_server/client/js/app.js +1499 -0
  16. agent_api_server/client/js/index.umd.js +824 -0
  17. agent_api_server/config_center/config_center.py +239 -0
  18. agent_api_server/configs/__init__.py +3 -0
  19. agent_api_server/configs/config.py +163 -0
  20. agent_api_server/dynamic_llm/__init__.py +0 -0
  21. agent_api_server/dynamic_llm/dynamic_llm.py +331 -0
  22. agent_api_server/listener.py +530 -0
  23. agent_api_server/log/__init__.py +0 -0
  24. agent_api_server/log/formatters.py +122 -0
  25. agent_api_server/log/logging.json +50 -0
  26. agent_api_server/mcp_convert/__init__.py +0 -0
  27. agent_api_server/mcp_convert/mcp_convert.py +375 -0
  28. agent_api_server/memeory/__init__.py +0 -0
  29. agent_api_server/memeory/postgres.py +233 -0
  30. agent_api_server/register/__init__.py +0 -0
  31. agent_api_server/register/register.py +65 -0
  32. agent_api_server/service.py +354 -0
  33. agent_api_server/service_hub/service_hub.py +233 -0
  34. agent_api_server/service_hub/service_hub_test.py +700 -0
  35. agent_api_server/shared/__init__.py +0 -0
  36. agent_api_server/shared/ase.py +54 -0
  37. agent_api_server/shared/base_model.py +103 -0
  38. agent_api_server/shared/common.py +110 -0
  39. agent_api_server/shared/decode_token.py +107 -0
  40. agent_api_server/shared/detect_message.py +410 -0
  41. agent_api_server/shared/get_model_info.py +491 -0
  42. agent_api_server/shared/message.py +419 -0
  43. agent_api_server/shared/util_func.py +372 -0
  44. agent_api_server/sso_service/__init__.py +1 -0
  45. agent_api_server/sso_service/sdk/__init__.py +1 -0
  46. agent_api_server/sso_service/sdk/client.py +224 -0
  47. agent_api_server/sso_service/sdk/credential.py +11 -0
  48. agent_api_server/sso_service/sdk/encoding.py +22 -0
  49. agent_api_server/sso_service/sso_service.py +177 -0
  50. agent_api_server-2.1.7.dist-info/METADATA +130 -0
  51. agent_api_server-2.1.7.dist-info/RECORD +52 -0
  52. agent_api_server-2.1.7.dist-info/WHEEL +4 -0
@@ -0,0 +1,372 @@
1
+ import os
2
+ from typing import Dict
3
+ from importlib import import_module
4
+ from fastapi import HTTPException
5
+ from pathlib import Path
6
+ import json
7
+ import logging
8
+ import aiofiles
9
+ import asyncio
10
+ import sys
11
+ from typing import List, Any
12
+ from agent_api_server.memeory.postgres import AsyncPostgresCheckpointer
13
+ from langgraph.pregel import Pregel
14
+
15
+ from agent_api_server.shared.common import ConfigCategory
16
+
17
logger = logging.getLogger(__name__)

# Process working directory; runtime data for this module lives below it.
BASE_DIR = Path.cwd()

# Data directory requested with owner-only permissions (0o700); created at import time.
DATA_DIR = BASE_DIR.joinpath('data')
DATA_DIR.mkdir(mode=0o700, exist_ok=True)

# Base path of the model-config env files; writers/readers append "_<tenant id>".
LLM_CFG_ENV = DATA_DIR.joinpath('.llm_cfg_env')
23
+
24
async def get_all_graph_names() -> List[str]:
    """Return the names of all graphs declared in the ``langgraph.json`` config.

    Returns:
        List[str]: The keys of the config's ``"graphs"`` mapping, or an empty
        list when that mapping is missing or empty.

    Raises:
        HTTPException: 500 with ``error="config_loading_failed"`` if the
            config cannot be loaded or its graph names cannot be read.
    """
    try:
        config = await load_graph_config()
        graphs = config.get("graphs")
        return list(graphs.keys()) if graphs else []
    except Exception as e:
        # Chain the original error so the root cause stays visible in tracebacks.
        raise HTTPException(
            status_code=500,
            detail={
                "error": "config_loading_failed",
                "message": f"Cannot get graph names: {str(e)}"
            }
        ) from e
47
+
48
+
49
async def load_graph(
    graph_name: str, config: Dict, with_checker: bool
) -> tuple[str, Pregel, None] | tuple[str, Pregel, AsyncPostgresCheckpointer]:
    """Import and instantiate the graph registered under ``graph_name``.

    ``config["graphs"][graph_name]`` must have the form
    ``"path/to/module.py:attribute_name"``. The attribute may be a coroutine
    function, a plain callable (factory) or a ready-made object; whatever it
    yields must be a ``Pregel`` instance.

    Args:
        graph_name: Key of the graph inside ``config["graphs"]``.
        config: Parsed ``langgraph.json`` content.
        with_checker: When true, attach the worker's Postgres checkpointer to
            the graph and return it as the third tuple element.

    Returns:
        ``(graph_name, graph, checkpointer)``. ``checkpointer`` is the
        ``AsyncPostgresCheckpointer`` worker instance, or ``None`` when
        ``with_checker`` is false.

    Raises:
        HTTPException:
            - 404 if the graph is not configured.
            - 400 if its path is malformed.
            - 500 if the module cannot be imported, the attribute is missing,
              or the graph cannot be built.
    """
    # 1. Look up the graph entry (a missing or null "graphs" section counts as empty).
    graphs = config.get("graphs") or {}
    if graph_name not in graphs:
        raise HTTPException(
            status_code=404,
            detail={
                "error": "graph_not_found",
                "available": list(graphs.keys()),
                "message": f"Graph '{graph_name}' not found in config"
            }
        )

    # 2. Split "path/to/module.py:attribute_name".
    graph_path = graphs[graph_name]
    try:
        if not isinstance(graph_path, str):
            raise ValueError(f"Graph path must be a string, got {type(graph_path).__name__}")
        module_path, attr_name = graph_path.rsplit(":", 1)
        if not module_path or not attr_name:
            raise ValueError("Empty module path or attribute name")
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "invalid_path_format",
                "expected": "path/to/module.py:attribute_name",
                "actual": graph_path,
                "message": str(e)
            }
        ) from e

    # 3. Turn the file path into a dotted import path. Backslashes are
    #    normalised *before* removing "./" so a Windows-style ".\\pkg\\g.py"
    #    also loses its "./" marker instead of turning into "..pkg.g".
    module_path = module_path.replace("\\", "/").replace("./", "").removesuffix(".py")
    import_path = module_path.replace("/", ".")

    # 4. Import the module in a worker thread so a slow import does not block the event loop.
    try:
        # The module's parent directory is put on sys.path (presumably so the
        # module's own sibling imports resolve).
        module_dir = str(Path(module_path).parent.resolve())
        if module_dir not in sys.path:
            sys.path.insert(0, module_dir)

        module = await asyncio.to_thread(import_module, import_path)
    except ModuleNotFoundError as e:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "module_not_found",
                "path": import_path,
                "sys_path": str(sys.path),
                "message": f"Module not found: {str(e)}"
            }
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "module_import_error",
                "path": import_path,
                "message": f"Error importing module: {str(e)}"
            }
        ) from e

    # 5. Resolve the attribute. Only this lookup maps to "missing_attribute";
    #    AttributeErrors raised while *building* the graph are reported in step 6.
    try:
        graph_factory = getattr(module, attr_name)
    except AttributeError as e:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "missing_attribute",
                "module": import_path,
                "attribute": attr_name,
                "available": dir(module)
            }
        ) from e

    # 6. Build the graph (coroutine factory, plain factory or instance) and
    #    optionally attach the checkpointer.
    try:
        if asyncio.iscoroutinefunction(graph_factory):
            graph = await graph_factory()
        elif callable(graph_factory):
            graph = graph_factory()
        else:
            graph = graph_factory

        if not isinstance(graph, Pregel):
            raise TypeError(f"Expected Pregel instance, got {type(graph)}")

        if not with_checker:
            return graph_name, graph, None

        checkpointer = AsyncPostgresCheckpointer.get_worker_instance()
        try:
            graph.checkpointer = await checkpointer.checkpointer()
        except Exception as e:
            raise RuntimeError(f"Failed to initialize checkpointer: {str(e)}") from e

        return graph_name, graph, checkpointer
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "graph_initialization_failed",
                "message": f"Failed to initialize graph: {str(e)}",
                "type": type(e).__name__
            }
        ) from e
163
+
164
+
165
def resolve_config_path(file_name: str) -> Path:
    """Return the first existing location of ``file_name``.

    Candidates are tried in this order:
      1. the current working directory;
      2. the parent of this module's directory (``Path(__file__).parent.parent``);
      3. every ancestor of the current working directory, nearest first.

    Because of ``pathlib`` join semantics, an absolute ``file_name`` makes
    every candidate equal to itself.

    Raises:
        FileNotFoundError: If no candidate exists.
    """
    cwd = Path.cwd()

    def _candidates():
        yield cwd / file_name
        yield Path(__file__).parent.parent / file_name
        for ancestor in cwd.parents:
            yield ancestor / file_name

    found = next((candidate for candidate in _candidates() if candidate.exists()), None)
    if found is None:
        raise FileNotFoundError(f"Could not locate file {file_name}")
    return found
181
+
182
+
183
async def load_graph_config() -> Dict[str, Any]:
    """Locate ``langgraph.json`` and return its parsed content.

    The file is located with ``resolve_config_path("langgraph.json")`` and
    decoded as UTF-8 JSON. The result holds nested values such as the
    ``"graphs"`` and ``"agent_description"`` mappings, not just strings.

    Returns:
        Dict[str, Any]: The parsed JSON object.

    Raises:
        HTTPException: 500 if the file cannot be found, read or parsed.
    """
    try:
        json_path = resolve_config_path("langgraph.json")
        # Explicit UTF-8 (the JSON standard encoding). The platform default
        # encoding is not always UTF-8, and the config may carry non-ASCII text.
        async with aiofiles.open(json_path, mode='r', encoding='utf-8') as f:
            content = await f.read()
        config = json.loads(content)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Config loading failed: {str(e)}"
        ) from e

    logger.info("Successfully loaded config from %s", json_path)
    return config
196
+
197
+
198
+ async def parse_agent_config() -> List[Dict[str, Any]]:
199
+ """
200
+ 解析langgraph.json配置文件,返回所有Agent的标准化信息数组
201
+
202
+ Returns:
203
+ List[Dict[str, Any]]: 每个Agent的信息字典组成的数组,结构为:
204
+ [{
205
+ "agent_name": str,
206
+ "agent_description": str,
207
+ "agent_api_version": str,
208
+ "agent_icon_url": str, # 预留字段
209
+ "agent_features": Dict[str, Any] # 预留字段
210
+ "agent_label": str, # Agent的标签
211
+ "has_site": bool # Agent是否要产生chatbot
212
+ "multilangs": Dict[str, Dict[str, str]] # 多语言标签,
213
+ "is_system_agent": bool # Agent是否要注册为系统级Agent,如果注册为系统级Agent则全局租户共用一个,否则分租户使用
214
+ }]
215
+
216
+ Raises:
217
+ HTTPException: 如果配置文件加载或解析失败,包含详细错误分类
218
+ """
219
+ try:
220
+ config = await load_graph_config()
221
+
222
+ if not config.get("graphs") or not config.get("agent_description"):
223
+ raise ValueError("Missing required 'graphs' or 'agent_description' in config")
224
+
225
+ multilangs = config.get("multilangs", {})
226
+ default_multilang = {
227
+ "default": {
228
+ "en_US": "Default",
229
+ "zh_CN": "默认",
230
+ "zh_TW": "预设"
231
+ }
232
+ }
233
+
234
+ if multilangs:
235
+ multilangs = {**default_multilang, **multilangs}
236
+ else:
237
+ multilangs = default_multilang
238
+
239
+ agents = []
240
+ for agent_name in config["graphs"].keys():
241
+ agents.append({
242
+ "agent_name": agent_name,
243
+ "agent_description": config["agent_description"].get(agent_name, ""),
244
+ "agent_api_version": config.get("agent_api_version", "v0.0.1"),
245
+ "agent_icon_url": config.get("agent_icon_url", ""),
246
+ "agent_features": config.get("agent_features", {}),
247
+ "agent_labels": config.get("agent_labels", []),
248
+ "has_site": config.get("has_site", False),
249
+ "is_system_agent": config.get("is_system_agent", False),
250
+ "multilangs": multilangs
251
+ })
252
+
253
+ return agents
254
+ except HTTPException as e:
255
+ raise e
256
+ except Exception as e:
257
+ raise e
258
+
259
+
260
def set_model_config(
    tool_name: str,
    model_type: str,
    model_provider: str,
    model_name: str,
    credentials: dict,
    agent_id: str,
    ts_id: str | None = None,
) -> None:
    """Persist provider, model, credentials and agent id for one model type.

    Keys come from the ``ConfigCategory`` suffixes for ``model_type``:

    - They are prefixed with ``"<TOOL_NAME>_"`` unless ``tool_name == "default"``.
    - They are suffixed with ``"_<ts_id>"`` when ``ts_id`` is truthy.

    The keys are merged into the env file ``LLM_CFG_ENV + "_<ts_id>"`` via
    ``update_env_file``.

    Args:
        tool_name: Tool the model is configured for; ``"default"`` means no key prefix.
        model_type: ``"llm"``, ``"text-embedding"`` or ``"rerank"``. Any other
            value is logged as a warning and nothing is written.
        model_provider: Provider value stored under the provider key.
        model_name: Model value stored under the model key.
        credentials: Stored under the credentials key. It is written with
            ``str()`` formatting by ``update_env_file``.
        agent_id: Stored under the agent-id key.
        ts_id: Tenant id used as key suffix and file-name suffix.
    """
    category_map = {
        'llm': (ConfigCategory.CHAT_PROVIDER.value, ConfigCategory.CHAT_MODEL.value,
                ConfigCategory.CHAT_CREDENTIALS.value, ConfigCategory.AGENT_ID.value),

        'text-embedding': (ConfigCategory.EMBEDDING_PROVIDER.value, ConfigCategory.EMBEDDING_MODEL.value,
                           ConfigCategory.EMBEDDING_CREDENTIALS.value, ConfigCategory.AGENT_ID.value),

        'rerank': (ConfigCategory.RERANK_PROVIDER.value, ConfigCategory.RERANK_MODEL.value,
                   ConfigCategory.RERANK_CREDENTIALS.value, ConfigCategory.AGENT_ID.value)
    }

    categories = category_map.get(model_type)
    if categories is None:
        warning_msg = f"Unsupported model type {model_type}"
        if tool_name != "default":
            warning_msg += f" for tool {tool_name}"
        logger.warning(warning_msg)
        return

    base_prefix = "" if tool_name == "default" else f"{tool_name.upper()}_"
    key_suffix = f"_{ts_id}" if ts_id else ""
    provider_key, model_key, credentials_key, agent_id_key = (
        f"{base_prefix}{category}{key_suffix}" for category in categories
    )

    new_vars = {
        provider_key: model_provider,
        model_key: model_name,
        credentials_key: credentials,
        agent_id_key: agent_id
    }
    # NOTE: the file name always carries "_<ts_id>" (it becomes "_None" when
    # ts_id is not given), matching the "<LLM_CFG_ENV>_<tenant>" file that
    # get_env() reads.
    llm_info_path = str(LLM_CFG_ENV) + f"_{ts_id}"
    update_env_file(llm_info_path, new_vars)

    # Credentials are secrets: log only the key they are stored under, never the value.
    logger.debug(
        "set model config to %s ok; provider_key is %s, provider_value is %s; "
        "model_key is %s, model_value is %s; agent_id_key is %s, agent id value is %s; "
        "credentials_key is %s (value redacted)",
        llm_info_path, provider_key, model_provider, model_key, model_name,
        agent_id_key, agent_id, credentials_key,
    )
317
+
318
def get_env(ts_tenant: str):
    """Return a copy of ``os.environ`` merged with one tenant's model config.

    Reads the file ``<LLM_CFG_ENV>_<ts_tenant>`` (located via
    ``resolve_config_path``). It skips blank lines and ``#`` comments, and
    keeps only ``KEY=VALUE`` pairs whose key ends with ``_<ts_tenant>``. Those
    pairs are written into the process-wide ``os.environ`` *and* into the
    returned dict.

    A missing file is silently ignored. Any other error is logged, and the
    environment copy taken at entry is returned with whatever was merged so far.
    """
    snapshot = dict(os.environ)
    tenant_suffix = f"_{ts_tenant}"
    llm_info_path = str(LLM_CFG_ENV) + f"_{ts_tenant}"

    try:
        located = resolve_config_path(llm_info_path)
        if not located.exists():
            return snapshot

        with open(located, 'r') as fh:
            raw_lines = fh.readlines()

        tenant_vars = {}
        for raw in raw_lines:
            entry = raw.strip()
            if not entry or entry.startswith('#'):
                continue
            try:
                key, value = entry.split('=', 1)
            except ValueError:
                logger.warning(f"Skipping malformed line in {located}: {entry}")
                continue
            if key.endswith(tenant_suffix):
                tenant_vars[key] = value

        os.environ.update(tenant_vars)
        snapshot.update(tenant_vars)
    except FileNotFoundError:
        pass
    except Exception as e:
        logger.error(f"Error decrypting environment: {str(e)}")
    return snapshot
348
+
349
def update_env_file(env_path, new_vars):
    """Merge ``new_vars`` into the ``KEY=VALUE`` file at ``env_path``.

    Existing ``KEY=VALUE`` lines are read first. Blank lines, ``#`` comments
    and lines without ``=`` are dropped. Then ``new_vars`` overrides or extends
    them, and the file is rewritten with values formatted by ``str()``.

    The new content goes to a temporary file in the same directory, which
    ``tempfile.mkstemp`` creates with mode 0600. It then atomically replaces
    ``env_path``. As a result:

    - the credentials it holds are never readable by others, not even briefly;
    - a crash mid-write cannot leave a truncated file behind.

    Args:
        env_path: Path of the env file (str or path-like).
        new_vars: Mapping of keys to values to set.
    """
    import tempfile

    existing_vars = {}
    if os.path.exists(env_path):
        with open(env_path, 'r') as f:
            for raw in f:
                line = raw.strip()
                if not line or line.startswith('#'):
                    continue
                key, sep, value = line.partition('=')
                if sep:
                    existing_vars[key] = value

    existing_vars.update(new_vars)

    target_dir = os.path.dirname(os.path.abspath(env_path))
    fd, tmp_path = tempfile.mkstemp(prefix=".env_tmp_", dir=target_dir)
    try:
        with os.fdopen(fd, 'w') as f:
            f.writelines(f"{key}={value}\n" for key, value in existing_vars.items())
        os.replace(tmp_path, env_path)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise

    os.chmod(env_path, 0o600)
@@ -0,0 +1 @@
1
+ from .sso_service import SSOConfig, SSOService
@@ -0,0 +1 @@
1
+ from .client import SSOClient, get_sso_client_id_and_secret
@@ -0,0 +1,224 @@
1
+ import os
2
+ import json
3
+ import logging
4
+ from urllib.parse import quote
5
+ from typing import List, Optional
6
+
7
+ import requests
8
+
9
logger = logging.getLogger(name=__name__)

# OAuth2 grant type sent to the SSO token endpoint.
CLIENT_CREDENTIALS = "client_credentials"
# App/service name used when SSOClient is created without an explicit service name.
CURRENT_SERVICE_NAME = "KB-Insight"
13
+
14
+
15
class ClientRequest:
    """Client description sent to the SSO server when querying, registering or updating a client.

    Unknown keyword arguments are accepted and ignored.
    """

    # (attribute name, camelCase wire key) pairs, in serialization order.
    _WIRE_NAMES = (
        ("datacenter", "datacenter"),
        ("cluster", "cluster"),
        ("workspace", "workspace"),
        ("namespace", "namespace"),
        ("app_id", "appId"),
        ("app_name", "appName"),
        ("scopes", "scopes"),
        ("service_name", "serviceName"),
        ("redirect_url", "redirectUrl"),
    )

    def __init__(
        self,
        datacenter: Optional[str] = None,
        cluster: Optional[str] = None,
        workspace: Optional[str] = None,
        namespace: Optional[str] = None,
        app_id: Optional[str] = None,
        app_name: Optional[str] = None,
        scopes: Optional[List[str]] = None,
        service_name: Optional[str] = None,
        redirect_url: Optional[str] = None,
        **kwargs,
    ):
        self.datacenter = datacenter
        self.cluster = cluster
        self.workspace = workspace
        self.namespace = namespace
        self.app_id = app_id
        self.app_name = app_name
        # A missing or empty scope list falls back to the three default roles.
        self.scopes = scopes or ["Admin", "Editor", "Viewer"]
        self.service_name = service_name
        self.redirect_url = redirect_url

    def to_dict(self):
        """Serialize to the camelCase JSON shape expected by the SSO API."""
        return {wire_key: getattr(self, attr) for attr, wire_key in self._WIRE_NAMES}
50
+ }
51
+
52
+
53
class ClientInfo(ClientRequest):
    """An SSO client as returned by the server.

    It holds the ``ClientRequest`` fields plus the server-assigned client id,
    secret, type and timestamps.
    """

    def __init__(
        self,
        client_id: Optional[str] = None,
        creation_time: Optional[int] = None,
        last_modified_time: Optional[int] = None,
        client_secret: Optional[str] = None,
        client_type: Optional[str] = None,
        **kwargs,
    ):
        # Remaining keyword arguments are the ClientRequest fields.
        super().__init__(**kwargs)
        self.client_id = client_id
        self.creation_time = creation_time
        self.last_modified_time = last_modified_time
        self.client_secret = client_secret
        self.client_type = client_type

    @classmethod
    def from_dict(cls, data: dict):
        """Build an instance from the server's camelCase JSON; unknown keys are ignored."""
        request_fields = {
            attr: data.get(wire_key)
            for attr, wire_key in (
                ("datacenter", "datacenter"),
                ("cluster", "cluster"),
                ("workspace", "workspace"),
                ("namespace", "namespace"),
                ("app_id", "appId"),
                ("app_name", "appName"),
                ("scopes", "scopes"),
                ("service_name", "serviceName"),
                ("redirect_url", "redirectUrl"),
            )
        }
        own_fields = {
            "client_id": data.get("clientId"),
            "creation_time": data.get("creationTime"),
            "last_modified_time": data.get("lastModifiedTime"),
            "client_secret": data.get("clientSecret"),
            "client_type": data.get("clientType"),
        }
        return cls(**own_fields, **request_fields)
92
+
93
+
94
class SsoToken:
    """OAuth token as returned by the SSO server's token endpoint."""

    def __init__(
        self,
        access_token: Optional[str] = None,
        token_type: Optional[str] = None,
        expires_in: Optional[int] = None,
        refresh_token: Optional[str] = None,
    ):
        self.access_token, self.token_type = access_token, token_type
        self.expires_in, self.refresh_token = expires_in, refresh_token

    @classmethod
    def from_dict(cls, data: dict):
        """Build a token from the server's camelCase JSON; missing keys become ``None``."""
        lookup = data.get
        return cls(
            lookup("accessToken"),
            lookup("tokenType"),
            lookup("expiresIn"),
            lookup("refreshToken"),
        )
115
+
116
+
117
class SSOClient:
    """Registers this service as an OAuth client with the SSO server (``/v4.0`` API).

    The initial client description is read from the environment variables
    ``datacenter``, ``cluster``, ``workspace``, ``namespace`` and
    ``appID``/``appId``. :meth:`set_sso_client` then:

    - looks the client up;
    - registers it if the server reports it as not found;
    - updates it if the server's ``appId`` differs from ours.
    """

    # Per-request timeout in seconds; ``requests`` waits indefinitely unless one is given.
    DEFAULT_TIMEOUT = 30

    def __init__(self, service_name: Optional[str] = None, timeout: Optional[float] = DEFAULT_TIMEOUT):
        """
        Args:
            service_name: App/service name to use; defaults to ``CURRENT_SERVICE_NAME``.
            timeout: Seconds to wait for each SSO HTTP call; ``None`` waits indefinitely.
        """
        self.timeout = timeout
        self.client_info = ClientInfo(
            datacenter=os.getenv("datacenter"),
            cluster=os.getenv("cluster"),
            workspace=os.getenv("workspace"),
            namespace=os.getenv("namespace"),
            app_id=os.getenv("appID") or os.getenv("appId"),
            app_name=service_name or CURRENT_SERVICE_NAME,
            service_name=service_name or CURRENT_SERVICE_NAME,
        )

    @staticmethod
    def _encode_query(params) -> str:
        """Build a query string from ``(name, value)`` pairs, percent-encoding each value.

        ``str()`` keeps unset fields rendered as ``"None"``.
        """
        return "&".join(f"{name}={quote(str(value))}" for name, value in params)

    @staticmethod
    def _is_not_found(error) -> bool:
        """True if an ``HTTPError`` represents a 404 from the SSO server."""
        # Compare with None explicitly: a requests Response is falsy for 4xx/5xx.
        response = getattr(error, "response", None)
        if response is not None and response.status_code == 404:
            return True
        # Fallback: match the reason phrase in the error message.
        return "Not Found" in str(error)

    def _auth_headers(self, service_name: str, json_body: bool = False) -> dict:
        """SRP auth header (plus JSON content type when sending a body)."""
        credential = self.new_srp_token_credential(service_name)
        headers = {credential["header_name"]: credential["header_value"]}
        if json_body:
            headers["Content-Type"] = "application/json"
        return headers

    def query_sso_client(self, address: str, req: ClientRequest) -> ClientInfo:
        """GET the client registered under ``req.app_name`` and return it."""
        path = f"{address}/clients/{quote(req.app_name)}?" + self._encode_query((
            ("namespace", req.namespace),
            ("workspace", req.workspace),
            ("cluster", req.cluster),
            ("datacenter", req.datacenter),
            ("serviceName", req.service_name),
            ("appId", req.app_id),
        ))

        response = requests.get(path, headers=self._auth_headers(req.service_name), timeout=self.timeout)
        response.raise_for_status()
        return ClientInfo.from_dict(response.json())

    def register_client(self, address: str, req: ClientRequest) -> ClientInfo:
        """POST a new client built from ``req`` and return the server's record."""
        path = f"{address}/clients"
        body = json.dumps(req.to_dict())

        logger.info("register sso client request: %s", body)

        response = requests.post(
            path, headers=self._auth_headers(req.service_name, json_body=True), data=body, timeout=self.timeout
        )
        response.raise_for_status()
        return ClientInfo.from_dict(response.json())

    def update_sso_client(self, address: str, client_id: str, req: ClientRequest) -> ClientInfo:
        """PUT ``req`` onto the existing client ``client_id`` and return the server's record."""
        path = f"{address}/clients/{quote(str(client_id))}"
        body = json.dumps(req.to_dict())

        response = requests.put(
            path, headers=self._auth_headers(req.service_name, json_body=True), data=body, timeout=self.timeout
        )
        response.raise_for_status()
        return ClientInfo.from_dict(response.json())

    def query_sso_client_token(self, address: str) -> SsoToken:
        """Request a client-credentials token for ``self.client_info``."""
        # NOTE(review): the client secret travels in the URL query string,
        # which proxies and access logs may record. Sending it in a form body
        # would be safer if the SSO server accepts that — TODO confirm.
        path = f"{address}/oauth/token?" + self._encode_query((
            ("grant_type", CLIENT_CREDENTIALS),
            ("client_id", self.client_info.client_id),
            ("client_secret", self.client_info.client_secret),
            ("duration", "eternal"),
        ))

        response = requests.post(path, timeout=self.timeout)
        response.raise_for_status()
        return SsoToken.from_dict(response.json())

    def set_sso_client(self, address: str) -> None:
        """Ensure this service is registered and store the server's record in ``client_info``.

        Raises:
            requests.exceptions.HTTPError: If the lookup fails with anything
                other than 404, or if registration fails.
            requests.exceptions.RequestException: On connection errors or timeouts.
        """
        if not address.endswith("/v4.0"):
            address += "/v4.0"

        try:
            rsp = self.query_sso_client(address, self.client_info)
        except requests.exceptions.HTTPError as e:
            if not self._is_not_found(e):
                raise
            rsp = self.register_client(address, self.client_info)

        if rsp is None:
            raise Exception("Failed to set SSO client")

        # The server record belongs to another appId: try to update it to ours.
        if rsp.app_id != self.client_info.app_id:
            try:
                rsp = self.update_sso_client(address, rsp.client_id, self.client_info)
            except requests.exceptions.HTTPError as e:
                # Best effort: keep the existing record if the update is rejected.
                logger.error(f"Failed to update SSO client: {e}")

        self.client_info = rsp
        return None

    @staticmethod
    def new_srp_token_credential(service_name: str) -> dict:
        """Return the ``X-Auth-SRPToken`` header name/value for ``service_name``."""
        from .credential import get_srp_token
        return {
            "header_name": "X-Auth-SRPToken",
            "header_value": get_srp_token(service_name),
        }
210
+
211
+
212
def get_sso_client_id_and_secret(address: str, service: str) -> tuple:
    """Register or look up ``service`` at the SSO server and return its credentials.

    Args:
        address: Base URL of the SSO server (``/v4.0`` is appended if missing).
        service: Service/app name to register under.

    Returns:
        ``(client_id, client_secret)`` as reported by the SSO server.

    Raises:
        ValueError: If any step of :meth:`SSOClient.set_sso_client` fails
            (the original error is chained as ``__cause__``).
    """
    sso = SSOClient(service)
    logger.debug(f"{sso.client_info.to_dict()}")

    try:
        sso.set_sso_client(address)
    except Exception as e:
        raise ValueError(f"get sso clientId and ClientSecret err: {e}") from e

    info = sso.client_info
    # The client secret is a credential: never write it to the log.
    logger.debug(
        "ServiceName: %s, ClientID: %s, ClientSecret present: %s",
        info.service_name, info.client_id, bool(info.client_secret),
    )

    return info.client_id, info.client_secret
@@ -0,0 +1,11 @@
1
+ import time
2
+ from .encoding import AesEncrypt, Base64UrlSafeEncode
3
+
4
import os

# Shared AES key used to encrypt SRP tokens (24 bytes, i.e. AES-192).
# NOTE(review): the default ships in the package source and is therefore not
# secret. SSO_SRP_TOKEN_KEY overrides it, so the key can be rotated without a
# code change. The SSO server must use the same key, and the key must be
# 16, 24 or 32 bytes long.
SRP_TOKEN_KEY = os.environ.get("SSO_SRP_TOKEN_KEY") or "ssoisno12345678987654321"


def get_srp_token(service_name: str, key: str | None = None) -> str:
    """Build the SRP token for ``service_name``.

    The plaintext ``"<unix seconds>-<service_name>"`` is encrypted with
    ``AesEncrypt`` and encoded with ``Base64UrlSafeEncode`` (URL-safe base64
    without ``=`` padding).

    Args:
        service_name: Service identity embedded in the token.
        key: AES key to use; defaults to the module-level ``SRP_TOKEN_KEY``.
    """
    cur_time = str(int(time.time()))
    src = f"{cur_time}-{service_name}"
    crypted = AesEncrypt(src, SRP_TOKEN_KEY if key is None else key)
    return Base64UrlSafeEncode(crypted)
@@ -0,0 +1,22 @@
1
+ import base64
2
+ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
3
+ from cryptography.hazmat.primitives import padding
4
+
5
+
6
def AesEncrypt(src: str, key: str) -> bytes:
    """Encrypt the UTF-8 bytes of ``src`` with AES-ECB and PKCS#7 padding.

    Args:
        src: Plaintext string.
        key: Key string; its UTF-8 bytes are the AES key, so it must encode to
            16, 24 or 32 bytes.

    Returns:
        The raw ciphertext bytes (a multiple of the 16-byte AES block size).

    NOTE(review): ECB mode leaks equal-plaintext-block patterns and is not a
    recommended cipher mode. It is presumably kept because the SSO server
    decrypts with the same scheme — TODO confirm before changing.
    """
    key_material = key.encode("utf-8")
    plaintext = src.encode("utf-8")

    padder = padding.PKCS7(algorithms.AES.block_size).padder()
    padded = padder.update(plaintext) + padder.finalize()

    encryptor = Cipher(algorithms.AES(key_material), modes.ECB()).encryptor()
    ciphertext = encryptor.update(padded)
    ciphertext += encryptor.finalize()
    return ciphertext
18
+
19
+
20
def Base64UrlSafeEncode(data: bytes) -> str:
    """Encode ``data`` as URL-safe base64 (``-``/``_`` alphabet) with the ``=`` padding removed."""
    unpadded = base64.urlsafe_b64encode(data).rstrip(b"=")
    return unpadded.decode("utf-8")