veadk-python 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of veadk-python might be problematic.
- veadk/agent.py +7 -3
- veadk/auth/veauth/ark_veauth.py +43 -51
- veadk/auth/veauth/utils.py +57 -0
- veadk/cli/cli.py +2 -0
- veadk/cli/cli_uploadevalset.py +125 -0
- veadk/cli/cli_web.py +15 -2
- veadk/configs/model_configs.py +3 -3
- veadk/consts.py +9 -0
- veadk/knowledgebase/knowledgebase.py +19 -32
- veadk/memory/long_term_memory.py +39 -92
- veadk/memory/long_term_memory_backends/base_backend.py +4 -2
- veadk/memory/long_term_memory_backends/in_memory_backend.py +8 -6
- veadk/memory/long_term_memory_backends/mem0_backend.py +8 -8
- veadk/memory/long_term_memory_backends/opensearch_backend.py +40 -36
- veadk/memory/long_term_memory_backends/redis_backend.py +59 -46
- veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py +54 -29
- veadk/memory/short_term_memory.py +9 -11
- veadk/runner.py +19 -11
- veadk/tools/builtin_tools/generate_image.py +230 -189
- veadk/tools/builtin_tools/image_edit.py +24 -5
- veadk/tools/builtin_tools/image_generate.py +24 -5
- veadk/tools/builtin_tools/load_knowledgebase.py +97 -0
- veadk/tools/builtin_tools/video_generate.py +38 -11
- veadk/utils/misc.py +6 -10
- veadk/utils/volcengine_sign.py +2 -0
- veadk/version.py +1 -1
- {veadk_python-0.2.10.dist-info → veadk_python-0.2.12.dist-info}/METADATA +2 -1
- {veadk_python-0.2.10.dist-info → veadk_python-0.2.12.dist-info}/RECORD +32 -29
- {veadk_python-0.2.10.dist-info → veadk_python-0.2.12.dist-info}/WHEEL +0 -0
- {veadk_python-0.2.10.dist-info → veadk_python-0.2.12.dist-info}/entry_points.txt +0 -0
- {veadk_python-0.2.10.dist-info → veadk_python-0.2.12.dist-info}/licenses/LICENSE +0 -0
- {veadk_python-0.2.10.dist-info → veadk_python-0.2.12.dist-info}/top_level.txt +0 -0
veadk/memory/long_term_memory.py
CHANGED
```diff
@@ -72,10 +72,6 @@ def _get_backend_cls(backend: str) -> type[BaseLongTermMemoryBackend]:
     raise ValueError(f"Unsupported long term memory backend: {backend}")
 
 
-def build_long_term_memory_index(app_name: str, user_id: str):
-    return f"{app_name}_{user_id}"
-
-
 class LongTermMemory(BaseMemoryService, BaseModel):
     backend: Union[
         Literal["local", "opensearch", "redis", "viking", "viking_mem", "mem0"],
@@ -89,54 +85,48 @@ class LongTermMemory(BaseMemoryService, BaseModel):
     top_k: int = 5
     """Number of top similar documents to retrieve during search."""
 
+    index: str = ""
+
     app_name: str = ""
 
     user_id: str = ""
+    """Deprecated attribute"""
 
     def model_post_init(self, __context: Any) -> None:
-        if self.backend == "viking_mem":
-            logger.warning(
-                "The `viking_mem` backend is deprecated, please use `viking` instead."
-            )
-            self.backend = "viking"
-
-        self._backend = None
-
         # Once user define a backend instance, use it directly
         if isinstance(self.backend, BaseLongTermMemoryBackend):
             self._backend = self.backend
+            self.index = self._backend.index
             logger.info(
-                f"Initialized long term memory with provided backend instance {self._backend.__class__.__name__}"
+                f"Initialized long term memory with provided backend instance {self._backend.__class__.__name__}, index={self.index}"
             )
             return
 
+        # Once user define backend config, use it directly
         if self.backend_config:
-            logger.warning(
-                f"Initialized long term memory backend {self.backend} with config. We will ignore `app_name` and `user_id` if provided."
-            )
             self._backend = _get_backend_cls(self.backend)(**self.backend_config)
-            _index = self.backend_config.get("index", None)
-            if _index:
-                self._index = _index
-                logger.info(f"Long term memory index set to {self._index}.")
-            else:
-                logger.warning(
-                    "Cannot find index via backend_config, please set `index` parameter."
-                )
             return
 
-
-
-
-
-
-        self._backend = _get_backend_cls(self.backend)(
-            index=self._index, **self.backend_config if self.backend_config else {}
+        # Check index
+        self.index = self.index or self.app_name
+        if not self.index:
+            logger.warning(
+                "Attribute `index` or `app_name` not provided, use `default_app` instead."
             )
-
+            self.index = "default_app"
+
+        # Forward compliance
+        if self.backend == "viking_mem":
             logger.warning(
-                "
+                "The `viking_mem` backend is deprecated, change to `viking` instead."
             )
+            self.backend = "viking"
+
+        self._backend = _get_backend_cls(self.backend)(index=self.index)
+
+        logger.info(
+            f"Initialized long term memory with provided backend instance {self._backend.__class__.__name__}, index={self.index}"
+        )
 
     def _filter_and_convert_events(self, events: list[Event]) -> list[str]:
         final_events = []
@@ -164,75 +154,32 @@ class LongTermMemory(BaseMemoryService, BaseModel):
         self,
         session: Session,
     ):
-        app_name = session.app_name
         user_id = session.user_id
-
-        if not self._backend and isinstance(self.backend, str):
-            self._index = build_long_term_memory_index(app_name, user_id)
-            self._backend = _get_backend_cls(self.backend)(
-                index=self._index, **self.backend_config if self.backend_config else {}
-            )
-            logger.info(
-                f"Initialize long term memory backend now, index is {self._index}"
-            )
-
-        if not self._index and self._index != build_long_term_memory_index(
-            app_name, user_id
-        ):
-            logger.warning(
-                f"The `app_name` or `user_id` is different from the initialized one, skip add session to memory. Initialized index: {self._index}, current built index: {build_long_term_memory_index(app_name, user_id)}"
-            )
-            return
         event_strings = self._filter_and_convert_events(session.events)
 
         logger.info(
-            f"Adding {len(event_strings)} events to long term memory: index={self._index}"
+            f"Adding {len(event_strings)} events to long term memory: index={self.index}"
+        )
+        self._backend.save_memory(user_id=user_id, event_strings=event_strings)
+        logger.info(
+            f"Added {len(event_strings)} events to long term memory: index={self.index}, user_id={user_id}"
         )
-
-        if self._backend:
-            self._backend.save_memory(event_strings=event_strings, user_id=user_id)
-
-            logger.info(
-                f"Added {len(event_strings)} events to long term memory: index={self._index}"
-            )
-        else:
-            logger.error(
-                "Long term memory backend initialize failed, cannot add session to memory."
-            )
 
     @override
-    async def search_memory(
-
-
-
-        self._backend = _get_backend_cls(self.backend)(
-            index=self._index, **self.backend_config if self.backend_config else {}
-        )
-        logger.info(
-            f"Initialize long term memory backend now, index is {self._index}"
-        )
+    async def search_memory(
+        self, *, app_name: str, user_id: str, query: str
+    ) -> SearchMemoryResponse:
+        logger.info(f"Search memory with query={query}")
 
-
-
-
-
-                f"The `app_name` or `user_id` is different from the initialized one. Initialized index: {self._index}, current built index: {build_long_term_memory_index(app_name, user_id)}. Search memory return empty list."
+        memory_chunks = []
+        try:
+            memory_chunks = self._backend.search_memory(
+                query=query, top_k=self.top_k, user_id=user_id
             )
-
-
-        if not self._backend:
+        except Exception as e:
             logger.error(
-                "
+                f"Exception orrcus during memory search: {e}. Return empty memory chunks"
             )
-            return SearchMemoryResponse(memories=[])
-
-        logger.info(
-            f"Searching long term memory: query={query} index={self._index} top_k={self.top_k}"
-        )
-
-        memory_chunks = self._backend.search_memory(
-            query=query, top_k=self.top_k, user_id=user_id
-        )
 
         memory_events = []
         for memory in memory_chunks:
@@ -260,6 +207,6 @@ class LongTermMemory(BaseMemoryService, BaseModel):
         )
 
         logger.info(
-            f"Return {len(memory_events)} memory events for query: {query} index={self._index}"
+            f"Return {len(memory_events)} memory events for query: {query} index={self.index} user_id={user_id}"
        )
        return SearchMemoryResponse(memories=memory_events)
```
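The net effect of the `model_post_init` rewrite is that the index is no longer derived per call from `app_name` and `user_id` via `build_long_term_memory_index`; it is resolved once at construction time as `index or app_name or "default_app"`, and the backend is built eagerly. A minimal sketch of the new resolution order (assuming the `local` literal maps to the in-memory backend; the actual mapping lives in `_get_backend_cls`, which is outside this hunk):

```python
from veadk.memory.long_term_memory import LongTermMemory

ltm = LongTermMemory(backend="local", app_name="demo_app")
assert ltm.index == "demo_app"  # empty `index` falls back to `app_name`

ltm_default = LongTermMemory(backend="local")
assert ltm_default.index == "default_app"  # logged as a warning, then defaulted
```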
veadk/memory/long_term_memory_backends/base_backend.py
CHANGED

```diff
@@ -25,9 +25,11 @@ class BaseLongTermMemoryBackend(ABC, BaseModel):
         """Check the index name is valid or not"""
 
     @abstractmethod
-    def save_memory(self, event_strings: list[str], **kwargs) -> bool:
+    def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
         """Save memory to long term memory backend"""
 
     @abstractmethod
-    def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
+    def search_memory(
+        self, user_id: str, query: str, top_k: int, **kwargs
+    ) -> list[str]:
        """Retrieve memory from long term memory backend"""
```
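Because `user_id` is now part of both abstract signatures, custom backends must accept it even if they ignore it. A toy backend sketched against the new interface (the `precheck_index_naming` signature with an `index` argument follows the OpenSearch and Redis backends below and is an assumption here):

```python
from veadk.memory.long_term_memory_backends.base_backend import (
    BaseLongTermMemoryBackend,
)


class DictLTMBackend(BaseLongTermMemoryBackend):
    """Toy backend: per-user lists in a dict, naive substring search."""

    _store: dict[str, list[str]] = {}

    def precheck_index_naming(self, index: str):
        pass  # no naming rules for this toy backend

    def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
        self._store.setdefault(user_id, []).extend(event_strings)
        return True

    def search_memory(
        self, user_id: str, query: str, top_k: int, **kwargs
    ) -> list[str]:
        hits = [s for s in self._store.get(user_id, []) if query.lower() in s.lower()]
        return hits[:top_k]
```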
veadk/memory/long_term_memory_backends/in_memory_backend.py
CHANGED

```diff
@@ -29,10 +29,6 @@ class InMemoryLTMBackend(BaseLongTermMemoryBackend):
     embedding_config: EmbeddingModelConfig = Field(default_factory=EmbeddingModelConfig)
     """Embedding model configs"""
 
-    def precheck_index_naming(self):
-        # no checking
-        pass
-
     def model_post_init(self, __context: Any) -> None:
         self._embed_model = OpenAILikeEmbedding(
             model_name=self.embedding_config.name,
@@ -41,8 +37,12 @@ class InMemoryLTMBackend(BaseLongTermMemoryBackend):
         )
         self._vector_index = VectorStoreIndex([], embed_model=self._embed_model)
 
+    def precheck_index_naming(self):
+        # no checking
+        pass
+
     @override
-    def save_memory(self, event_strings: list[str], **kwargs) -> bool:
+    def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
         for event_string in event_strings:
             document = Document(text=event_string)
             nodes = self._split_documents([document])
@@ -50,7 +50,9 @@ class InMemoryLTMBackend(BaseLongTermMemoryBackend):
         return True
 
     @override
-    def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
+    def search_memory(
+        self, user_id: str, query: str, top_k: int, **kwargs
+    ) -> list[str]:
         _retriever = self._vector_index.as_retriever(similarity_top_k=top_k)
         retrieved_nodes = _retriever.retrieve(query)
         return [node.text for node in retrieved_nodes]
```
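Note that the in-memory backend accepts the new `user_id` parameter but never uses it, so all users share the single `VectorStoreIndex`. A sketch of the resulting behavior (assumes valid embedding credentials in `EmbeddingModelConfig`; passing `index` explicitly is a precaution, since its default is not shown in this diff):

```python
from veadk.memory.long_term_memory_backends.in_memory_backend import (
    InMemoryLTMBackend,
)

backend = InMemoryLTMBackend(index="demo_app")
backend.save_memory(user_id="alice", event_strings=["Alice likes tea."])
# Bob's search can surface Alice's memory: user_id is accepted but unused here.
print(backend.search_memory(user_id="bob", query="tea", top_k=1))
```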
veadk/memory/long_term_memory_backends/mem0_backend.py
CHANGED

```diff
@@ -13,12 +13,11 @@
 # limitations under the License.
 
 from typing import Any
-
+
 from pydantic import Field
+from typing_extensions import override
 
 from veadk.configs.database_configs import Mem0Config
-
-
 from veadk.memory.long_term_memory_backends.base_backend import (
     BaseLongTermMemoryBackend,
 )
@@ -66,7 +65,9 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
         pass
 
     @override
-    def save_memory(self, event_strings: list[str], **kwargs) -> bool:
+    def save_memory(
+        self, event_strings: list[str], user_id: str = "default_user", **kwargs
+    ) -> bool:
         """Save memory to Mem0
 
         Args:
@@ -76,8 +77,6 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
         Returns:
             bool: True if saved successfully, False otherwise
         """
-        user_id = kwargs.get("user_id", "default_user")
-
         try:
             logger.info(
                 f"Saving {len(event_strings)} events to Mem0 for user: {user_id}"
@@ -100,7 +99,9 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
             return False
 
     @override
-    def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
+    def search_memory(
+        self, query: str, top_k: int, user_id: str = "default_user", **kwargs
+    ) -> list[str]:
         """Search memory from Mem0
 
         Args:
@@ -111,7 +112,6 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
         Returns:
             list[str]: List of memory strings
         """
-        user_id = kwargs.get("user_id", "default_user")
 
         try:
             logger.info(
```
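For Mem0, `user_id` moves from a `kwargs.get(...)` lookup into the signature with the same `"default_user"` fallback, so callers that omitted it keep the old behavior. A sketch (assumes a reachable Mem0 service configured through `Mem0Config`, and that `index` is a settable field as on the other backends):

```python
from veadk.memory.long_term_memory_backends.mem0_backend import Mem0LTMBackend

backend = Mem0LTMBackend(index="demo_app")
backend.save_memory(event_strings=["User speaks French."], user_id="alice")
backend.save_memory(event_strings=["Anonymous note."])  # stored under "default_user"
print(backend.search_memory(query="preferred language", top_k=3, user_id="alice"))
```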
veadk/memory/long_term_memory_backends/opensearch_backend.py
CHANGED

```diff
@@ -14,11 +14,7 @@
 
 import re
 
-from llama_index.core import (
-    Document,
-    StorageContext,
-    VectorStoreIndex,
-)
+from llama_index.core import Document, VectorStoreIndex
 from llama_index.core.schema import BaseNode
 from llama_index.embeddings.openai_like import OpenAILikeEmbedding
 from pydantic import Field
@@ -31,6 +27,7 @@ from veadk.knowledgebase.backends.utils import get_llama_index_splitter
 from veadk.memory.long_term_memory_backends.base_backend import (
     BaseLongTermMemoryBackend,
 )
+from veadk.utils.logger import get_logger
 
 try:
     from llama_index.vector_stores.opensearch import (
@@ -42,6 +39,8 @@ except ImportError:
         "Please install VeADK extensions\npip install veadk-python[extensions]"
     )
 
+logger = get_logger(__name__)
+
 
 class OpensearchLTMBackend(BaseLongTermMemoryBackend):
     opensearch_config: OpensearchConfig = Field(default_factory=OpensearchConfig)
@@ -52,19 +51,30 @@ class OpensearchLTMBackend(BaseLongTermMemoryBackend):
     )
     """Embedding model configs"""
 
-    def precheck_index_naming(self):
+    def model_post_init(self, __context: Any) -> None:
+        self._embed_model = OpenAILikeEmbedding(
+            model_name=self.embedding_config.name,
+            api_key=self.embedding_config.api_key,
+            api_base=self.embedding_config.api_base,
+        )
+
+    def precheck_index_naming(self, index: str):
         if not (
-            isinstance(self.index, str)
-            and not self.index.startswith(("_", "-"))
-            and self.index.islower()
-            and re.match(r"^[a-z0-9_\-.]+$", self.index)
+            isinstance(index, str)
+            and not index.startswith(("_", "-"))
+            and index.islower()
+            and re.match(r"^[a-z0-9_\-.]+$", index)
         ):
             raise ValueError(
-                "The index name does not conform to the naming rules of OpenSearch"
+                f"The index name {index} does not conform to the naming rules of OpenSearch"
             )
 
-    def model_post_init(self, __context: Any) -> None:
-
+    def _create_vector_index(self, index: str) -> VectorStoreIndex:
+        logger.info(f"Create OpenSearch vector index with index={index}")
+
+        self.precheck_index_naming(index)
+
+        opensearch_client = OpensearchVectorClient(
             endpoint=self.opensearch_config.host,
             port=self.opensearch_config.port,
             http_auth=(
@@ -74,39 +84,33 @@ class OpensearchLTMBackend(BaseLongTermMemoryBackend):
             use_ssl=True,
             verify_certs=False,
             dim=self.embedding_config.dim,
-            index=self.index,
+            index=index,
         )
-
-
-
-        self._storage_context = StorageContext.from_defaults(
-            vector_store=self._vector_store
-        )
-
-        self._embed_model = OpenAILikeEmbedding(
-            model_name=self.embedding_config.name,
-            api_key=self.embedding_config.api_key,
-            api_base=self.embedding_config.api_base,
-        )
-
-        self._vector_index = VectorStoreIndex.from_documents(
-            documents=[],
-            storage_context=self._storage_context,
-            embed_model=self._embed_model,
+        vector_store = OpensearchVectorStore(client=opensearch_client)
+        return VectorStoreIndex.from_vector_store(
+            vector_store=vector_store, embed_model=self._embed_model
         )
-        self._retriever = self._vector_index.as_retriever()
 
     @override
-    def save_memory(self, event_strings: list[str], **kwargs) -> bool:
+    def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
+        index = f"{self.index}_{user_id}"
+        vector_index = self._create_vector_index(index)
+
         for event_string in event_strings:
             document = Document(text=event_string)
             nodes = self._split_documents([document])
-            self._vector_index.insert_nodes(nodes)
+            vector_index.insert_nodes(nodes)
         return True
 
     @override
-    def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
-        _retriever = self._vector_index.as_retriever(similarity_top_k=top_k)
+    def search_memory(
+        self, user_id: str, query: str, top_k: int, **kwargs
+    ) -> list[str]:
+        index = f"{self.index}_{user_id}"
+
+        vector_index = self._create_vector_index(index)
+
+        _retriever = vector_index.as_retriever(similarity_top_k=top_k)
         retrieved_nodes = _retriever.retrieve(query)
         return [node.text for node in retrieved_nodes]
 
```
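Since the per-user index is built as `{self.index}_{user_id}` and only then validated, the user id itself must satisfy OpenSearch's naming rules. A standalone sketch mirroring `precheck_index_naming`:

```python
import re


def is_valid_opensearch_index(index: str) -> bool:
    # Mirrors the precheck above: a string, no leading "_"/"-", lowercase,
    # and only characters from [a-z0-9_\-.].
    return (
        isinstance(index, str)
        and not index.startswith(("_", "-"))
        and index.islower()
        and bool(re.match(r"^[a-z0-9_\-.]+$", index))
    )


assert is_valid_opensearch_index("demo_app_alice")
assert not is_valid_opensearch_index("demo_app_Alice")  # an uppercase user id would raise
```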
veadk/memory/long_term_memory_backends/redis_backend.py
CHANGED

```diff
@@ -12,11 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from llama_index.core import (
-    Document,
-    StorageContext,
-    VectorStoreIndex,
-)
+from llama_index.core import Document, VectorStoreIndex
 from llama_index.core.schema import BaseNode
 from llama_index.embeddings.openai_like import OpenAILikeEmbedding
 from pydantic import Field
@@ -29,21 +25,22 @@ from veadk.knowledgebase.backends.utils import get_llama_index_splitter
 from veadk.memory.long_term_memory_backends.base_backend import (
     BaseLongTermMemoryBackend,
 )
+from veadk.utils.logger import get_logger
 
 try:
     from llama_index.vector_stores.redis import RedisVectorStore
-    from llama_index.vector_stores.redis.schema import (
-        RedisIndexInfo,
-        RedisVectorStoreSchema,
-    )
     from redis import Redis
-    from redisvl.schema
+    from redisvl.schema import IndexSchema
+
 except ImportError:
     raise ImportError(
         "Please install VeADK extensions\npip install veadk-python[extensions]"
     )
 
 
+logger = get_logger(__name__)
+
+
 class RedisLTMBackend(BaseLongTermMemoryBackend):
     redis_config: RedisConfig = Field(default_factory=RedisConfig)
     """Redis client configs"""
@@ -53,67 +50,83 @@ class RedisLTMBackend(BaseLongTermMemoryBackend):
     )
     """Embedding model configs"""
 
-    def precheck_index_naming(self):
+    def model_post_init(self, __context: Any) -> None:
+        self._embed_model = OpenAILikeEmbedding(
+            model_name=self.embedding_config.name,
+            api_key=self.embedding_config.api_key,
+            api_base=self.embedding_config.api_base,
+        )
+
+    def precheck_index_naming(self, index: str):
         # no checking
         pass
 
-    def model_post_init(self, __context: Any) -> None:
+    def _create_vector_index(self, index: str) -> VectorStoreIndex:
+        logger.info(f"Create Redis vector index with index={index}")
+
+        self.precheck_index_naming(index)
+
         # We will use `from_url` to init Redis client once the
         # AK/SK -> STS token is ready.
         # self._redis_client = Redis.from_url(url=...)
-
-        self._redis_client = Redis(
+        redis_client = Redis(
             host=self.redis_config.host,
             port=self.redis_config.port,
             db=self.redis_config.db,
             password=self.redis_config.password,
         )
 
-
-
-
-
+        # Create an index for each user
+        # Should be Optimized in the future
+        schema = IndexSchema.from_dict(
+            {
+                "index": {"name": index, "prefix": index, "key_separator": "_"},
+                "fields": [
+                    {"name": "id", "type": "tag", "attrs": {"sortable": False}},
+                    {"name": "doc_id", "type": "tag", "attrs": {"sortable": False}},
+                    {"name": "text", "type": "text", "attrs": {"weight": 1.0}},
+                    {
+                        "name": "vector",
+                        "type": "vector",
+                        "attrs": {
+                            "dims": self.embedding_config.dim,
+                            "algorithm": "flat",
+                            "distance_metric": "cosine",
+                        },
+                    },
+                ],
+            }
         )
+        vector_store = RedisVectorStore(schema=schema, redis_client=redis_client)
 
-
-
-        )
-        if "vector" in self._schema.fields:
-            vector_field = self._schema.fields["vector"]
-            if (
-                vector_field
-                and vector_field.attrs
-                and isinstance(vector_field.attrs, BaseVectorFieldAttributes)
-            ):
-                vector_field.attrs.dims = self.embedding_config.dim
-        self._vector_store = RedisVectorStore(
-            schema=self._schema,
-            redis_client=self._redis_client,
-            overwrite=True,
-            collection_name=self.index,
-        )
-
-        self._storage_context = StorageContext.from_defaults(
-            vector_store=self._vector_store
+        logger.info(
+            f"Create vector store done, index_name={vector_store.index_name} prefix={vector_store.schema.index.prefix}"
         )
 
-
-
-            storage_context=self._storage_context,
-            embed_model=self._embed_model,
+        return VectorStoreIndex.from_vector_store(
+            vector_store=vector_store, embed_model=self._embed_model
         )
 
     @override
-    def save_memory(self, event_strings: list[str], **kwargs) -> bool:
+    def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
+        index = f"veadk-ltm/{self.index}/{user_id}"
+        vector_index = self._create_vector_index(index)
+
         for event_string in event_strings:
             document = Document(text=event_string)
             nodes = self._split_documents([document])
-            self._vector_index.insert_nodes(nodes)
+            vector_index.insert_nodes(nodes)
+
         return True
 
     @override
-    def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
-        _retriever = self._vector_index.as_retriever(similarity_top_k=top_k)
+    def search_memory(
+        self, user_id: str, query: str, top_k: int, **kwargs
+    ) -> list[str]:
+        index = f"veadk-ltm/{self.index}/{user_id}"
+        vector_index = self._create_vector_index(index)
+
+        _retriever = vector_index.as_retriever(similarity_top_k=top_k)
         retrieved_nodes = _retriever.retrieve(query)
         return [node.text for node in retrieved_nodes]
 
```
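Unlike OpenSearch, the Redis backend keeps a no-op naming precheck and uses a slash-separated name, `veadk-ltm/{self.index}/{user_id}`, as both the RediSearch index name and the key prefix. A hedged round-trip sketch (assumes a reachable Redis, valid embedding credentials, and that `index` is a constructor field, inferred from the `self.index` references above):

```python
from veadk.memory.long_term_memory_backends.redis_backend import RedisLTMBackend

backend = RedisLTMBackend(index="demo_app")
# Documents land under the per-user prefix "veadk-ltm/demo_app/alice".
backend.save_memory(user_id="alice", event_strings=["User prefers window seats."])
print(backend.search_memory(user_id="alice", query="seat preference", top_k=3))
```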