veadk-python 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of veadk-python has been flagged as potentially problematic; see the registry's advisory for more details.
- veadk/agent.py +3 -2
- veadk/auth/veauth/opensearch_veauth.py +75 -0
- veadk/auth/veauth/postgresql_veauth.py +75 -0
- veadk/cli/cli.py +3 -1
- veadk/cli/cli_eval.py +160 -0
- veadk/cli/cli_prompt.py +9 -2
- veadk/cli/cli_web.py +6 -1
- veadk/configs/database_configs.py +43 -0
- veadk/configs/model_configs.py +32 -0
- veadk/consts.py +11 -4
- veadk/evaluation/adk_evaluator/adk_evaluator.py +5 -2
- veadk/evaluation/base_evaluator.py +95 -68
- veadk/evaluation/deepeval_evaluator/deepeval_evaluator.py +23 -15
- veadk/evaluation/eval_set_recorder.py +2 -2
- veadk/integrations/ve_prompt_pilot/ve_prompt_pilot.py +9 -3
- veadk/integrations/ve_tls/utils.py +1 -2
- veadk/integrations/ve_tls/ve_tls.py +9 -5
- veadk/integrations/ve_tos/ve_tos.py +542 -68
- veadk/knowledgebase/backends/base_backend.py +59 -0
- veadk/knowledgebase/backends/in_memory_backend.py +82 -0
- veadk/knowledgebase/backends/opensearch_backend.py +136 -0
- veadk/knowledgebase/backends/redis_backend.py +144 -0
- veadk/knowledgebase/backends/utils.py +91 -0
- veadk/knowledgebase/backends/vikingdb_knowledge_backend.py +524 -0
- veadk/{database/__init__.py → knowledgebase/entry.py} +10 -2
- veadk/knowledgebase/knowledgebase.py +120 -139
- veadk/memory/__init__.py +22 -0
- veadk/memory/long_term_memory.py +124 -41
- veadk/{database/base_database.py → memory/long_term_memory_backends/base_backend.py} +10 -22
- veadk/memory/long_term_memory_backends/in_memory_backend.py +65 -0
- veadk/memory/long_term_memory_backends/mem0_backend.py +129 -0
- veadk/memory/long_term_memory_backends/opensearch_backend.py +120 -0
- veadk/memory/long_term_memory_backends/redis_backend.py +127 -0
- veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py +148 -0
- veadk/memory/short_term_memory.py +80 -72
- veadk/memory/short_term_memory_backends/base_backend.py +31 -0
- veadk/memory/short_term_memory_backends/mysql_backend.py +41 -0
- veadk/memory/short_term_memory_backends/postgresql_backend.py +41 -0
- veadk/memory/short_term_memory_backends/sqlite_backend.py +48 -0
- veadk/runner.py +12 -19
- veadk/tools/builtin_tools/generate_image.py +355 -0
- veadk/tools/builtin_tools/image_edit.py +56 -16
- veadk/tools/builtin_tools/image_generate.py +51 -15
- veadk/tools/builtin_tools/video_generate.py +41 -41
- veadk/tools/builtin_tools/web_scraper.py +1 -1
- veadk/tools/builtin_tools/web_search.py +7 -7
- veadk/tools/load_knowledgebase_tool.py +2 -8
- veadk/tracing/telemetry/attributes/extractors/llm_attributes_extractors.py +21 -3
- veadk/tracing/telemetry/exporters/apmplus_exporter.py +24 -6
- veadk/tracing/telemetry/exporters/cozeloop_exporter.py +2 -0
- veadk/tracing/telemetry/exporters/inmemory_exporter.py +22 -8
- veadk/tracing/telemetry/exporters/tls_exporter.py +2 -0
- veadk/tracing/telemetry/opentelemetry_tracer.py +13 -10
- veadk/tracing/telemetry/telemetry.py +66 -63
- veadk/utils/misc.py +15 -0
- veadk/version.py +1 -1
- {veadk_python-0.2.7.dist-info → veadk_python-0.2.9.dist-info}/METADATA +28 -5
- {veadk_python-0.2.7.dist-info → veadk_python-0.2.9.dist-info}/RECORD +65 -56
- veadk/database/database_adapter.py +0 -533
- veadk/database/database_factory.py +0 -80
- veadk/database/kv/redis_database.py +0 -159
- veadk/database/local_database.py +0 -62
- veadk/database/relational/mysql_database.py +0 -173
- veadk/database/vector/opensearch_vector_database.py +0 -263
- veadk/database/vector/type.py +0 -50
- veadk/database/viking/__init__.py +0 -13
- veadk/database/viking/viking_database.py +0 -638
- veadk/database/viking/viking_memory_db.py +0 -525
- /veadk/{database/kv → knowledgebase/backends}/__init__.py +0 -0
- /veadk/{database/relational → memory/long_term_memory_backends}/__init__.py +0 -0
- /veadk/{database/vector → memory/short_term_memory_backends}/__init__.py +0 -0
- {veadk_python-0.2.7.dist-info → veadk_python-0.2.9.dist-info}/WHEEL +0 -0
- {veadk_python-0.2.7.dist-info → veadk_python-0.2.9.dist-info}/entry_points.txt +0 -0
- {veadk_python-0.2.7.dist-info → veadk_python-0.2.9.dist-info}/licenses/LICENSE +0 -0
- {veadk_python-0.2.7.dist-info → veadk_python-0.2.9.dist-info}/top_level.txt +0 -0
|
@@ -12,95 +12,103 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
import
|
|
16
|
-
from typing import Literal
|
|
15
|
+
from functools import wraps
|
|
16
|
+
from typing import Any, Callable, Literal
|
|
17
|
+
|
|
18
|
+
from google.adk.sessions import (
|
|
19
|
+
BaseSessionService,
|
|
20
|
+
DatabaseSessionService,
|
|
21
|
+
InMemorySessionService,
|
|
22
|
+
)
|
|
23
|
+
from pydantic import BaseModel, Field, PrivateAttr
|
|
24
|
+
|
|
25
|
+
from veadk.memory.short_term_memory_backends.mysql_backend import (
|
|
26
|
+
MysqlSTMBackend,
|
|
27
|
+
)
|
|
28
|
+
from veadk.memory.short_term_memory_backends.postgresql_backend import (
|
|
29
|
+
PostgreSqlSTMBackend,
|
|
30
|
+
)
|
|
31
|
+
from veadk.memory.short_term_memory_backends.sqlite_backend import (
|
|
32
|
+
SQLiteSTMBackend,
|
|
33
|
+
)
|
|
34
|
+
from veadk.utils.logger import get_logger
|
|
17
35
|
|
|
18
|
-
|
|
36
|
+
logger = get_logger(__name__)
|
|
19
37
|
|
|
20
|
-
from veadk.config import getenv
|
|
21
|
-
from veadk.utils.logger import get_logger
|
|
22
38
|
|
|
23
|
-
|
|
39
|
+
def wrap_get_session_with_callbacks(obj, callback_fn: Callable):
    """Patch ``obj.get_session`` so ``callback_fn`` runs after each call.

    ADK session services expose ``get_session`` as a coroutine function
    (callers ``await`` it). The wrapper must therefore await the original
    call before handing the result to the callback. Otherwise the callback
    would receive an un-awaited coroutine object instead of the session.

    Args:
        obj: Object exposing a ``get_session`` method (sync or async). The
            method is replaced in place on ``obj``.
        callback_fn: Invoked as ``callback_fn(result, *args, **kwargs)`` with
            the value returned by ``get_session`` and the original call
            arguments. When ``get_session`` is async and the callback returns
            an awaitable, that awaitable is awaited too.
    """
    import inspect

    get_session_fn = getattr(obj, "get_session")

    if inspect.iscoroutinefunction(get_session_fn):

        @wraps(get_session_fn)
        async def async_wrapper(*args, **kwargs):
            # Await first so the callback sees the loaded session, not a coroutine.
            result = await get_session_fn(*args, **kwargs)
            callback_result = callback_fn(result, *args, **kwargs)
            # Allow async callbacks when the wrapped method is itself async.
            if inspect.isawaitable(callback_result):
                await callback_result
            return result

        setattr(obj, "get_session", async_wrapper)
        return

    @wraps(get_session_fn)
    def wrapper(*args, **kwargs):
        result = get_session_fn(*args, **kwargs)
        callback_fn(result, *args, **kwargs)
        return result

    setattr(obj, "get_session", wrapper)
|
|
28
49
|
|
|
29
50
|
|
|
30
|
-
class ShortTermMemory:
|
|
31
|
-
"""
|
|
32
|
-
Short term memory
|
|
51
|
+
class ShortTermMemory(BaseModel):
|
|
52
|
+
backend: Literal["local", "mysql", "sqlite", "postgresql", "database"] = "local"
|
|
53
|
+
"""Short term memory backend. `Local` for in-memory storage, `mysql` for mysql / PostgreSQL storage. `sqlite` for sqlite storage."""
|
|
33
54
|
|
|
34
|
-
|
|
35
|
-
"""
|
|
55
|
+
backend_configs: dict = Field(default_factory=dict)
|
|
56
|
+
"""Backend specific configurations."""
|
|
36
57
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
backend: Literal["local", "database", "mysql"] = "local",
|
|
40
|
-
db_url: str = "",
|
|
41
|
-
enable_memory_optimization: bool = False,
|
|
42
|
-
):
|
|
43
|
-
self.backend = backend
|
|
44
|
-
self.db_url = db_url
|
|
45
|
-
|
|
46
|
-
if self.backend == "mysql":
|
|
47
|
-
host = getenv("DATABASE_MYSQL_HOST")
|
|
48
|
-
user = getenv("DATABASE_MYSQL_USER")
|
|
49
|
-
password = getenv("DATABASE_MYSQL_PASSWORD")
|
|
50
|
-
database = getenv("DATABASE_MYSQL_DATABASE")
|
|
51
|
-
db_url = f"mysql+pymysql://{user}:{password}@{host}/{database}"
|
|
52
|
-
|
|
53
|
-
self.db_url = db_url
|
|
54
|
-
self.backend = "database"
|
|
55
|
-
|
|
56
|
-
if self.backend == "local":
|
|
57
|
-
logger.warning(
|
|
58
|
-
f"Short term memory backend: {self.backend}, the history will be lost after application shutdown."
|
|
59
|
-
)
|
|
60
|
-
self.session_service = InMemorySessionService()
|
|
61
|
-
elif self.backend == "database":
|
|
62
|
-
if self.db_url == "" or self.db_url is None:
|
|
63
|
-
logger.warning("The `db_url` is an empty or None string.")
|
|
64
|
-
self._use_default_database()
|
|
65
|
-
else:
|
|
66
|
-
try:
|
|
67
|
-
self.session_service = DatabaseSessionService(db_url=self.db_url)
|
|
68
|
-
logger.info("Connected to database with db_url.")
|
|
69
|
-
except Exception as e:
|
|
70
|
-
logger.error(f"Failed to connect to database, error: {e}.")
|
|
71
|
-
self._use_default_database()
|
|
72
|
-
else:
|
|
73
|
-
raise ValueError(f"Unknown short term memory backend: {self.backend}")
|
|
58
|
+
db_url: str = ""
|
|
59
|
+
"""Database connection URL, e.g. `sqlite:///./test.db`. Once set, it will override the `backend` parameter."""
|
|
74
60
|
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
)
|
|
61
|
+
local_database_path: str = "/tmp/veadk_local_database.db"
|
|
62
|
+
"""Local database path, only used when `backend` is `sqlite`. Default to `/tmp/veadk_local_database.db`."""
|
|
63
|
+
|
|
64
|
+
after_load_memory_callback: Callable | None = None
|
|
65
|
+
"""A callback to be called after loading memory from the backend. The callback function should accept `Session` as an input."""
|
|
81
66
|
|
|
82
|
-
|
|
83
|
-
self.db_url = DEFAULT_LOCAL_DATABASE_PATH
|
|
84
|
-
logger.info(f"Using default local database {self.db_url}")
|
|
85
|
-
if not os.path.exists(self.db_url):
|
|
86
|
-
self.create_local_sqlite3_db(self.db_url)
|
|
87
|
-
self.session_service = DatabaseSessionService(db_url="sqlite:///" + self.db_url)
|
|
67
|
+
_session_service: BaseSessionService = PrivateAttr()
|
|
88
68
|
|
|
89
|
-
def
|
|
90
|
-
|
|
69
|
+
def model_post_init(self, __context: Any) -> None:
|
|
70
|
+
if self.db_url:
|
|
71
|
+
logger.info("The `db_url` is set, ignore `backend` option.")
|
|
72
|
+
self._session_service = DatabaseSessionService(db_url=self.db_url)
|
|
73
|
+
else:
|
|
74
|
+
if self.backend == "database":
|
|
75
|
+
logger.warning(
|
|
76
|
+
"Backend `database` is deprecated, use `sqlite` to create short term memory."
|
|
77
|
+
)
|
|
78
|
+
self.backend = "sqlite"
|
|
79
|
+
match self.backend:
|
|
80
|
+
case "local":
|
|
81
|
+
self._session_service = InMemorySessionService()
|
|
82
|
+
case "mysql":
|
|
83
|
+
self._session_service = MysqlSTMBackend(
|
|
84
|
+
**self.backend_configs
|
|
85
|
+
).session_service
|
|
86
|
+
case "sqlite":
|
|
87
|
+
self._session_service = SQLiteSTMBackend(
|
|
88
|
+
local_path=self.local_database_path
|
|
89
|
+
).session_service
|
|
90
|
+
case "postgresql":
|
|
91
|
+
self._session_service = PostgreSqlSTMBackend(
|
|
92
|
+
**self.backend_configs
|
|
93
|
+
).session_service
|
|
94
|
+
|
|
95
|
+
if self.after_load_memory_callback:
|
|
96
|
+
wrap_get_session_with_callbacks(
|
|
97
|
+
self._session_service, self.after_load_memory_callback
|
|
98
|
+
)
|
|
91
99
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
100
|
+
@property
|
|
101
|
+
def session_service(self) -> BaseSessionService:
|
|
102
|
+
return self._session_service
|
|
95
103
|
|
|
96
104
|
async def create_session(
|
|
97
105
|
self,
|
|
98
106
|
app_name: str,
|
|
99
107
|
user_id: str,
|
|
100
108
|
session_id: str,
|
|
101
|
-
):
|
|
102
|
-
if isinstance(self.
|
|
103
|
-
list_sessions_response = await self.
|
|
109
|
+
) -> None:
|
|
110
|
+
if isinstance(self._session_service, DatabaseSessionService):
|
|
111
|
+
list_sessions_response = await self._session_service.list_sessions(
|
|
104
112
|
app_name=app_name, user_id=user_id
|
|
105
113
|
)
|
|
106
114
|
|
|
@@ -109,12 +117,12 @@ class ShortTermMemory:
|
|
|
109
117
|
)
|
|
110
118
|
|
|
111
119
|
if (
|
|
112
|
-
await self.
|
|
120
|
+
await self._session_service.get_session(
|
|
113
121
|
app_name=app_name, user_id=user_id, session_id=session_id
|
|
114
122
|
)
|
|
115
123
|
is None
|
|
116
124
|
):
|
|
117
125
|
# create a new session for this running
|
|
118
|
-
await self.
|
|
126
|
+
await self._session_service.create_session(
|
|
119
127
|
app_name=app_name, user_id=user_id, session_id=session_id
|
|
120
128
|
)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
from abc import ABC, abstractmethod
|
|
17
|
+
from functools import cached_property
|
|
18
|
+
|
|
19
|
+
from google.adk.sessions import BaseSessionService
|
|
20
|
+
from pydantic import BaseModel
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class BaseShortTermMemoryBackend(ABC, BaseModel):
    """Base class for short term memory backends.

    Subclasses must implement ``session_service``, returning the ADK
    ``BaseSessionService`` used to store sessions for this backend.
    """

    # ``functools.cached_property`` does not propagate ``__isabstractmethod__``,
    # so stacking it over ``abstractmethod`` silently made this member concrete
    # and allowed incomplete subclasses to be instantiated. ``property`` does
    # propagate it. Subclasses may still override with ``cached_property``.
    @property
    @abstractmethod
    def session_service(self) -> BaseSessionService:
        """Return the session service instance."""
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from functools import cached_property
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
from google.adk.sessions import (
|
|
19
|
+
BaseSessionService,
|
|
20
|
+
DatabaseSessionService,
|
|
21
|
+
)
|
|
22
|
+
from pydantic import Field
|
|
23
|
+
from typing_extensions import override
|
|
24
|
+
|
|
25
|
+
import veadk.config # noqa E401
|
|
26
|
+
from veadk.configs.database_configs import MysqlConfig
|
|
27
|
+
from veadk.memory.short_term_memory_backends.base_backend import (
|
|
28
|
+
BaseShortTermMemoryBackend,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class MysqlSTMBackend(BaseShortTermMemoryBackend):
    """Short term memory backend that stores sessions in MySQL (``pymysql`` driver)."""

    # Connection settings; defaults come from ``MysqlConfig``'s own factory.
    mysql_config: MysqlConfig = Field(default_factory=MysqlConfig)

    def model_post_init(self, context: Any) -> None:
        """Build the SQLAlchemy connection URL from ``mysql_config``."""
        from urllib.parse import quote

        # Percent-encode credentials. Characters such as ``@``, ``:``, ``/``
        # or ``#`` in a user name or password would otherwise break URL
        # parsing. ``safe=""`` also encodes ``/`` and ``+`` so the value
        # round-trips through either unquote variant.
        user = quote(str(self.mysql_config.user), safe="")
        password = quote(str(self.mysql_config.password), safe="")
        self._db_url = (
            f"mysql+pymysql://{user}:{password}"
            f"@{self.mysql_config.host}/{self.mysql_config.database}"
        )

    @cached_property
    @override
    def session_service(self) -> BaseSessionService:
        """Create (once) a ``DatabaseSessionService`` bound to the MySQL URL."""
        return DatabaseSessionService(db_url=self._db_url)
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from functools import cached_property
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
from google.adk.sessions import (
|
|
19
|
+
BaseSessionService,
|
|
20
|
+
DatabaseSessionService,
|
|
21
|
+
)
|
|
22
|
+
from pydantic import Field
|
|
23
|
+
from typing_extensions import override
|
|
24
|
+
|
|
25
|
+
import veadk.config # noqa E401
|
|
26
|
+
from veadk.configs.database_configs import PostgreSqlConfig
|
|
27
|
+
from veadk.memory.short_term_memory_backends.base_backend import (
|
|
28
|
+
BaseShortTermMemoryBackend,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class PostgreSqlSTMBackend(BaseShortTermMemoryBackend):
    """Short term memory backend that stores sessions in PostgreSQL (``psycopg2`` driver)."""

    # Connection settings; defaults come from ``PostgreSqlConfig``'s own factory.
    postgresql_config: PostgreSqlConfig = Field(default_factory=PostgreSqlConfig)

    def model_post_init(self, context: Any) -> None:
        """Build the SQLAlchemy connection URL from ``postgresql_config``."""
        from urllib.parse import quote

        # Percent-encode credentials. Characters such as ``@``, ``:``, ``/``
        # or ``#`` in a user name or password would otherwise break URL
        # parsing. ``safe=""`` also encodes ``/`` and ``+`` so the value
        # round-trips through either unquote variant.
        user = quote(str(self.postgresql_config.user), safe="")
        password = quote(str(self.postgresql_config.password), safe="")
        self._db_url = (
            f"postgresql+psycopg2://{user}:{password}"
            f"@{self.postgresql_config.host}:{self.postgresql_config.port}"
            f"/{self.postgresql_config.database}"
        )

    @cached_property
    @override
    def session_service(self) -> BaseSessionService:
        """Create (once) a ``DatabaseSessionService`` bound to the PostgreSQL URL."""
        return DatabaseSessionService(db_url=self._db_url)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
import sqlite3
|
|
17
|
+
from functools import cached_property
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
from google.adk.sessions import (
|
|
21
|
+
BaseSessionService,
|
|
22
|
+
DatabaseSessionService,
|
|
23
|
+
)
|
|
24
|
+
from typing_extensions import override
|
|
25
|
+
|
|
26
|
+
from veadk.memory.short_term_memory_backends.base_backend import (
|
|
27
|
+
BaseShortTermMemoryBackend,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class SQLiteSTMBackend(BaseShortTermMemoryBackend):
    """Short term memory backend that stores sessions in a local SQLite file."""

    # Filesystem path of the SQLite database file (created if missing).
    local_path: str

    def model_post_init(self, context: Any) -> None:
        """Ensure the database file exists and build its SQLAlchemy URL."""
        if not self._db_exists():
            # sqlite3 cannot create missing parent directories and would fail
            # with "unable to open database file"; create them first.
            parent_dir = os.path.dirname(self.local_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            # Opening a connection creates an empty database file.
            conn = sqlite3.connect(self.local_path)
            conn.close()

        self._db_url = f"sqlite:///{self.local_path}"

    @cached_property
    @override
    def session_service(self) -> BaseSessionService:
        """Create (once) a ``DatabaseSessionService`` bound to the SQLite file."""
        return DatabaseSessionService(db_url=self._db_url)

    def _db_exists(self) -> bool:
        """Return whether ``local_path`` already exists on disk."""
        return os.path.exists(self.local_path)
|
veadk/runner.py
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
import
|
|
15
|
+
import os
|
|
16
16
|
import functools
|
|
17
17
|
from types import MethodType
|
|
18
18
|
from typing import Union
|
|
@@ -47,7 +47,7 @@ RunnerMessage = Union[
|
|
|
47
47
|
]
|
|
48
48
|
|
|
49
49
|
|
|
50
|
-
def pre_run_process(self, process_func, new_message, user_id, session_id):
|
|
50
|
+
async def pre_run_process(self, process_func, new_message, user_id, session_id):
|
|
51
51
|
if new_message.parts:
|
|
52
52
|
for part in new_message.parts:
|
|
53
53
|
if (
|
|
@@ -55,13 +55,12 @@ def pre_run_process(self, process_func, new_message, user_id, session_id):
|
|
|
55
55
|
and part.inline_data.mime_type == "image/png"
|
|
56
56
|
and self.upload_inline_data_to_tos
|
|
57
57
|
):
|
|
58
|
-
process_func(
|
|
58
|
+
await process_func(
|
|
59
59
|
part,
|
|
60
60
|
self.app_name,
|
|
61
61
|
user_id,
|
|
62
62
|
session_id,
|
|
63
63
|
)
|
|
64
|
-
return
|
|
65
64
|
|
|
66
65
|
|
|
67
66
|
def post_run_process(self):
|
|
@@ -79,7 +78,7 @@ def intercept_new_message(process_func):
|
|
|
79
78
|
new_message: types.Content,
|
|
80
79
|
**kwargs,
|
|
81
80
|
):
|
|
82
|
-
pre_run_process(self, process_func, new_message, user_id, session_id)
|
|
81
|
+
await pre_run_process(self, process_func, new_message, user_id, session_id)
|
|
83
82
|
|
|
84
83
|
async for event in func(
|
|
85
84
|
user_id=user_id,
|
|
@@ -137,27 +136,21 @@ def _convert_messages(
|
|
|
137
136
|
return _messages
|
|
138
137
|
|
|
139
138
|
|
|
140
|
-
def _upload_image_to_tos(
|
|
139
|
+
async def _upload_image_to_tos(
|
|
141
140
|
part: genai.types.Part, app_name: str, user_id: str, session_id: str
|
|
142
141
|
) -> None:
|
|
143
142
|
try:
|
|
144
143
|
if part.inline_data and part.inline_data.display_name and part.inline_data.data:
|
|
145
144
|
from veadk.integrations.ve_tos.ve_tos import VeTOS
|
|
146
145
|
|
|
146
|
+
filename = os.path.basename(part.inline_data.display_name)
|
|
147
|
+
object_key = f"{app_name}/{user_id}-{session_id}-{filename}"
|
|
147
148
|
ve_tos = VeTOS()
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
session_id=session_id,
|
|
153
|
-
data_path=part.inline_data.display_name,
|
|
149
|
+
tos_url = ve_tos.build_tos_signed_url(object_key=object_key)
|
|
150
|
+
await ve_tos.async_upload_bytes(
|
|
151
|
+
object_key=object_key,
|
|
152
|
+
data=part.inline_data.data,
|
|
154
153
|
)
|
|
155
|
-
|
|
156
|
-
upload_task = ve_tos.upload(object_key, part.inline_data.data)
|
|
157
|
-
|
|
158
|
-
if upload_task is not None:
|
|
159
|
-
asyncio.create_task(upload_task)
|
|
160
|
-
|
|
161
154
|
part.inline_data.display_name = tos_url
|
|
162
155
|
except Exception as e:
|
|
163
156
|
logger.error(f"Upload to TOS failed: {e}")
|
|
@@ -226,7 +219,7 @@ class Runner(ADKRunner):
|
|
|
226
219
|
)
|
|
227
220
|
|
|
228
221
|
self.run_async = MethodType(
|
|
229
|
-
intercept_new_message(_upload_image_to_tos)(
|
|
222
|
+
intercept_new_message(_upload_image_to_tos)(super().run_async), self
|
|
230
223
|
)
|
|
231
224
|
|
|
232
225
|
async def run(
|