veadk-python 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of veadk-python might be problematic. Click here for more details.
- veadk/agent.py +28 -8
- veadk/cli/cli_deploy.py +3 -1
- veadk/cloud/cloud_app.py +21 -6
- veadk/consts.py +14 -1
- veadk/database/viking/viking_database.py +3 -3
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/clean.py +23 -0
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/src/app.py +4 -1
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/src/run.sh +11 -1
- veadk/integrations/ve_tos/ve_tos.py +176 -0
- veadk/runner.py +107 -34
- veadk/tools/builtin_tools/image_edit.py +236 -0
- veadk/tools/builtin_tools/image_generate.py +236 -0
- veadk/tools/builtin_tools/video_generate.py +326 -0
- veadk/tools/sandbox/browser_sandbox.py +19 -9
- veadk/tools/sandbox/code_sandbox.py +21 -11
- veadk/tools/sandbox/computer_sandbox.py +16 -9
- veadk/tracing/base_tracer.py +0 -19
- veadk/tracing/telemetry/attributes/extractors/llm_attributes_extractors.py +65 -6
- veadk/tracing/telemetry/attributes/extractors/tool_attributes_extractors.py +20 -14
- veadk/tracing/telemetry/exporters/inmemory_exporter.py +3 -0
- veadk/tracing/telemetry/opentelemetry_tracer.py +4 -1
- veadk/tracing/telemetry/telemetry.py +113 -24
- veadk/utils/misc.py +40 -0
- veadk/version.py +1 -1
- {veadk_python-0.2.4.dist-info → veadk_python-0.2.5.dist-info}/METADATA +1 -1
- {veadk_python-0.2.4.dist-info → veadk_python-0.2.5.dist-info}/RECORD +30 -25
- {veadk_python-0.2.4.dist-info → veadk_python-0.2.5.dist-info}/WHEEL +0 -0
- {veadk_python-0.2.4.dist-info → veadk_python-0.2.5.dist-info}/entry_points.txt +0 -0
- {veadk_python-0.2.4.dist-info → veadk_python-0.2.5.dist-info}/licenses/LICENSE +0 -0
- {veadk_python-0.2.4.dist-info → veadk_python-0.2.5.dist-info}/top_level.txt +0 -0
veadk/agent.py
CHANGED
|
@@ -28,9 +28,10 @@ from typing_extensions import Any
|
|
|
28
28
|
|
|
29
29
|
from veadk.config import getenv
|
|
30
30
|
from veadk.consts import (
|
|
31
|
-
|
|
31
|
+
DEFAULT_MODEL_AGENT_PROVIDER,
|
|
32
32
|
DEFAULT_MODEL_AGENT_API_BASE,
|
|
33
33
|
DEFAULT_MODEL_AGENT_NAME,
|
|
34
|
+
DEFAULT_MODEL_EXTRA_HEADERS,
|
|
34
35
|
)
|
|
35
36
|
from veadk.evaluation import EvalSetRecorder
|
|
36
37
|
from veadk.knowledgebase import KnowledgeBase
|
|
@@ -64,7 +65,7 @@ class Agent(LlmAgent):
|
|
|
64
65
|
model_name: str = getenv("MODEL_AGENT_NAME", DEFAULT_MODEL_AGENT_NAME)
|
|
65
66
|
"""The name of the model for agent running."""
|
|
66
67
|
|
|
67
|
-
model_provider: str = getenv("MODEL_AGENT_PROVIDER",
|
|
68
|
+
model_provider: str = getenv("MODEL_AGENT_PROVIDER", DEFAULT_MODEL_AGENT_PROVIDER)
|
|
68
69
|
"""The provider of the model for agent running."""
|
|
69
70
|
|
|
70
71
|
model_api_base: str = getenv("MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE)
|
|
@@ -73,6 +74,9 @@ class Agent(LlmAgent):
|
|
|
73
74
|
model_api_key: str = Field(default_factory=lambda: getenv("MODEL_AGENT_API_KEY"))
|
|
74
75
|
"""The api key of the model for agent running."""
|
|
75
76
|
|
|
77
|
+
model_extra_config: dict = Field(default_factory=dict)
|
|
78
|
+
"""The extra config to include in the model requests."""
|
|
79
|
+
|
|
76
80
|
tools: list[ToolUnion] = []
|
|
77
81
|
"""The tools provided to agent."""
|
|
78
82
|
|
|
@@ -96,11 +100,27 @@ class Agent(LlmAgent):
|
|
|
96
100
|
|
|
97
101
|
def model_post_init(self, __context: Any) -> None:
|
|
98
102
|
super().model_post_init(None) # for sub_agents init
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
103
|
+
|
|
104
|
+
# add model request source (veadk) in extra headers
|
|
105
|
+
if self.model_extra_config and "extra_headers" in self.model_extra_config:
|
|
106
|
+
self.model_extra_config["extra_headers"] |= DEFAULT_MODEL_EXTRA_HEADERS
|
|
107
|
+
else:
|
|
108
|
+
self.model_extra_config["extra_headers"] = DEFAULT_MODEL_EXTRA_HEADERS
|
|
109
|
+
|
|
110
|
+
if not self.model:
|
|
111
|
+
self.model = LiteLlm(
|
|
112
|
+
model=f"{self.model_provider}/{self.model_name}",
|
|
113
|
+
api_key=self.model_api_key,
|
|
114
|
+
api_base=self.model_api_base,
|
|
115
|
+
**self.model_extra_config,
|
|
116
|
+
)
|
|
117
|
+
logger.debug(
|
|
118
|
+
f"LiteLLM client created with config: {self.model_extra_config}"
|
|
119
|
+
)
|
|
120
|
+
else:
|
|
121
|
+
logger.warning(
|
|
122
|
+
"You are trying to use your own LiteLLM client, some default request headers may be missing."
|
|
123
|
+
)
|
|
104
124
|
|
|
105
125
|
if self.knowledgebase:
|
|
106
126
|
from veadk.tools import load_knowledgebase_tool
|
|
@@ -117,7 +137,7 @@ class Agent(LlmAgent):
|
|
|
117
137
|
|
|
118
138
|
logger.info(f"{self.__class__.__name__} `{self.name}` init done.")
|
|
119
139
|
logger.debug(
|
|
120
|
-
f"Agent: {self.model_dump(include={'name', 'model_name', 'model_api_base', 'tools'
|
|
140
|
+
f"Agent: {self.model_dump(include={'name', 'model_name', 'model_api_base', 'tools'})}"
|
|
121
141
|
)
|
|
122
142
|
|
|
123
143
|
async def _run(
|
veadk/cli/cli_deploy.py
CHANGED
|
@@ -29,7 +29,9 @@ TEMP_PATH = "/tmp"
|
|
|
29
29
|
default=None,
|
|
30
30
|
help="Volcengine secret key",
|
|
31
31
|
)
|
|
32
|
-
@click.option(
|
|
32
|
+
@click.option(
|
|
33
|
+
"--vefaas-app-name", required=True, help="Expected Volcengine FaaS application name"
|
|
34
|
+
)
|
|
33
35
|
@click.option(
|
|
34
36
|
"--veapig-instance-name", default="", help="Expected Volcengine APIG instance name"
|
|
35
37
|
)
|
veadk/cloud/cloud_app.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
import json
|
|
16
|
+
import time
|
|
16
17
|
from typing import Any
|
|
17
18
|
from uuid import uuid4
|
|
18
19
|
|
|
@@ -60,9 +61,11 @@ class CloudApp:
|
|
|
60
61
|
if not vefaas_endpoint:
|
|
61
62
|
self.vefaas_endpoint = self._get_vefaas_endpoint()
|
|
62
63
|
|
|
63
|
-
if
|
|
64
|
-
|
|
65
|
-
|
|
64
|
+
if (
|
|
65
|
+
self.vefaas_endpoint
|
|
66
|
+
and not self.vefaas_endpoint.startswith("http")
|
|
67
|
+
and not self.vefaas_endpoint.startswith("https")
|
|
68
|
+
):
|
|
66
69
|
raise ValueError(
|
|
67
70
|
f"Invalid endpoint: {vefaas_endpoint}. The endpoint must start with `http` or `https`."
|
|
68
71
|
)
|
|
@@ -92,12 +95,13 @@ class CloudApp:
|
|
|
92
95
|
raise ValueError(
|
|
93
96
|
f"VeFaaS CloudAPP with application_id `{self.vefaas_application_id}` or application_name `{self.vefaas_application_name}` not found."
|
|
94
97
|
)
|
|
95
|
-
cloud_resource = json.loads(app["CloudResource"])
|
|
96
98
|
|
|
97
99
|
try:
|
|
100
|
+
cloud_resource = json.loads(app["CloudResource"])
|
|
98
101
|
vefaas_endpoint = cloud_resource["framework"]["url"]["system_url"]
|
|
99
102
|
except Exception as e:
|
|
100
|
-
|
|
103
|
+
logger.warning(f"VeFaaS cloudAPP could not get endpoint. Error: {e}")
|
|
104
|
+
vefaas_endpoint = ""
|
|
101
105
|
return vefaas_endpoint
|
|
102
106
|
|
|
103
107
|
def _get_vefaas_application_id_by_name(self) -> str:
|
|
@@ -167,7 +171,18 @@ class CloudApp:
|
|
|
167
171
|
|
|
168
172
|
vefaas_client = VeFaaS(access_key=volcengine_ak, secret_key=volcengine_sk)
|
|
169
173
|
vefaas_client.delete(self.vefaas_application_id)
|
|
170
|
-
print(
|
|
174
|
+
print(
|
|
175
|
+
f"Cloud app {self.vefaas_application_id} delete request has been sent to VeFaaS"
|
|
176
|
+
)
|
|
177
|
+
while True:
|
|
178
|
+
try:
|
|
179
|
+
id = self._get_vefaas_application_id_by_name()
|
|
180
|
+
if not id:
|
|
181
|
+
break
|
|
182
|
+
time.sleep(3)
|
|
183
|
+
except Exception as _:
|
|
184
|
+
break
|
|
185
|
+
print("Delete application done.")
|
|
171
186
|
|
|
172
187
|
async def message_send(
|
|
173
188
|
self, message: str, session_id: str, user_id: str, timeout: float = 600.0
|
veadk/consts.py
CHANGED
|
@@ -12,6 +12,19 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
+
from veadk.version import VERSION
|
|
16
|
+
|
|
15
17
|
DEFAULT_MODEL_AGENT_NAME = "doubao-seed-1-6-250615"
|
|
16
|
-
|
|
18
|
+
DEFAULT_MODEL_AGENT_PROVIDER = "openai"
|
|
17
19
|
DEFAULT_MODEL_AGENT_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/"
|
|
20
|
+
DEFAULT_MODEL_EXTRA_HEADERS = {"veadk-source": "veadk", "veadk-version": VERSION}
|
|
21
|
+
|
|
22
|
+
DEFAULT_APMPLUS_OTEL_EXPORTER_ENDPOINT = "http://apmplus-cn-beijing.volces.com:4317"
|
|
23
|
+
DEFAULT_APMPLUS_OTEL_EXPORTER_SERVICE_NAME = "veadk_tracing"
|
|
24
|
+
|
|
25
|
+
DEFAULT_COZELOOP_OTEL_EXPORTER_ENDPOINT = (
|
|
26
|
+
"https://api.coze.cn/v1/loop/opentelemetry/v1/traces"
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
DEFAULT_TLS_OTEL_EXPORTER_ENDPOINT = "https://tls-cn-beijing.volces.com:4318/v1/traces"
|
|
30
|
+
DEFAULT_TLS_OTEL_EXPORTER_REGION = "cn-beijing"
|
|
@@ -387,9 +387,9 @@ class VikingDatabase(BaseModel, BaseDatabase):
|
|
|
387
387
|
logger.error(f"Error in list_collections: {result['message']}")
|
|
388
388
|
raise ValueError(f"Error in list_collections: {result['message']}")
|
|
389
389
|
|
|
390
|
-
collections = result["data"]
|
|
391
|
-
if
|
|
392
|
-
|
|
390
|
+
collections = result["data"].get("collection_list", [])
|
|
391
|
+
if len(collections) == 0:
|
|
392
|
+
return False
|
|
393
393
|
|
|
394
394
|
collection_list = set()
|
|
395
395
|
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from veadk.cloud.cloud_app import CloudApp
|
|
16
|
+
|
|
17
|
+
def main() -> None:
    """Delete the VeFaaS cloud application generated from this template.

    The application is looked up by the name rendered into this file by
    cookiecutter, then asked to delete itself.
    """
    CloudApp(
        vefaas_application_name="{{cookiecutter.vefaas_application_name}}"
    ).delete_self()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
if __name__ == "__main__":
|
|
23
|
+
main()
|
|
@@ -26,6 +26,7 @@ from fastmcp import FastMCP
|
|
|
26
26
|
from starlette.routing import Route
|
|
27
27
|
|
|
28
28
|
from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder
|
|
29
|
+
from a2a.types import AgentProvider
|
|
29
30
|
|
|
30
31
|
from veadk.a2a.ve_a2a_server import init_app
|
|
31
32
|
from veadk.runner import Runner
|
|
@@ -46,7 +47,9 @@ app_name = agent_run_config.app_name
|
|
|
46
47
|
agent = agent_run_config.agent
|
|
47
48
|
short_term_memory = agent_run_config.short_term_memory
|
|
48
49
|
|
|
49
|
-
|
|
50
|
+
VEFAAS_REGION = os.getenv("APP_REGION", "cn-beijing")
|
|
51
|
+
VEFAAS_FUNC_ID = os.getenv("_FAAS_FUNC_ID", "")
|
|
52
|
+
agent_card_builder = AgentCardBuilder(agent=agent, provider=AgentProvider(organization="Volcengine Agent Development Kit (VeADK)", url=f"https://console.volcengine.com/vefaas/region:vefaas+{VEFAAS_REGION}/function/detail/{VEFAAS_FUNC_ID}"))
|
|
50
53
|
|
|
51
54
|
|
|
52
55
|
def load_tracer() -> None:
|
|
@@ -34,7 +34,17 @@ while [[ $# -gt 0 ]]; do
|
|
|
34
34
|
done
|
|
35
35
|
|
|
36
36
|
# in case of deployment deps not installed in user's requirements.txt
|
|
37
|
-
|
|
37
|
+
if pip list | grep -q "^fastapi \|^uvicorn "; then
|
|
38
|
+
echo "fastapi and uvicorn already installed"
|
|
39
|
+
else
|
|
40
|
+
python3 -m pip install uvicorn[standard] fastapi
|
|
41
|
+
fi
|
|
42
|
+
|
|
43
|
+
# Check if MODEL_AGENT_API_KEY is set
|
|
44
|
+
if [ -z "$MODEL_AGENT_API_KEY" ]; then
|
|
45
|
+
echo "MODEL_AGENT_API_KEY is not set. Please set it in your environment variables."
|
|
46
|
+
exit 1
|
|
47
|
+
fi
|
|
38
48
|
|
|
39
49
|
USE_ADK_WEB=${USE_ADK_WEB:-False}
|
|
40
50
|
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
from veadk.config import getenv
|
|
17
|
+
from veadk.utils.logger import get_logger
|
|
18
|
+
import tos
|
|
19
|
+
import asyncio
|
|
20
|
+
from typing import Union
|
|
21
|
+
from pydantic import BaseModel, Field
|
|
22
|
+
from typing import Any
|
|
23
|
+
from urllib.parse import urlparse
|
|
24
|
+
from datetime import datetime
|
|
25
|
+
|
|
26
|
+
logger = get_logger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TOSConfig(BaseModel):
    """Connection settings for TOS.

    Every field falls back to a value read through ``getenv`` when it is not
    passed explicitly.
    """

    # Read from DATABASE_TOS_REGION when not provided.
    region: str = Field(
        description="TOS region",
        default_factory=lambda: getenv("DATABASE_TOS_REGION"),
    )

    # Read from VOLCENGINE_ACCESS_KEY when not provided.
    ak: str = Field(
        description="Volcengine access key",
        default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY"),
    )

    # Read from VOLCENGINE_SECRET_KEY when not provided.
    sk: str = Field(
        description="Volcengine secret key",
        default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY"),
    )

    # Read from DATABASE_TOS_BUCKET when not provided.
    bucket_name: str = Field(
        description="TOS bucket name",
        default_factory=lambda: getenv("DATABASE_TOS_BUCKET"),
    )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class VeTOS(BaseModel):
    """Small wrapper around ``tos.TosClientV2`` for uploading and downloading objects.

    The client is built in ``model_post_init`` from ``config``. When that
    fails the error is logged and ``_client`` stays ``None``; every method
    then reports failure (``False``) instead of raising.
    """

    config: TOSConfig = Field(default_factory=TOSConfig)

    def model_post_init(self, __context: Any) -> None:
        """Create the TOS client; log (do not raise) when construction fails."""
        # Define the attribute up front so that later `self._client` checks
        # never hit AttributeError when the constructor below raises.
        self._client = None
        try:
            self._client = tos.TosClientV2(
                self.config.ak,
                self.config.sk,
                endpoint=f"tos-{self.config.region}.volces.com",
                region=self.config.region,
            )
            logger.info("Connected to TOS successfully.")
        except Exception as e:
            logger.error(f"Client initialization failed:{e}")

    def create_bucket(self) -> bool:
        """Ensure the configured bucket exists, creating it when it is missing.

        Returns:
            True when the bucket already exists or was created, False on any
            failure (including a missing client).
        """
        if not self._client:
            logger.error("Bucket creation failed: TOS client is not initialized")
            return False

        bucket_name = self.config.bucket_name
        try:
            self._client.head_bucket(bucket_name)
            logger.info(f"Bucket {bucket_name} already exists")
            return True
        except tos.exceptions.TosServerError as e:
            # Only a 404 means "bucket missing"; any other server error is a failure.
            if e.status_code != 404:
                logger.error(f"Bucket creation failed: {str(e)}")
                return False
        except Exception as e:
            logger.error(f"Bucket creation failed: {str(e)}")
            return False

        # Creation runs outside the handler above so that its own errors are
        # caught here rather than escaping to the caller.
        try:
            self._client.create_bucket(
                bucket=bucket_name,
                storage_class=tos.StorageClassType.Storage_Class_Standard,
                acl=tos.ACLType.ACL_Private,
            )
        except Exception as e:
            logger.error(f"Bucket creation failed: {str(e)}")
            return False

        logger.info(f"Bucket {bucket_name} created successfully")
        return True

    def build_tos_url(
        self, user_id: str, app_name: str, session_id: str, data_path: str
    ) -> tuple[str, str]:
        """Build the object key and public-style URL for a file.

        Args:
            user_id: User identifier, used in the key prefix.
            app_name: Application name, used in the key prefix.
            session_id: Session identifier, used in the key prefix.
            data_path: Local path or http/https/ftp/ftps URL; only its base
                name is used.

        Returns:
            ``(object_key, tos_url)`` where the key is
            ``{app_name}-{user_id}-{session_id}/{timestamp}-{file_name}``.
        """
        parsed_url = urlparse(data_path)

        # For remote URLs take the base name of the URL path only, so that a
        # query string or fragment does not end up in the file name.
        if parsed_url.scheme and parsed_url.scheme in ("http", "https", "ftp", "ftps"):
            file_name = os.path.basename(parsed_url.path)
        else:
            file_name = os.path.basename(data_path)

        # Local time with millisecond precision (microseconds cut to 3 digits).
        timestamp: str = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
        object_key: str = f"{app_name}-{user_id}-{session_id}/{timestamp}-{file_name}"
        tos_url: str = f"https://{self.config.bucket_name}.tos-{self.config.region}.volces.com/{object_key}"

        return object_key, tos_url

    def upload(
        self,
        object_key: str,
        data: Union[str, bytes],
    ):
        """Start an upload in a worker thread.

        Args:
            object_key: Destination key inside the configured bucket.
            data: A file path (``str``) or the raw content (``bytes``).

        Returns:
            The coroutine produced by ``asyncio.to_thread``; awaiting it
            yields True on success and False on failure.

        Raises:
            ValueError: If ``data`` is neither ``str`` nor ``bytes``.
        """
        if isinstance(data, str):
            worker = self._do_upload_file
        elif isinstance(data, bytes):
            worker = self._do_upload_bytes
        else:
            error_msg = f"Upload failed: data type error. Only str (file path) and bytes are supported, got {type(data)}"
            logger.error(error_msg)
            raise ValueError(error_msg)
        return asyncio.to_thread(worker, object_key, data)

    def _do_upload_bytes(self, object_key: str, bytes: bytes) -> bool:
        """Upload in-memory content; the client is closed after the attempt."""
        try:
            if not self._client:
                return False
            if not self.create_bucket():
                return False
            self._client.put_object(
                bucket=self.config.bucket_name, key=object_key, content=bytes
            )
            logger.debug(f"Upload success, object_key: {object_key}")
            self._close()
            return True
        except Exception as e:
            logger.error(f"Upload failed: {e}")
            self._close()
            return False

    def _do_upload_file(self, object_key: str, file_path: str) -> bool:
        """Upload a local file; the client is closed after the attempt."""
        try:
            if not self._client:
                return False
            if not self.create_bucket():
                return False

            self._client.put_object_from_file(
                bucket=self.config.bucket_name, key=object_key, file_path=file_path
            )
            self._close()
            logger.debug(f"Upload success, object_key: {object_key}")
            return True
        except Exception as e:
            logger.error(f"Upload failed: {e}")
            self._close()
            return False

    def download(self, object_key: str, save_path: str) -> bool:
        """Download an object from the configured bucket to ``save_path``.

        Missing parent directories of ``save_path`` are created.

        Returns:
            True on success, False when anything goes wrong (the error is logged).
        """
        try:
            object_stream = self._client.get_object(self.config.bucket_name, object_key)

            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)

            with open(save_path, "wb") as f:
                for chunk in object_stream:
                    f.write(chunk)

            logger.debug(f"Image download success, saved to: {save_path}")
            return True

        except Exception as e:
            logger.error(f"Image download failed: {str(e)}")

            return False

    def _close(self):
        """Close the underlying client if one was created."""
        if self._client:
            self._client.close()
|
veadk/runner.py
CHANGED
|
@@ -11,9 +11,11 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
import asyncio
|
|
14
15
|
from typing import Union
|
|
15
16
|
|
|
16
17
|
from google.adk.agents import RunConfig
|
|
18
|
+
from google.adk.agents.invocation_context import LlmCallsLimitExceededError
|
|
17
19
|
from google.adk.agents.run_config import StreamingMode
|
|
18
20
|
from google.adk.plugins.base_plugin import BasePlugin
|
|
19
21
|
from google.adk.runners import Runner as ADKRunner
|
|
@@ -30,6 +32,7 @@ from veadk.memory.short_term_memory import ShortTermMemory
|
|
|
30
32
|
from veadk.types import MediaMessage
|
|
31
33
|
from veadk.utils.logger import get_logger
|
|
32
34
|
from veadk.utils.misc import read_png_to_bytes
|
|
35
|
+
from veadk.integrations.ve_tos.ve_tos import VeTOS
|
|
33
36
|
|
|
34
37
|
logger = get_logger(__name__)
|
|
35
38
|
|
|
@@ -49,20 +52,25 @@ class Runner:
|
|
|
49
52
|
def __init__(
|
|
50
53
|
self,
|
|
51
54
|
agent: VeAgent,
|
|
52
|
-
short_term_memory: ShortTermMemory,
|
|
55
|
+
short_term_memory: ShortTermMemory | None = None,
|
|
53
56
|
plugins: list[BasePlugin] | None = None,
|
|
54
57
|
app_name: str = "veadk_default_app",
|
|
55
58
|
user_id: str = "veadk_default_user",
|
|
56
59
|
):
|
|
57
|
-
# basic settings
|
|
58
60
|
self.app_name = app_name
|
|
59
61
|
self.user_id = user_id
|
|
60
62
|
|
|
61
|
-
# agent settings
|
|
62
63
|
self.agent = agent
|
|
63
64
|
|
|
64
|
-
|
|
65
|
-
|
|
65
|
+
if not short_term_memory:
|
|
66
|
+
logger.info(
|
|
67
|
+
"No short term memory provided, using a in-memory memory by default."
|
|
68
|
+
)
|
|
69
|
+
self.short_term_memory = ShortTermMemory()
|
|
70
|
+
else:
|
|
71
|
+
self.short_term_memory = short_term_memory
|
|
72
|
+
|
|
73
|
+
self.session_service = self.short_term_memory.session_service
|
|
66
74
|
|
|
67
75
|
# prevent VeRemoteAgent has no long-term memory attr
|
|
68
76
|
if isinstance(self.agent, Agent):
|
|
@@ -78,13 +86,25 @@ class Runner:
|
|
|
78
86
|
plugins=plugins,
|
|
79
87
|
)
|
|
80
88
|
|
|
81
|
-
def _convert_messages(self, messages) -> list:
|
|
89
|
+
def _convert_messages(self, messages, session_id) -> list:
|
|
82
90
|
if isinstance(messages, str):
|
|
83
91
|
messages = [types.Content(role="user", parts=[types.Part(text=messages)])]
|
|
84
92
|
elif isinstance(messages, MediaMessage):
|
|
85
93
|
assert messages.media.endswith(".png"), (
|
|
86
94
|
"The MediaMessage only supports PNG format file for now."
|
|
87
95
|
)
|
|
96
|
+
data = read_png_to_bytes(messages.media)
|
|
97
|
+
|
|
98
|
+
ve_tos = VeTOS()
|
|
99
|
+
object_key, tos_url = ve_tos.build_tos_url(
|
|
100
|
+
self.user_id, self.app_name, session_id, messages.media
|
|
101
|
+
)
|
|
102
|
+
try:
|
|
103
|
+
asyncio.create_task(ve_tos.upload(object_key, data))
|
|
104
|
+
except Exception as e:
|
|
105
|
+
logger.error(f"Upload to TOS failed: {e}")
|
|
106
|
+
tos_url = None
|
|
107
|
+
|
|
88
108
|
messages = [
|
|
89
109
|
types.Content(
|
|
90
110
|
role="user",
|
|
@@ -92,8 +112,8 @@ class Runner:
|
|
|
92
112
|
types.Part(text=messages.text),
|
|
93
113
|
types.Part(
|
|
94
114
|
inline_data=Blob(
|
|
95
|
-
display_name=
|
|
96
|
-
data=
|
|
115
|
+
display_name=tos_url,
|
|
116
|
+
data=data,
|
|
97
117
|
mime_type="image/png",
|
|
98
118
|
)
|
|
99
119
|
),
|
|
@@ -103,7 +123,7 @@ class Runner:
|
|
|
103
123
|
elif isinstance(messages, list):
|
|
104
124
|
converted_messages = []
|
|
105
125
|
for message in messages:
|
|
106
|
-
converted_messages.extend(self._convert_messages(message))
|
|
126
|
+
converted_messages.extend(self._convert_messages(message, session_id))
|
|
107
127
|
messages = converted_messages
|
|
108
128
|
else:
|
|
109
129
|
raise ValueError(f"Unknown message type: {type(messages)}")
|
|
@@ -114,35 +134,44 @@ class Runner:
|
|
|
114
134
|
self,
|
|
115
135
|
session_id: str,
|
|
116
136
|
message: types.Content,
|
|
137
|
+
run_config: RunConfig | None = None,
|
|
117
138
|
stream: bool = False,
|
|
118
139
|
):
|
|
119
140
|
stream_mode = StreamingMode.SSE if stream else StreamingMode.NONE
|
|
120
141
|
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
run_config=RunConfig(streaming_mode=stream_mode),
|
|
127
|
-
):
|
|
128
|
-
if event.get_function_calls():
|
|
129
|
-
for function_call in event.get_function_calls():
|
|
130
|
-
logger.debug(f"Function call: {function_call}")
|
|
131
|
-
elif (
|
|
132
|
-
event.content is not None
|
|
133
|
-
and event.content.parts
|
|
134
|
-
and event.content.parts[0].text is not None
|
|
135
|
-
and len(event.content.parts[0].text.strip()) > 0
|
|
136
|
-
):
|
|
137
|
-
yield event.content.parts[0].text
|
|
142
|
+
if run_config is not None:
|
|
143
|
+
stream_mode = run_config.streaming_mode
|
|
144
|
+
else:
|
|
145
|
+
run_config = RunConfig(streaming_mode=stream_mode)
|
|
146
|
+
try:
|
|
138
147
|
|
|
139
|
-
|
|
140
|
-
|
|
148
|
+
async def event_generator():
|
|
149
|
+
async for event in self.runner.run_async(
|
|
150
|
+
user_id=self.user_id,
|
|
151
|
+
session_id=session_id,
|
|
152
|
+
new_message=message,
|
|
153
|
+
run_config=run_config,
|
|
154
|
+
):
|
|
155
|
+
if event.get_function_calls():
|
|
156
|
+
for function_call in event.get_function_calls():
|
|
157
|
+
logger.debug(f"Function call: {function_call}")
|
|
158
|
+
elif (
|
|
159
|
+
event.content is not None
|
|
160
|
+
and event.content.parts
|
|
161
|
+
and event.content.parts[0].text is not None
|
|
162
|
+
and len(event.content.parts[0].text.strip()) > 0
|
|
163
|
+
):
|
|
164
|
+
yield event.content.parts[0].text
|
|
165
|
+
|
|
166
|
+
final_output = ""
|
|
167
|
+
async for chunk in event_generator():
|
|
168
|
+
if stream:
|
|
169
|
+
print(chunk, end="", flush=True)
|
|
170
|
+
final_output += chunk
|
|
141
171
|
if stream:
|
|
142
|
-
print(
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
print() # end with a new line
|
|
172
|
+
print() # end with a new line
|
|
173
|
+
except LlmCallsLimitExceededError as e:
|
|
174
|
+
logger.warning(f"Max number of llm calls limit exceeded: {e}")
|
|
146
175
|
|
|
147
176
|
return final_output
|
|
148
177
|
|
|
@@ -151,9 +180,10 @@ class Runner:
|
|
|
151
180
|
messages: RunnerMessage,
|
|
152
181
|
session_id: str,
|
|
153
182
|
stream: bool = False,
|
|
183
|
+
run_config: RunConfig | None = None,
|
|
154
184
|
save_tracing_data: bool = False,
|
|
155
185
|
):
|
|
156
|
-
converted_messages: list = self._convert_messages(messages)
|
|
186
|
+
converted_messages: list = self._convert_messages(messages, session_id)
|
|
157
187
|
|
|
158
188
|
await self.short_term_memory.create_session(
|
|
159
189
|
app_name=self.app_name, user_id=self.user_id, session_id=session_id
|
|
@@ -163,7 +193,9 @@ class Runner:
|
|
|
163
193
|
|
|
164
194
|
final_output = ""
|
|
165
195
|
for converted_message in converted_messages:
|
|
166
|
-
final_output = await self._run(
|
|
196
|
+
final_output = await self._run(
|
|
197
|
+
session_id, converted_message, run_config, stream
|
|
198
|
+
)
|
|
167
199
|
|
|
168
200
|
# try to save tracing file
|
|
169
201
|
if save_tracing_data:
|
|
@@ -193,6 +225,47 @@ class Runner:
|
|
|
193
225
|
logger.warning(f"Get tracer id failed as {e}")
|
|
194
226
|
return "<unknown_trace_id>"
|
|
195
227
|
|
|
228
|
+
async def run_with_raw_message(
    self,
    message: types.Content,
    session_id: str,
    run_config: RunConfig | None = None,
):
    """Run the agent on an already-built ``types.Content`` message.

    Args:
        message: Content forwarded to the runner as ``new_message``.
        session_id: Session id passed to ``short_term_memory.create_session``
            and to the runner.
        run_config: Optional run configuration; a default ``RunConfig()`` is
            used when none is given.

    Returns:
        The concatenation of every non-blank text taken from the first part
        of each event. When the LLM call limit is exceeded a warning is
        logged and the text collected so far is returned.
    """
    if not run_config:
        run_config = RunConfig()

    await self.short_term_memory.create_session(
        app_name=self.app_name, user_id=self.user_id, session_id=session_id
    )

    async def text_chunks():
        # Yield the text of each event that is not a function-call event.
        async for event in self.runner.run_async(
            user_id=self.user_id,
            session_id=session_id,
            new_message=message,
            run_config=run_config,
        ):
            if event.get_function_calls():
                for function_call in event.get_function_calls():
                    logger.debug(f"Function call: {function_call}")
                continue

            if event.content is None or not event.content.parts:
                continue

            text = event.content.parts[0].text
            if text is not None and len(text.strip()) > 0:
                yield text

    pieces = []
    try:
        async for piece in text_chunks():
            pieces.append(piece)
    except LlmCallsLimitExceededError as e:
        logger.warning(f"Max number of llm calls limit exceeded: {e}")

    return "".join(pieces)
|
|
268
|
+
|
|
196
269
|
def _print_trace_id(self) -> None:
|
|
197
270
|
if not isinstance(self.agent, Agent):
|
|
198
271
|
logger.warning(
|