veadk-python 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of veadk-python might be problematic. Click here for more details.
- veadk/a2a/remote_ve_agent.py +63 -6
- veadk/agent.py +10 -3
- veadk/agent_builder.py +2 -3
- veadk/auth/veauth/ark_veauth.py +43 -51
- veadk/auth/veauth/utils.py +57 -0
- veadk/cli/cli.py +2 -0
- veadk/cli/cli_kb.py +75 -0
- veadk/cli/cli_web.py +4 -0
- veadk/configs/model_configs.py +3 -3
- veadk/consts.py +9 -0
- veadk/integrations/__init__.py +13 -0
- veadk/integrations/ve_viking_db_memory/__init__.py +13 -0
- veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py +293 -0
- veadk/knowledgebase/knowledgebase.py +19 -32
- veadk/memory/__init__.py +1 -1
- veadk/memory/long_term_memory.py +40 -68
- veadk/memory/long_term_memory_backends/base_backend.py +4 -2
- veadk/memory/long_term_memory_backends/in_memory_backend.py +8 -6
- veadk/memory/long_term_memory_backends/mem0_backend.py +25 -10
- veadk/memory/long_term_memory_backends/opensearch_backend.py +40 -36
- veadk/memory/long_term_memory_backends/redis_backend.py +59 -46
- veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py +56 -35
- veadk/memory/short_term_memory.py +12 -8
- veadk/memory/short_term_memory_backends/postgresql_backend.py +3 -1
- veadk/runner.py +42 -19
- veadk/tools/builtin_tools/generate_image.py +56 -17
- veadk/tools/builtin_tools/image_edit.py +17 -7
- veadk/tools/builtin_tools/image_generate.py +17 -7
- veadk/tools/builtin_tools/load_knowledgebase.py +97 -0
- veadk/tools/builtin_tools/video_generate.py +11 -9
- veadk/tools/builtin_tools/web_search.py +10 -3
- veadk/tools/load_knowledgebase_tool.py +12 -0
- veadk/tracing/telemetry/attributes/extractors/llm_attributes_extractors.py +5 -0
- veadk/tracing/telemetry/attributes/extractors/tool_attributes_extractors.py +7 -0
- veadk/tracing/telemetry/exporters/apmplus_exporter.py +82 -2
- veadk/tracing/telemetry/exporters/inmemory_exporter.py +8 -2
- veadk/tracing/telemetry/telemetry.py +41 -5
- veadk/utils/misc.py +6 -10
- veadk/utils/volcengine_sign.py +2 -0
- veadk/version.py +1 -1
- {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/METADATA +4 -3
- {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/RECORD +46 -40
- {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/WHEEL +0 -0
- {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/entry_points.txt +0 -0
- {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/licenses/LICENSE +0 -0
- {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import threading
|
|
17
|
+
|
|
18
|
+
from volcengine.ApiInfo import ApiInfo
|
|
19
|
+
from volcengine.auth.SignerV4 import SignerV4
|
|
20
|
+
from volcengine.base.Service import Service
|
|
21
|
+
from volcengine.Credentials import Credentials
|
|
22
|
+
from volcengine.ServiceInfo import ServiceInfo
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class VikingDBMemoryException(Exception):
    """Error raised for failed VikingDB memory API calls.

    Attributes:
        code: Error code reported by the service (or a placeholder chosen by
            the caller).
        request_id: Request id reported by the service (or a placeholder).
        message: Display text of the form
            ``"<message>, code:<code>,request_id:<request_id>"``; this is also
            what ``str()`` returns.
    """

    def __init__(self, code, request_id, message=None):
        self.code, self.request_id = code, request_id
        # ``message`` defaults to None, which is rendered literally as "None".
        self.message = f"{message}, code:{code},request_id:{request_id}"

    def __str__(self):
        return self.message
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class VikingDBMemoryClient(Service):
    """Signed HTTP client for the VikingDB memory API.

    Builds on the ``volcengine`` ``Service`` base class. It registers the
    memory endpoints (collections, search, add messages, ping), signs each
    request with ``SignerV4`` and translates error responses into
    :class:`VikingDBMemoryException`.

    NOTE(review): ``__new__`` returns one shared instance, but ``__init__``
    still runs on every ``VikingDBMemoryClient(...)`` call. Each construction
    therefore re-pings the service and rebuilds the shared instance's service
    info (including its credentials object) for every holder of the client.
    Confirm this is intended before creating clients with different hosts or
    credentials.
    """

    _instance_lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # Double-checked locking: the unlocked ``hasattr`` is the fast path;
        # the re-check under the lock guarantees a single instance is created.
        if not hasattr(VikingDBMemoryClient, "_instance"):
            with VikingDBMemoryClient._instance_lock:
                if not hasattr(VikingDBMemoryClient, "_instance"):
                    VikingDBMemoryClient._instance = object.__new__(cls)
        return VikingDBMemoryClient._instance

    def __init__(
        self,
        host="api-knowledgebase.mlp.cn-beijing.volces.com",
        region="cn-beijing",
        ak="",
        sk="",
        sts_token="",
        scheme="http",
        connection_timeout=30,
        socket_timeout=30,
    ):
        """Configure endpoints and credentials, then ping the service.

        Args:
            host: API host name.
            region: Region used for request signing.
            ak / sk / sts_token: Credentials; each is applied only when
                non-empty.
            scheme: URL scheme. NOTE(review): defaults to plain ``http``;
                consider ``https`` for production traffic.
            connection_timeout / socket_timeout: Timeouts in seconds, passed
                to ``ServiceInfo`` and used by ``get_body``.

        Raises:
            VikingDBMemoryException: If the initial ``Ping`` request fails.
        """
        self.service_info = VikingDBMemoryClient.get_service_info(
            host, region, scheme, connection_timeout, socket_timeout
        )
        self.api_info = VikingDBMemoryClient.get_api_info()
        super(VikingDBMemoryClient, self).__init__(self.service_info, self.api_info)
        if ak:
            self.set_ak(ak)
        if sk:
            self.set_sk(sk)
        if sts_token:
            self.set_session_token(session_token=sts_token)
        try:
            self.get_body("Ping", {}, json.dumps({}))
        except Exception as e:
            raise VikingDBMemoryException(
                1000028, "missed", "host or region is incorrect: {}".format(str(e))
            ) from None

    def setHeader(self, header):
        """Add the entries of ``header`` to the headers of every API.

        The API table is rebuilt from ``get_api_info()`` first, so headers
        set by a previous call are discarded.
        """
        api_info = VikingDBMemoryClient.get_api_info()
        for key in api_info:
            for item in header:
                api_info[key].header[item] = header[item]
        self.api_info = api_info

    @staticmethod
    def get_service_info(host, region, scheme, connection_timeout, socket_timeout):
        """Build the ``ServiceInfo`` (host, signing region, timeouts, scheme).

        Credentials start empty and are filled in later via ``set_ak`` /
        ``set_sk`` / ``set_session_token``.
        """
        service_info = ServiceInfo(
            host,
            {"Host": host},
            Credentials("", "", "air", region),
            connection_timeout,
            socket_timeout,
            scheme=scheme,
        )
        return service_info

    @staticmethod
    def get_api_info():
        """Return a fresh mapping of API name -> ``ApiInfo``.

        Each entry gets its own header dict, so ``setHeader`` can mutate the
        headers without affecting other entries or later calls.
        """
        endpoints = {
            "CreateCollection": ("POST", "/api/memory/collection/create"),
            "GetCollection": ("POST", "/api/memory/collection/info"),
            "DropCollection": ("POST", "/api/memory/collection/delete"),
            "UpdateCollection": ("POST", "/api/memory/collection/update"),
            "SearchMemory": ("POST", "/api/memory/search"),
            "AddMessages": ("POST", "/api/memory/messages/add"),
            "Ping": ("GET", "/api/memory/ping"),
        }
        return {
            name: ApiInfo(
                method,
                path,
                {},
                {},
                {"Accept": "application/json", "Content-Type": "application/json"},
            )
            for name, (method, path) in endpoints.items()
        }

    def get_body(self, api, params, body):
        """Send a signed request for ``api`` with the raw JSON string ``body``.

        Returns:
            The response JSON, re-serialized as a string.

        Raises:
            Exception: ``"no such api"`` for an unknown API name, or
                ``Exception(<utf-8 encoded response text>)`` for a non-200
                response.
        """
        if api not in self.api_info:
            raise Exception("no such api")
        api_info = self.api_info[api]
        r = self.prepare_request(api_info, params)
        r.headers["Content-Type"] = "application/json"
        r.headers["Traffic-Source"] = "SDK"
        r.body = body

        SignerV4.sign(r, self.service_info.credentials)

        url = r.build()
        # Send with the verb declared for this API (the one the request was
        # prepared and signed with) instead of always using GET; otherwise
        # POST endpoints would be sent with the wrong verb and signature.
        resp = self.session.request(
            api_info.method,
            url,
            headers=r.headers,
            data=r.body,
            timeout=(
                self.service_info.connection_timeout,
                self.service_info.socket_timeout,
            ),
        )
        if resp.status_code == 200:
            return json.dumps(resp.json())
        else:
            raise Exception(resp.text.encode("utf-8"))

    @staticmethod
    def _to_memory_exception(error):
        """Translate a failed request into a ``VikingDBMemoryException``.

        Failed requests raise ``Exception(<utf-8 bytes of the response body>)``.
        When that body is a JSON object, its ``code`` / ``request_id`` /
        ``message`` fields are used. Any other error (non-JSON body, str
        arguments such as ``"no such api"``, transport errors, empty args) is
        reported with placeholder code ``1000028`` and the ORIGINAL error text.
        """
        raw = error.args[0] if error.args else str(error)
        if isinstance(raw, (bytes, bytearray)):
            text = raw.decode("utf-8", errors="replace")
        elif isinstance(raw, str):
            text = raw
        else:
            text = str(error)

        try:
            res_json = json.loads(text)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            res_json = None

        if not isinstance(res_json, dict):
            return VikingDBMemoryException(
                1000028, "missed", "json load res error, res:{}".format(text)
            )
        return VikingDBMemoryException(
            res_json.get("code", 1000028),
            res_json.get("request_id", 1000028),
            res_json.get("message", None),
        )

    @staticmethod
    def _require_non_empty(res):
        """Return ``res``; raise if the service returned an empty string."""
        if res == "":
            raise VikingDBMemoryException(
                1000028,
                "missed",
                "empty response due to unknown error, please contact customer service",
            )
        return res

    def get_body_exception(self, api, params, body):
        """Like ``get_body`` but raises ``VikingDBMemoryException`` on failure."""
        try:
            res = self.get_body(api, params, body)
        except Exception as e:
            raise self._to_memory_exception(e) from e
        return self._require_non_empty(res)

    def get_exception(self, api, params):
        """Call the base ``Service.get``; raise ``VikingDBMemoryException`` on failure."""
        try:
            res = self.get(api, params)
        except Exception as e:
            raise self._to_memory_exception(e) from e
        return self._require_non_empty(res)

    @staticmethod
    def _list_or_empty(value):
        """Return ``value``, or a new empty list when it is ``None``."""
        return [] if value is None else value

    def create_collection(
        self,
        collection_name,
        description="",
        custom_event_type_schemas=None,
        custom_entity_type_schemas=None,
        builtin_event_types=None,
        builtin_entity_types=None,
    ):
        """Create a memory collection; returns the decoded JSON response.

        List arguments default to ``None``, which is sent as an empty list
        (this avoids shared mutable default arguments).
        """
        params = {
            "CollectionName": collection_name,
            "Description": description,
            "CustomEventTypeSchemas": self._list_or_empty(custom_event_type_schemas),
            "CustomEntityTypeSchemas": self._list_or_empty(custom_entity_type_schemas),
            "BuiltinEventTypes": self._list_or_empty(builtin_event_types),
            "BuiltinEntityTypes": self._list_or_empty(builtin_entity_types),
        }
        res = self.json("CreateCollection", {}, json.dumps(params))
        return json.loads(res)

    def get_collection(self, collection_name):
        """Fetch collection info; returns the decoded JSON response."""
        params = {"CollectionName": collection_name}
        res = self.json("GetCollection", {}, json.dumps(params))
        return json.loads(res)

    def drop_collection(self, collection_name):
        """Delete a collection; returns the decoded JSON response."""
        params = {"CollectionName": collection_name}
        res = self.json("DropCollection", {}, json.dumps(params))
        return json.loads(res)

    def update_collection(
        self,
        collection_name,
        custom_event_type_schemas=None,
        custom_entity_type_schemas=None,
        builtin_event_types=None,
        builtin_entity_types=None,
    ):
        """Update a collection's schemas; returns the decoded JSON response.

        List arguments default to ``None``, which is sent as an empty list.
        """
        params = {
            "CollectionName": collection_name,
            "CustomEventTypeSchemas": self._list_or_empty(custom_event_type_schemas),
            "CustomEntityTypeSchemas": self._list_or_empty(custom_entity_type_schemas),
            "BuiltinEventTypes": self._list_or_empty(builtin_event_types),
            "BuiltinEntityTypes": self._list_or_empty(builtin_entity_types),
        }
        res = self.json("UpdateCollection", {}, json.dumps(params))
        return json.loads(res)

    def search_memory(self, collection_name, query, filter, limit=10):
        """Search memories in a collection; returns the decoded JSON response.

        ``query`` is only sent when truthy. (``filter`` shadows the builtin
        but is kept because it is part of the public keyword interface.)
        """
        params = {
            "collection_name": collection_name,
            "limit": limit,
            "filter": filter,
        }
        if query:
            params["query"] = query
        res = self.json("SearchMemory", {}, json.dumps(params))
        return json.loads(res)

    def add_messages(
        self, collection_name, session_id, messages, metadata, entities=None
    ):
        """Add session messages to a collection; returns the decoded JSON response.

        ``entities`` is only sent when it is not ``None``.
        """
        params = {
            "collection_name": collection_name,
            "session_id": session_id,
            "messages": messages,
            "metadata": metadata,
        }
        if entities is not None:
            params["entities"] = entities
        res = self.json("AddMessages", {}, json.dumps(params))
        return json.loads(res)
|
|
@@ -12,10 +12,11 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from typing import Any, Callable, Literal, Union
|
|
16
18
|
|
|
17
19
|
from pydantic import BaseModel, Field
|
|
18
|
-
from typing_extensions import Union
|
|
19
20
|
|
|
20
21
|
from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend
|
|
21
22
|
from veadk.knowledgebase.entry import KnowledgebaseEntry
|
|
@@ -54,11 +55,11 @@ def _get_backend_cls(backend: str) -> type[BaseKnowledgebaseBackend]:
|
|
|
54
55
|
raise ValueError(f"Unsupported knowledgebase backend: {backend}")
|
|
55
56
|
|
|
56
57
|
|
|
57
|
-
|
|
58
|
-
|
|
58
|
+
class KnowledgeBase(BaseModel):
|
|
59
|
+
name: str = "user_knowledgebase"
|
|
59
60
|
|
|
61
|
+
description: str = "This knowledgebase stores some user-related information."
|
|
60
62
|
|
|
61
|
-
class KnowledgeBase(BaseModel):
|
|
62
63
|
backend: Union[
|
|
63
64
|
Literal["local", "opensearch", "viking", "redis"], BaseKnowledgebaseBackend
|
|
64
65
|
] = "local"
|
|
@@ -73,9 +74,7 @@ class KnowledgeBase(BaseModel):
|
|
|
73
74
|
"""Configuration for the backend"""
|
|
74
75
|
|
|
75
76
|
top_k: int = 10
|
|
76
|
-
"""Number of top similar documents to retrieve during search
|
|
77
|
-
|
|
78
|
-
Default is 10."""
|
|
77
|
+
"""Number of top similar documents to retrieve during search"""
|
|
79
78
|
|
|
80
79
|
app_name: str = ""
|
|
81
80
|
|
|
@@ -85,38 +84,27 @@ class KnowledgeBase(BaseModel):
|
|
|
85
84
|
def model_post_init(self, __context: Any) -> None:
|
|
86
85
|
if isinstance(self.backend, BaseKnowledgebaseBackend):
|
|
87
86
|
self._backend = self.backend
|
|
87
|
+
self.index = self._backend.index
|
|
88
88
|
logger.info(
|
|
89
89
|
f"Initialized knowledgebase with provided backend instance {self._backend.__class__.__name__}"
|
|
90
90
|
)
|
|
91
91
|
return
|
|
92
92
|
|
|
93
|
-
#
|
|
94
|
-
if
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
)
|
|
98
|
-
|
|
99
|
-
# priority use index
|
|
100
|
-
if self.app_name and self.index:
|
|
101
|
-
logger.warning(
|
|
102
|
-
"`app_name` and `index` are both provided, using `index` as the knowledgebase index name."
|
|
103
|
-
)
|
|
93
|
+
# Once user define backend config, use it directly
|
|
94
|
+
if self.backend_config:
|
|
95
|
+
self._backend = _get_backend_cls(self.backend)(**self.backend_config)
|
|
96
|
+
return
|
|
104
97
|
|
|
105
|
-
|
|
106
|
-
if
|
|
107
|
-
|
|
108
|
-
logger.info(
|
|
109
|
-
f"Knowledgebase index is set to {self.index} (generated by the app_name: {self.app_name})."
|
|
110
|
-
)
|
|
98
|
+
self.index = self.index or self.app_name
|
|
99
|
+
if not self.index:
|
|
100
|
+
raise ValueError("Either `index` or `app_name` must be provided.")
|
|
111
101
|
|
|
112
102
|
logger.info(
|
|
113
|
-
f"Initializing knowledgebase: backend={self.backend} top_k={self.top_k}"
|
|
114
|
-
)
|
|
115
|
-
self._backend = _get_backend_cls(self.backend)(
|
|
116
|
-
index=self.index, **self.backend_config if self.backend_config else {}
|
|
103
|
+
f"Initializing knowledgebase: backend={self.backend} index={self.index} top_k={self.top_k}"
|
|
117
104
|
)
|
|
105
|
+
self._backend = _get_backend_cls(self.backend)(index=self.index)
|
|
118
106
|
logger.info(
|
|
119
|
-
f"Initialized knowledgebase with backend {self.
|
|
107
|
+
f"Initialized knowledgebase with backend {self.backend.__class__.__name__}"
|
|
120
108
|
)
|
|
121
109
|
|
|
122
110
|
def add_from_directory(self, directory: str, **kwargs) -> bool:
|
|
@@ -133,8 +121,7 @@ class KnowledgeBase(BaseModel):
|
|
|
133
121
|
|
|
134
122
|
def search(self, query: str, top_k: int = 0, **kwargs) -> list[KnowledgebaseEntry]:
|
|
135
123
|
"""Search knowledge from knowledgebase"""
|
|
136
|
-
if top_k
|
|
137
|
-
top_k = self.top_k
|
|
124
|
+
top_k = top_k if top_k != 0 else self.top_k
|
|
138
125
|
|
|
139
126
|
_entries = self._backend.search(query=query, top_k=top_k, **kwargs)
|
|
140
127
|
|
veadk/memory/__init__.py
CHANGED
|
@@ -25,7 +25,7 @@ def __getattr__(name):
|
|
|
25
25
|
from veadk.memory.short_term_memory import ShortTermMemory
|
|
26
26
|
|
|
27
27
|
return ShortTermMemory
|
|
28
|
-
if name == "
|
|
28
|
+
if name == "LongTermMemory":
|
|
29
29
|
from veadk.memory.long_term_memory import LongTermMemory
|
|
30
30
|
|
|
31
31
|
return LongTermMemory
|
veadk/memory/long_term_memory.py
CHANGED
|
@@ -72,10 +72,6 @@ def _get_backend_cls(backend: str) -> type[BaseLongTermMemoryBackend]:
|
|
|
72
72
|
raise ValueError(f"Unsupported long term memory backend: {backend}")
|
|
73
73
|
|
|
74
74
|
|
|
75
|
-
def build_long_term_memory_index(app_name: str, user_id: str):
|
|
76
|
-
return f"{app_name}_{user_id}"
|
|
77
|
-
|
|
78
|
-
|
|
79
75
|
class LongTermMemory(BaseMemoryService, BaseModel):
|
|
80
76
|
backend: Union[
|
|
81
77
|
Literal["local", "opensearch", "redis", "viking", "viking_mem", "mem0"],
|
|
@@ -89,46 +85,48 @@ class LongTermMemory(BaseMemoryService, BaseModel):
|
|
|
89
85
|
top_k: int = 5
|
|
90
86
|
"""Number of top similar documents to retrieve during search."""
|
|
91
87
|
|
|
88
|
+
index: str = ""
|
|
89
|
+
|
|
92
90
|
app_name: str = ""
|
|
93
91
|
|
|
94
92
|
user_id: str = ""
|
|
93
|
+
"""Deprecated attribute"""
|
|
95
94
|
|
|
96
95
|
def model_post_init(self, __context: Any) -> None:
|
|
97
|
-
if self.backend == "viking_mem":
|
|
98
|
-
logger.warning(
|
|
99
|
-
"The `viking_mem` backend is deprecated, please use `viking` instead."
|
|
100
|
-
)
|
|
101
|
-
self.backend = "viking"
|
|
102
|
-
|
|
103
|
-
self._backend = None
|
|
104
|
-
|
|
105
96
|
# Once user define a backend instance, use it directly
|
|
106
97
|
if isinstance(self.backend, BaseLongTermMemoryBackend):
|
|
107
98
|
self._backend = self.backend
|
|
99
|
+
self.index = self._backend.index
|
|
108
100
|
logger.info(
|
|
109
|
-
f"Initialized long term memory with provided backend instance {self._backend.__class__.__name__}"
|
|
101
|
+
f"Initialized long term memory with provided backend instance {self._backend.__class__.__name__}, index={self.index}"
|
|
110
102
|
)
|
|
111
103
|
return
|
|
112
104
|
|
|
105
|
+
# Once user define backend config, use it directly
|
|
113
106
|
if self.backend_config:
|
|
114
|
-
logger.warning(
|
|
115
|
-
f"Initialized long term memory backend {self.backend} with config. We will ignore `app_name` and `user_id` if provided."
|
|
116
|
-
)
|
|
117
107
|
self._backend = _get_backend_cls(self.backend)(**self.backend_config)
|
|
118
108
|
return
|
|
119
109
|
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
self._backend = _get_backend_cls(self.backend)(
|
|
126
|
-
index=self._index, **self.backend_config if self.backend_config else {}
|
|
110
|
+
# Check index
|
|
111
|
+
self.index = self.index or self.app_name
|
|
112
|
+
if not self.index:
|
|
113
|
+
logger.warning(
|
|
114
|
+
"Attribute `index` or `app_name` not provided, use `default_app` instead."
|
|
127
115
|
)
|
|
128
|
-
|
|
116
|
+
self.index = "default_app"
|
|
117
|
+
|
|
118
|
+
# Forward compliance
|
|
119
|
+
if self.backend == "viking_mem":
|
|
129
120
|
logger.warning(
|
|
130
|
-
"
|
|
121
|
+
"The `viking_mem` backend is deprecated, change to `viking` instead."
|
|
131
122
|
)
|
|
123
|
+
self.backend = "viking"
|
|
124
|
+
|
|
125
|
+
self._backend = _get_backend_cls(self.backend)(index=self.index)
|
|
126
|
+
|
|
127
|
+
logger.info(
|
|
128
|
+
f"Initialized long term memory with provided backend instance {self._backend.__class__.__name__}, index={self.index}"
|
|
129
|
+
)
|
|
132
130
|
|
|
133
131
|
def _filter_and_convert_events(self, events: list[Event]) -> list[str]:
|
|
134
132
|
final_events = []
|
|
@@ -156,58 +154,32 @@ class LongTermMemory(BaseMemoryService, BaseModel):
|
|
|
156
154
|
self,
|
|
157
155
|
session: Session,
|
|
158
156
|
):
|
|
159
|
-
app_name = session.app_name
|
|
160
157
|
user_id = session.user_id
|
|
161
|
-
|
|
162
|
-
if not self._backend and isinstance(self.backend, str):
|
|
163
|
-
self._index = build_long_term_memory_index(app_name, user_id)
|
|
164
|
-
self._backend = _get_backend_cls(self.backend)(
|
|
165
|
-
index=self._index, **self.backend_config if self.backend_config else {}
|
|
166
|
-
)
|
|
167
|
-
logger.info(
|
|
168
|
-
f"Initialize long term memory backend now, index is {self._index}"
|
|
169
|
-
)
|
|
170
|
-
|
|
171
|
-
if not self._index and self._index != build_long_term_memory_index(
|
|
172
|
-
app_name, user_id
|
|
173
|
-
):
|
|
174
|
-
logger.warning(
|
|
175
|
-
f"The `app_name` or `user_id` is different from the initialized one, skip add session to memory. Initialized index: {self._index}, current built index: {build_long_term_memory_index(app_name, user_id)}"
|
|
176
|
-
)
|
|
177
|
-
return
|
|
178
158
|
event_strings = self._filter_and_convert_events(session.events)
|
|
179
159
|
|
|
180
160
|
logger.info(
|
|
181
|
-
f"Adding {len(event_strings)} events to long term memory: index={self.
|
|
161
|
+
f"Adding {len(event_strings)} events to long term memory: index={self.index}"
|
|
162
|
+
)
|
|
163
|
+
self._backend.save_memory(user_id=user_id, event_strings=event_strings)
|
|
164
|
+
logger.info(
|
|
165
|
+
f"Added {len(event_strings)} events to long term memory: index={self.index}, user_id={user_id}"
|
|
182
166
|
)
|
|
183
167
|
|
|
184
|
-
|
|
185
|
-
|
|
168
|
+
@override
|
|
169
|
+
async def search_memory(
|
|
170
|
+
self, *, app_name: str, user_id: str, query: str
|
|
171
|
+
) -> SearchMemoryResponse:
|
|
172
|
+
logger.info(f"Search memory with query={query}")
|
|
186
173
|
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
logger.error(
|
|
192
|
-
"Long term memory backend initialize failed, cannot add session to memory."
|
|
174
|
+
memory_chunks = []
|
|
175
|
+
try:
|
|
176
|
+
memory_chunks = self._backend.search_memory(
|
|
177
|
+
query=query, top_k=self.top_k, user_id=user_id
|
|
193
178
|
)
|
|
194
|
-
|
|
195
|
-
@override
|
|
196
|
-
async def search_memory(self, *, app_name: str, user_id: str, query: str):
|
|
197
|
-
# prevent model invoke `load_memory` before add session to this memory
|
|
198
|
-
if not self._backend:
|
|
179
|
+
except Exception as e:
|
|
199
180
|
logger.error(
|
|
200
|
-
"
|
|
181
|
+
f"Exception orrcus during memory search: {e}. Return empty memory chunks"
|
|
201
182
|
)
|
|
202
|
-
return SearchMemoryResponse(memories=[])
|
|
203
|
-
|
|
204
|
-
logger.info(
|
|
205
|
-
f"Searching long term memory: query={query} index={self._index} top_k={self.top_k}"
|
|
206
|
-
)
|
|
207
|
-
|
|
208
|
-
memory_chunks = self._backend.search_memory(
|
|
209
|
-
query=query, top_k=self.top_k, user_id=user_id
|
|
210
|
-
)
|
|
211
183
|
|
|
212
184
|
memory_events = []
|
|
213
185
|
for memory in memory_chunks:
|
|
@@ -235,6 +207,6 @@ class LongTermMemory(BaseMemoryService, BaseModel):
|
|
|
235
207
|
)
|
|
236
208
|
|
|
237
209
|
logger.info(
|
|
238
|
-
f"Return {len(memory_events)} memory events for query: {query} index={self.
|
|
210
|
+
f"Return {len(memory_events)} memory events for query: {query} index={self.index} user_id={user_id}"
|
|
239
211
|
)
|
|
240
212
|
return SearchMemoryResponse(memories=memory_events)
|
|
@@ -25,9 +25,11 @@ class BaseLongTermMemoryBackend(ABC, BaseModel):
|
|
|
25
25
|
"""Check the index name is valid or not"""
|
|
26
26
|
|
|
27
27
|
@abstractmethod
|
|
28
|
-
def save_memory(self, event_strings: list[str], **kwargs) -> bool:
|
|
28
|
+
def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
|
|
29
29
|
"""Save memory to long term memory backend"""
|
|
30
30
|
|
|
31
31
|
@abstractmethod
|
|
32
|
-
def search_memory(
|
|
32
|
+
def search_memory(
|
|
33
|
+
self, user_id: str, query: str, top_k: int, **kwargs
|
|
34
|
+
) -> list[str]:
|
|
33
35
|
"""Retrieve memory from long term memory backend"""
|
|
@@ -29,10 +29,6 @@ class InMemoryLTMBackend(BaseLongTermMemoryBackend):
|
|
|
29
29
|
embedding_config: EmbeddingModelConfig = Field(default_factory=EmbeddingModelConfig)
|
|
30
30
|
"""Embedding model configs"""
|
|
31
31
|
|
|
32
|
-
def precheck_index_naming(self):
|
|
33
|
-
# no checking
|
|
34
|
-
pass
|
|
35
|
-
|
|
36
32
|
def model_post_init(self, __context: Any) -> None:
|
|
37
33
|
self._embed_model = OpenAILikeEmbedding(
|
|
38
34
|
model_name=self.embedding_config.name,
|
|
@@ -41,8 +37,12 @@ class InMemoryLTMBackend(BaseLongTermMemoryBackend):
|
|
|
41
37
|
)
|
|
42
38
|
self._vector_index = VectorStoreIndex([], embed_model=self._embed_model)
|
|
43
39
|
|
|
40
|
+
def precheck_index_naming(self):
|
|
41
|
+
# no checking
|
|
42
|
+
pass
|
|
43
|
+
|
|
44
44
|
@override
|
|
45
|
-
def save_memory(self, event_strings: list[str], **kwargs) -> bool:
|
|
45
|
+
def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
|
|
46
46
|
for event_string in event_strings:
|
|
47
47
|
document = Document(text=event_string)
|
|
48
48
|
nodes = self._split_documents([document])
|
|
@@ -50,7 +50,9 @@ class InMemoryLTMBackend(BaseLongTermMemoryBackend):
|
|
|
50
50
|
return True
|
|
51
51
|
|
|
52
52
|
@override
|
|
53
|
-
def search_memory(
|
|
53
|
+
def search_memory(
|
|
54
|
+
self, user_id: str, query: str, top_k: int, **kwargs
|
|
55
|
+
) -> list[str]:
|
|
54
56
|
_retriever = self._vector_index.as_retriever(similarity_top_k=top_k)
|
|
55
57
|
retrieved_nodes = _retriever.retrieve(query)
|
|
56
58
|
return [node.text for node in retrieved_nodes]
|