veadk-python 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of veadk-python might be problematic. Click here for more details.
- veadk/agent.py +40 -8
- veadk/cli/cli_deploy.py +5 -1
- veadk/cli/cli_init.py +25 -6
- veadk/cloud/cloud_app.py +21 -6
- veadk/consts.py +33 -1
- veadk/database/database_adapter.py +88 -0
- veadk/database/kv/redis_database.py +47 -0
- veadk/database/local_database.py +22 -4
- veadk/database/relational/mysql_database.py +58 -0
- veadk/database/vector/opensearch_vector_database.py +6 -3
- veadk/database/viking/viking_database.py +72 -3
- veadk/integrations/ve_cr/__init__.py +13 -0
- veadk/integrations/ve_cr/ve_cr.py +205 -0
- veadk/integrations/ve_faas/template/cookiecutter.json +2 -1
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/clean.py +23 -0
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/src/app.py +28 -2
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/src/requirements.txt +3 -1
- veadk/integrations/ve_faas/template/{{cookiecutter.local_dir_name}}/src/run.sh +5 -2
- veadk/integrations/ve_faas/ve_faas.py +2 -0
- veadk/integrations/ve_faas/web_template/cookiecutter.json +17 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/__init__.py +13 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/clean.py +23 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/config.yaml.example +2 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/deploy.py +41 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/Dockerfile +23 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/app.py +123 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/init_db.py +46 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/models.py +36 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/requirements.txt +4 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/run.sh +21 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/static/css/style.css +368 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/static/js/admin.js +0 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/templates/admin/dashboard.html +21 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/templates/admin/edit_post.html +24 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/templates/admin/login.html +21 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/templates/admin/posts.html +53 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/templates/base.html +45 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/templates/index.html +29 -0
- veadk/integrations/ve_faas/web_template/{{cookiecutter.local_dir_name}}/src/templates/post.html +14 -0
- veadk/integrations/ve_tos/ve_tos.py +238 -0
- veadk/knowledgebase/knowledgebase.py +8 -0
- veadk/runner.py +140 -34
- veadk/tools/builtin_tools/image_edit.py +236 -0
- veadk/tools/builtin_tools/image_generate.py +236 -0
- veadk/tools/builtin_tools/video_generate.py +326 -0
- veadk/tools/sandbox/browser_sandbox.py +19 -9
- veadk/tools/sandbox/code_sandbox.py +21 -11
- veadk/tools/sandbox/computer_sandbox.py +16 -9
- veadk/tracing/base_tracer.py +0 -19
- veadk/tracing/telemetry/attributes/extractors/common_attributes_extractors.py +5 -0
- veadk/tracing/telemetry/attributes/extractors/llm_attributes_extractors.py +311 -128
- veadk/tracing/telemetry/attributes/extractors/tool_attributes_extractors.py +20 -14
- veadk/tracing/telemetry/attributes/extractors/types.py +15 -4
- veadk/tracing/telemetry/exporters/inmemory_exporter.py +3 -0
- veadk/tracing/telemetry/opentelemetry_tracer.py +15 -6
- veadk/tracing/telemetry/telemetry.py +128 -24
- veadk/utils/misc.py +40 -0
- veadk/version.py +1 -1
- {veadk_python-0.2.4.dist-info → veadk_python-0.2.6.dist-info}/METADATA +1 -1
- {veadk_python-0.2.4.dist-info → veadk_python-0.2.6.dist-info}/RECORD +64 -37
- {veadk_python-0.2.4.dist-info → veadk_python-0.2.6.dist-info}/WHEEL +0 -0
- {veadk_python-0.2.4.dist-info → veadk_python-0.2.6.dist-info}/entry_points.txt +0 -0
- {veadk_python-0.2.4.dist-info → veadk_python-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {veadk_python-0.2.4.dist-info → veadk_python-0.2.6.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
from veadk.config import getenv
|
|
17
|
+
from veadk.utils.logger import get_logger
|
|
18
|
+
import asyncio
|
|
19
|
+
from typing import Union
|
|
20
|
+
from pydantic import BaseModel, Field
|
|
21
|
+
from typing import Any
|
|
22
|
+
from urllib.parse import urlparse
|
|
23
|
+
from datetime import datetime
|
|
24
|
+
|
|
25
|
+
# Initialize logger before using it
|
|
26
|
+
logger = get_logger(__name__)
|
|
27
|
+
|
|
28
|
+
# Try to import tos module, and provide helpful error message if it fails
|
|
29
|
+
try:
|
|
30
|
+
import tos
|
|
31
|
+
except ImportError as e:
|
|
32
|
+
logger.error(
|
|
33
|
+
"Failed to import 'tos' module. Please install it using: pip install tos\n"
|
|
34
|
+
)
|
|
35
|
+
raise ImportError(
|
|
36
|
+
"Missing 'tos' module. Please install it using: pip install tos\n"
|
|
37
|
+
) from e
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class TOSConfig(BaseModel):
    # Connection settings for Volcengine TOS. Any field not passed explicitly is
    # filled at instantiation time from ``veadk.config.getenv`` (presumably the
    # environment / project config — see veadk.config).
    region: str = Field(
        description="TOS region",
        default_factory=lambda: getenv("DATABASE_TOS_REGION"),
    )
    ak: str = Field(
        description="Volcengine access key",
        default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY"),
    )
    sk: str = Field(
        description="Volcengine secret key",
        default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY"),
    )
    bucket_name: str = Field(
        description="TOS bucket name",
        default_factory=lambda: getenv("DATABASE_TOS_BUCKET"),
    )
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class VeTOS(BaseModel):
    """Helper around the Volcengine TOS SDK for uploading and downloading objects.

    A ``tos.TosClientV2`` is built from ``config`` right after the model is
    constructed (``model_post_init``). If that fails, the error is logged and
    ``_client`` stays ``None``. Every operation then logs an error and returns
    a failure value instead of raising.
    """

    config: TOSConfig = Field(default_factory=TOSConfig)

    def _new_client(self):
        """Construct a fresh ``tos.TosClientV2`` from ``self.config``.

        The endpoint is derived from the region as ``tos-<region>.volces.com``.
        SDK errors propagate; callers are responsible for handling them.
        """
        return tos.TosClientV2(
            self.config.ak,
            self.config.sk,
            endpoint=f"tos-{self.config.region}.volces.com",
            region=self.config.region,
        )

    def model_post_init(self, __context: Any) -> None:
        """Pydantic post-init hook: create the TOS client.

        On failure the error is logged and ``_client`` is set to ``None``.
        """
        try:
            self._client = self._new_client()
            logger.info("Connected to TOS successfully.")
        except Exception as e:
            logger.error(f"Client initialization failed:{e}")
            self._client = None

    def _refresh_client(self):
        """Close the current client (if any) and replace it with a new one.

        On failure the error is logged and ``_client`` is set to ``None``.
        """
        try:
            self._close()
            self._client = self._new_client()
            logger.info("refreshed client successfully.")
        except Exception as e:
            logger.error(f"Failed to refresh client: {str(e)}")
            self._client = None

    def create_bucket(self) -> bool:
        """Ensure the configured bucket exists, then (re)apply its CORS rules.

        If ``head_bucket`` reports 404, the bucket is created with the standard
        storage class and a public-read ACL, and the client is refreshed.

        Returns:
            True only if the bucket exists or was created AND the CORS rules
            were applied. False on any failure (all failures are logged).
        """
        if not self._client:
            logger.error("TOS client is not initialized")
            return False
        try:
            self._client.head_bucket(self.config.bucket_name)
            logger.info(f"Bucket {self.config.bucket_name} already exists")
        except tos.exceptions.TosServerError as e:
            if e.status_code != 404:
                logger.error(f"Bucket check failed: {str(e)}")
                return False
            try:
                # NOTE(review): public-read ACL plus the wildcard CORS rule set
                # in _set_cors_rules makes every uploaded object readable by
                # anyone who has its URL. Confirm that is intended.
                self._client.create_bucket(
                    bucket=self.config.bucket_name,
                    storage_class=tos.StorageClassType.Storage_Class_Standard,
                    acl=tos.ACLType.ACL_Public_Read,  # public read
                )
                logger.info(
                    f"Bucket {self.config.bucket_name} created successfully"
                )
                self._refresh_client()
            except Exception as create_error:
                logger.error(f"Bucket creation failed: {str(create_error)}")
                return False
        except Exception as e:
            logger.error(f"Bucket check failed: {str(e)}")
            return False

        # Every successful path ends here, so the result is always a bool.
        return self._set_cors_rules()

    def _set_cors_rules(self) -> bool:
        """Allow cross-origin GET/HEAD from any origin on the bucket.

        Returns True on success, False (logged) on failure.
        """
        if not self._client:
            logger.error("TOS client is not initialized")
            return False
        try:
            rule = tos.models2.CORSRule(
                allowed_origins=["*"],
                allowed_methods=["GET", "HEAD"],
                allowed_headers=["*"],
                max_age_seconds=1000,
            )
            self._client.put_bucket_cors(self.config.bucket_name, [rule])
            logger.info(
                f"CORS rules for bucket {self.config.bucket_name} set successfully"
            )
            return True
        except Exception as e:
            logger.error(
                f"Failed to set CORS rules for bucket {self.config.bucket_name}: {str(e)}"
            )
            return False

    def build_tos_url(
        self, user_id: str, app_name: str, session_id: str, data_path: str
    ) -> tuple[str, str]:
        """Build the object key and public URL for uploading ``data_path``.

        ``data_path`` may be a local path or an http(s)/ftp(s) URL; only its
        base name is kept. The key has the form
        ``<app_name>-<user_id>-<session_id>/<YYYYmmddHHMMSSmmm>-<file_name>``
        (local time, millisecond precision).

        Returns:
            ``(object_key, tos_url)``. The key is returned verbatim. In the URL
            it is percent-encoded, so file names with spaces or non-ASCII
            characters still yield a valid URL.
        """
        from urllib.parse import quote

        parsed_url = urlparse(data_path)
        if parsed_url.scheme in ("http", "https", "ftp", "ftps"):
            file_name = os.path.basename(parsed_url.path)
        else:
            file_name = os.path.basename(data_path)

        # %f is microseconds; dropping the last 3 digits leaves milliseconds.
        timestamp: str = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
        object_key: str = f"{app_name}-{user_id}-{session_id}/{timestamp}-{file_name}"
        tos_url: str = (
            f"https://{self.config.bucket_name}.tos-{self.config.region}.volces.com/"
            f"{quote(object_key, safe='/')}"
        )

        return object_key, tos_url

    def upload(
        self,
        object_key: str,
        data: Union[str, bytes],
    ):
        """Prepare an upload of ``data`` to ``object_key`` in a worker thread.

        ``data`` is either a local file path (``str``) or raw content
        (``bytes``). This returns the coroutine produced by
        ``asyncio.to_thread``. Nothing is uploaded until the caller awaits it
        or schedules it as a task. Upload failures are logged, not raised.

        Raises:
            ValueError: if ``data`` is neither ``str`` nor ``bytes``.
        """
        if isinstance(data, str):
            # data is a file path
            return asyncio.to_thread(self._do_upload_file, object_key, data)
        elif isinstance(data, bytes):
            # data is bytes content
            return asyncio.to_thread(self._do_upload_bytes, object_key, data)
        else:
            error_msg = f"Upload failed: data type error. Only str (file path) and bytes are supported, got {type(data)}"
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _upload_with(self, object_key: str, put) -> None:
        """Ensure the bucket exists, run ``put()``, then always close the client.

        ``put`` is a zero-argument callable performing the SDK upload. It must
        read ``self._client`` lazily, because ``create_bucket`` may replace the
        client. Errors are logged and swallowed (best-effort upload).
        """
        if not self._client:
            logger.error("TOS client is not initialized")
            return
        try:
            if not self.create_bucket():
                return
            put()
            logger.debug(f"Upload success, object_key: {object_key}")
        except Exception as e:
            logger.error(f"Upload failed: {e}")
        finally:
            # Runs on every path, including a failed create_bucket, which
            # previously returned without closing the client.
            self._close()

    def _do_upload_bytes(self, object_key: str, data: bytes) -> None:
        """Upload in-memory ``data`` to ``object_key`` (blocking)."""
        self._upload_with(
            object_key,
            lambda: self._client.put_object(
                bucket=self.config.bucket_name, key=object_key, content=data
            ),
        )

    def _do_upload_file(self, object_key: str, file_path: str) -> None:
        """Upload the local file at ``file_path`` to ``object_key`` (blocking)."""
        self._upload_with(
            object_key,
            lambda: self._client.put_object_from_file(
                bucket=self.config.bucket_name, key=object_key, file_path=file_path
            ),
        )

    def download(self, object_key: str, save_path: str) -> bool:
        """Download ``object_key`` from the bucket to ``save_path``.

        Missing parent directories of ``save_path`` are created, and the
        object is written chunk by chunk.

        Returns:
            True on success; False (logged) if the client is missing or any
            step fails.
        """
        if not self._client:
            logger.error("TOS client is not initialized")
            return False
        try:
            object_stream = self._client.get_object(self.config.bucket_name, object_key)

            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)

            with open(save_path, "wb") as f:
                for chunk in object_stream:
                    f.write(chunk)

            logger.debug(f"Image download success, saved to: {save_path}")
            return True

        except Exception as e:
            logger.error(f"Image download failed: {str(e)}")

            return False

    def _close(self):
        """Close the underlying client if one exists. This never raises.

        The (closed) client object is kept on ``_client``. Whether the SDK can
        reuse it afterwards is not established here (TODO confirm).
        """
        if self._client:
            try:
                self._client.close()
            except Exception as e:
                logger.warning(f"Failed to close TOS client: {e}")
|
|
@@ -80,3 +80,11 @@ class KnowledgeBase:
|
|
|
80
80
|
if len(result) == 0:
|
|
81
81
|
logger.warning(f"No documents found in knowledgebase. Query: {query}")
|
|
82
82
|
return result
|
|
83
|
+
|
|
84
|
+
def delete_doc(self, app_name: str, id: str) -> bool:
    """Delete document ``id`` from the knowledge-base index of ``app_name``.

    The index name comes from ``build_knowledgebase_index``. The adapter's
    ``delete_doc`` result is returned unchanged.
    """
    target_index = build_knowledgebase_index(app_name)
    outcome = self.adapter.delete_doc(index=target_index, id=id)
    return outcome
|
|
87
|
+
|
|
88
|
+
def list_docs(self, app_name: str, offset: int = 0, limit: int = 100) -> list[dict]:
    """List documents in the knowledge-base index of ``app_name``.

    ``offset`` and ``limit`` are forwarded as-is to the adapter's
    ``list_docs``, and its result is returned unchanged.
    """
    target_index = build_knowledgebase_index(app_name)
    documents = self.adapter.list_docs(index=target_index, offset=offset, limit=limit)
    return documents
|
veadk/runner.py
CHANGED
|
@@ -11,9 +11,11 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
import asyncio
|
|
14
15
|
from typing import Union
|
|
15
16
|
|
|
16
17
|
from google.adk.agents import RunConfig
|
|
18
|
+
from google.adk.agents.invocation_context import LlmCallsLimitExceededError
|
|
17
19
|
from google.adk.agents.run_config import StreamingMode
|
|
18
20
|
from google.adk.plugins.base_plugin import BasePlugin
|
|
19
21
|
from google.adk.runners import Runner as ADKRunner
|
|
@@ -25,6 +27,7 @@ from veadk.agent import Agent
|
|
|
25
27
|
from veadk.agents.loop_agent import LoopAgent
|
|
26
28
|
from veadk.agents.parallel_agent import ParallelAgent
|
|
27
29
|
from veadk.agents.sequential_agent import SequentialAgent
|
|
30
|
+
from veadk.config import getenv
|
|
28
31
|
from veadk.evaluation import EvalSetRecorder
|
|
29
32
|
from veadk.memory.short_term_memory import ShortTermMemory
|
|
30
33
|
from veadk.types import MediaMessage
|
|
@@ -49,20 +52,25 @@ class Runner:
|
|
|
49
52
|
def __init__(
|
|
50
53
|
self,
|
|
51
54
|
agent: VeAgent,
|
|
52
|
-
short_term_memory: ShortTermMemory,
|
|
55
|
+
short_term_memory: ShortTermMemory | None = None,
|
|
53
56
|
plugins: list[BasePlugin] | None = None,
|
|
54
57
|
app_name: str = "veadk_default_app",
|
|
55
58
|
user_id: str = "veadk_default_user",
|
|
56
59
|
):
|
|
57
|
-
# basic settings
|
|
58
60
|
self.app_name = app_name
|
|
59
61
|
self.user_id = user_id
|
|
60
62
|
|
|
61
|
-
# agent settings
|
|
62
63
|
self.agent = agent
|
|
63
64
|
|
|
64
|
-
|
|
65
|
-
|
|
65
|
+
if not short_term_memory:
|
|
66
|
+
logger.info(
|
|
67
|
+
"No short term memory provided, using a in-memory memory by default."
|
|
68
|
+
)
|
|
69
|
+
self.short_term_memory = ShortTermMemory()
|
|
70
|
+
else:
|
|
71
|
+
self.short_term_memory = short_term_memory
|
|
72
|
+
|
|
73
|
+
self.session_service = self.short_term_memory.session_service
|
|
66
74
|
|
|
67
75
|
# prevent VeRemoteAgent has no long-term memory attr
|
|
68
76
|
if isinstance(self.agent, Agent):
|
|
@@ -78,13 +86,37 @@ class Runner:
|
|
|
78
86
|
plugins=plugins,
|
|
79
87
|
)
|
|
80
88
|
|
|
81
|
-
def _convert_messages(
|
|
89
|
+
def _convert_messages(
|
|
90
|
+
self, messages, session_id, upload_inline_data_to_tos
|
|
91
|
+
) -> list:
|
|
82
92
|
if isinstance(messages, str):
|
|
83
93
|
messages = [types.Content(role="user", parts=[types.Part(text=messages)])]
|
|
84
94
|
elif isinstance(messages, MediaMessage):
|
|
85
95
|
assert messages.media.endswith(".png"), (
|
|
86
96
|
"The MediaMessage only supports PNG format file for now."
|
|
87
97
|
)
|
|
98
|
+
data = read_png_to_bytes(messages.media)
|
|
99
|
+
tos_url = "<tos_url>"
|
|
100
|
+
if upload_inline_data_to_tos:
|
|
101
|
+
try:
|
|
102
|
+
from veadk.integrations.ve_tos.ve_tos import VeTOS
|
|
103
|
+
|
|
104
|
+
ve_tos = VeTOS()
|
|
105
|
+
object_key, tos_url = ve_tos.build_tos_url(
|
|
106
|
+
self.user_id, self.app_name, session_id, messages.media
|
|
107
|
+
)
|
|
108
|
+
upload_task = ve_tos.upload(object_key, data)
|
|
109
|
+
if upload_task is not None:
|
|
110
|
+
asyncio.create_task(upload_task)
|
|
111
|
+
except Exception as e:
|
|
112
|
+
logger.error(f"Upload to TOS failed: {e}")
|
|
113
|
+
tos_url = None
|
|
114
|
+
|
|
115
|
+
else:
|
|
116
|
+
logger.warning(
|
|
117
|
+
"Loss of multimodal data may occur in the tracing process."
|
|
118
|
+
)
|
|
119
|
+
|
|
88
120
|
messages = [
|
|
89
121
|
types.Content(
|
|
90
122
|
role="user",
|
|
@@ -92,8 +124,8 @@ class Runner:
|
|
|
92
124
|
types.Part(text=messages.text),
|
|
93
125
|
types.Part(
|
|
94
126
|
inline_data=Blob(
|
|
95
|
-
display_name=
|
|
96
|
-
data=
|
|
127
|
+
display_name=tos_url,
|
|
128
|
+
data=data,
|
|
97
129
|
mime_type="image/png",
|
|
98
130
|
)
|
|
99
131
|
),
|
|
@@ -103,7 +135,11 @@ class Runner:
|
|
|
103
135
|
elif isinstance(messages, list):
|
|
104
136
|
converted_messages = []
|
|
105
137
|
for message in messages:
|
|
106
|
-
converted_messages.extend(
|
|
138
|
+
converted_messages.extend(
|
|
139
|
+
self._convert_messages(
|
|
140
|
+
message, session_id, upload_inline_data_to_tos
|
|
141
|
+
)
|
|
142
|
+
)
|
|
107
143
|
messages = converted_messages
|
|
108
144
|
else:
|
|
109
145
|
raise ValueError(f"Unknown message type: {type(messages)}")
|
|
@@ -114,35 +150,51 @@ class Runner:
|
|
|
114
150
|
self,
|
|
115
151
|
session_id: str,
|
|
116
152
|
message: types.Content,
|
|
153
|
+
run_config: RunConfig | None = None,
|
|
117
154
|
stream: bool = False,
|
|
118
155
|
):
|
|
119
156
|
stream_mode = StreamingMode.SSE if stream else StreamingMode.NONE
|
|
120
157
|
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
)
|
|
128
|
-
if event.get_function_calls():
|
|
129
|
-
for function_call in event.get_function_calls():
|
|
130
|
-
logger.debug(f"Function call: {function_call}")
|
|
131
|
-
elif (
|
|
132
|
-
event.content is not None
|
|
133
|
-
and event.content.parts
|
|
134
|
-
and event.content.parts[0].text is not None
|
|
135
|
-
and len(event.content.parts[0].text.strip()) > 0
|
|
136
|
-
):
|
|
137
|
-
yield event.content.parts[0].text
|
|
158
|
+
if run_config is not None:
|
|
159
|
+
stream_mode = run_config.streaming_mode
|
|
160
|
+
else:
|
|
161
|
+
run_config = RunConfig(
|
|
162
|
+
streaming_mode=stream_mode,
|
|
163
|
+
max_llm_calls=int(getenv("MODEL_AGENT_MAX_LLM_CALLS", 100)),
|
|
164
|
+
)
|
|
138
165
|
|
|
139
|
-
|
|
140
|
-
|
|
166
|
+
logger.info(f"Run config: {run_config}")
|
|
167
|
+
|
|
168
|
+
try:
|
|
169
|
+
|
|
170
|
+
async def event_generator():
|
|
171
|
+
async for event in self.runner.run_async(
|
|
172
|
+
user_id=self.user_id,
|
|
173
|
+
session_id=session_id,
|
|
174
|
+
new_message=message,
|
|
175
|
+
run_config=run_config,
|
|
176
|
+
):
|
|
177
|
+
if event.get_function_calls():
|
|
178
|
+
for function_call in event.get_function_calls():
|
|
179
|
+
logger.debug(f"Function call: {function_call}")
|
|
180
|
+
elif (
|
|
181
|
+
event.content is not None
|
|
182
|
+
and event.content.parts
|
|
183
|
+
and event.content.parts[0].text is not None
|
|
184
|
+
and len(event.content.parts[0].text.strip()) > 0
|
|
185
|
+
):
|
|
186
|
+
yield event.content.parts[0].text
|
|
187
|
+
|
|
188
|
+
final_output = ""
|
|
189
|
+
async for chunk in event_generator():
|
|
190
|
+
if stream:
|
|
191
|
+
print(chunk, end="", flush=True)
|
|
192
|
+
final_output += chunk
|
|
141
193
|
if stream:
|
|
142
|
-
print(
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
194
|
+
print() # end with a new line
|
|
195
|
+
except LlmCallsLimitExceededError as e:
|
|
196
|
+
logger.warning(f"Max number of llm calls limit exceeded: {e}")
|
|
197
|
+
final_output = ""
|
|
146
198
|
|
|
147
199
|
return final_output
|
|
148
200
|
|
|
@@ -151,9 +203,13 @@ class Runner:
|
|
|
151
203
|
messages: RunnerMessage,
|
|
152
204
|
session_id: str,
|
|
153
205
|
stream: bool = False,
|
|
206
|
+
run_config: RunConfig | None = None,
|
|
154
207
|
save_tracing_data: bool = False,
|
|
208
|
+
upload_inline_data_to_tos: bool = False,
|
|
155
209
|
):
|
|
156
|
-
converted_messages: list = self._convert_messages(
|
|
210
|
+
converted_messages: list = self._convert_messages(
|
|
211
|
+
messages, session_id, upload_inline_data_to_tos
|
|
212
|
+
)
|
|
157
213
|
|
|
158
214
|
await self.short_term_memory.create_session(
|
|
159
215
|
app_name=self.app_name, user_id=self.user_id, session_id=session_id
|
|
@@ -163,7 +219,9 @@ class Runner:
|
|
|
163
219
|
|
|
164
220
|
final_output = ""
|
|
165
221
|
for converted_message in converted_messages:
|
|
166
|
-
final_output = await self._run(
|
|
222
|
+
final_output = await self._run(
|
|
223
|
+
session_id, converted_message, run_config, stream
|
|
224
|
+
)
|
|
167
225
|
|
|
168
226
|
# try to save tracing file
|
|
169
227
|
if save_tracing_data:
|
|
@@ -193,6 +251,54 @@ class Runner:
|
|
|
193
251
|
logger.warning(f"Get tracer id failed as {e}")
|
|
194
252
|
return "<unknown_trace_id>"
|
|
195
253
|
|
|
254
|
+
async def run_with_raw_message(
    self,
    message: types.Content,
    session_id: str,
    run_config: RunConfig | None = None,
):
    """Run the agent on an already-built ``types.Content`` message.

    If no ``run_config`` is given, one is created with ``max_llm_calls`` taken
    from ``MODEL_AGENT_MAX_LLM_CALLS`` (default 100). The session is created
    before the run.

    Returns:
        The concatenation of the first part's text of every event that is not
        a function call and whose first-part text is non-blank. If the LLM-call
        limit is exceeded, this returns "" (partial output is discarded).
    """
    if not run_config:
        run_config = RunConfig(
            max_llm_calls=int(getenv("MODEL_AGENT_MAX_LLM_CALLS", 100))
        )

    logger.info(f"Run config: {run_config}")

    await self.short_term_memory.create_session(
        app_name=self.app_name, user_id=self.user_id, session_id=session_id
    )

    pieces: list[str] = []
    try:
        async for event in self.runner.run_async(
            user_id=self.user_id,
            session_id=session_id,
            new_message=message,
            run_config=run_config,
        ):
            if event.get_function_calls():
                for call in event.get_function_calls():
                    logger.debug(f"Function call: {call}")
                continue

            content = event.content
            if content is None or not content.parts:
                continue
            text = content.parts[0].text
            if text is not None and text.strip():
                pieces.append(text)
    except LlmCallsLimitExceededError as e:
        logger.warning(f"Max number of llm calls limit exceeded: {e}")
        return ""

    return "".join(pieces)
|
|
301
|
+
|
|
196
302
|
def _print_trace_id(self) -> None:
|
|
197
303
|
if not isinstance(self.agent, Agent):
|
|
198
304
|
logger.warning(
|