veadk-python 0.2.16__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64):
  1. veadk/a2a/remote_ve_agent.py +56 -1
  2. veadk/agent.py +79 -26
  3. veadk/agents/loop_agent.py +22 -9
  4. veadk/agents/parallel_agent.py +21 -9
  5. veadk/agents/sequential_agent.py +18 -9
  6. veadk/auth/veauth/apmplus_veauth.py +32 -39
  7. veadk/auth/veauth/ark_veauth.py +3 -1
  8. veadk/auth/veauth/utils.py +12 -0
  9. veadk/auth/veauth/viking_mem0_veauth.py +91 -0
  10. veadk/cli/cli.py +5 -1
  11. veadk/cli/cli_create.py +62 -1
  12. veadk/cli/cli_deploy.py +36 -1
  13. veadk/cli/cli_eval.py +55 -0
  14. veadk/cli/cli_init.py +44 -3
  15. veadk/cli/cli_kb.py +36 -1
  16. veadk/cli/cli_pipeline.py +66 -1
  17. veadk/cli/cli_prompt.py +16 -1
  18. veadk/cli/cli_uploadevalset.py +15 -1
  19. veadk/cli/cli_web.py +35 -4
  20. veadk/cloud/cloud_agent_engine.py +142 -25
  21. veadk/cloud/cloud_app.py +219 -12
  22. veadk/configs/database_configs.py +4 -0
  23. veadk/configs/model_configs.py +5 -1
  24. veadk/configs/tracing_configs.py +2 -2
  25. veadk/evaluation/adk_evaluator/adk_evaluator.py +77 -17
  26. veadk/evaluation/base_evaluator.py +219 -3
  27. veadk/evaluation/deepeval_evaluator/deepeval_evaluator.py +116 -1
  28. veadk/evaluation/eval_set_file_loader.py +20 -0
  29. veadk/evaluation/eval_set_recorder.py +54 -0
  30. veadk/evaluation/types.py +32 -0
  31. veadk/evaluation/utils/prometheus.py +61 -0
  32. veadk/knowledgebase/backends/base_backend.py +14 -1
  33. veadk/knowledgebase/backends/in_memory_backend.py +10 -1
  34. veadk/knowledgebase/backends/opensearch_backend.py +26 -0
  35. veadk/knowledgebase/backends/redis_backend.py +29 -2
  36. veadk/knowledgebase/backends/vikingdb_knowledge_backend.py +43 -5
  37. veadk/knowledgebase/knowledgebase.py +173 -12
  38. veadk/memory/long_term_memory.py +148 -4
  39. veadk/memory/long_term_memory_backends/mem0_backend.py +11 -0
  40. veadk/memory/short_term_memory.py +119 -5
  41. veadk/runner.py +412 -1
  42. veadk/tools/builtin_tools/llm_shield.py +381 -0
  43. veadk/tools/builtin_tools/mcp_router.py +9 -2
  44. veadk/tools/builtin_tools/run_code.py +25 -5
  45. veadk/tools/builtin_tools/web_search.py +38 -154
  46. veadk/tracing/base_tracer.py +28 -1
  47. veadk/tracing/telemetry/attributes/extractors/common_attributes_extractors.py +105 -1
  48. veadk/tracing/telemetry/attributes/extractors/llm_attributes_extractors.py +260 -0
  49. veadk/tracing/telemetry/attributes/extractors/tool_attributes_extractors.py +69 -0
  50. veadk/tracing/telemetry/attributes/extractors/types.py +78 -0
  51. veadk/tracing/telemetry/exporters/apmplus_exporter.py +157 -0
  52. veadk/tracing/telemetry/exporters/base_exporter.py +8 -0
  53. veadk/tracing/telemetry/exporters/cozeloop_exporter.py +60 -1
  54. veadk/tracing/telemetry/exporters/inmemory_exporter.py +118 -1
  55. veadk/tracing/telemetry/exporters/tls_exporter.py +66 -0
  56. veadk/tracing/telemetry/opentelemetry_tracer.py +111 -1
  57. veadk/tracing/telemetry/telemetry.py +118 -2
  58. veadk/version.py +1 -1
  59. {veadk_python-0.2.16.dist-info → veadk_python-0.2.17.dist-info}/METADATA +1 -1
  60. {veadk_python-0.2.16.dist-info → veadk_python-0.2.17.dist-info}/RECORD +64 -62
  61. {veadk_python-0.2.16.dist-info → veadk_python-0.2.17.dist-info}/WHEEL +0 -0
  62. {veadk_python-0.2.16.dist-info → veadk_python-0.2.17.dist-info}/entry_points.txt +0 -0
  63. {veadk_python-0.2.16.dist-info → veadk_python-0.2.17.dist-info}/licenses/LICENSE +0 -0
  64. {veadk_python-0.2.16.dist-info → veadk_python-0.2.17.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,381 @@
1
+ # Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ import os
16
+ import requests
17
+ from typing import Optional, List, Dict, Any, Union
18
+ from volcenginesdkllmshield.models.llm_shield_sign import request_sign
19
+
20
+ from google.adk.plugins import BasePlugin
21
+ from google.adk.agents.callback_context import CallbackContext
22
+ from google.adk.tools.tool_context import ToolContext
23
+ from google.adk.models import LlmRequest, LlmResponse
24
+ from google.genai import types
25
+ from google.adk.tools.base_tool import BaseTool
26
+
27
+ from veadk.config import getenv
28
+ from veadk.utils.logger import get_logger
29
+ from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
30
+
31
logger = get_logger(__name__)


class LLMShieldPlugin(BasePlugin):
    """
    LLM Shield Plugin for content moderation and safety.

    This plugin integrates with Volcano Engine's LLM Shield service to provide
    real-time content moderation for user inputs, model outputs, and tool interactions.
    It helps detect and block potentially harmful content including sensitive information,
    prompt injection attacks, and policy violations.

    Moderation is fail-open: if the app ID is not configured, credentials cannot
    be obtained, the service is unreachable, or it returns an unexpected payload,
    the problem is logged and the content is allowed through.

    NOTE(review): the hooks below are synchronous, matching the agent-level
    callback usage shown in the example. ADK's ``BasePlugin`` declares its hooks
    as coroutines, so confirm behavior before registering this object as a
    Runner plugin.

    Examples:
        Basic usage with default settings:
        ```python
        from veadk.tools.builtin_tools.llm_shield import content_safety
        agent = Agent(
            before_model_callback=content_safety.before_model_callback,
            after_model_callback=content_safety.after_model_callback,
            before_tool_callback=content_safety.before_tool_callback,
            after_tool_callback=content_safety.after_tool_callback,
        )
        ```
    """

    # Fixed parameters of the LLM Shield "Moderate" API call.
    _PATH = "/v2/moderate"
    _ACTION = "Moderate"
    _VERSION = "2025-08-31"
    # DecisionType value that this plugin treats as "block the content".
    _DECISION_BLOCK = 2
    # Fallback reason when a block decision carries no recognizable category.
    _DEFAULT_REASON = "security policy violation"

    def __init__(self, region: str = "cn-beijing", timeout: int = 50) -> None:
        """
        Initialize the LLM Shield Plugin.

        Args:
            region (str, optional): The service region. Defaults to "cn-beijing".
            timeout (int, optional): Request timeout in seconds. Defaults to 50.
        """
        self.name = "LLMShieldPlugin"
        super().__init__(name=self.name)

        self.appid = getenv("TOOL_LLM_SHIELD_APP_ID")
        self.region = region
        self.timeout = timeout

        # Maps numeric risk categories returned by the service to readable names.
        self.category_map = {
            101: "Model Misuse",
            103: "Sensitive Information",
            104: "Prompt Injection",
            106: "General Topic Control",
            107: "Computational Resource Consumption",
        }

    # ------------------------------------------------------------------ #
    # Service interaction helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _resolve_credentials() -> tuple[str, str, str]:
        """Return ``(access_key, secret_key, session_token)`` for request signing.

        Environment variables take precedence (with an empty session token).
        Otherwise credentials come from ``get_credential_from_vefaas_iam``,
        which also supplies a session token.
        """
        ak = os.getenv("VOLCENGINE_ACCESS_KEY")
        sk = os.getenv("VOLCENGINE_SECRET_KEY")
        if ak and sk:
            logger.debug("Successfully get AK/SK from environment variables.")
            return ak, sk, ""

        logger.debug("Get AK/SK from environment variables failed.")
        credential = get_credential_from_vefaas_iam()
        return (
            credential.access_key_id,
            credential.secret_access_key,
            credential.session_token,
        )

    def _post_moderation(self, message: str, role: str) -> Optional[Dict[str, Any]]:
        """Send a signed Moderate request and return the decoded JSON object.

        Args:
            message (str): The content to be moderated.
            role (str): The role of the message sender ("user" or "assistant").

        Returns:
            Optional[Dict[str, Any]]: The decoded response object, or None if
            any step fails (credentials, signing, network, HTTP status, or
            JSON decoding). Every failure is logged.
        """
        body = {
            "Message": {
                "Role": role,
                "Content": message,
                "ContentType": 1,
            },
            "Scene": self.appid,
        }
        body_json = json.dumps(body).encode("utf-8")
        url = f"https://{self.region}.sdk.access.llm-shield.omini-shield.com"

        # Credential lookup and signing are inside the try block so that their
        # failures follow the same fail-open contract as network errors.
        try:
            ak, sk, session_token = self._resolve_credentials()
            signed_header = request_sign(
                {"X-Security-Token": session_token},
                ak,
                sk,
                self.region,
                url,
                self._PATH,
                self._ACTION,
                body_json,
            )
            signed_header.update(
                {
                    "Content-Type": "application/json",
                    "X-Top-Service": "llmshield",
                    "X-Top-Region": self.region,
                }
            )
            response = requests.post(
                url + self._PATH,
                headers=signed_header,
                data=body_json,
                params={"Action": self._ACTION, "Version": self._VERSION},
                timeout=self.timeout,
            )
        except requests.exceptions.Timeout:
            logger.error("LLM Shield request timeout")
            return None
        except requests.exceptions.RequestException as e:
            logger.error(f"LLM Shield network request failed: {e}")
            return None
        except Exception as e:
            logger.error(f"LLM Shield request failed: {e}")
            return None

        if response.status_code != 200:
            logger.error(
                f"LLM Shield HTTP error: {response.status_code} - {response.text}"
            )
            return None

        # Decoding is handled separately: requests' JSONDecodeError subclasses
        # RequestException and would otherwise be logged as a network failure.
        try:
            payload = response.json()
        except ValueError as e:
            logger.error(f"LLM Shield response JSON decode failed: {e}")
            return None

        if not isinstance(payload, dict):
            logger.error(f"LLM Shield returned an unexpected payload: {payload!r}")
            return None
        return payload

    def _category_name(self, category: Any) -> str:
        """Map a raw risk category to a readable name, tolerating bad values."""
        try:
            return self.category_map.get(int(category), f"Category {category}")
        except (TypeError, ValueError):
            return f"Category {category}"

    def _build_block_message(self, payload: Dict[str, Any]) -> Optional[str]:
        """Turn a moderation response into a blocking message, or None if allowed.

        Content is blocked only when ``Result.Decision.DecisionType`` equals
        ``_DECISION_BLOCK`` and ``Result.RiskInfo.Risks`` is non-empty.
        """
        result = payload.get("Result")
        if not isinstance(result, dict) or not result:
            return None

        # "Decision" may be absent or null; treat that as "no decision".
        decision = result.get("Decision") or {}
        decision_type = decision.get("DecisionType")
        risk_info = result.get("RiskInfo")
        if decision_type is None or not risk_info:
            return None

        try:
            if int(decision_type) != self._DECISION_BLOCK:
                return None
        except (TypeError, ValueError):
            logger.error(f"LLM Shield returned invalid DecisionType: {decision_type!r}")
            return None

        risks = risk_info.get("Risks") or []
        if not risks:
            return None

        # Keep first-seen order (deduplicated) so the message text is stable.
        reasons: List[str] = []
        for risk in risks:
            category = risk.get("Category") if isinstance(risk, dict) else None
            if not category:
                continue
            name = self._category_name(category)
            if name not in reasons:
                reasons.append(name)

        reason_text = ", ".join(reasons) if reasons else self._DEFAULT_REASON
        return (
            f"Your request has been blocked due to: {reason_text}. "
            f"Please modify your input and try again."
        )

    def _request_llm_shield(self, message: str, role: str) -> Optional[str]:
        """
        Make a request to the LLM Shield service for content moderation.

        This method sends a message to the LLM Shield API for security analysis.
        If the content is deemed risky, it returns a blocking message explaining
        the violation. Otherwise, it returns None to allow the content through.
        Empty messages are not sent to the service.

        Args:
            message (str): The content to be moderated
            role (str): The role of the message sender ("user" or "assistant")

        Returns:
            Optional[str]: A blocking message if content violates policies,
                None if content is empty, safe, or on error
        """
        if not message:
            return None

        if not self.appid:
            logger.error("LLM Shield app ID not configured")
            return None

        payload = self._post_moderation(message, role)
        if payload is None:
            return None
        return self._build_block_message(payload)

    # ------------------------------------------------------------------ #
    # Content extraction helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _extract_text(content: Any, role: str) -> str:
        """Join the text of every part of ``content`` if its role matches.

        All text parts are moderated, not just the first one. Otherwise, text
        after the first part (or after a non-text part) would bypass the check.
        """
        if not content or getattr(content, "role", "") != role:
            return ""
        parts = getattr(content, "parts", None) or []
        texts = [
            part_text for part in parts if (part_text := getattr(part, "text", None))
        ]
        return "\n".join(texts)

    @staticmethod
    def _blocked_llm_response(text: str) -> LlmResponse:
        """Build the model-role response that replaces blocked content."""
        return LlmResponse(
            content=types.Content(
                role="model",
                parts=[types.Part(text=text)],
            )
        )

    # ------------------------------------------------------------------ #
    # Callbacks
    # ------------------------------------------------------------------ #

    def before_agent_callback(
        self, callback_context: CallbackContext, **kwargs
    ) -> None:
        # TODO: Implement agent-level input validation and context analysis
        return None

    def after_agent_callback(self, callback_context: CallbackContext, **kwargs) -> None:
        # TODO: Implement post-agent analysis and context analysis
        return None

    def before_model_callback(
        self, callback_context: CallbackContext, llm_request: LlmRequest, **kwargs
    ) -> Optional[LlmResponse]:
        """
        Moderate user input before sending to the language model.

        Takes the text of the last content in the LLM request, if that content
        has the "user" role, and checks it against the LLM Shield service. If the
        content violates safety policies, returns a blocking response instead of
        allowing the request to proceed.

        Args:
            callback_context (CallbackContext): The callback execution context
            llm_request (LlmRequest): The incoming LLM request to moderate
            **kwargs: Additional keyword arguments

        Returns:
            Optional[LlmResponse]: A blocking response if content is unsafe,
                None if content is safe to proceed
        """
        contents = getattr(llm_request, "contents", None) or []
        if not contents:
            return None

        message = self._extract_text(contents[-1], role="user")
        response = self._request_llm_shield(message=message, role="user")
        if response:
            logger.debug("LLM Shield triggered in before_model_callback.")
            return self._blocked_llm_response(response)
        return None

    def after_model_callback(
        self, callback_context: CallbackContext, llm_response: LlmResponse, **kwargs
    ) -> Optional[LlmResponse]:
        """
        Moderate model output before returning to the user.

        Takes the text of the response content, if it has the "model" role, and
        checks it against the LLM Shield service. If the model's output violates
        safety policies, returns a blocking response instead of the original
        model output.

        Args:
            callback_context (CallbackContext): The callback execution context
            llm_response (LlmResponse): The model's response to moderate
            **kwargs: Additional keyword arguments

        Returns:
            Optional[LlmResponse]: A blocking response if content is unsafe,
                None if content is safe to return
        """
        message = self._extract_text(
            getattr(llm_response, "content", None), role="model"
        )
        response = self._request_llm_shield(message=message, role="assistant")
        if response:
            logger.debug("LLM Shield triggered in after_model_callback.")
            return self._blocked_llm_response(response)
        return None

    def before_tool_callback(
        self, tool: BaseTool, args: Dict[str, Any], tool_context: ToolContext, **kwargs
    ) -> Optional[Dict]:
        """
        Moderate tool arguments before tool execution.

        Combines all tool arguments into a ``"key: value"`` per line message and
        checks it against the LLM Shield service. If the arguments contain unsafe
        content, returns a blocking result instead of allowing tool execution.
        Calls with no arguments are not sent to the service.

        Args:
            tool (BaseTool): The tool to be executed
            args (Dict[str, Any]): The arguments passed to the tool
            tool_context (ToolContext): The tool execution context
            **kwargs: Additional keyword arguments

        Returns:
            Optional[Dict]: A blocking result if arguments are unsafe,
                None if arguments are safe to proceed
        """
        message = "\n".join(f"{key}: {value}" for key, value in (args or {}).items())
        response = self._request_llm_shield(message=message, role="user")
        if response:
            logger.debug("LLM Shield triggered in before_tool_callback.")
            return {"result": response}
        return None

    def after_tool_callback(
        self,
        tool: BaseTool,
        args: Dict[str, Any],
        tool_context: ToolContext,
        tool_response: Union[str, Dict[str, Any], List[Any]],
        **kwargs,
    ) -> Optional[Dict]:
        """
        Moderate tool output after tool execution.

        Converts the tool's response into a message and checks it against the
        LLM Shield service. Strings are used as-is; dict values and list items
        go one per line. If the tool's output violates safety policies, returns
        a blocking result. Empty or unsupported responses are not sent.

        Args:
            tool (BaseTool): The tool that was executed
            args (Dict[str, Any]): The arguments that were passed to the tool
            tool_context (ToolContext): The tool execution context
            tool_response (Union[str, Dict[str, Any], List[Any]]): The tool's response
            **kwargs: Additional keyword arguments

        Returns:
            Optional[Dict]: A blocking result if tool output is unsafe,
                None if output is safe to return
        """
        if isinstance(tool_response, str):
            message = tool_response
        elif isinstance(tool_response, dict):
            message = "\n".join(f"{value}" for value in tool_response.values())
        elif isinstance(tool_response, list):
            message = "\n".join(f"{item}" for item in tool_response)
        else:
            message = ""

        response = self._request_llm_shield(message=message, role="assistant")
        if response:
            logger.debug("LLM Shield triggered in after_tool_callback.")
            return {"result": response}
        return None


# Module-level instance used in the documented agent-callback wiring.
content_safety = LLMShieldPlugin()
@@ -15,8 +15,15 @@
15
15
  from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
16
16
 
17
17
  from veadk.config import getenv
18
- from veadk.utils.mcp_utils import get_mcp_params
18
+ from google.adk.tools.mcp_tool.mcp_session_manager import (
19
+ StreamableHTTPConnectionParams,
20
+ )
19
21
 
20
22
# Endpoint and credential for the MCP router service.
url = getenv("TOOL_MCP_ROUTER_URL")
api_key = getenv("TOOL_MCP_ROUTER_API_KEY")

# Attach the bearer token only when an API key is actually configured. If
# getenv yields an empty or None value, an unconditional header would send a
# bogus token such as "Bearer None" to the router.
mcp_router = MCPToolset(
    connection_params=StreamableHTTPConnectionParams(
        url=url,
        headers={"Authorization": f"Bearer {api_key}"} if api_key else None,
    ),
)
@@ -13,12 +13,14 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import json
16
+ import os
16
17
 
17
18
  from google.adk.tools import ToolContext
18
19
 
19
20
  from veadk.config import getenv
20
21
  from veadk.utils.logger import get_logger
21
22
  from veadk.utils.volcengine_sign import ve_request
23
+ from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
22
24
 
23
25
  logger = get_logger(__name__)
24
26
 
@@ -51,8 +53,26 @@ def run_code(code: str, language: str, tool_context: ToolContext) -> str:
51
53
  f"Running code in language: {language}, session_id={session_id}, code={code}, tool_id={tool_id}, host={host}, service={service}, region={region}"
52
54
  )
53
55
 
54
- access_key = getenv("VOLCENGINE_ACCESS_KEY")
55
- secret_key = getenv("VOLCENGINE_SECRET_KEY")
56
+ ak = tool_context.state.get("VOLCENGINE_ACCESS_KEY")
57
+ sk = tool_context.state.get("VOLCENGINE_SECRET_KEY")
58
+ header = {}
59
+
60
+ if not (ak and sk):
61
+ logger.debug("Get AK/SK from tool context failed.")
62
+ ak = os.getenv("VOLCENGINE_ACCESS_KEY")
63
+ sk = os.getenv("VOLCENGINE_SECRET_KEY")
64
+ if not (ak and sk):
65
+ logger.debug(
66
+ "Get AK/SK from environment variables failed. Try to use credential from Iam."
67
+ )
68
+ credential = get_credential_from_vefaas_iam()
69
+ ak = credential.access_key_id
70
+ sk = credential.secret_access_key
71
+ header = {"X-Security-Token": credential.session_token}
72
+ else:
73
+ logger.debug("Successfully get AK/SK from environment variables.")
74
+ else:
75
+ logger.debug("Successfully get AK/SK from tool context.")
56
76
 
57
77
  res = ve_request(
58
78
  request_body={
@@ -68,14 +88,14 @@ def run_code(code: str, language: str, tool_context: ToolContext) -> str:
68
88
  ),
69
89
  },
70
90
  action="InvokeTool",
71
- ak=access_key,
72
- sk=secret_key,
91
+ ak=ak,
92
+ sk=sk,
73
93
  service=service,
74
94
  version="2025-10-30",
75
95
  region=region,
76
96
  host=host,
97
+ header=header,
77
98
  )
78
-
79
99
  logger.debug(f"Invoke run code response: {res}")
80
100
 
81
101
  try:
@@ -16,142 +16,17 @@
16
16
  The document of this tool see: https://www.volcengine.com/docs/85508/1650263
17
17
  """
18
18
 
19
- import datetime
20
- import hashlib
21
- import hmac
22
- import json
23
- from urllib.parse import quote
19
+ import os
24
20
 
25
- import requests
26
21
  from google.adk.tools import ToolContext
27
22
 
28
- from veadk.config import getenv
23
+ from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
29
24
  from veadk.utils.logger import get_logger
25
+ from veadk.utils.volcengine_sign import ve_request
30
26
 
31
27
  logger = get_logger(__name__)
32
28
 
33
29
 
34
- Service = "volc_torchlight_api"
35
- Version = "2025-01-01"
36
- Region = "cn-beijing"
37
- Host = "mercury.volcengineapi.com"
38
- ContentType = "application/json"
39
-
40
-
41
- def norm_query(params):
42
- query = ""
43
- for key in sorted(params.keys()):
44
- if isinstance(params[key], list):
45
- for k in params[key]:
46
- query = (
47
- query + quote(key, safe="-_.~") + "=" + quote(k, safe="-_.~") + "&"
48
- )
49
- else:
50
- query = (
51
- query
52
- + quote(key, safe="-_.~")
53
- + "="
54
- + quote(params[key], safe="-_.~")
55
- + "&"
56
- )
57
- query = query[:-1]
58
- return query.replace("+", "%20")
59
-
60
-
61
- def hmac_sha256(key: bytes, content: str):
62
- return hmac.new(key, content.encode("utf-8"), hashlib.sha256).digest()
63
-
64
-
65
- def hash_sha256(content: str):
66
- return hashlib.sha256(content.encode("utf-8")).hexdigest()
67
-
68
-
69
- def request(method, date, query, header, ak, sk, action, body):
70
- credential = {
71
- "access_key_id": ak,
72
- "secret_access_key": sk,
73
- "service": Service,
74
- "region": Region,
75
- }
76
- request_param = {
77
- "body": body,
78
- "host": Host,
79
- "path": "/",
80
- "method": method,
81
- "content_type": ContentType,
82
- "date": date,
83
- "query": {"Action": action, "Version": Version, **query},
84
- }
85
- if body is None:
86
- request_param["body"] = ""
87
- # 第四步:接下来开始计算签名。在计算签名前,先准备好用于接收签算结果的 signResult 变量,并设置一些参数。
88
- # 初始化签名结果的结构体
89
- x_date = request_param["date"].strftime("%Y%m%dT%H%M%SZ")
90
- short_x_date = x_date[:8]
91
- x_content_sha256 = hash_sha256(request_param["body"])
92
- sign_result = {
93
- "Host": request_param["host"],
94
- "X-Content-Sha256": x_content_sha256,
95
- "X-Date": x_date,
96
- "Content-Type": request_param["content_type"],
97
- }
98
- signed_headers_str = ";".join(
99
- ["content-type", "host", "x-content-sha256", "x-date"]
100
- )
101
- # signed_headers_str = signed_headers_str + ";x-security-token"
102
- canonical_request_str = "\n".join(
103
- [
104
- request_param["method"].upper(),
105
- request_param["path"],
106
- norm_query(request_param["query"]),
107
- "\n".join(
108
- [
109
- "content-type:" + request_param["content_type"],
110
- "host:" + request_param["host"],
111
- "x-content-sha256:" + x_content_sha256,
112
- "x-date:" + x_date,
113
- ]
114
- ),
115
- "",
116
- signed_headers_str,
117
- x_content_sha256,
118
- ]
119
- )
120
-
121
- hashed_canonical_request = hash_sha256(canonical_request_str)
122
-
123
- credential_scope = "/".join(
124
- [short_x_date, credential["region"], credential["service"], "request"]
125
- )
126
- string_to_sign = "\n".join(
127
- ["HMAC-SHA256", x_date, credential_scope, hashed_canonical_request]
128
- )
129
-
130
- k_date = hmac_sha256(credential["secret_access_key"].encode("utf-8"), short_x_date)
131
- k_region = hmac_sha256(k_date, credential["region"])
132
- k_service = hmac_sha256(k_region, credential["service"])
133
- k_signing = hmac_sha256(k_service, "request")
134
- signature = hmac_sha256(k_signing, string_to_sign).hex()
135
-
136
- sign_result["Authorization"] = (
137
- "HMAC-SHA256 Credential={}, SignedHeaders={}, Signature={}".format(
138
- credential["access_key_id"] + "/" + credential_scope,
139
- signed_headers_str,
140
- signature,
141
- )
142
- )
143
- header = {**header, **sign_result}
144
- # header = {**header, **{"X-Security-Token": SessionToken}}
145
- r = requests.request(
146
- method=method,
147
- url="https://{}{}".format(request_param["host"], request_param["path"]),
148
- headers=header,
149
- params=request_param["query"],
150
- data=request_param["body"],
151
- )
152
- return r.json()
153
-
154
-
155
30
  def web_search(query: str, tool_context: ToolContext) -> list[str]:
156
31
  """Search a query in websites.
157
32
 
@@ -161,39 +36,48 @@ def web_search(query: str, tool_context: ToolContext) -> list[str]:
161
36
  Returns:
162
37
  A list of result documents.
163
38
  """
164
- req = {
165
- "Query": query,
166
- "SearchType": "web",
167
- "Count": 5,
168
- "NeedSummary": True,
169
- }
170
-
171
39
  ak = tool_context.state.get("VOLCENGINE_ACCESS_KEY")
172
- if not ak:
173
- ak = getenv("VOLCENGINE_ACCESS_KEY")
174
-
175
40
  sk = tool_context.state.get("VOLCENGINE_SECRET_KEY")
176
- if not sk:
177
- sk = getenv("VOLCENGINE_SECRET_KEY")
178
-
179
- now = datetime.datetime.utcnow()
180
- response_body = request(
181
- "POST",
182
- now,
183
- {},
184
- {},
185
- ak,
186
- sk,
187
- "WebSearch",
188
- json.dumps(req),
41
+ session_token = ""
42
+
43
+ if not (ak and sk):
44
+ logger.debug("Get AK/SK from tool context failed.")
45
+ ak = os.getenv("VOLCENGINE_ACCESS_KEY")
46
+ sk = os.getenv("VOLCENGINE_SECRET_KEY")
47
+ if not (ak and sk):
48
+ logger.debug("Get AK/SK from environment variables failed.")
49
+ credential = get_credential_from_vefaas_iam()
50
+ ak = credential.access_key_id
51
+ sk = credential.secret_access_key
52
+ session_token = credential.session_token
53
+ else:
54
+ logger.debug("Successfully get AK/SK from environment variables.")
55
+ else:
56
+ logger.debug("Successfully get AK/SK from tool context.")
57
+
58
+ response = ve_request(
59
+ request_body={
60
+ "Query": query,
61
+ "SearchType": "web",
62
+ "Count": 5,
63
+ "NeedSummary": True,
64
+ },
65
+ action="WebSearch",
66
+ ak=ak,
67
+ sk=sk,
68
+ service="volc_torchlight_api",
69
+ version="2025-01-01",
70
+ region="cn-beijing",
71
+ host="mercury.volcengineapi.com",
72
+ header={"X-Security-Token": session_token},
189
73
  )
190
74
 
191
75
  try:
192
- results: list = response_body["Result"]["WebResults"]
76
+ results: list = response["Result"]["WebResults"]
193
77
  final_results = []
194
78
  for result in results:
195
79
  final_results.append(result["Summary"].strip())
196
80
  return final_results
197
81
  except Exception as e:
198
- logger.error(f"Web search failed {e}, response body: {response_body}")
199
- return [response_body]
82
+ logger.error(f"Web search failed {e}, response body: {response}")
83
+ return [response]