agent-api-server 2.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_api_server/__init__.py +0 -0
- agent_api_server/api/__init__.py +0 -0
- agent_api_server/api/v1/__init__.py +0 -0
- agent_api_server/api/v1/api.py +25 -0
- agent_api_server/api/v1/config.py +57 -0
- agent_api_server/api/v1/graph.py +59 -0
- agent_api_server/api/v1/schema.py +57 -0
- agent_api_server/api/v1/thread.py +563 -0
- agent_api_server/cache/__init__.py +0 -0
- agent_api_server/cache/redis_cache.py +385 -0
- agent_api_server/callback_handler.py +18 -0
- agent_api_server/client/css/styles.css +1202 -0
- agent_api_server/client/favicon.ico +0 -0
- agent_api_server/client/index.html +102 -0
- agent_api_server/client/js/app.js +1499 -0
- agent_api_server/client/js/index.umd.js +824 -0
- agent_api_server/config_center/config_center.py +239 -0
- agent_api_server/configs/__init__.py +3 -0
- agent_api_server/configs/config.py +163 -0
- agent_api_server/dynamic_llm/__init__.py +0 -0
- agent_api_server/dynamic_llm/dynamic_llm.py +331 -0
- agent_api_server/listener.py +530 -0
- agent_api_server/log/__init__.py +0 -0
- agent_api_server/log/formatters.py +122 -0
- agent_api_server/log/logging.json +50 -0
- agent_api_server/mcp_convert/__init__.py +0 -0
- agent_api_server/mcp_convert/mcp_convert.py +375 -0
- agent_api_server/memeory/__init__.py +0 -0
- agent_api_server/memeory/postgres.py +233 -0
- agent_api_server/register/__init__.py +0 -0
- agent_api_server/register/register.py +65 -0
- agent_api_server/service.py +354 -0
- agent_api_server/service_hub/service_hub.py +233 -0
- agent_api_server/service_hub/service_hub_test.py +700 -0
- agent_api_server/shared/__init__.py +0 -0
- agent_api_server/shared/ase.py +54 -0
- agent_api_server/shared/base_model.py +103 -0
- agent_api_server/shared/common.py +110 -0
- agent_api_server/shared/decode_token.py +107 -0
- agent_api_server/shared/detect_message.py +410 -0
- agent_api_server/shared/get_model_info.py +491 -0
- agent_api_server/shared/message.py +419 -0
- agent_api_server/shared/util_func.py +372 -0
- agent_api_server/sso_service/__init__.py +1 -0
- agent_api_server/sso_service/sdk/__init__.py +1 -0
- agent_api_server/sso_service/sdk/client.py +224 -0
- agent_api_server/sso_service/sdk/credential.py +11 -0
- agent_api_server/sso_service/sdk/encoding.py +22 -0
- agent_api_server/sso_service/sso_service.py +177 -0
- agent_api_server-2.1.7.dist-info/METADATA +130 -0
- agent_api_server-2.1.7.dist-info/RECORD +52 -0
- agent_api_server-2.1.7.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,491 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import json
|
|
3
|
+
from typing import Dict, Any, Optional, Union
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from llm_sdk.model_providers.base import ConfigType
|
|
6
|
+
from langchain_core.runnables.config import RunnableConfig
|
|
7
|
+
from langchain_deepseek.chat_models import DEFAULT_API_BASE
|
|
8
|
+
from agent_api_server.configs import global_config
|
|
9
|
+
from agent_api_server.dynamic_llm.dynamic_llm import DynamicLLM
|
|
10
|
+
from agent_api_server.shared.ase import AESCipher
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
# Provider name -> key under which that provider's API key is stored in the
# parsed credentials dict. Providers absent from this table (e.g. ollama,
# which does not require an API key) have no key to look up.
PROVIDER_CREDENTIAL_MAP: Dict[str, str] = {
    "azure_openai": "openai_api_key",
    "xinference": "api_key",
    "tongyi": "dashscope_api_key",
    "openai_api_compatible": "api_key",
    "deepseek": "api_key",
}
|
|
23
|
+
|
|
24
|
+
# Fallback API key returned for providers that can be used without a
# configured key: a placeholder string for xinference, empty for ollama.
SPECIAL_PROVIDER_DEFAULTS: Dict[str, str] = {
    "xinference": "sk-xxxxxx",
    "ollama": "",
}
|
|
29
|
+
|
|
30
|
+
# Providers whose base URL is fixed in code rather than read from the
# per-tool credentials; get_base_url returns these values directly.
PROVIDER_BASE_URLS: Dict[str, str] = {
    "deepseek": DEFAULT_API_BASE,
    "tongyi": "https://dashscope.aliyuncs.com/compatible-mode/v1",
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class LLMConfig:
    """Per-tool LLM settings extracted from a runnable config.

    Every field defaults to ``None``, which means the value was not present
    in the config the instance was built from.
    """

    # Provider identifier (compared against the keys of PROVIDER_CREDENTIAL_MAP).
    provider: Optional[str] = None
    # Model name configured for the tool.
    model: Optional[str] = None
    # Credentials as stored in the config: either an already-parsed mapping
    # or a string (JSON text or a bare key).
    credentials: Optional[Union[str, Dict[str, Any]]] = None
    # Taken from ``configurable["graph_name"]``; also passed as the app id
    # when decrypting secrets.
    agent_id: Optional[str] = None
    # Tenant identifier used as the optional suffix of configuration keys.
    ts_tenant: Optional[str] = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class CredentialsInfo:
    """Pair of API key and base URL; both default to the empty string."""

    # API key value; empty when none is available.
    api_key: str = ''
    # Endpoint base URL; empty when none is available.
    base_url: str = ''
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ConfigKeyBuilder:
    """Derives the names of per-tool configuration entries."""

    @staticmethod
    def get_key(tool_name: str, field: str, ts_tenant: Optional[str] = None) -> str:
        """
        Build a configuration key by joining its segments with ``_``.

        The segments are, in order: the upper-cased tool name (left out when
        the tool name is "default", compared case-insensitively), the value
        of ``ConfigType.CHAT``, the field name and, when given, the tenant.

        Args:
            tool_name: Name of the tool
            field: Configuration field (e.g., PROVIDER, MODEL)
            ts_tenant: Optional tenant identifier

        Returns:
            Formatted configuration key
        """
        tool_segment = [] if tool_name.lower() == "default" else [tool_name.upper()]
        tenant_segment = [ts_tenant] if ts_tenant else []
        return "_".join([*tool_segment, ConfigType.CHAT.value, field, *tenant_segment])
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class CredentialsParser:
    """Parses credential data from various formats."""

    @staticmethod
    def parse_credentials(credentials: Union[str, Dict[str, Any]],
                          tool_name: str) -> Optional[Dict[str, Any]]:
        """
        Parse credentials from string or dict format.

        Accepted inputs:
          * a dict - returned unchanged;
          * a blank string - yields an empty dict;
          * a string starting with ``{`` or ``[`` - decoded as JSON, and
            accepted only when the decoded value is a JSON object;
          * any other string - treated as a bare key and wrapped as
            ``{"raw_key": <stripped string>}``.

        Args:
            credentials: Credentials data (string JSON or dict)
            tool_name: Name of the tool for logging

        Returns:
            Parsed credentials as dict, or None on error (unsupported input
            type, malformed JSON, or JSON that is not an object)
        """
        if isinstance(credentials, dict):
            return credentials

        if not isinstance(credentials, str):
            logger.error(f"Unexpected credentials type for tool {tool_name}: {type(credentials)}")
            return None

        credentials_str = credentials.strip()
        if not credentials_str:
            return {}

        if not credentials_str.startswith(('{', '[')):
            return {"raw_key": credentials_str}

        try:
            parsed = json.loads(credentials_str)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse JSON credentials for tool {tool_name}: {e}")
            # Log only the size: the raw text may contain secrets.
            logger.debug(f"Raw credentials length: {len(credentials_str)}")
            return None

        # A JSON array (or scalar) would break every caller, all of which
        # use the result as a mapping.
        if not isinstance(parsed, dict):
            logger.error(
                f"Parsed credentials for tool {tool_name} is not a dict: {type(parsed)}"
            )
            return None

        return parsed
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class SecurityManager:
    """Decrypts encrypted values using one AES cipher per application id."""

    def __init__(self):
        # app_id -> cipher, filled on first use by _get_aes_client.
        self._aes_cache: Dict[str, AESCipher] = {}

    def _get_aes_client(self, app_id: str) -> AESCipher:
        """Return the cipher for ``app_id``, constructing and caching it on first use."""
        cache = self._aes_cache
        if app_id in cache:
            return cache[app_id]
        cipher = AESCipher(app_id)
        cache[app_id] = cipher
        return cipher

    def decrypt_api_key(self, encrypted_key: str, app_id: str) -> Optional[str]:
        """
        Decrypt an encrypted value with the cipher belonging to ``app_id``.

        Args:
            encrypted_key: Encrypted API key
            app_id: Application ID for decryption

        Returns:
            Decrypted key, or None when either argument is empty or when
            anything raises during decryption (the error is logged)
        """
        if not (encrypted_key and app_id):
            return None

        try:
            plaintext = self._get_aes_client(app_id).decrypt(encrypted_key)
            # Only the length is logged, never the key itself.
            logger.debug(f"Successfully decrypted API key (length: {len(plaintext)})")
        except Exception as e:
            logger.error(f"Failed to decrypt API key: {str(e)}")
            return None
        return plaintext
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class LLMConfigManager:
    """Main manager for LLM configuration and credential retrieval.

    All lookups read the ``configurable`` mapping of the supplied config,
    using keys produced by ``ConfigKeyBuilder`` for the given tool and
    tenant. Every public getter returns an empty string when the requested
    value is unavailable.
    """

    # Provider -> credentials-dict key that holds that provider's base URL.
    _BASE_URL_CREDENTIAL_KEYS = {
        'azure_openai': 'openai_api_base',
        'openai_api_compatible': 'endpoint_url',
        'ollama': 'base_url',
    }

    def __init__(self):
        self.key_builder = ConfigKeyBuilder()
        self.credentials_parser = CredentialsParser()
        self.security_manager = SecurityManager()

    def get_llm_config(self, tool_name: str, context: Dict[str, Any]) -> LLMConfig:
        """
        Extract LLM configuration from context.

        Args:
            tool_name: Name of the tool
            context: Configuration context dictionary

        Returns:
            LLMConfig object with extracted values; fields that are not
            present in the context are None
        """
        # Tolerate a missing context as well as an explicit
        # ``"configurable": None`` entry.
        configurable = (context or {}).get("configurable") or {}
        ts_tenant = configurable.get("ts_tenant") or configurable.get("TSTenant")

        def lookup(field: str) -> Any:
            return configurable.get(self.key_builder.get_key(tool_name, field, ts_tenant))

        return LLMConfig(
            provider=lookup("PROVIDER"),
            model=lookup("MODEL"),
            credentials=lookup("CREDENTIALS"),
            agent_id=configurable.get("graph_name"),
            ts_tenant=ts_tenant,
        )

    @staticmethod
    def _string_field(credentials_dict: Any, field: str) -> str:
        """Return ``credentials_dict[field]``, or '' when the container is not a dict or the value is missing/empty."""
        if not isinstance(credentials_dict, dict):
            return ''
        return credentials_dict.get(field) or ''

    def _get_credential_field(self, tool_name: str, config: Optional[RunnableConfig],
                              field: str) -> str:
        """Look up one field of the tool's parsed credentials, '' when unavailable."""
        if not config:
            logger.debug(f"No config provided for tool: {tool_name}")
            return ''

        llm_config = self.get_llm_config(tool_name, context=config)
        if not llm_config.credentials:
            # Nothing to parse; avoids a misleading "unexpected type" error log.
            logger.debug(f"No credentials found for tool: {tool_name}")
            return ''

        credentials_dict = self.credentials_parser.parse_credentials(
            llm_config.credentials, tool_name
        )
        return self._string_field(credentials_dict, field)

    def get_api_version(self, tool_name: str, config: Optional[RunnableConfig]) -> str:
        """Get the 'openai_api_version' entry of the tool's credentials, or ''."""
        return self._get_credential_field(tool_name, config, 'openai_api_version')

    def get_base_model_name(self, tool_name: str, config: Optional[RunnableConfig]) -> str:
        """Get the 'base_model_name' entry of the tool's credentials, or ''."""
        return self._get_credential_field(tool_name, config, 'base_model_name')

    def get_provider(self, tool_name: str, config: Optional[RunnableConfig]) -> str:
        """Get provider name from configuration."""
        llm_config = self.get_llm_config(tool_name, context=config or {})
        return llm_config.provider or ''

    def get_model_name(self, tool_name: str, config: Optional[RunnableConfig]) -> str:
        """Get model name from configuration."""
        llm_config = self.get_llm_config(tool_name, context=config or {})
        return llm_config.model or ''

    def _process_xinference_base_url(self, credentials_dict: Dict[str, Any], app_id: str) -> str:
        """
        Process Xinference base URL, handling decryption and formatting.

        The 'server_url' credential is decrypted with ``app_id``; trailing
        slashes are stripped and '/v1' is appended unless already present.

        Args:
            credentials_dict: Parsed credentials dictionary
            app_id: Application ID for decryption

        Returns:
            Formatted base URL, or empty string on error
        """
        server_url = credentials_dict.get("server_url")
        if not server_url:
            logger.error("server_url is required for xinference provider")
            return ''

        decrypted_url = self.security_manager.decrypt_api_key(server_url, app_id)
        if not decrypted_url:
            logger.warning("Failed to decrypt xinference server_url")
            return ''

        base_url = decrypted_url.rstrip("/")
        if not base_url.endswith("/v1"):
            base_url = base_url + "/v1"

        return base_url

    @staticmethod
    def _get_base_url_from_credentials(credentials_dict: Dict[str, Any], provider: str) -> str:
        """
        Extract base URL from credentials for specific providers.

        Args:
            credentials_dict: Parsed credentials dictionary
            provider: Provider name

        Returns:
            Base URL if found, empty string otherwise (including when the
            stored value is null)
        """
        url_key = LLMConfigManager._BASE_URL_CREDENTIAL_KEYS.get(provider)
        if url_key is None:
            return ''
        return credentials_dict.get(url_key) or ''

    def get_base_url(self, tool_name: str, config: Optional[RunnableConfig]) -> str:
        """
        Get base URL for the LLM provider.

        Providers listed in PROVIDER_BASE_URLS use their fixed URL;
        xinference uses its decrypted 'server_url'; the remaining supported
        providers read a provider-specific key from the credentials.

        Args:
            tool_name: Name of the tool
            config: Runnable configuration

        Returns:
            Base URL string, or empty string if not found
        """
        if not config:
            logger.debug(f"No config provided for tool: {tool_name}")
            return ''

        llm_config = self.get_llm_config(tool_name, context=config)
        provider = llm_config.provider
        if not provider:
            logger.debug(f"No provider found for tool: {tool_name}")
            return ''

        # Check provider-specific defaults
        if provider in PROVIDER_BASE_URLS:
            return PROVIDER_BASE_URLS[provider]

        if not llm_config.credentials:
            logger.debug(f"No credentials found for tool: {tool_name}")
            return ''

        credentials_dict = self.credentials_parser.parse_credentials(
            llm_config.credentials, tool_name
        )
        # The isinstance check keeps a non-mapping parse result away from
        # the ``.get`` calls below.
        if not credentials_dict or not isinstance(credentials_dict, dict):
            return ''

        # Special handling for xinference
        if provider == 'xinference':
            if not llm_config.agent_id:
                logger.warning("No agent_id found in config for xinference")
                return ''
            return self._process_xinference_base_url(credentials_dict, llm_config.agent_id)

        return self._get_base_url_from_credentials(credentials_dict, provider)

    @staticmethod
    def _get_raw_api_key_from_credentials(credentials: Union[str, Dict[str, Any]],
                                          credential_key: str) -> Optional[str]:
        """
        Extract raw API key from credentials.

        A dict (or a JSON object string) is looked up by ``credential_key``;
        a string that does not look like JSON is itself taken as the key.

        Args:
            credentials: Credentials data
            credential_key: Key to look for in credentials

        Returns:
            Raw API key, or None if not found
        """
        if isinstance(credentials, dict):
            return credentials.get(credential_key)

        if not isinstance(credentials, str):
            return None

        credentials_str = credentials.strip()
        if not credentials_str:
            return None

        if not credentials_str.startswith(('{', '[')):
            return credentials_str

        try:
            cred_dict = json.loads(credentials_str)
        except json.JSONDecodeError as e:
            # Report the failure instead of dropping it silently; the raw
            # text is deliberately not logged.
            logger.warning(f"Failed to parse JSON credentials: {e}")
            return None

        if isinstance(cred_dict, dict):
            return cred_dict.get(credential_key)
        logger.warning(f"Parsed credentials is not a dict: {type(cred_dict)}")
        return None

    def get_api_key(self, tool_name: str, config: Optional[RunnableConfig]) -> str:
        """
        Get API key for the LLM provider, handling decryption.

        Args:
            tool_name: Name of the tool
            config: Runnable configuration

        Returns:
            API key string (decrypted if encrypted, raw otherwise)
        """
        if not config:
            logger.debug(f"No config provided for tool: {tool_name}")
            return ''

        llm_config = self.get_llm_config(tool_name, context=config)

        # Check for required fields
        provider = llm_config.provider
        if not provider:
            logger.debug(f"No provider found for tool: {tool_name}")
            return ''

        # None when the provider has no built-in default key.
        default_key = SPECIAL_PROVIDER_DEFAULTS.get(provider)

        # Return defaults for special providers if no credentials
        if default_key is not None and not llm_config.credentials:
            return default_key

        # Get credential key for the provider
        credential_key = PROVIDER_CREDENTIAL_MAP.get(provider)
        if not credential_key:
            logger.debug(f"No credential mapping for provider: {provider}")
            return ''

        # Extract raw key from credentials
        raw_key = self._get_raw_api_key_from_credentials(
            llm_config.credentials, credential_key
        )

        # Handle missing keys
        if not raw_key:
            return default_key or ''

        # A stored key equal to the built-in default is not encrypted.
        if raw_key == default_key:
            return raw_key

        if not llm_config.agent_id:
            # No agent_id for decryption, return raw key
            logger.debug(f"No agent_id for decryption, using raw key for {provider}")
            return raw_key

        # Attempt decryption for encrypted keys
        decrypted_key = self.security_manager.decrypt_api_key(raw_key, llm_config.agent_id)
        if decrypted_key is None:
            # Return raw key if decryption fails
            logger.warning(f"Decryption failed for {provider}, using raw key")
            return raw_key
        return decrypted_key
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
# Module-level LLMConfigManager instance shared by the get_chat_llm_model_* helpers below
|
|
417
|
+
config_manager = LLMConfigManager()
|
|
418
|
+
|
|
419
|
+
def get_chat_model_information(tool_name: str, config: Optional[RunnableConfig]) -> Dict:
    """
    Collect the connection details of the chat model configured for a tool.

    When ``configurable["use_sys_llm"]`` equals the string 'true', the
    details come from the model-management service via
    ``ModelInstanceFactory`` and are returned as produced by
    ``unified_credential()``. Otherwise they are assembled from the
    per-tool entries of ``config``.

    Args:
        tool_name: Name of the tool
        config: Runnable configuration; may be None or lack "configurable"

    Returns:
        In the per-tool case, a dict with the keys 'openai_api_key',
        'openai_api_type', 'openai_api_deployment', 'openai_api_version'
        and 'openai_api_base'
    """
    # ``config`` is Optional and "configurable" may be absent or None;
    # fall back to an empty mapping instead of raising AttributeError.
    configurable = (config or {}).get("configurable") or {}
    graph_name = configurable.get('graph_name')
    use_sys_llm = configurable.get('use_sys_llm')

    if use_sys_llm == 'true':
        from llm_sdk.llm import ModelInstanceFactory
        # NOTE: 'model_managment_url' is the keyword exactly as this call
        # has always spelled it; do not "correct" it without checking the SDK.
        _model = ModelInstanceFactory.get_model_instance(
            model_managment_url=global_config.MODEL_MANAGER_SERVICE_URL,
            config_type=ConfigType.CHAT,
            token=global_config.CLIENT_TOKEN,
            app_id=graph_name,
            app_name=graph_name,
        )

        return _model.unified_credential()

    chat_llm = DynamicLLM(tool_name=tool_name, config_type=ConfigType.CHAT)
    llm_config = chat_llm._get_llm_config_from_context(config or {})

    if chat_llm._is_credentials_empty(config=llm_config):
        # Result intentionally discarded - presumably called for its side
        # effect on ``config``; TODO confirm against DynamicLLM.
        _ = chat_llm._get_model_instance(llm_config, config)

    model_provider = get_chat_llm_model_provider(tool_name=tool_name, config=config)
    api_key = get_chat_llm_model_api_key(tool_name=tool_name, config=config)
    base_url = get_chat_llm_model_base_url(tool_name=tool_name, config=config)
    api_version = get_chat_llm_model_api_version(tool_name=tool_name, config=config)

    # Azure reports the credentials' base model name in place of the
    # configured model name.
    if model_provider == 'azure_openai':
        model_name = get_chat_llm_model_base_model_name(tool_name=tool_name, config=config)
    else:
        model_name = get_chat_llm_model_name(tool_name=tool_name, config=config)

    return {
        "openai_api_key": api_key,
        "openai_api_type": model_provider,
        "openai_api_deployment": model_name,
        "openai_api_version": api_version,
        "openai_api_base": base_url
    }
|
|
459
|
+
|
|
460
|
+
def get_chat_llm_model_provider(tool_name: str, config: Optional[RunnableConfig]) -> str:
    """Return the provider configured for ``tool_name`` ('' when none), via the shared config manager."""
    provider = config_manager.get_provider(tool_name, config)
    return provider
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def get_chat_llm_model_name(tool_name: str, config: Optional[RunnableConfig]) -> str:
    """Return the model name configured for ``tool_name`` ('' when none), via the shared config manager."""
    model_name = config_manager.get_model_name(tool_name, config)
    return model_name
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def get_chat_llm_model_base_url(tool_name: str, config: Optional[RunnableConfig]) -> str:
    """Return the base URL for the provider configured for ``tool_name``, via the shared config manager."""
    base_url = config_manager.get_base_url(tool_name, config)
    return base_url
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def get_chat_llm_model_api_key(tool_name: str, config: Optional[RunnableConfig]) -> str:
    """Return the API key for the provider configured for ``tool_name``, via the shared config manager."""
    api_key = config_manager.get_api_key(tool_name, config)
    return api_key
|
|
478
|
+
|
|
479
|
+
def get_chat_llm_model_api_version(tool_name: str, config: Optional[RunnableConfig]) -> str:
    """Return the API version from the tool's credentials; only looked up for 'azure_openai', '' for any other provider."""
    if config_manager.get_provider(tool_name, config) != 'azure_openai':
        return ''
    return config_manager.get_api_version(tool_name, config)
|
|
485
|
+
|
|
486
|
+
def get_chat_llm_model_base_model_name(tool_name: str, config: Optional[RunnableConfig]) -> str:
    """Return the base model name from the tool's credentials; only looked up for 'azure_openai', '' for any other provider."""
    if config_manager.get_provider(tool_name, config) != 'azure_openai':
        return ''
    return config_manager.get_base_model_name(tool_name, config)
|