veadk-python 0.2.10__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of veadk-python might be problematic; see the package's registry page for more details.
- veadk/agent.py +7 -3
- veadk/auth/veauth/ark_veauth.py +43 -51
- veadk/auth/veauth/utils.py +57 -0
- veadk/configs/model_configs.py +3 -3
- veadk/consts.py +9 -0
- veadk/knowledgebase/knowledgebase.py +19 -32
- veadk/memory/long_term_memory.py +39 -92
- veadk/memory/long_term_memory_backends/base_backend.py +4 -2
- veadk/memory/long_term_memory_backends/in_memory_backend.py +8 -6
- veadk/memory/long_term_memory_backends/mem0_backend.py +8 -8
- veadk/memory/long_term_memory_backends/opensearch_backend.py +40 -36
- veadk/memory/long_term_memory_backends/redis_backend.py +59 -46
- veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py +53 -28
- veadk/memory/short_term_memory.py +9 -11
- veadk/runner.py +19 -11
- veadk/tools/builtin_tools/generate_image.py +11 -6
- veadk/tools/builtin_tools/image_edit.py +9 -4
- veadk/tools/builtin_tools/image_generate.py +9 -4
- veadk/tools/builtin_tools/load_knowledgebase.py +97 -0
- veadk/tools/builtin_tools/video_generate.py +6 -4
- veadk/utils/misc.py +6 -10
- veadk/utils/volcengine_sign.py +2 -0
- veadk/version.py +1 -1
- {veadk_python-0.2.10.dist-info → veadk_python-0.2.11.dist-info}/METADATA +2 -1
- {veadk_python-0.2.10.dist-info → veadk_python-0.2.11.dist-info}/RECORD +29 -27
- {veadk_python-0.2.10.dist-info → veadk_python-0.2.11.dist-info}/WHEEL +0 -0
- {veadk_python-0.2.10.dist-info → veadk_python-0.2.11.dist-info}/entry_points.txt +0 -0
- {veadk_python-0.2.10.dist-info → veadk_python-0.2.11.dist-info}/licenses/LICENSE +0 -0
- {veadk_python-0.2.10.dist-info → veadk_python-0.2.11.dist-info}/top_level.txt +0 -0
|
@@ -14,11 +14,7 @@
|
|
|
14
14
|
|
|
15
15
|
import re
|
|
16
16
|
|
|
17
|
-
from llama_index.core import
|
|
18
|
-
Document,
|
|
19
|
-
StorageContext,
|
|
20
|
-
VectorStoreIndex,
|
|
21
|
-
)
|
|
17
|
+
from llama_index.core import Document, VectorStoreIndex
|
|
22
18
|
from llama_index.core.schema import BaseNode
|
|
23
19
|
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
|
|
24
20
|
from pydantic import Field
|
|
@@ -31,6 +27,7 @@ from veadk.knowledgebase.backends.utils import get_llama_index_splitter
|
|
|
31
27
|
from veadk.memory.long_term_memory_backends.base_backend import (
|
|
32
28
|
BaseLongTermMemoryBackend,
|
|
33
29
|
)
|
|
30
|
+
from veadk.utils.logger import get_logger
|
|
34
31
|
|
|
35
32
|
try:
|
|
36
33
|
from llama_index.vector_stores.opensearch import (
|
|
@@ -42,6 +39,8 @@ except ImportError:
|
|
|
42
39
|
"Please install VeADK extensions\npip install veadk-python[extensions]"
|
|
43
40
|
)
|
|
44
41
|
|
|
42
|
+
logger = get_logger(__name__)
|
|
43
|
+
|
|
45
44
|
|
|
46
45
|
class OpensearchLTMBackend(BaseLongTermMemoryBackend):
|
|
47
46
|
opensearch_config: OpensearchConfig = Field(default_factory=OpensearchConfig)
|
|
@@ -52,19 +51,30 @@ class OpensearchLTMBackend(BaseLongTermMemoryBackend):
|
|
|
52
51
|
)
|
|
53
52
|
"""Embedding model configs"""
|
|
54
53
|
|
|
55
|
-
def
|
|
54
|
+
def model_post_init(self, __context: Any) -> None:
|
|
55
|
+
self._embed_model = OpenAILikeEmbedding(
|
|
56
|
+
model_name=self.embedding_config.name,
|
|
57
|
+
api_key=self.embedding_config.api_key,
|
|
58
|
+
api_base=self.embedding_config.api_base,
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
def precheck_index_naming(self, index: str):
|
|
56
62
|
if not (
|
|
57
|
-
isinstance(
|
|
58
|
-
and not
|
|
59
|
-
and
|
|
60
|
-
and re.match(r"^[a-z0-9_\-.]+$",
|
|
63
|
+
isinstance(index, str)
|
|
64
|
+
and not index.startswith(("_", "-"))
|
|
65
|
+
and index.islower()
|
|
66
|
+
and re.match(r"^[a-z0-9_\-.]+$", index)
|
|
61
67
|
):
|
|
62
68
|
raise ValueError(
|
|
63
|
-
"The index name does not conform to the naming rules of OpenSearch"
|
|
69
|
+
f"The index name {index} does not conform to the naming rules of OpenSearch"
|
|
64
70
|
)
|
|
65
71
|
|
|
66
|
-
def
|
|
67
|
-
|
|
72
|
+
def _create_vector_index(self, index: str) -> VectorStoreIndex:
|
|
73
|
+
logger.info(f"Create OpenSearch vector index with index={index}")
|
|
74
|
+
|
|
75
|
+
self.precheck_index_naming(index)
|
|
76
|
+
|
|
77
|
+
opensearch_client = OpensearchVectorClient(
|
|
68
78
|
endpoint=self.opensearch_config.host,
|
|
69
79
|
port=self.opensearch_config.port,
|
|
70
80
|
http_auth=(
|
|
@@ -74,39 +84,33 @@ class OpensearchLTMBackend(BaseLongTermMemoryBackend):
|
|
|
74
84
|
use_ssl=True,
|
|
75
85
|
verify_certs=False,
|
|
76
86
|
dim=self.embedding_config.dim,
|
|
77
|
-
index=
|
|
87
|
+
index=index,
|
|
78
88
|
)
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
self._storage_context = StorageContext.from_defaults(
|
|
83
|
-
vector_store=self._vector_store
|
|
84
|
-
)
|
|
85
|
-
|
|
86
|
-
self._embed_model = OpenAILikeEmbedding(
|
|
87
|
-
model_name=self.embedding_config.name,
|
|
88
|
-
api_key=self.embedding_config.api_key,
|
|
89
|
-
api_base=self.embedding_config.api_base,
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
self._vector_index = VectorStoreIndex.from_documents(
|
|
93
|
-
documents=[],
|
|
94
|
-
storage_context=self._storage_context,
|
|
95
|
-
embed_model=self._embed_model,
|
|
89
|
+
vector_store = OpensearchVectorStore(client=opensearch_client)
|
|
90
|
+
return VectorStoreIndex.from_vector_store(
|
|
91
|
+
vector_store=vector_store, embed_model=self._embed_model
|
|
96
92
|
)
|
|
97
|
-
self._retriever = self._vector_index.as_retriever()
|
|
98
93
|
|
|
99
94
|
@override
|
|
100
|
-
def save_memory(self, event_strings: list[str], **kwargs) -> bool:
|
|
95
|
+
def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
|
|
96
|
+
index = f"{self.index}_{user_id}"
|
|
97
|
+
vector_index = self._create_vector_index(index)
|
|
98
|
+
|
|
101
99
|
for event_string in event_strings:
|
|
102
100
|
document = Document(text=event_string)
|
|
103
101
|
nodes = self._split_documents([document])
|
|
104
|
-
|
|
102
|
+
vector_index.insert_nodes(nodes)
|
|
105
103
|
return True
|
|
106
104
|
|
|
107
105
|
@override
|
|
108
|
-
def search_memory(
|
|
109
|
-
|
|
106
|
+
def search_memory(
|
|
107
|
+
self, user_id: str, query: str, top_k: int, **kwargs
|
|
108
|
+
) -> list[str]:
|
|
109
|
+
index = f"{self.index}_{user_id}"
|
|
110
|
+
|
|
111
|
+
vector_index = self._create_vector_index(index)
|
|
112
|
+
|
|
113
|
+
_retriever = vector_index.as_retriever(similarity_top_k=top_k)
|
|
110
114
|
retrieved_nodes = _retriever.retrieve(query)
|
|
111
115
|
return [node.text for node in retrieved_nodes]
|
|
112
116
|
|
|
@@ -12,11 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from llama_index.core import
|
|
16
|
-
Document,
|
|
17
|
-
StorageContext,
|
|
18
|
-
VectorStoreIndex,
|
|
19
|
-
)
|
|
15
|
+
from llama_index.core import Document, VectorStoreIndex
|
|
20
16
|
from llama_index.core.schema import BaseNode
|
|
21
17
|
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
|
|
22
18
|
from pydantic import Field
|
|
@@ -29,21 +25,22 @@ from veadk.knowledgebase.backends.utils import get_llama_index_splitter
|
|
|
29
25
|
from veadk.memory.long_term_memory_backends.base_backend import (
|
|
30
26
|
BaseLongTermMemoryBackend,
|
|
31
27
|
)
|
|
28
|
+
from veadk.utils.logger import get_logger
|
|
32
29
|
|
|
33
30
|
try:
|
|
34
31
|
from llama_index.vector_stores.redis import RedisVectorStore
|
|
35
|
-
from llama_index.vector_stores.redis.schema import (
|
|
36
|
-
RedisIndexInfo,
|
|
37
|
-
RedisVectorStoreSchema,
|
|
38
|
-
)
|
|
39
32
|
from redis import Redis
|
|
40
|
-
from redisvl.schema
|
|
33
|
+
from redisvl.schema import IndexSchema
|
|
34
|
+
|
|
41
35
|
except ImportError:
|
|
42
36
|
raise ImportError(
|
|
43
37
|
"Please install VeADK extensions\npip install veadk-python[extensions]"
|
|
44
38
|
)
|
|
45
39
|
|
|
46
40
|
|
|
41
|
+
logger = get_logger(__name__)
|
|
42
|
+
|
|
43
|
+
|
|
47
44
|
class RedisLTMBackend(BaseLongTermMemoryBackend):
|
|
48
45
|
redis_config: RedisConfig = Field(default_factory=RedisConfig)
|
|
49
46
|
"""Redis client configs"""
|
|
@@ -53,67 +50,83 @@ class RedisLTMBackend(BaseLongTermMemoryBackend):
|
|
|
53
50
|
)
|
|
54
51
|
"""Embedding model configs"""
|
|
55
52
|
|
|
56
|
-
def
|
|
53
|
+
def model_post_init(self, __context: Any) -> None:
|
|
54
|
+
self._embed_model = OpenAILikeEmbedding(
|
|
55
|
+
model_name=self.embedding_config.name,
|
|
56
|
+
api_key=self.embedding_config.api_key,
|
|
57
|
+
api_base=self.embedding_config.api_base,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
def precheck_index_naming(self, index: str):
|
|
57
61
|
# no checking
|
|
58
62
|
pass
|
|
59
63
|
|
|
60
|
-
def
|
|
64
|
+
def _create_vector_index(self, index: str) -> VectorStoreIndex:
|
|
65
|
+
logger.info(f"Create Redis vector index with index={index}")
|
|
66
|
+
|
|
67
|
+
self.precheck_index_naming(index)
|
|
68
|
+
|
|
61
69
|
# We will use `from_url` to init Redis client once the
|
|
62
70
|
# AK/SK -> STS token is ready.
|
|
63
71
|
# self._redis_client = Redis.from_url(url=...)
|
|
64
|
-
|
|
65
|
-
self._redis_client = Redis(
|
|
72
|
+
redis_client = Redis(
|
|
66
73
|
host=self.redis_config.host,
|
|
67
74
|
port=self.redis_config.port,
|
|
68
75
|
db=self.redis_config.db,
|
|
69
76
|
password=self.redis_config.password,
|
|
70
77
|
)
|
|
71
78
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
79
|
+
# Create an index for each user
|
|
80
|
+
# Should be Optimized in the future
|
|
81
|
+
schema = IndexSchema.from_dict(
|
|
82
|
+
{
|
|
83
|
+
"index": {"name": index, "prefix": index, "key_separator": "_"},
|
|
84
|
+
"fields": [
|
|
85
|
+
{"name": "id", "type": "tag", "attrs": {"sortable": False}},
|
|
86
|
+
{"name": "doc_id", "type": "tag", "attrs": {"sortable": False}},
|
|
87
|
+
{"name": "text", "type": "text", "attrs": {"weight": 1.0}},
|
|
88
|
+
{
|
|
89
|
+
"name": "vector",
|
|
90
|
+
"type": "vector",
|
|
91
|
+
"attrs": {
|
|
92
|
+
"dims": self.embedding_config.dim,
|
|
93
|
+
"algorithm": "flat",
|
|
94
|
+
"distance_metric": "cosine",
|
|
95
|
+
},
|
|
96
|
+
},
|
|
97
|
+
],
|
|
98
|
+
}
|
|
76
99
|
)
|
|
100
|
+
vector_store = RedisVectorStore(schema=schema, redis_client=redis_client)
|
|
77
101
|
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
)
|
|
81
|
-
if "vector" in self._schema.fields:
|
|
82
|
-
vector_field = self._schema.fields["vector"]
|
|
83
|
-
if (
|
|
84
|
-
vector_field
|
|
85
|
-
and vector_field.attrs
|
|
86
|
-
and isinstance(vector_field.attrs, BaseVectorFieldAttributes)
|
|
87
|
-
):
|
|
88
|
-
vector_field.attrs.dims = self.embedding_config.dim
|
|
89
|
-
self._vector_store = RedisVectorStore(
|
|
90
|
-
schema=self._schema,
|
|
91
|
-
redis_client=self._redis_client,
|
|
92
|
-
overwrite=True,
|
|
93
|
-
collection_name=self.index,
|
|
94
|
-
)
|
|
95
|
-
|
|
96
|
-
self._storage_context = StorageContext.from_defaults(
|
|
97
|
-
vector_store=self._vector_store
|
|
102
|
+
logger.info(
|
|
103
|
+
f"Create vector store done, index_name={vector_store.index_name} prefix={vector_store.schema.index.prefix}"
|
|
98
104
|
)
|
|
99
105
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
storage_context=self._storage_context,
|
|
103
|
-
embed_model=self._embed_model,
|
|
106
|
+
return VectorStoreIndex.from_vector_store(
|
|
107
|
+
vector_store=vector_store, embed_model=self._embed_model
|
|
104
108
|
)
|
|
105
109
|
|
|
106
110
|
@override
|
|
107
|
-
def save_memory(self, event_strings: list[str], **kwargs) -> bool:
|
|
111
|
+
def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
|
|
112
|
+
index = f"veadk-ltm/{self.index}/{user_id}"
|
|
113
|
+
vector_index = self._create_vector_index(index)
|
|
114
|
+
|
|
108
115
|
for event_string in event_strings:
|
|
109
116
|
document = Document(text=event_string)
|
|
110
117
|
nodes = self._split_documents([document])
|
|
111
|
-
|
|
118
|
+
vector_index.insert_nodes(nodes)
|
|
119
|
+
|
|
112
120
|
return True
|
|
113
121
|
|
|
114
122
|
@override
|
|
115
|
-
def search_memory(
|
|
116
|
-
|
|
123
|
+
def search_memory(
|
|
124
|
+
self, user_id: str, query: str, top_k: int, **kwargs
|
|
125
|
+
) -> list[str]:
|
|
126
|
+
index = f"veadk-ltm/{self.index}/{user_id}"
|
|
127
|
+
vector_index = self._create_vector_index(index)
|
|
128
|
+
|
|
129
|
+
_retriever = vector_index.as_retriever(similarity_top_k=top_k)
|
|
117
130
|
retrieved_nodes = _retriever.retrieve(query)
|
|
118
131
|
return [node.text for node in retrieved_nodes]
|
|
119
132
|
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
import json
|
|
16
|
+
import os
|
|
16
17
|
import re
|
|
17
18
|
import time
|
|
18
19
|
import uuid
|
|
@@ -22,7 +23,7 @@ from pydantic import Field
|
|
|
22
23
|
from typing_extensions import override
|
|
23
24
|
|
|
24
25
|
import veadk.config # noqa E401
|
|
25
|
-
from veadk.
|
|
26
|
+
from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
|
|
26
27
|
from veadk.integrations.ve_viking_db_memory.ve_viking_db_memory import (
|
|
27
28
|
VikingDBMemoryClient,
|
|
28
29
|
)
|
|
@@ -35,17 +36,24 @@ logger = get_logger(__name__)
|
|
|
35
36
|
|
|
36
37
|
|
|
37
38
|
class VikingDBLTMBackend(BaseLongTermMemoryBackend):
|
|
38
|
-
volcengine_access_key: str = Field(
|
|
39
|
-
default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY")
|
|
39
|
+
volcengine_access_key: str | None = Field(
|
|
40
|
+
default_factory=lambda: os.getenv("VOLCENGINE_ACCESS_KEY")
|
|
40
41
|
)
|
|
41
42
|
|
|
42
|
-
volcengine_secret_key: str = Field(
|
|
43
|
-
default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY")
|
|
43
|
+
volcengine_secret_key: str | None = Field(
|
|
44
|
+
default_factory=lambda: os.getenv("VOLCENGINE_SECRET_KEY")
|
|
44
45
|
)
|
|
45
46
|
|
|
47
|
+
session_token: str = ""
|
|
48
|
+
|
|
46
49
|
region: str = "cn-beijing"
|
|
47
50
|
"""VikingDB memory region"""
|
|
48
51
|
|
|
52
|
+
def model_post_init(self, __context: Any) -> None:
|
|
53
|
+
# check whether collection exist, if not, create it
|
|
54
|
+
if not self._collection_exist():
|
|
55
|
+
self._create_collection()
|
|
56
|
+
|
|
49
57
|
def precheck_index_naming(self):
|
|
50
58
|
if not (
|
|
51
59
|
isinstance(self.index, str)
|
|
@@ -56,37 +64,39 @@ class VikingDBLTMBackend(BaseLongTermMemoryBackend):
|
|
|
56
64
|
"The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128."
|
|
57
65
|
)
|
|
58
66
|
|
|
59
|
-
def model_post_init(self, __context: Any) -> None:
|
|
60
|
-
self._client = VikingDBMemoryClient(
|
|
61
|
-
ak=self.volcengine_access_key,
|
|
62
|
-
sk=self.volcengine_secret_key,
|
|
63
|
-
region=self.region,
|
|
64
|
-
)
|
|
65
|
-
|
|
66
|
-
# check whether collection exist, if not, create it
|
|
67
|
-
if not self._collection_exist():
|
|
68
|
-
self._create_collection()
|
|
69
|
-
|
|
70
67
|
def _collection_exist(self) -> bool:
|
|
71
68
|
try:
|
|
72
|
-
self.
|
|
69
|
+
client = self._get_client()
|
|
70
|
+
client.get_collection(collection_name=self.index)
|
|
73
71
|
return True
|
|
74
72
|
except Exception:
|
|
75
73
|
return False
|
|
76
74
|
|
|
77
75
|
def _create_collection(self) -> None:
|
|
78
|
-
|
|
76
|
+
client = self._get_client()
|
|
77
|
+
response = client.create_collection(
|
|
79
78
|
collection_name=self.index,
|
|
80
79
|
description="Created by Volcengine Agent Development Kit VeADK",
|
|
81
80
|
builtin_event_types=["sys_event_v1"],
|
|
82
81
|
)
|
|
83
82
|
return response
|
|
84
83
|
|
|
84
|
+
def _get_client(self) -> VikingDBMemoryClient:
|
|
85
|
+
if not (self.volcengine_access_key and self.volcengine_secret_key):
|
|
86
|
+
cred = get_credential_from_vefaas_iam()
|
|
87
|
+
self.volcengine_access_key = cred.access_key_id
|
|
88
|
+
self.volcengine_secret_key = cred.secret_access_key
|
|
89
|
+
self.session_token = cred.session_token
|
|
90
|
+
|
|
91
|
+
return VikingDBMemoryClient(
|
|
92
|
+
ak=self.volcengine_access_key,
|
|
93
|
+
sk=self.volcengine_secret_key,
|
|
94
|
+
sts_token=self.session_token,
|
|
95
|
+
region=self.region,
|
|
96
|
+
)
|
|
97
|
+
|
|
85
98
|
@override
|
|
86
|
-
def save_memory(self, event_strings: list[str], **kwargs) -> bool:
|
|
87
|
-
user_id = kwargs.get("user_id")
|
|
88
|
-
if user_id is None:
|
|
89
|
-
raise ValueError("user_id is required")
|
|
99
|
+
def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
|
|
90
100
|
session_id = str(uuid.uuid1())
|
|
91
101
|
messages = []
|
|
92
102
|
for raw_events in event_strings:
|
|
@@ -101,31 +111,46 @@ class VikingDBLTMBackend(BaseLongTermMemoryBackend):
|
|
|
101
111
|
"default_assistant_id": "assistant",
|
|
102
112
|
"time": int(time.time() * 1000),
|
|
103
113
|
}
|
|
104
|
-
|
|
114
|
+
|
|
115
|
+
logger.debug(
|
|
116
|
+
f"Request for add {len(messages)} memory to VikingDB: collection_name={self.index}, metadata={metadata}, session_id={session_id}"
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
client = self._get_client()
|
|
120
|
+
response = client.add_messages(
|
|
105
121
|
collection_name=self.index,
|
|
106
122
|
messages=messages,
|
|
107
123
|
metadata=metadata,
|
|
108
124
|
session_id=session_id,
|
|
109
125
|
)
|
|
110
126
|
|
|
127
|
+
logger.debug(f"Response from add memory to VikingDB: {response}")
|
|
128
|
+
|
|
111
129
|
if not response.get("code") == 0:
|
|
112
130
|
raise ValueError(f"Save VikingDB memory error: {response}")
|
|
113
131
|
|
|
114
132
|
return True
|
|
115
133
|
|
|
116
134
|
@override
|
|
117
|
-
def search_memory(
|
|
118
|
-
user_id
|
|
119
|
-
|
|
120
|
-
raise ValueError("user_id is required")
|
|
135
|
+
def search_memory(
|
|
136
|
+
self, user_id: str, query: str, top_k: int, **kwargs
|
|
137
|
+
) -> list[str]:
|
|
121
138
|
filter = {
|
|
122
139
|
"user_id": user_id,
|
|
123
140
|
"memory_type": ["sys_event_v1"],
|
|
124
141
|
}
|
|
125
|
-
|
|
142
|
+
|
|
143
|
+
logger.debug(
|
|
144
|
+
f"Request for search memory in VikingDB: filter={filter}, collection_name={self.index}, query={query}, limit={top_k}"
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
client = self._get_client()
|
|
148
|
+
response = client.search_memory(
|
|
126
149
|
collection_name=self.index, query=query, filter=filter, limit=top_k
|
|
127
150
|
)
|
|
128
151
|
|
|
152
|
+
logger.debug(f"Response from search memory in VikingDB: {response}")
|
|
153
|
+
|
|
129
154
|
if not response.get("code") == 0:
|
|
130
155
|
raise ValueError(f"Search VikingDB memory error: {response}")
|
|
131
156
|
|
|
@@ -117,18 +117,16 @@ class ShortTermMemory(BaseModel):
|
|
|
117
117
|
f"Loaded {len(list_sessions_response.sessions)} sessions from db {self.db_url}."
|
|
118
118
|
)
|
|
119
119
|
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
120
|
+
session = await self._session_service.get_session(
|
|
121
|
+
app_name=app_name, user_id=user_id, session_id=session_id
|
|
122
|
+
)
|
|
123
|
+
|
|
124
|
+
if session:
|
|
125
|
+
logger.info(
|
|
126
|
+
f"Session {session_id} already exists with app_name={app_name} user_id={user_id}."
|
|
123
127
|
)
|
|
124
|
-
|
|
125
|
-
|
|
128
|
+
return session
|
|
129
|
+
else:
|
|
126
130
|
return await self._session_service.create_session(
|
|
127
131
|
app_name=app_name, user_id=user_id, session_id=session_id
|
|
128
132
|
)
|
|
129
|
-
else:
|
|
130
|
-
logger.info(
|
|
131
|
-
f"Session {session_id} already exists with app_name={app_name} user_id={user_id}."
|
|
132
|
-
)
|
|
133
|
-
|
|
134
|
-
return None
|
veadk/runner.py
CHANGED
|
@@ -34,7 +34,7 @@ from veadk.evaluation import EvalSetRecorder
|
|
|
34
34
|
from veadk.memory.short_term_memory import ShortTermMemory
|
|
35
35
|
from veadk.types import MediaMessage
|
|
36
36
|
from veadk.utils.logger import get_logger
|
|
37
|
-
from veadk.utils.misc import formatted_timestamp,
|
|
37
|
+
from veadk.utils.misc import formatted_timestamp, read_file_to_bytes
|
|
38
38
|
|
|
39
39
|
logger = get_logger(__name__)
|
|
40
40
|
|
|
@@ -50,11 +50,7 @@ RunnerMessage = Union[
|
|
|
50
50
|
async def pre_run_process(self, process_func, new_message, user_id, session_id):
|
|
51
51
|
if new_message.parts:
|
|
52
52
|
for part in new_message.parts:
|
|
53
|
-
if
|
|
54
|
-
part.inline_data
|
|
55
|
-
and part.inline_data.mime_type == "image/png"
|
|
56
|
-
and self.upload_inline_data_to_tos
|
|
57
|
-
):
|
|
53
|
+
if part.inline_data and self.upload_inline_data_to_tos:
|
|
58
54
|
await process_func(
|
|
59
55
|
part,
|
|
60
56
|
self.app_name,
|
|
@@ -105,9 +101,20 @@ def _convert_messages(
|
|
|
105
101
|
if isinstance(messages, str):
|
|
106
102
|
_messages = [types.Content(role="user", parts=[types.Part(text=messages)])]
|
|
107
103
|
elif isinstance(messages, MediaMessage):
|
|
108
|
-
|
|
109
|
-
|
|
104
|
+
import filetype
|
|
105
|
+
|
|
106
|
+
file_data = read_file_to_bytes(messages.media)
|
|
107
|
+
|
|
108
|
+
kind = filetype.guess(file_data)
|
|
109
|
+
if kind is None:
|
|
110
|
+
raise ValueError("Unsupported or unknown file type.")
|
|
111
|
+
|
|
112
|
+
mime_type = kind.mime
|
|
113
|
+
|
|
114
|
+
assert mime_type.startswith(("image/", "video/")), (
|
|
115
|
+
f"Unsupported media type: {mime_type}"
|
|
110
116
|
)
|
|
117
|
+
|
|
111
118
|
_messages = [
|
|
112
119
|
types.Content(
|
|
113
120
|
role="user",
|
|
@@ -116,8 +123,8 @@ def _convert_messages(
|
|
|
116
123
|
types.Part(
|
|
117
124
|
inline_data=Blob(
|
|
118
125
|
display_name=messages.media,
|
|
119
|
-
data=
|
|
120
|
-
mime_type=
|
|
126
|
+
data=file_data,
|
|
127
|
+
mime_type=mime_type,
|
|
121
128
|
)
|
|
122
129
|
),
|
|
123
130
|
],
|
|
@@ -277,7 +284,8 @@ class Runner(ADKRunner):
|
|
|
277
284
|
and event.content.parts[0].text is not None
|
|
278
285
|
and len(event.content.parts[0].text.strip()) > 0
|
|
279
286
|
):
|
|
280
|
-
final_output
|
|
287
|
+
final_output = event.content.parts[0].text
|
|
288
|
+
logger.debug(f"Event output: {final_output}")
|
|
281
289
|
except LlmCallsLimitExceededError as e:
|
|
282
290
|
logger.warning(f"Max number of llm calls limit exceeded: {e}")
|
|
283
291
|
final_output = ""
|
|
@@ -25,17 +25,22 @@ from opentelemetry.trace import Span
|
|
|
25
25
|
from volcenginesdkarkruntime import Ark
|
|
26
26
|
from volcenginesdkarkruntime.types.images.images import SequentialImageGenerationOptions
|
|
27
27
|
|
|
28
|
-
from veadk.config import getenv
|
|
29
|
-
from veadk.consts import
|
|
28
|
+
from veadk.config import getenv, settings
|
|
29
|
+
from veadk.consts import (
|
|
30
|
+
DEFAULT_IMAGE_GENERATE_MODEL_NAME,
|
|
31
|
+
DEFAULT_IMAGE_GENERATE_MODEL_API_BASE,
|
|
32
|
+
)
|
|
30
33
|
from veadk.utils.logger import get_logger
|
|
31
|
-
from veadk.utils.misc import formatted_timestamp,
|
|
34
|
+
from veadk.utils.misc import formatted_timestamp, read_file_to_bytes
|
|
32
35
|
from veadk.version import VERSION
|
|
33
36
|
|
|
34
37
|
logger = get_logger(__name__)
|
|
35
38
|
|
|
36
39
|
client = Ark(
|
|
37
|
-
api_key=getenv(
|
|
38
|
-
|
|
40
|
+
api_key=getenv(
|
|
41
|
+
"MODEL_IMAGE_API_KEY", getenv("MODEL_AGENT_API_KEY", settings.model.api_key)
|
|
42
|
+
),
|
|
43
|
+
base_url=getenv("MODEL_IMAGE_API_BASE", DEFAULT_IMAGE_GENERATE_MODEL_API_BASE),
|
|
39
44
|
)
|
|
40
45
|
|
|
41
46
|
|
|
@@ -299,7 +304,7 @@ async def image_generate(
|
|
|
299
304
|
artifact=Part(
|
|
300
305
|
inline_data=Blob(
|
|
301
306
|
display_name=filename,
|
|
302
|
-
data=
|
|
307
|
+
data=read_file_to_bytes(image_tos_url),
|
|
303
308
|
mime_type=mimetypes.guess_type(image_tos_url)[0],
|
|
304
309
|
)
|
|
305
310
|
),
|
|
@@ -15,8 +15,11 @@
|
|
|
15
15
|
from typing import Dict
|
|
16
16
|
from google.adk.tools import ToolContext
|
|
17
17
|
from volcenginesdkarkruntime import Ark
|
|
18
|
-
from veadk.config import getenv
|
|
19
|
-
from veadk.consts import
|
|
18
|
+
from veadk.config import getenv, settings
|
|
19
|
+
from veadk.consts import (
|
|
20
|
+
DEFAULT_IMAGE_EDIT_MODEL_API_BASE,
|
|
21
|
+
DEFAULT_IMAGE_EDIT_MODEL_NAME,
|
|
22
|
+
)
|
|
20
23
|
import base64
|
|
21
24
|
from opentelemetry import trace
|
|
22
25
|
import traceback
|
|
@@ -28,8 +31,10 @@ from veadk.utils.logger import get_logger
|
|
|
28
31
|
logger = get_logger(__name__)
|
|
29
32
|
|
|
30
33
|
client = Ark(
|
|
31
|
-
api_key=getenv(
|
|
32
|
-
|
|
34
|
+
api_key=getenv(
|
|
35
|
+
"MODEL_EDIT_API_KEY", getenv("MODEL_AGENT_API_KEY", settings.model.api_key)
|
|
36
|
+
),
|
|
37
|
+
base_url=getenv("MODEL_EDIT_API_BASE", DEFAULT_IMAGE_EDIT_MODEL_API_BASE),
|
|
33
38
|
)
|
|
34
39
|
|
|
35
40
|
|
|
@@ -15,8 +15,11 @@
|
|
|
15
15
|
from typing import Dict
|
|
16
16
|
|
|
17
17
|
from google.adk.tools import ToolContext
|
|
18
|
-
from veadk.config import getenv
|
|
19
|
-
from veadk.consts import
|
|
18
|
+
from veadk.config import getenv, settings
|
|
19
|
+
from veadk.consts import (
|
|
20
|
+
DEFAULT_TEXT_TO_IMAGE_MODEL_NAME,
|
|
21
|
+
DEFAULT_TEXT_TO_IMAGE_MODEL_API_BASE,
|
|
22
|
+
)
|
|
20
23
|
import base64
|
|
21
24
|
from volcenginesdkarkruntime import Ark
|
|
22
25
|
from opentelemetry import trace
|
|
@@ -29,8 +32,10 @@ from veadk.utils.logger import get_logger
|
|
|
29
32
|
logger = get_logger(__name__)
|
|
30
33
|
|
|
31
34
|
client = Ark(
|
|
32
|
-
api_key=getenv(
|
|
33
|
-
|
|
35
|
+
api_key=getenv(
|
|
36
|
+
"MODEL_IMAGE_API_KEY", getenv("MODEL_AGENT_API_KEY", settings.model.api_key)
|
|
37
|
+
),
|
|
38
|
+
base_url=getenv("MODEL_IMAGE_API_BASE", DEFAULT_TEXT_TO_IMAGE_MODEL_API_BASE),
|
|
34
39
|
)
|
|
35
40
|
|
|
36
41
|
|