veadk-python 0.2.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- veadk/__init__.py +37 -0
- veadk/a2a/__init__.py +13 -0
- veadk/a2a/agent_card.py +45 -0
- veadk/a2a/remote_ve_agent.py +390 -0
- veadk/a2a/utils/__init__.py +13 -0
- veadk/a2a/utils/agent_to_a2a.py +170 -0
- veadk/a2a/ve_a2a_server.py +93 -0
- veadk/a2a/ve_agent_executor.py +78 -0
- veadk/a2a/ve_middlewares.py +313 -0
- veadk/a2a/ve_task_store.py +37 -0
- veadk/agent.py +402 -0
- veadk/agent_builder.py +93 -0
- veadk/agents/loop_agent.py +68 -0
- veadk/agents/parallel_agent.py +72 -0
- veadk/agents/sequential_agent.py +64 -0
- veadk/auth/__init__.py +13 -0
- veadk/auth/base_auth.py +22 -0
- veadk/auth/ve_credential_service.py +203 -0
- veadk/auth/veauth/__init__.py +13 -0
- veadk/auth/veauth/apmplus_veauth.py +58 -0
- veadk/auth/veauth/ark_veauth.py +75 -0
- veadk/auth/veauth/base_veauth.py +50 -0
- veadk/auth/veauth/cozeloop_veauth.py +13 -0
- veadk/auth/veauth/opensearch_veauth.py +75 -0
- veadk/auth/veauth/postgresql_veauth.py +75 -0
- veadk/auth/veauth/prompt_pilot_veauth.py +60 -0
- veadk/auth/veauth/speech_veauth.py +54 -0
- veadk/auth/veauth/utils.py +69 -0
- veadk/auth/veauth/vesearch_veauth.py +62 -0
- veadk/auth/veauth/viking_mem0_veauth.py +91 -0
- veadk/cli/__init__.py +13 -0
- veadk/cli/cli.py +58 -0
- veadk/cli/cli_clean.py +87 -0
- veadk/cli/cli_create.py +163 -0
- veadk/cli/cli_deploy.py +233 -0
- veadk/cli/cli_eval.py +215 -0
- veadk/cli/cli_init.py +214 -0
- veadk/cli/cli_kb.py +110 -0
- veadk/cli/cli_pipeline.py +285 -0
- veadk/cli/cli_prompt.py +86 -0
- veadk/cli/cli_update.py +106 -0
- veadk/cli/cli_uploadevalset.py +139 -0
- veadk/cli/cli_web.py +143 -0
- veadk/cloud/__init__.py +13 -0
- veadk/cloud/cloud_agent_engine.py +485 -0
- veadk/cloud/cloud_app.py +475 -0
- veadk/config.py +115 -0
- veadk/configs/__init__.py +13 -0
- veadk/configs/auth_configs.py +133 -0
- veadk/configs/database_configs.py +132 -0
- veadk/configs/model_configs.py +78 -0
- veadk/configs/tool_configs.py +54 -0
- veadk/configs/tracing_configs.py +110 -0
- veadk/consts.py +74 -0
- veadk/evaluation/__init__.py +17 -0
- veadk/evaluation/adk_evaluator/__init__.py +17 -0
- veadk/evaluation/adk_evaluator/adk_evaluator.py +302 -0
- veadk/evaluation/base_evaluator.py +642 -0
- veadk/evaluation/deepeval_evaluator/__init__.py +17 -0
- veadk/evaluation/deepeval_evaluator/deepeval_evaluator.py +339 -0
- veadk/evaluation/eval_set_file_loader.py +48 -0
- veadk/evaluation/eval_set_recorder.py +146 -0
- veadk/evaluation/types.py +65 -0
- veadk/evaluation/utils/prometheus.py +196 -0
- veadk/integrations/__init__.py +13 -0
- veadk/integrations/ve_apig/__init__.py +13 -0
- veadk/integrations/ve_apig/ve_apig.py +349 -0
- veadk/integrations/ve_apig/ve_apig_utils.py +332 -0
- veadk/integrations/ve_code_pipeline/__init__.py +13 -0
- veadk/integrations/ve_code_pipeline/ve_code_pipeline.py +431 -0
- veadk/integrations/ve_cozeloop/__init__.py +13 -0
- veadk/integrations/ve_cozeloop/ve_cozeloop.py +96 -0
- veadk/integrations/ve_cr/__init__.py +13 -0
- veadk/integrations/ve_cr/ve_cr.py +220 -0
- veadk/integrations/ve_faas/__init__.py +13 -0
- veadk/integrations/ve_faas/template/cookiecutter.json +15 -0
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/__init__.py +13 -0
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/clean.py +23 -0
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/config.yaml.example +6 -0
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/deploy.py +106 -0
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/src/__init__.py +13 -0
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/src/agent.py +25 -0
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/src/app.py +202 -0
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/src/requirements.txt +3 -0
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/src/run.sh +49 -0
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/src/{{ cookiecutter.app_name }}/__init__.py +14 -0
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/src/{{ cookiecutter.app_name }}/agent.py +27 -0
- veadk/integrations/ve_faas/ve_faas.py +754 -0
- veadk/integrations/ve_faas/ve_faas_utils.py +408 -0
- veadk/integrations/ve_faas/web_template/cookiecutter.json +20 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/__init__.py +13 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/clean.py +23 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/config.yaml.example +2 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/deploy.py +44 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/Dockerfile +23 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/app.py +123 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/init_db.py +46 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/models.py +36 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/requirements.txt +4 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/run.sh +21 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/static/css/style.css +368 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/static/js/admin.js +0 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/templates/admin/dashboard.html +21 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/templates/admin/edit_post.html +24 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/templates/admin/login.html +21 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/templates/admin/posts.html +53 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/templates/base.html +45 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/templates/index.html +29 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/templates/post.html +14 -0
- veadk/integrations/ve_identity/__init__.py +110 -0
- veadk/integrations/ve_identity/auth_config.py +261 -0
- veadk/integrations/ve_identity/auth_mixins.py +650 -0
- veadk/integrations/ve_identity/auth_processor.py +385 -0
- veadk/integrations/ve_identity/function_tool.py +158 -0
- veadk/integrations/ve_identity/identity_client.py +864 -0
- veadk/integrations/ve_identity/mcp_tool.py +181 -0
- veadk/integrations/ve_identity/mcp_toolset.py +431 -0
- veadk/integrations/ve_identity/models.py +228 -0
- veadk/integrations/ve_identity/token_manager.py +188 -0
- veadk/integrations/ve_identity/utils.py +151 -0
- veadk/integrations/ve_prompt_pilot/__init__.py +13 -0
- veadk/integrations/ve_prompt_pilot/ve_prompt_pilot.py +85 -0
- veadk/integrations/ve_tls/__init__.py +13 -0
- veadk/integrations/ve_tls/utils.py +116 -0
- veadk/integrations/ve_tls/ve_tls.py +212 -0
- veadk/integrations/ve_tos/ve_tos.py +710 -0
- veadk/integrations/ve_viking_db_memory/__init__.py +13 -0
- veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py +308 -0
- veadk/knowledgebase/__init__.py +17 -0
- veadk/knowledgebase/backends/__init__.py +13 -0
- veadk/knowledgebase/backends/base_backend.py +72 -0
- veadk/knowledgebase/backends/in_memory_backend.py +91 -0
- veadk/knowledgebase/backends/opensearch_backend.py +162 -0
- veadk/knowledgebase/backends/redis_backend.py +172 -0
- veadk/knowledgebase/backends/utils.py +92 -0
- veadk/knowledgebase/backends/vikingdb_knowledge_backend.py +608 -0
- veadk/knowledgebase/entry.py +25 -0
- veadk/knowledgebase/knowledgebase.py +307 -0
- veadk/memory/__init__.py +35 -0
- veadk/memory/long_term_memory.py +365 -0
- veadk/memory/long_term_memory_backends/__init__.py +13 -0
- veadk/memory/long_term_memory_backends/base_backend.py +35 -0
- veadk/memory/long_term_memory_backends/in_memory_backend.py +67 -0
- veadk/memory/long_term_memory_backends/mem0_backend.py +155 -0
- veadk/memory/long_term_memory_backends/opensearch_backend.py +124 -0
- veadk/memory/long_term_memory_backends/redis_backend.py +140 -0
- veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py +189 -0
- veadk/memory/short_term_memory.py +252 -0
- veadk/memory/short_term_memory_backends/__init__.py +13 -0
- veadk/memory/short_term_memory_backends/base_backend.py +31 -0
- veadk/memory/short_term_memory_backends/mysql_backend.py +49 -0
- veadk/memory/short_term_memory_backends/postgresql_backend.py +49 -0
- veadk/memory/short_term_memory_backends/sqlite_backend.py +55 -0
- veadk/memory/short_term_memory_processor.py +100 -0
- veadk/processors/__init__.py +26 -0
- veadk/processors/base_run_processor.py +120 -0
- veadk/prompts/__init__.py +13 -0
- veadk/prompts/agent_default_prompt.py +30 -0
- veadk/prompts/prompt_evaluator.py +20 -0
- veadk/prompts/prompt_memory_processor.py +55 -0
- veadk/prompts/prompt_optimization.py +150 -0
- veadk/runner.py +732 -0
- veadk/tools/__init__.py +13 -0
- veadk/tools/builtin_tools/__init__.py +13 -0
- veadk/tools/builtin_tools/agent_authorization.py +94 -0
- veadk/tools/builtin_tools/generate_image.py +23 -0
- veadk/tools/builtin_tools/image_edit.py +300 -0
- veadk/tools/builtin_tools/image_generate.py +446 -0
- veadk/tools/builtin_tools/lark.py +67 -0
- veadk/tools/builtin_tools/las.py +24 -0
- veadk/tools/builtin_tools/link_reader.py +66 -0
- veadk/tools/builtin_tools/llm_shield.py +381 -0
- veadk/tools/builtin_tools/load_knowledgebase.py +97 -0
- veadk/tools/builtin_tools/mcp_router.py +29 -0
- veadk/tools/builtin_tools/run_code.py +113 -0
- veadk/tools/builtin_tools/tts.py +253 -0
- veadk/tools/builtin_tools/vesearch.py +49 -0
- veadk/tools/builtin_tools/video_generate.py +363 -0
- veadk/tools/builtin_tools/web_scraper.py +76 -0
- veadk/tools/builtin_tools/web_search.py +83 -0
- veadk/tools/demo_tools.py +58 -0
- veadk/tools/load_knowledgebase_tool.py +149 -0
- veadk/tools/sandbox/__init__.py +13 -0
- veadk/tools/sandbox/browser_sandbox.py +37 -0
- veadk/tools/sandbox/code_sandbox.py +40 -0
- veadk/tools/sandbox/computer_sandbox.py +34 -0
- veadk/tracing/__init__.py +13 -0
- veadk/tracing/base_tracer.py +58 -0
- veadk/tracing/telemetry/__init__.py +13 -0
- veadk/tracing/telemetry/attributes/attributes.py +29 -0
- veadk/tracing/telemetry/attributes/extractors/common_attributes_extractors.py +180 -0
- veadk/tracing/telemetry/attributes/extractors/llm_attributes_extractors.py +858 -0
- veadk/tracing/telemetry/attributes/extractors/tool_attributes_extractors.py +152 -0
- veadk/tracing/telemetry/attributes/extractors/types.py +164 -0
- veadk/tracing/telemetry/exporters/__init__.py +13 -0
- veadk/tracing/telemetry/exporters/apmplus_exporter.py +558 -0
- veadk/tracing/telemetry/exporters/base_exporter.py +39 -0
- veadk/tracing/telemetry/exporters/cozeloop_exporter.py +129 -0
- veadk/tracing/telemetry/exporters/inmemory_exporter.py +248 -0
- veadk/tracing/telemetry/exporters/tls_exporter.py +139 -0
- veadk/tracing/telemetry/opentelemetry_tracer.py +320 -0
- veadk/tracing/telemetry/telemetry.py +411 -0
- veadk/types.py +47 -0
- veadk/utils/__init__.py +13 -0
- veadk/utils/audio_manager.py +95 -0
- veadk/utils/auth.py +294 -0
- veadk/utils/logger.py +59 -0
- veadk/utils/mcp_utils.py +44 -0
- veadk/utils/misc.py +184 -0
- veadk/utils/patches.py +101 -0
- veadk/utils/volcengine_sign.py +205 -0
- veadk/version.py +15 -0
- veadk_python-0.2.27.dist-info/METADATA +373 -0
- veadk_python-0.2.27.dist-info/RECORD +218 -0
- veadk_python-0.2.27.dist-info/WHEEL +5 -0
- veadk_python-0.2.27.dist-info/entry_points.txt +2 -0
- veadk_python-0.2.27.dist-info/licenses/LICENSE +201 -0
- veadk_python-0.2.27.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Data models for veIdentity integration."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from typing import Any, Callable, Optional, TYPE_CHECKING, List
|
|
20
|
+
|
|
21
|
+
from pydantic import BaseModel, model_validator, field_validator
|
|
22
|
+
from google.adk.auth.auth_credential import OAuth2Auth
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from veadk.integrations.ve_identity.identity_client import IdentityClient
|
|
26
|
+
else:
|
|
27
|
+
# For runtime, use Any to avoid circular import issues
|
|
28
|
+
IdentityClient = Any
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Base interface for pollers; referenced in AuthRequestConfig's type hints.
class OAuth2AuthPoller:
    """Base class for objects that poll for OAuth2 authentication results.

    After a user has authorized access, a poller queries the identity service
    until the full OAuth2 authentication data is available (or a timeout
    occurs). Concrete pollers override :meth:`poll_for_auth`.
    """

    async def poll_for_auth(self) -> OAuth2Auth:
        """Wait for OAuth2 authentication data and return it.

        Returns:
            The OAuth2Auth object holding the tokens and related metadata.

        Raises:
            asyncio.TimeoutError: If no auth data arrives before polling times out.
            NotImplementedError: Always, in this base implementation.
        """
        message = "Subclasses must implement poll_for_auth"
        raise NotImplementedError(message)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class AuthRequestConfig(BaseModel):
    """Options controlling how authentication requests are processed.

    Attributes:
        on_auth_url: Optional callback invoked with the authorization URL when
            one is generated. May be a plain function or a coroutine function.
        oauth2_auth_poller: Optional factory that builds an ``OAuth2AuthPoller``.
            It takes one argument (currently the auth URI) and returns the
            poller used to fetch tokens after the user authorizes. If None, a
            default poller is used.
        max_auth_cycles: Optional limit on authentication cycles (presumably
            the number of authorize/poll rounds - confirm with the consumer).
        identity_client: Optional IdentityClient to use for identity service
            calls (presumably overrides the default client - verify).
        region: Optional identity service region (presumably used when no
            client is supplied - verify).
    """

    model_config = {"arbitrary_types_allowed": True}

    on_auth_url: Optional[Callable[[str], Any]] = None
    # Only auth_uri is passed to the factory today; other fields such as
    # exchanged_auth_credential may be supported later.
    oauth2_auth_poller: Optional[Callable[[Any], OAuth2AuthPoller]] = None
    max_auth_cycles: Optional[int] = None
    identity_client: Optional[IdentityClient] = None
    region: Optional[str] = None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class OAuth2TokenResponse(BaseModel):
    """Result of an OAuth2 token request.

    Attributes:
        response_type: Either "token" (a token was issued) or "auth_url"
            (user consent is still required).
        access_token: The OAuth2 access token; required when response_type
            is "token".
        authorization_url: URL the user must visit to grant consent; required
            when response_type is "auth_url".
        resource_ref: For "auth_url" responses, the serialized request
            parameters needed to poll for the final OAuth2 tokens once the
            user has authorized.
    """

    response_type: str
    access_token: Optional[str] = None
    authorization_url: Optional[str] = None
    resource_ref: Optional[str] = None

    @field_validator("response_type")
    @classmethod
    def validate_response_type(cls, v: str) -> str:
        """Accept only the two supported response types."""
        supported = ("token", "auth_url")
        if v in supported:
            return v
        raise ValueError("response_type must be either 'token' or 'auth_url'")

    @model_validator(mode="after")
    def validate_response_fields(self):
        """Require the payload field that matches the response type."""
        if self.response_type == "token" and not self.access_token:
            raise ValueError("access_token is required when response_type is 'token'")
        if self.response_type == "auth_url" and not self.authorization_url:
            raise ValueError(
                "authorization_url is required when response_type is 'auth_url'"
            )
        return self
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class DCRRegistrationRequest(BaseModel):
    """Request body for OAuth 2.0 Dynamic Client Registration (RFC 7591).

    Every field except ``client_name`` is optional. ``client_name`` is
    whitespace-stripped and must not be blank.
    """

    client_name: str = "VeADK Framework"
    redirect_uris: Optional[List[str]] = None
    scope: Optional[str] = None
    grant_types: Optional[List[str]] = None
    response_types: Optional[List[str]] = None
    token_endpoint_auth_method: Optional[str] = None

    @field_validator("client_name")
    @classmethod
    def validate_client_name_not_empty(cls, v: str) -> str:
        """Return the stripped client name, rejecting blank values."""
        cleaned = v.strip() if v else ""
        if not cleaned:
            raise ValueError("client_name cannot be empty")
        return cleaned
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class DCRRegistrationResponse(BaseModel):
    """Response body from OAuth 2.0 Dynamic Client Registration (RFC 7591).

    Only ``client_id`` is mandatory. It is whitespace-stripped and must not be
    blank.
    """

    client_id: str
    client_secret: Optional[str] = None
    client_id_issued_at: Optional[int] = None
    client_secret_expires_at: Optional[int] = None
    redirect_uris: Optional[List[str]] = None
    grant_types: Optional[List[str]] = None
    response_types: Optional[List[str]] = None
    scope: Optional[str] = None
    token_endpoint_auth_method: Optional[str] = None

    @field_validator("client_id")
    @classmethod
    def validate_client_id_not_empty(cls, v: str) -> str:
        """Return the stripped client ID, rejecting blank values."""
        cleaned = v.strip() if v else ""
        if not cleaned:
            raise ValueError("client_id cannot be empty")
        return cleaned
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class AuthorizationServerMetadata(BaseModel):
    """Extended Authorization Server Metadata with DCR support.

    Based on RFC 8414 - OAuth 2.0 Authorization Server Metadata
    and RFC 7591 - OAuth 2.0 Dynamic Client Registration Protocol.

    Attributes:
        authorization_endpoint: URL of the authorization endpoint (stripped,
            must not be blank).
        token_endpoint: URL of the token endpoint (stripped, must not be blank).
        issuer: Issuer identifier (stripped, must not be blank).
        register_endpoint: Optional Dynamic Client Registration endpoint. RFC
            8414 names this key ``registration_endpoint``; both spellings are
            accepted on input.
        response_types: Optional list of supported response types. The RFC
            8414 key ``response_types_supported`` is also accepted on input.
    """

    authorization_endpoint: str
    token_endpoint: str
    issuer: str
    register_endpoint: Optional[str] = None  # DCR endpoint
    response_types: Optional[List[str]] = None

    @model_validator(mode="before")
    @classmethod
    def accept_rfc8414_metadata_keys(cls, data: Any) -> Any:
        """Map standard RFC 8414 metadata keys onto this model's field names.

        Without this step, a metadata document parsed directly into the model
        would lose the DCR endpoint and the supported response types. The
        standard keys do not match the field names, and pydantic ignores
        unknown keys by default. Explicitly supplied field values always win.
        """
        if not isinstance(data, dict):
            return data
        standard_to_field = {
            "registration_endpoint": "register_endpoint",
            "response_types_supported": "response_types",
        }
        mapped = dict(data)
        for standard_key, field_name in standard_to_field.items():
            if mapped.get(field_name) is None and standard_key in mapped:
                mapped[field_name] = mapped[standard_key]
        return mapped

    @field_validator("authorization_endpoint", "token_endpoint", "issuer")
    @classmethod
    def validate_url_not_empty(cls, v: str) -> str:
        """Validate that URL fields are not empty and strip surrounding whitespace."""
        if not v or not v.strip():
            raise ValueError("URL fields cannot be empty")
        return v.strip()
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class OAuth2Discovery(BaseModel):
    """OAuth2 discovery settings, including DCR-capable server metadata.

    Attributes:
        authorization_server_metadata: Metadata describing the server endpoints.
        discovery_url: Optional URL the metadata came from. If given, it is
            whitespace-stripped and must not be blank.
    """

    authorization_server_metadata: AuthorizationServerMetadata
    discovery_url: Optional[str] = None

    @field_validator("discovery_url")
    @classmethod
    def validate_discovery_url(cls, v: Optional[str]) -> Optional[str]:
        """Pass None through unchanged; otherwise strip and reject blank URLs."""
        if v is None:
            return None
        cleaned = v.strip()
        if not cleaned:
            raise ValueError("discovery_url cannot be empty string")
        return cleaned
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class WorkloadToken(BaseModel):
    """A workload access token together with its expiry.

    Attributes:
        workload_access_token: The token string (stripped, must not be blank).
        expires_at: Expiry as a Unix timestamp in seconds; must be > 0.
    """

    workload_access_token: str
    expires_at: int

    @field_validator("workload_access_token")
    @classmethod
    def validate_token_not_empty(cls, v: str) -> str:
        """Return the stripped token, rejecting blank values."""
        cleaned = v.strip() if v else ""
        if not cleaned:
            raise ValueError("workload_access_token cannot be empty")
        return cleaned

    @field_validator("expires_at")
    @classmethod
    def validate_expires_at_positive(cls, v: int) -> int:
        """Accept only strictly positive timestamps."""
        if v > 0:
            return v
        raise ValueError("expires_at must be a positive Unix timestamp")
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class AssumeRoleCredential(BaseModel):
    """Temporary credential triple of access key, secret key and session token.

    NOTE(review): presumably populated from an assume-role (STS) call made
    elsewhere in the integration - confirm against the caller.
    """

    access_key_id: str
    secret_access_key: str
    session_token: str
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Token manager for agent identity tokens with caching and expiration handling."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import time
|
|
20
|
+
from typing import Optional, Union
|
|
21
|
+
|
|
22
|
+
from google.adk.tools.tool_context import ToolContext
|
|
23
|
+
from google.adk.agents.callback_context import CallbackContext
|
|
24
|
+
from google.adk.agents.readonly_context import ReadonlyContext
|
|
25
|
+
|
|
26
|
+
from veadk.integrations.ve_identity.auth_config import (
|
|
27
|
+
get_default_identity_client,
|
|
28
|
+
)
|
|
29
|
+
from veadk.utils.logger import get_logger
|
|
30
|
+
|
|
31
|
+
from veadk.integrations.ve_identity.identity_client import IdentityClient
|
|
32
|
+
from veadk.integrations.ve_identity.models import WorkloadToken
|
|
33
|
+
|
|
34
|
+
logger = get_logger(__name__)

# Token expiration buffer in seconds - tokens will be refreshed this many seconds before actual expiration.
# WorkloadTokenManager._is_token_expired treats a cached token as expired once
# the current time reaches `expires_at - TOKEN_EXPIRATION_BUFFER_SECONDS`, so a
# new token is requested slightly before the old one actually lapses.
TOKEN_EXPIRATION_BUFFER_SECONDS = 60
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class WorkloadTokenManager:
    """Manager for workload access tokens with caching and expiration handling.

    This class manages the lifecycle of workload access tokens, including:
    - Caching tokens in the invocation's session state
    - Refreshing tokens that have expired or expire within
      TOKEN_EXPIRATION_BUFFER_SECONDS
    - Requesting tokens either with a user JWT or with the invocation's user ID

    Attributes:
        _identity_client: The IdentityClient used to request workload tokens.
    """

    def __init__(
        self,
        identity_client: Optional[IdentityClient] = None,
        region: Optional[str] = None,
    ):
        """Initialize the token manager.

        Args:
            identity_client: Optional IdentityClient instance to use for token
                requests. If not provided, the client returned by
                ``get_default_identity_client(region=region)`` is used.
            region: Optional region forwarded to ``get_default_identity_client``.
                Only relevant when ``identity_client`` is not provided.
        """
        self._identity_client = identity_client or get_default_identity_client(
            region=region
        )

    def _build_cache_key(
        self,
        tool_context: Union[ToolContext, CallbackContext, ReadonlyContext],
        workload_name: Optional[str] = None,
    ) -> str:
        """Build the session-state key under which a workload token is cached.

        The key is scoped per agent and user. It is also scoped per workload
        when one is named, so that tokens for different workloads never
        overwrite or shadow each other.

        Args:
            tool_context: The context providing the agent name and user ID.
            workload_name: Optional workload name to include in the key.

        Returns:
            ``"workload_token:{agent}:{user}"`` when no workload name is given,
            otherwise ``"workload_token:{agent}:{user}:{workload_name}"``.
        """
        key = f"workload_token:{tool_context.agent_name}:{tool_context._invocation_context.user_id}"
        if workload_name:
            key = f"{key}:{workload_name}"
        return key

    def _is_token_expired(self, expires_at: int) -> bool:
        """Check if a token has expired or will expire soon.

        Args:
            expires_at: The expiration timestamp in seconds since Unix epoch.

        Returns:
            True if the token has expired or will expire within the buffer period,
            False otherwise.
        """
        current_time = int(time.time())
        return current_time >= (expires_at - TOKEN_EXPIRATION_BUFFER_SECONDS)

    @staticmethod
    def _load_cached_token(cached: object) -> Optional[WorkloadToken]:
        """Turn a raw session-state value into a usable WorkloadToken, or None.

        Accepts a WorkloadToken instance, or a plain dict with its fields (in
        case session state was serialized). Returns None for anything else,
        including values that fail validation or lack a token/expiry.
        """
        if isinstance(cached, dict):
            try:
                cached = WorkloadToken.model_validate(cached)
            except ValueError:  # pydantic's ValidationError subclasses ValueError
                return None
        if not isinstance(cached, WorkloadToken):
            return None
        if not (cached.workload_access_token and cached.expires_at):
            return None
        return cached

    async def get_workload_token(
        self,
        tool_context: Union[ToolContext, CallbackContext, ReadonlyContext],
        workload_name: Optional[str] = None,
        user_token: Optional[str] = None,
    ) -> str:
        """Get or refresh the workload access token.

        1. Looks up a cached token in the invocation's session state.
        2. Returns it if it is not expired (see ``_is_token_expired``).
        3. Otherwise requests a new token from the identity client, caches it,
           and returns it.

        Args:
            tool_context: The context holding the session state and user info.
            workload_name: Optional workload name. It is forwarded to the
                identity client unchanged (None included) and scopes the
                cache key.
            user_token: Optional JWT for user-scoped authentication. When given,
                ``user_id`` is sent as None; otherwise the invocation's user ID
                is sent.

        Returns:
            The workload access token string.

        Raises:
            Whatever ``IdentityClient.get_workload_access_token`` raises on
            failure (not visible here).
        """
        cache_key = self._build_cache_key(tool_context, workload_name)
        session_state = tool_context._invocation_context.session.state

        cached_token = self._load_cached_token(session_state.get(cache_key))
        if cached_token is not None:
            if not self._is_token_expired(cached_token.expires_at):
                return cached_token.workload_access_token
            logger.info(
                f"Cached workload token expired for agent '{tool_context.agent_name}', refreshing..."
            )

        # With a user JWT the identity is carried by the token itself.
        user_id = None if user_token else tool_context._invocation_context.user_id

        workload_token: WorkloadToken = self._identity_client.get_workload_access_token(
            workload_name=workload_name,
            user_token=user_token,
            user_id=user_id,
        )

        session_state[cache_key] = workload_token
        return workload_token.workload_access_token
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
async def get_workload_token(
    tool_context: Union[ToolContext | CallbackContext | ReadonlyContext],
    identity_client: Optional[IdentityClient] = None,
    workload_name: Optional[str] = None,
    user_token: Optional[str] = None,
    region: str = "cn-beijing",
) -> str:
    """Fetch a workload access token through a one-off WorkloadTokenManager.

    A shortcut for common cases: it builds a manager and delegates to its
    ``get_workload_token``, which uses the session-state cache and refreshes
    expired tokens.

    Args:
        tool_context: The context holding the session state and user info.
        identity_client: Optional IdentityClient. If omitted, the manager falls
            back to ``get_default_identity_client(region=region)``.
        workload_name: Optional workload name forwarded to the manager.
        user_token: Optional JWT for user-scoped authentication.
        region: Region handed to the manager for the default client.
            Defaults to "cn-beijing".

    Returns:
        The workload access token string.
    """
    manager = WorkloadTokenManager(identity_client=identity_client, region=region)
    token = await manager.get_workload_token(
        tool_context=tool_context,
        workload_name=workload_name,
        user_token=user_token,
    )
    return token
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Utility functions for handling authentication events in the veADK framework."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
import base64
|
|
19
|
+
from typing import Optional
|
|
20
|
+
|
|
21
|
+
from google.adk.events import Event
|
|
22
|
+
from google.adk.auth import AuthConfig
|
|
23
|
+
from google.adk.auth.auth_credential import AuthCredential
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def is_pending_auth_event(event: Event) -> bool:
    """Check if an ADK event represents a pending authentication request.

    The ADK framework emits a special function call ('adk_request_credential')
    when a tool requires user authentication that hasn't been satisfied yet.
    Only the first part of the event content is inspected, consistent with the
    other helpers in this module.

    Args:
        event: The ADK Event object to inspect. ``None`` is tolerated and is
            treated as "not an auth request".

    Returns:
        True if the first content part is an 'adk_request_credential' function
        call, False otherwise. The result is always a real ``bool``.
    """
    # bool() is needed because a bare ``and`` chain returns its first falsy
    # operand (None, [], an empty Content, ...) rather than False, which
    # would violate the ``-> bool`` contract.
    return bool(
        event
        and event.content
        and event.content.parts
        and event.content.parts[0]
        and event.content.parts[0].function_call
        and event.content.parts[0].function_call.name == "adk_request_credential"
    )
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_function_call_id(event: Event) -> str:
    """Return the id of the function call carried by an ADK event.

    The id ties a later function response back to the function call that
    started the authentication request. Only the first content part of the
    event is examined.

    Args:
        event: The ADK Event object holding the function call.

    Returns:
        The function call's identifier string.

    Raises:
        ValueError: If the event has no content, no parts, no function call in
            its first part, or a function call with an empty id.
    """
    parts = event.content.parts if event and event.content else None
    first_part = parts[0] if parts else None
    call = first_part.function_call if first_part else None
    call_id = call.id if call else None

    if not call_id:
        raise ValueError(f"Cannot extract function call ID from event: {event}")
    return call_id
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def get_function_call_auth_config(event: Event) -> AuthConfig:
    """Build an AuthConfig from the arguments of an 'adk_request_credential' call.

    The client fills in this AuthConfig with the authentication details it
    obtained (for example OAuth codes and state) and sends it back to the ADK
    so that the OAuth token exchange can continue. Only the first content part
    of the event is examined.

    Args:
        event: The ADK Event object carrying the 'adk_request_credential' call.

    Returns:
        An AuthConfig constructed from the call's ``authConfig`` argument.

    Raises:
        ValueError: If the event has no function call in its first part, or the
            call has no (or an empty) 'authConfig' argument.
    """
    call = None
    if event and event.content and event.content.parts and event.content.parts[0]:
        call = event.content.parts[0].function_call

    raw_config = call.args.get("authConfig") if call and call.args else None
    if not raw_config:
        raise ValueError(f"Cannot extract auth config from event: {event}")

    return AuthConfig(**raw_config)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def generate_headers(credential: AuthCredential) -> Optional[dict[str, str]]:
    """Extracts an ``Authorization`` header from an ADK credential.

    The credential kinds are checked in this order; the first one present wins:

    * ``oauth2``: ``Bearer <access_token>``, only if an access token is set.
    * ``http`` with scheme ``bearer`` (case-insensitive) and a token:
      ``Bearer <token>``.
    * ``http`` with scheme ``basic``: ``Basic base64(username:password)``,
      only if both username and password are set.
    * ``http`` with any other scheme and a token: ``<scheme> <token>``.
    * ``api_key``: the raw key is used as the ``Authorization`` value.
    * ``service_account``: not handled; no headers are produced.

    Args:
        credential: The authentication credential to process. May be None.

    Returns:
        Dictionary of headers to add to the request, or None if no usable
        credential material was found.
    """
    headers: Optional[dict[str, str]] = None
    if not credential:
        return headers

    if credential.oauth2:
        # Emit nothing rather than the literal header "Bearer None" when the
        # OAuth2 flow has not produced an access token yet.
        if credential.oauth2.access_token:
            headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"}
    elif credential.http:
        # Handle HTTP authentication schemes
        scheme = credential.http.scheme
        http_credentials = credential.http.credentials
        if scheme.lower() == "bearer" and http_credentials.token:
            headers = {"Authorization": f"Bearer {http_credentials.token}"}
        elif scheme.lower() == "basic":
            # RFC 7617: base64 of "username:password".
            if http_credentials.username and http_credentials.password:
                raw = f"{http_credentials.username}:{http_credentials.password}"
                encoded_credentials = base64.b64encode(raw.encode()).decode()
                headers = {"Authorization": f"Basic {encoded_credentials}"}
        elif http_credentials.token:
            # Handle other HTTP schemes with token
            headers = {"Authorization": f"{scheme} {http_credentials.token}"}
    elif credential.api_key:
        headers = {"Authorization": credential.api_key}
    # service_account credentials are intentionally not converted to headers.

    return headers
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import agent_pilot as ap
|
|
18
|
+
from agent_pilot.models import TaskType
|
|
19
|
+
from veadk import Agent
|
|
20
|
+
from veadk.prompts import prompt_optimization
|
|
21
|
+
from veadk.utils.logger import get_logger
|
|
22
|
+
|
|
23
|
+
logger = get_logger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class VePromptPilot:
    """Optimize agent instructions with the PromptPilot service.

    This wraps ``agent_pilot.generate_prompt_stream``. For each agent it renders
    a task description (optionally including user feedback), streams back an
    optimized prompt, prints it and logs token usage.
    """

    def __init__(
        self,
        api_key: str,
        workspace_id: str,
        path: str = "",
        task_id: str | None = None,
    ) -> None:
        """Store the PromptPilot credentials and options.

        Args:
            api_key: API key passed to every ``generate_prompt_stream`` call.
            workspace_id: Workspace id passed to every ``generate_prompt_stream``
                call.
            path: Stored on the instance; not used by ``optimize``.
            task_id: Optional task identifier. Stored on the instance; not used
                by ``optimize``.
        """
        self.api_key = api_key
        self.workspace_id = workspace_id

        self.path = path
        # Keep the argument so callers can read it back.
        self.task_id = task_id

    def optimize(
        self,
        agents: list[Agent],
        feedback: str = "",
        model_name: str = "doubao-1.5-pro-32k-250115",
    ) -> str:
        """Generate an optimized prompt for each agent in ``agents``.

        Each agent's prompt is printed to stdout. The agents themselves are not
        modified.

        Args:
            agents: Agents whose ``instruction`` should be optimized.
            feedback: Optional user feedback to include in the task
                description. When empty, a feedback-free description is used.
            model_name: Model used by PromptPilot to generate the prompt.

        Returns:
            The optimized prompt of the LAST agent processed. Prompts for earlier
            agents are only printed. Returns "" when ``agents`` is empty.
        """
        # Bound before the loop so an empty ``agents`` list returns "" instead
        # of raising UnboundLocalError.
        optimized_prompt = ""
        for idx, agent in enumerate(agents):
            optimized_prompt = ""
            if not feedback:
                logger.info("Optimizing prompt without feedback.")
                task_description = prompt_optimization.render_prompt_with_jinja2(agent)
            else:
                logger.info(f"Optimizing prompt with feedback: {feedback}")
                task_description = (
                    prompt_optimization.render_prompt_feedback_with_jinja2(
                        agent, feedback
                    )
                )

            logger.info(
                f"Optimizing prompt for agent {agent.name} by {model_name} [{idx + 1}/{len(agents)}]"
            )

            usage = None
            for chunk in ap.generate_prompt_stream(
                task_description=task_description,
                current_prompt=str(agent.instruction),
                model_name=model_name,
                task_type=TaskType.DIALOG,
                temperature=1.0,
                top_p=0.7,
                api_key=self.api_key,
                workspace_id=self.workspace_id,
            ):  # stream chunks of optimized prompt
                # Skip chunks whose data or content is missing (None) instead
                # of failing on str + None.
                if chunk.data and chunk.data.content:
                    optimized_prompt += chunk.data.content
                if chunk.event == "usage":
                    usage = chunk.data.usage if chunk.data else 0
            # Turn literal backslash-n sequences in the streamed text into real
            # line breaks.
            optimized_prompt = optimized_prompt.replace("\\n", "\n")
            print(f"Optimized prompt for agent {agent.name}:\n{optimized_prompt}")
            if usage:
                logger.info(f"Token usage: {usage['total_tokens']}")
            else:
                logger.warning("No usage data.")

        return optimized_prompt
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|