veadk-python 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of veadk-python might be problematic.
- veadk/a2a/remote_ve_agent.py +63 -6
- veadk/agent.py +10 -3
- veadk/agent_builder.py +2 -3
- veadk/auth/veauth/ark_veauth.py +43 -51
- veadk/auth/veauth/utils.py +57 -0
- veadk/cli/cli.py +2 -0
- veadk/cli/cli_kb.py +75 -0
- veadk/cli/cli_web.py +4 -0
- veadk/configs/model_configs.py +3 -3
- veadk/consts.py +9 -0
- veadk/integrations/__init__.py +13 -0
- veadk/integrations/ve_viking_db_memory/__init__.py +13 -0
- veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py +293 -0
- veadk/knowledgebase/knowledgebase.py +19 -32
- veadk/memory/__init__.py +1 -1
- veadk/memory/long_term_memory.py +40 -68
- veadk/memory/long_term_memory_backends/base_backend.py +4 -2
- veadk/memory/long_term_memory_backends/in_memory_backend.py +8 -6
- veadk/memory/long_term_memory_backends/mem0_backend.py +25 -10
- veadk/memory/long_term_memory_backends/opensearch_backend.py +40 -36
- veadk/memory/long_term_memory_backends/redis_backend.py +59 -46
- veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py +56 -35
- veadk/memory/short_term_memory.py +12 -8
- veadk/memory/short_term_memory_backends/postgresql_backend.py +3 -1
- veadk/runner.py +42 -19
- veadk/tools/builtin_tools/generate_image.py +56 -17
- veadk/tools/builtin_tools/image_edit.py +17 -7
- veadk/tools/builtin_tools/image_generate.py +17 -7
- veadk/tools/builtin_tools/load_knowledgebase.py +97 -0
- veadk/tools/builtin_tools/video_generate.py +11 -9
- veadk/tools/builtin_tools/web_search.py +10 -3
- veadk/tools/load_knowledgebase_tool.py +12 -0
- veadk/tracing/telemetry/attributes/extractors/llm_attributes_extractors.py +5 -0
- veadk/tracing/telemetry/attributes/extractors/tool_attributes_extractors.py +7 -0
- veadk/tracing/telemetry/exporters/apmplus_exporter.py +82 -2
- veadk/tracing/telemetry/exporters/inmemory_exporter.py +8 -2
- veadk/tracing/telemetry/telemetry.py +41 -5
- veadk/utils/misc.py +6 -10
- veadk/utils/volcengine_sign.py +2 -0
- veadk/version.py +1 -1
- {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/METADATA +4 -3
- {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/RECORD +46 -40
- {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/WHEEL +0 -0
- {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/entry_points.txt +0 -0
- {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/licenses/LICENSE +0 -0
- {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/top_level.txt +0 -0
veadk/memory/long_term_memory_backends/mem0_backend.py

@@ -13,12 +13,11 @@
 # limitations under the License.

 from typing import Any
-
+
 from pydantic import Field
+from typing_extensions import override

 from veadk.configs.database_configs import Mem0Config
-
-
 from veadk.memory.long_term_memory_backends.base_backend import (
     BaseLongTermMemoryBackend,
 )

@@ -46,12 +45,17 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):

         try:
             self._mem0_client = MemoryClient(
-
+                host=self.mem0_config.base_url,  # mem0 endpoint
                 api_key=self.mem0_config.api_key,  # mem0 API key
             )
+            logger.info(
+                f"Initialized Mem0 client for host: {self.mem0_config.base_url}"
+            )
             logger.info(f"Initialized Mem0 client for index: {self.index}")
         except Exception as e:
-            logger.error(
+            logger.error(
+                f"Failed to initialize Mem0 client for host {self.mem0_config.base_url} : {str(e)}"
+            )
             raise

     def precheck_index_naming(self):

@@ -61,7 +65,9 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
         pass

     @override
-    def save_memory(
+    def save_memory(
+        self, event_strings: list[str], user_id: str = "default_user", **kwargs
+    ) -> bool:
         """Save memory to Mem0

         Args:

@@ -71,8 +77,6 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
         Returns:
             bool: True if saved successfully, False otherwise
         """
-        user_id = kwargs.get("user_id", "default_user")
-
         try:
             logger.info(
                 f"Saving {len(event_strings)} events to Mem0 for user: {user_id}"

@@ -84,6 +88,7 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
                 [{"role": "user", "content": event_string}],
                 user_id=user_id,
                 output_format="v1.1",
+                async_mode=True,
             )
             logger.debug(f"Saved memory result: {result}")

@@ -94,7 +99,9 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
             return False

     @override
-    def search_memory(
+    def search_memory(
+        self, query: str, top_k: int, user_id: str = "default_user", **kwargs
+    ) -> list[str]:
         """Search memory from Mem0

         Args:

@@ -105,7 +112,6 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
         Returns:
             list[str]: List of memory strings
         """
-        user_id = kwargs.get("user_id", "default_user")

         try:
             logger.info(

@@ -116,7 +122,16 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
                 query, user_id=user_id, output_format="v1.1", top_k=top_k
             )

+            logger.debug(f"return relevant memories: {memories}")
+
             memory_list = []
+            # If memories is a list, return it directly
+            if isinstance(memories, list):
+                for mem in memories:
+                    if "memory" in mem:
+                        memory_list.append(mem["memory"])
+                return memory_list
+
             if memories.get("results", []):
                 for mem in memories["results"]:
                     if "memory" in mem:
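
Not part of the diff: a minimal usage sketch of the reworked Mem0 backend API. The constructor argument and config resolution are assumptions; the save/search signatures mirror the hunks above.

# Hypothetical usage of the 0.2.11 Mem0 backend; constructor arguments are
# assumptions, the method signatures come from the diff above.
from veadk.memory.long_term_memory_backends.mem0_backend import Mem0LTMBackend

backend = Mem0LTMBackend(index="demo_memory")  # assumes Mem0Config is resolved from config/env

# user_id is now an explicit keyword argument with a "default_user" fallback,
# and writes go through Mem0 with async_mode=True.
backend.save_memory(["User prefers concise answers."], user_id="user_42")

# search_memory flattens either a bare list or a {"results": [...]} payload
# into list[str].
memories = backend.search_memory(query="answer style", top_k=3, user_id="user_42")
print(memories)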
veadk/memory/long_term_memory_backends/opensearch_backend.py

@@ -14,11 +14,7 @@

 import re

-from llama_index.core import (
-    Document,
-    StorageContext,
-    VectorStoreIndex,
-)
+from llama_index.core import Document, VectorStoreIndex
 from llama_index.core.schema import BaseNode
 from llama_index.embeddings.openai_like import OpenAILikeEmbedding
 from pydantic import Field

@@ -31,6 +27,7 @@ from veadk.knowledgebase.backends.utils import get_llama_index_splitter
 from veadk.memory.long_term_memory_backends.base_backend import (
     BaseLongTermMemoryBackend,
 )
+from veadk.utils.logger import get_logger

 try:
     from llama_index.vector_stores.opensearch import (

@@ -42,6 +39,8 @@ except ImportError:
         "Please install VeADK extensions\npip install veadk-python[extensions]"
     )

+logger = get_logger(__name__)
+

 class OpensearchLTMBackend(BaseLongTermMemoryBackend):
     opensearch_config: OpensearchConfig = Field(default_factory=OpensearchConfig)

@@ -52,19 +51,30 @@ class OpensearchLTMBackend(BaseLongTermMemoryBackend):
     )
     """Embedding model configs"""

-    def
+    def model_post_init(self, __context: Any) -> None:
+        self._embed_model = OpenAILikeEmbedding(
+            model_name=self.embedding_config.name,
+            api_key=self.embedding_config.api_key,
+            api_base=self.embedding_config.api_base,
+        )
+
+    def precheck_index_naming(self, index: str):
         if not (
-            isinstance(
-            and not
-            and
-            and re.match(r"^[a-z0-9_\-.]+$",
+            isinstance(index, str)
+            and not index.startswith(("_", "-"))
+            and index.islower()
+            and re.match(r"^[a-z0-9_\-.]+$", index)
         ):
             raise ValueError(
-                "The index name does not conform to the naming rules of OpenSearch"
+                f"The index name {index} does not conform to the naming rules of OpenSearch"
             )

-    def
-
+    def _create_vector_index(self, index: str) -> VectorStoreIndex:
+        logger.info(f"Create OpenSearch vector index with index={index}")
+
+        self.precheck_index_naming(index)
+
+        opensearch_client = OpensearchVectorClient(
             endpoint=self.opensearch_config.host,
             port=self.opensearch_config.port,
             http_auth=(

@@ -74,39 +84,33 @@ class OpensearchLTMBackend(BaseLongTermMemoryBackend):
             use_ssl=True,
             verify_certs=False,
             dim=self.embedding_config.dim,
-            index=
+            index=index,
         )
-
-
-
-        self._storage_context = StorageContext.from_defaults(
-            vector_store=self._vector_store
-        )
-
-        self._embed_model = OpenAILikeEmbedding(
-            model_name=self.embedding_config.name,
-            api_key=self.embedding_config.api_key,
-            api_base=self.embedding_config.api_base,
-        )
-
-        self._vector_index = VectorStoreIndex.from_documents(
-            documents=[],
-            storage_context=self._storage_context,
-            embed_model=self._embed_model,
+        vector_store = OpensearchVectorStore(client=opensearch_client)
+        return VectorStoreIndex.from_vector_store(
+            vector_store=vector_store, embed_model=self._embed_model
         )
-        self._retriever = self._vector_index.as_retriever()

     @override
-    def save_memory(self, event_strings: list[str], **kwargs) -> bool:
+    def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
+        index = f"{self.index}_{user_id}"
+        vector_index = self._create_vector_index(index)
+
         for event_string in event_strings:
             document = Document(text=event_string)
             nodes = self._split_documents([document])
-
+            vector_index.insert_nodes(nodes)
         return True

     @override
-    def search_memory(
-
+    def search_memory(
+        self, user_id: str, query: str, top_k: int, **kwargs
+    ) -> list[str]:
+        index = f"{self.index}_{user_id}"
+
+        vector_index = self._create_vector_index(index)
+
+        _retriever = vector_index.as_retriever(similarity_top_k=top_k)
         retrieved_nodes = _retriever.retrieve(query)
         return [node.text for node in retrieved_nodes]
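
Not part of the diff: a sketch of how the per-user OpenSearch indexing introduced here might be exercised. The constructor argument and config resolution are assumptions; the user_id-first signatures and the "{index}_{user_id}" naming come from the hunks above.

# Hypothetical caller of the 0.2.11 OpenSearch backend; only the method
# signatures and index naming are taken from the diff, the rest is assumed.
from veadk.memory.long_term_memory_backends.opensearch_backend import OpensearchLTMBackend

backend = OpensearchLTMBackend(index="agent_ltm")  # OpensearchConfig/EmbeddingConfig assumed from config

# Each call now derives "agent_ltm_<user_id>", so every user gets a dedicated
# OpenSearch index that is created (or reopened) on demand.
backend.save_memory(user_id="alice", event_strings=["Alice booked a flight to Berlin."])
hits = backend.search_memory(user_id="alice", query="travel plans", top_k=5)

# precheck_index_naming() rejects names that are not lowercase [a-z0-9_.-] or
# that start with "_" or "-", so user_id values should follow the same rules.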
veadk/memory/long_term_memory_backends/redis_backend.py

@@ -12,11 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from llama_index.core import (
-    Document,
-    StorageContext,
-    VectorStoreIndex,
-)
+from llama_index.core import Document, VectorStoreIndex
 from llama_index.core.schema import BaseNode
 from llama_index.embeddings.openai_like import OpenAILikeEmbedding
 from pydantic import Field

@@ -29,21 +25,22 @@ from veadk.knowledgebase.backends.utils import get_llama_index_splitter
 from veadk.memory.long_term_memory_backends.base_backend import (
     BaseLongTermMemoryBackend,
 )
+from veadk.utils.logger import get_logger

 try:
     from llama_index.vector_stores.redis import RedisVectorStore
-    from llama_index.vector_stores.redis.schema import (
-        RedisIndexInfo,
-        RedisVectorStoreSchema,
-    )
     from redis import Redis
-    from redisvl.schema
+    from redisvl.schema import IndexSchema
+
 except ImportError:
     raise ImportError(
         "Please install VeADK extensions\npip install veadk-python[extensions]"
     )


+logger = get_logger(__name__)
+
+
 class RedisLTMBackend(BaseLongTermMemoryBackend):
     redis_config: RedisConfig = Field(default_factory=RedisConfig)
     """Redis client configs"""

@@ -53,67 +50,83 @@ class RedisLTMBackend(BaseLongTermMemoryBackend):
     )
     """Embedding model configs"""

-    def
+    def model_post_init(self, __context: Any) -> None:
+        self._embed_model = OpenAILikeEmbedding(
+            model_name=self.embedding_config.name,
+            api_key=self.embedding_config.api_key,
+            api_base=self.embedding_config.api_base,
+        )
+
+    def precheck_index_naming(self, index: str):
         # no checking
         pass

-    def
+    def _create_vector_index(self, index: str) -> VectorStoreIndex:
+        logger.info(f"Create Redis vector index with index={index}")
+
+        self.precheck_index_naming(index)
+
         # We will use `from_url` to init Redis client once the
         # AK/SK -> STS token is ready.
         # self._redis_client = Redis.from_url(url=...)
-
-        self._redis_client = Redis(
+        redis_client = Redis(
             host=self.redis_config.host,
             port=self.redis_config.port,
             db=self.redis_config.db,
             password=self.redis_config.password,
         )

-
-
-
-
+        # Create an index for each user
+        # Should be Optimized in the future
+        schema = IndexSchema.from_dict(
+            {
+                "index": {"name": index, "prefix": index, "key_separator": "_"},
+                "fields": [
+                    {"name": "id", "type": "tag", "attrs": {"sortable": False}},
+                    {"name": "doc_id", "type": "tag", "attrs": {"sortable": False}},
+                    {"name": "text", "type": "text", "attrs": {"weight": 1.0}},
+                    {
+                        "name": "vector",
+                        "type": "vector",
+                        "attrs": {
+                            "dims": self.embedding_config.dim,
+                            "algorithm": "flat",
+                            "distance_metric": "cosine",
+                        },
+                    },
+                ],
+            }
         )
+        vector_store = RedisVectorStore(schema=schema, redis_client=redis_client)

-
-
-        )
-        if "vector" in self._schema.fields:
-            vector_field = self._schema.fields["vector"]
-            if (
-                vector_field
-                and vector_field.attrs
-                and isinstance(vector_field.attrs, BaseVectorFieldAttributes)
-            ):
-                vector_field.attrs.dims = self.embedding_config.dim
-        self._vector_store = RedisVectorStore(
-            schema=self._schema,
-            redis_client=self._redis_client,
-            overwrite=True,
-            collection_name=self.index,
-        )
-
-        self._storage_context = StorageContext.from_defaults(
-            vector_store=self._vector_store
+        logger.info(
+            f"Create vector store done, index_name={vector_store.index_name} prefix={vector_store.schema.index.prefix}"
         )

-
-
-            storage_context=self._storage_context,
-            embed_model=self._embed_model,
+        return VectorStoreIndex.from_vector_store(
+            vector_store=vector_store, embed_model=self._embed_model
         )

     @override
-    def save_memory(self, event_strings: list[str], **kwargs) -> bool:
+    def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
+        index = f"veadk-ltm/{self.index}/{user_id}"
+        vector_index = self._create_vector_index(index)
+
         for event_string in event_strings:
             document = Document(text=event_string)
             nodes = self._split_documents([document])
-
+            vector_index.insert_nodes(nodes)
+
         return True

     @override
-    def search_memory(
-
+    def search_memory(
+        self, user_id: str, query: str, top_k: int, **kwargs
+    ) -> list[str]:
+        index = f"veadk-ltm/{self.index}/{user_id}"
+        vector_index = self._create_vector_index(index)
+
+        _retriever = vector_index.as_retriever(similarity_top_k=top_k)
         retrieved_nodes = _retriever.retrieve(query)
         return [node.text for node in retrieved_nodes]
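
Not part of the diff: the same per-user pattern for the Redis backend, which now builds one RediSearch index per user named "veadk-ltm/<index>/<user_id>" with a flat, cosine-distance vector field sized to the embedding dimension. Constructor details below are assumptions.

# Hypothetical caller of the 0.2.11 Redis backend; index naming and method
# signatures mirror the diff, construction details are assumed.
from veadk.memory.long_term_memory_backends.redis_backend import RedisLTMBackend

backend = RedisLTMBackend(index="support_bot")  # RedisConfig/EmbeddingConfig assumed from config

# Saving creates/targets the per-user index "veadk-ltm/support_bot/bob".
backend.save_memory(user_id="bob", event_strings=["Bob asked about the refund policy."])

# Searching rebuilds the same index handle and retrieves top_k node texts.
answers = backend.search_memory(user_id="bob", query="refund", top_k=3)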
veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py

@@ -13,6 +13,7 @@
 # limitations under the License.

 import json
+import os
 import re
 import time
 import uuid

@@ -22,34 +23,37 @@ from pydantic import Field
 from typing_extensions import override

 import veadk.config # noqa E401
-from veadk.
+from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
+from veadk.integrations.ve_viking_db_memory.ve_viking_db_memory import (
+    VikingDBMemoryClient,
+)
 from veadk.memory.long_term_memory_backends.base_backend import (
     BaseLongTermMemoryBackend,
 )
 from veadk.utils.logger import get_logger

-try:
-    from mcp_server_vikingdb_memory.common.memory_client import VikingDBMemoryService
-except ImportError:
-    raise ImportError(
-        "Please install VeADK extensions\npip install veadk-python[extensions]"
-    )
-
 logger = get_logger(__name__)


 class VikingDBLTMBackend(BaseLongTermMemoryBackend):
-    volcengine_access_key: str = Field(
-        default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY")
+    volcengine_access_key: str | None = Field(
+        default_factory=lambda: os.getenv("VOLCENGINE_ACCESS_KEY")
     )

-    volcengine_secret_key: str = Field(
-        default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY")
+    volcengine_secret_key: str | None = Field(
+        default_factory=lambda: os.getenv("VOLCENGINE_SECRET_KEY")
     )

+    session_token: str = ""
+
     region: str = "cn-beijing"
     """VikingDB memory region"""

+    def model_post_init(self, __context: Any) -> None:
+        # check whether collection exist, if not, create it
+        if not self._collection_exist():
+            self._create_collection()
+
     def precheck_index_naming(self):
         if not (
             isinstance(self.index, str)

@@ -60,37 +64,39 @@ class VikingDBLTMBackend(BaseLongTermMemoryBackend):
                 "The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128."
             )

-    def model_post_init(self, __context: Any) -> None:
-        self._client = VikingDBMemoryService(
-            ak=self.volcengine_access_key,
-            sk=self.volcengine_secret_key,
-            region=self.region,
-        )
-
-        # check whether collection exist, if not, create it
-        if not self._collection_exist():
-            self._create_collection()
-
     def _collection_exist(self) -> bool:
         try:
-            self.
+            client = self._get_client()
+            client.get_collection(collection_name=self.index)
             return True
         except Exception:
             return False

     def _create_collection(self) -> None:
-
+        client = self._get_client()
+        response = client.create_collection(
             collection_name=self.index,
             description="Created by Volcengine Agent Development Kit VeADK",
             builtin_event_types=["sys_event_v1"],
         )
         return response

+    def _get_client(self) -> VikingDBMemoryClient:
+        if not (self.volcengine_access_key and self.volcengine_secret_key):
+            cred = get_credential_from_vefaas_iam()
+            self.volcengine_access_key = cred.access_key_id
+            self.volcengine_secret_key = cred.secret_access_key
+            self.session_token = cred.session_token
+
+        return VikingDBMemoryClient(
+            ak=self.volcengine_access_key,
+            sk=self.volcengine_secret_key,
+            sts_token=self.session_token,
+            region=self.region,
+        )
+
     @override
-    def save_memory(self, event_strings: list[str], **kwargs) -> bool:
-        user_id = kwargs.get("user_id")
-        if user_id is None:
-            raise ValueError("user_id is required")
+    def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
         session_id = str(uuid.uuid1())
         messages = []
         for raw_events in event_strings:

@@ -105,31 +111,46 @@ class VikingDBLTMBackend(BaseLongTermMemoryBackend):
             "default_assistant_id": "assistant",
             "time": int(time.time() * 1000),
         }
-
+
+        logger.debug(
+            f"Request for add {len(messages)} memory to VikingDB: collection_name={self.index}, metadata={metadata}, session_id={session_id}"
+        )
+
+        client = self._get_client()
+        response = client.add_messages(
             collection_name=self.index,
             messages=messages,
             metadata=metadata,
             session_id=session_id,
         )

+        logger.debug(f"Response from add memory to VikingDB: {response}")
+
         if not response.get("code") == 0:
             raise ValueError(f"Save VikingDB memory error: {response}")

         return True

     @override
-    def search_memory(
-        user_id
-
-            raise ValueError("user_id is required")
+    def search_memory(
+        self, user_id: str, query: str, top_k: int, **kwargs
+    ) -> list[str]:
         filter = {
             "user_id": user_id,
             "memory_type": ["sys_event_v1"],
         }
-
+
+        logger.debug(
+            f"Request for search memory in VikingDB: filter={filter}, collection_name={self.index}, query={query}, limit={top_k}"
+        )
+
+        client = self._get_client()
+        response = client.search_memory(
             collection_name=self.index, query=query, filter=filter, limit=top_k
         )

+        logger.debug(f"Response from search memory in VikingDB: {response}")
+
         if not response.get("code") == 0:
             raise ValueError(f"Search VikingDB memory error: {response}")
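
Not part of the diff: a sketch of the lazy credential flow the new _get_client() implements, i.e. fall back to veFaaS IAM temporary credentials (AK/SK plus STS session token) only when the VOLCENGINE_* environment variables are absent. The standalone helper below is illustrative; the two imports and attribute names are the ones added by the diff.

# Illustrative sketch of the 0.2.11 credential fallback, written outside the
# backend class; the function name is made up.
import os

from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
from veadk.integrations.ve_viking_db_memory.ve_viking_db_memory import (
    VikingDBMemoryClient,
)


def build_vikingdb_client(region: str = "cn-beijing") -> VikingDBMemoryClient:
    ak = os.getenv("VOLCENGINE_ACCESS_KEY")
    sk = os.getenv("VOLCENGINE_SECRET_KEY")
    session_token = ""
    if not (ak and sk):
        # Running on veFaaS without static keys: ask the instance IAM role
        # for temporary credentials.
        cred = get_credential_from_vefaas_iam()
        ak, sk, session_token = cred.access_key_id, cred.secret_access_key, cred.session_token
    return VikingDBMemoryClient(ak=ak, sk=sk, sts_token=session_token, region=region)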
@@ -19,6 +19,7 @@ from google.adk.sessions import (
|
|
|
19
19
|
BaseSessionService,
|
|
20
20
|
DatabaseSessionService,
|
|
21
21
|
InMemorySessionService,
|
|
22
|
+
Session,
|
|
22
23
|
)
|
|
23
24
|
from pydantic import BaseModel, Field, PrivateAttr
|
|
24
25
|
|
|
@@ -106,7 +107,7 @@ class ShortTermMemory(BaseModel):
|
|
|
106
107
|
app_name: str,
|
|
107
108
|
user_id: str,
|
|
108
109
|
session_id: str,
|
|
109
|
-
) -> None:
|
|
110
|
+
) -> Session | None:
|
|
110
111
|
if isinstance(self._session_service, DatabaseSessionService):
|
|
111
112
|
list_sessions_response = await self._session_service.list_sessions(
|
|
112
113
|
app_name=app_name, user_id=user_id
|
|
@@ -116,13 +117,16 @@ class ShortTermMemory(BaseModel):
|
|
|
116
117
|
f"Loaded {len(list_sessions_response.sessions)} sessions from db {self.db_url}."
|
|
117
118
|
)
|
|
118
119
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
120
|
+
session = await self._session_service.get_session(
|
|
121
|
+
app_name=app_name, user_id=user_id, session_id=session_id
|
|
122
|
+
)
|
|
123
|
+
|
|
124
|
+
if session:
|
|
125
|
+
logger.info(
|
|
126
|
+
f"Session {session_id} already exists with app_name={app_name} user_id={user_id}."
|
|
122
127
|
)
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
await self._session_service.create_session(
|
|
128
|
+
return session
|
|
129
|
+
else:
|
|
130
|
+
return await self._session_service.create_session(
|
|
127
131
|
app_name=app_name, user_id=user_id, session_id=session_id
|
|
128
132
|
)
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
|
|
15
15
|
from functools import cached_property
|
|
16
16
|
from typing import Any
|
|
17
|
+
from venv import logger
|
|
17
18
|
|
|
18
19
|
from google.adk.sessions import (
|
|
19
20
|
BaseSessionService,
|
|
@@ -33,7 +34,8 @@ class PostgreSqlSTMBackend(BaseShortTermMemoryBackend):
|
|
|
33
34
|
postgresql_config: PostgreSqlConfig = Field(default_factory=PostgreSqlConfig)
|
|
34
35
|
|
|
35
36
|
def model_post_init(self, context: Any) -> None:
|
|
36
|
-
self._db_url = f"postgresql
|
|
37
|
+
self._db_url = f"postgresql://{self.postgresql_config.user}:{self.postgresql_config.password}@{self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}"
|
|
38
|
+
logger.debug(self._db_url)
|
|
37
39
|
|
|
38
40
|
@cached_property
|
|
39
41
|
@override
|
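
Not part of the diff: the get-or-create behavior the short-term memory change encodes, written against the google.adk session service calls that the hunks themselves make. The standalone helper name is illustrative.

# Sketch of the session get-or-create flow introduced in 0.2.11; the helper
# name is made up, the service calls match the diff above.
from google.adk.sessions import BaseSessionService, Session


async def get_or_create_session(
    service: BaseSessionService, app_name: str, user_id: str, session_id: str
) -> Session:
    session = await service.get_session(
        app_name=app_name, user_id=user_id, session_id=session_id
    )
    if session:
        # 0.2.11 returns the existing session instead of discarding it.
        return session
    return await service.create_session(
        app_name=app_name, user_id=user_id, session_id=session_id
    )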