MemoryOS 0.2.2__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic. Click here for more details.
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/METADATA +7 -1
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/RECORD +81 -66
- memos/__init__.py +1 -1
- memos/api/config.py +31 -8
- memos/api/context/context.py +1 -1
- memos/api/context/context_thread.py +96 -0
- memos/api/middleware/request_context.py +94 -0
- memos/api/product_api.py +5 -1
- memos/api/product_models.py +16 -0
- memos/api/routers/product_router.py +39 -3
- memos/api/start_api.py +3 -0
- memos/configs/internet_retriever.py +13 -0
- memos/configs/mem_scheduler.py +38 -16
- memos/configs/memory.py +13 -0
- memos/configs/reranker.py +18 -0
- memos/graph_dbs/base.py +33 -4
- memos/graph_dbs/nebular.py +631 -236
- memos/graph_dbs/neo4j.py +18 -7
- memos/graph_dbs/neo4j_community.py +6 -3
- memos/llms/vllm.py +2 -0
- memos/log.py +125 -8
- memos/mem_os/core.py +49 -11
- memos/mem_os/main.py +1 -1
- memos/mem_os/product.py +392 -215
- memos/mem_os/utils/default_config.py +1 -1
- memos/mem_os/utils/format_utils.py +11 -47
- memos/mem_os/utils/reference_utils.py +153 -0
- memos/mem_reader/simple_struct.py +112 -43
- memos/mem_scheduler/base_scheduler.py +58 -55
- memos/mem_scheduler/{modules → general_modules}/base.py +1 -2
- memos/mem_scheduler/{modules → general_modules}/dispatcher.py +54 -15
- memos/mem_scheduler/{modules → general_modules}/rabbitmq_service.py +4 -4
- memos/mem_scheduler/{modules → general_modules}/redis_service.py +1 -1
- memos/mem_scheduler/{modules → general_modules}/retriever.py +19 -5
- memos/mem_scheduler/{modules → general_modules}/scheduler_logger.py +10 -4
- memos/mem_scheduler/general_scheduler.py +110 -67
- memos/mem_scheduler/monitors/__init__.py +0 -0
- memos/mem_scheduler/monitors/dispatcher_monitor.py +305 -0
- memos/mem_scheduler/{modules/monitor.py → monitors/general_monitor.py} +57 -19
- memos/mem_scheduler/mos_for_test_scheduler.py +7 -1
- memos/mem_scheduler/schemas/general_schemas.py +3 -2
- memos/mem_scheduler/schemas/message_schemas.py +2 -1
- memos/mem_scheduler/schemas/monitor_schemas.py +10 -2
- memos/mem_scheduler/utils/misc_utils.py +43 -2
- memos/mem_user/mysql_user_manager.py +4 -2
- memos/memories/activation/item.py +1 -1
- memos/memories/activation/kv.py +20 -8
- memos/memories/textual/base.py +1 -1
- memos/memories/textual/general.py +1 -1
- memos/memories/textual/item.py +1 -1
- memos/memories/textual/tree.py +31 -1
- memos/memories/textual/tree_text_memory/organize/{conflict.py → handler.py} +30 -48
- memos/memories/textual/tree_text_memory/organize/manager.py +8 -96
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +2 -0
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +102 -140
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +231 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +9 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +67 -10
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +1 -1
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +246 -134
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +7 -2
- memos/memories/textual/tree_text_memory/retrieve/utils.py +7 -5
- memos/memos_tools/lockfree_dict.py +120 -0
- memos/memos_tools/notification_utils.py +46 -0
- memos/memos_tools/thread_safe_dict.py +288 -0
- memos/reranker/__init__.py +4 -0
- memos/reranker/base.py +24 -0
- memos/reranker/cosine_local.py +95 -0
- memos/reranker/factory.py +43 -0
- memos/reranker/http_bge.py +99 -0
- memos/reranker/noop.py +16 -0
- memos/templates/mem_reader_prompts.py +290 -39
- memos/templates/mem_scheduler_prompts.py +23 -10
- memos/templates/mos_prompts.py +133 -31
- memos/templates/tree_reorganize_prompts.py +24 -17
- memos/utils.py +19 -0
- memos/memories/textual/tree_text_memory/organize/redundancy.py +0 -193
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/LICENSE +0 -0
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/WHEEL +0 -0
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/entry_points.txt +0 -0
- /memos/mem_scheduler/{modules → general_modules}/__init__.py +0 -0
- /memos/mem_scheduler/{modules → general_modules}/misc.py +0 -0
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
"""BochaAI Search API retriever for tree text memory."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
|
|
8
|
+
import requests
|
|
9
|
+
|
|
10
|
+
from memos.embedders.factory import OllamaEmbedder
|
|
11
|
+
from memos.log import get_logger
|
|
12
|
+
from memos.mem_reader.base import BaseMemReader
|
|
13
|
+
from memos.memories.textual.item import TextualMemoryItem
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
logger = get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BochaAISearchAPI:
    """BochaAI Search API client.

    Thin wrapper over the BochaAI ``web-search`` and ``ai-search`` HTTP
    endpoints that normalizes both response shapes into a flat list of
    webpage-result dicts. All failures are logged and swallowed so callers
    always receive a (possibly empty) list.
    """

    def __init__(self, api_key: str, max_results: int = 20, timeout: float = 30.0):
        """
        Initialize BochaAI Search API client.

        Args:
            api_key: BochaAI API key
            max_results: Maximum number of search results to retrieve
            timeout: Per-request timeout in seconds; prevents a hung HTTP
                call from blocking the retrieval pipeline indefinitely
        """
        self.api_key = api_key
        self.max_results = max_results
        self.timeout = timeout

        self.web_url = "https://api.bochaai.com/v1/web-search"
        self.ai_url = "https://api.bochaai.com/v1/ai-search"

        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    def search_web(self, query: str, summary: bool = True, freshness: str = "noLimit") -> list[dict]:
        """
        Perform a Web Search (equivalent to the first curl).

        Args:
            query: Search query string
            summary: Whether to include summary in the results
            freshness: Freshness filter (e.g. 'noLimit', 'day', 'week')

        Returns:
            A list of search result dicts
        """
        body = {
            "query": query,
            "summary": summary,
            "freshness": freshness,
            "count": self.max_results,
        }
        return self._post(self.web_url, body)

    def search_ai(
        self, query: str, answer: bool = False, stream: bool = False, freshness: str = "noLimit"
    ) -> list[dict]:
        """
        Perform an AI Search (equivalent to the second curl).

        Args:
            query: Search query string
            answer: Whether BochaAI should generate an answer
            stream: Whether to use streaming response
            freshness: Freshness filter (e.g. 'noLimit', 'day', 'week')

        Returns:
            A list of search result dicts
        """
        body = {
            "query": query,
            "freshness": freshness,
            "count": self.max_results,
            "answer": answer,
            "stream": stream,
        }
        return self._post(self.ai_url, body)

    def _post(self, url: str, body: dict) -> list[dict]:
        """Send a POST request and parse BochaAI search results.

        Handles both response layouts: AI Search (webpage sources embedded
        as JSON strings inside ``messages``) and Web Search (results nested
        under ``data.webPages.value``). Returns [] on any failure.
        """
        try:
            # timeout guards against the request hanging forever on a
            # stalled connection.
            resp = requests.post(url, headers=self.headers, json=body, timeout=self.timeout)
            resp.raise_for_status()
            raw_data = resp.json()

            # AI Search: webpage results arrive as JSON-encoded strings in
            # "source"-type messages.
            if "messages" in raw_data:
                results = []
                for msg in raw_data["messages"]:
                    if msg.get("type") == "source" and msg.get("content_type") == "webpage":
                        try:
                            content_json = json.loads(msg["content"])
                            results.extend(content_json.get("value", []))
                        except Exception as e:
                            logger.error(f"Failed to parse message content: {e}")
                return results

            # Web Search: results nested under data.webPages.value.
            return raw_data.get("data", {}).get("webPages", {}).get("value", [])

        except Exception:
            # Best-effort contract: log the full traceback, never propagate.
            logger.error("BochaAI search error", exc_info=True)
            return []
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class BochaAISearchRetriever:
    """BochaAI retriever that converts search results into TextualMemoryItem objects."""

    def __init__(
        self,
        access_key: str,
        embedder: OllamaEmbedder,
        reader: BaseMemReader,
        max_results: int = 20,
    ):
        """
        Initialize BochaAI Search retriever.

        Args:
            access_key: BochaAI API key
            embedder: Embedder instance for generating embeddings
            reader: MemReader instance for processing internet content
            max_results: Maximum number of search results to retrieve
        """
        self.bocha_api = BochaAISearchAPI(access_key, max_results=max_results)
        self.embedder = embedder
        self.reader = reader

    def retrieve_from_internet(
        self, query: str, top_k: int = 10, parsed_goal=None, info=None
    ) -> list[TextualMemoryItem]:
        """
        Default internet retrieval. Delegates to BochaAI **AI Search**
        (same path as ``retrieve_from_ai``); use ``retrieve_from_web`` for
        plain web search. Keeps a consistent API with the Xinyu and Google
        retrievers.

        Args:
            query: Search query
            top_k: Number of results to retrieve. NOTE(review): currently
                advisory — the API returns up to ``max_results`` items and
                no truncation is applied here; confirm intended contract.
            parsed_goal: Parsed task goal (optional)
            info (dict): Metadata for memory consumption tracking

        Returns:
            List of TextualMemoryItem
        """
        search_results = self.bocha_api.search_ai(query)
        return self._convert_to_mem_items(search_results, query, parsed_goal, info)

    def retrieve_from_web(
        self, query: str, top_k: int = 10, parsed_goal=None, info=None
    ) -> list[TextualMemoryItem]:
        """Explicitly retrieve using Bocha Web Search."""
        search_results = self.bocha_api.search_web(query)
        return self._convert_to_mem_items(search_results, query, parsed_goal, info)

    def retrieve_from_ai(
        self, query: str, top_k: int = 10, parsed_goal=None, info=None
    ) -> list[TextualMemoryItem]:
        """Explicitly retrieve using Bocha AI Search."""
        search_results = self.bocha_api.search_ai(query)
        return self._convert_to_mem_items(search_results, query, parsed_goal, info)

    def _convert_to_mem_items(
        self, search_results: list[dict], query: str, parsed_goal=None, info=None
    ):
        """Convert API search results into TextualMemoryItem objects.

        Results are processed concurrently (reader calls involve I/O/LLM
        work) and deduplicated by memory text afterwards.
        """
        memory_items = []
        if not info:
            info = {"user_id": "", "session_id": ""}

        with ThreadPoolExecutor(max_workers=8) as executor:
            futures = [
                executor.submit(self._process_result, r, query, parsed_goal, info)
                for r in search_results
            ]
            for future in as_completed(futures):
                try:
                    memory_items.extend(future.result())
                except Exception as e:
                    logger.error(f"Error processing BochaAI search result: {e}")

        # Deduplicate items by memory text (last occurrence wins)
        unique_memory_items = {item.memory: item for item in memory_items}
        return list(unique_memory_items.values())

    def _process_result(
        self, result: dict, query: str, parsed_goal=None, info: dict | None = None
    ) -> list[TextualMemoryItem]:
        """Process one Bocha search result into TextualMemoryItem objects.

        Args:
            result: One raw webpage-result dict from the BochaAI API
            query: Original search query (kept for signature parity)
            parsed_goal: Parsed task goal (optional, currently unused here)
            info: Metadata dict forwarded to the reader

        Returns:
            List of TextualMemoryItem built from the reader's chunks
        """
        title = result.get("name", "")
        # Prefer the richer "summary" field; fall back to the snippet.
        content = result.get("summary", "") or result.get("snippet", "")
        summary = result.get("snippet", "")
        url = result.get("url", "")
        publish_time = result.get("datePublished", "")

        # Normalize the publish time to YYYY-MM-DD; fall back to today on
        # missing or unparseable timestamps.
        if publish_time:
            try:
                publish_time = datetime.fromisoformat(publish_time.replace("Z", "+00:00")).strftime(
                    "%Y-%m-%d"
                )
            except Exception:
                publish_time = datetime.now().strftime("%Y-%m-%d")
        else:
            publish_time = datetime.now().strftime("%Y-%m-%d")

        # Use reader to split and process the content into chunks
        read_items = self.reader.get_memory([content], type="doc", info=info)
        if not read_items:
            # Reader produced nothing for this page; skip it rather than
            # raising IndexError on read_items[0].
            return []

        memory_items = []
        for read_item_i in read_items[0]:
            read_item_i.memory = (
                f"[Outer internet view] Title: {title}\nNewsTime:"
                f" {publish_time}\nSummary:"
                f" {summary}\n"
                f"Content: {read_item_i.memory}"
            )
            read_item_i.metadata.source = "web"
            read_item_i.metadata.memory_type = "OuterMemory"
            read_item_i.metadata.sources = [url] if url else []
            read_item_i.metadata.visibility = "public"
            memory_items.append(read_item_i)
        return memory_items
|
|
@@ -5,6 +5,7 @@ from typing import Any, ClassVar
|
|
|
5
5
|
from memos.configs.internet_retriever import InternetRetrieverConfigFactory
|
|
6
6
|
from memos.embedders.base import BaseEmbedder
|
|
7
7
|
from memos.mem_reader.factory import MemReaderFactory
|
|
8
|
+
from memos.memories.textual.tree_text_memory.retrieve.bochasearch import BochaAISearchRetriever
|
|
8
9
|
from memos.memories.textual.tree_text_memory.retrieve.internet_retriever import (
|
|
9
10
|
InternetGoogleRetriever,
|
|
10
11
|
)
|
|
@@ -18,6 +19,7 @@ class InternetRetrieverFactory:
|
|
|
18
19
|
"google": InternetGoogleRetriever,
|
|
19
20
|
"bing": InternetGoogleRetriever, # TODO: Implement BingRetriever
|
|
20
21
|
"xinyu": XinyuSearchRetriever,
|
|
22
|
+
"bocha": BochaAISearchRetriever,
|
|
21
23
|
}
|
|
22
24
|
|
|
23
25
|
@classmethod
|
|
@@ -70,6 +72,13 @@ class InternetRetrieverFactory:
|
|
|
70
72
|
reader=MemReaderFactory.from_config(config.reader),
|
|
71
73
|
max_results=config.max_results,
|
|
72
74
|
)
|
|
75
|
+
elif backend == "bocha":
|
|
76
|
+
return retriever_class(
|
|
77
|
+
access_key=config.api_key, # Use api_key as access_key for xinyu
|
|
78
|
+
embedder=embedder,
|
|
79
|
+
reader=MemReaderFactory.from_config(config.reader),
|
|
80
|
+
max_results=config.max_results,
|
|
81
|
+
)
|
|
73
82
|
else:
|
|
74
83
|
raise ValueError(f"Unsupported backend: {backend}")
|
|
75
84
|
|
|
@@ -44,16 +44,23 @@ class GraphMemoryRetriever:
|
|
|
44
44
|
|
|
45
45
|
if memory_scope == "WorkingMemory":
|
|
46
46
|
# For working memory, retrieve all entries (no filtering)
|
|
47
|
-
working_memories = self.graph_store.get_all_memory_items(
|
|
47
|
+
working_memories = self.graph_store.get_all_memory_items(
|
|
48
|
+
scope="WorkingMemory", include_embedding=True
|
|
49
|
+
)
|
|
48
50
|
return [TextualMemoryItem.from_dict(record) for record in working_memories]
|
|
49
51
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
+
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
|
53
|
+
# Structured graph-based retrieval
|
|
54
|
+
future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope)
|
|
55
|
+
# Vector similarity search
|
|
56
|
+
future_vector = executor.submit(
|
|
57
|
+
self._vector_recall, query_embedding, memory_scope, top_k
|
|
58
|
+
)
|
|
52
59
|
|
|
53
|
-
|
|
54
|
-
|
|
60
|
+
graph_results = future_graph.result()
|
|
61
|
+
vector_results = future_vector.result()
|
|
55
62
|
|
|
56
|
-
#
|
|
63
|
+
# Merge and deduplicate by ID
|
|
57
64
|
combined = {item.id: item for item in graph_results + vector_results}
|
|
58
65
|
|
|
59
66
|
graph_ids = {item.id for item in graph_results}
|
|
@@ -67,6 +74,51 @@ class GraphMemoryRetriever:
|
|
|
67
74
|
|
|
68
75
|
return list(combined.values())
|
|
69
76
|
|
|
77
|
+
def retrieve_from_cube(
|
|
78
|
+
self,
|
|
79
|
+
top_k: int,
|
|
80
|
+
memory_scope: str,
|
|
81
|
+
query_embedding: list[list[float]] | None = None,
|
|
82
|
+
cube_name: str = "memos_cube01",
|
|
83
|
+
) -> list[TextualMemoryItem]:
|
|
84
|
+
"""
|
|
85
|
+
Perform hybrid memory retrieval:
|
|
86
|
+
- Run graph-based lookup from dispatch plan.
|
|
87
|
+
- Run vector similarity search from embedded query.
|
|
88
|
+
- Merge and return combined result set.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
top_k (int): Number of candidates to return.
|
|
92
|
+
memory_scope (str): One of ['working', 'long_term', 'user'].
|
|
93
|
+
query_embedding(list of embedding): list of embedding of query
|
|
94
|
+
cube_name: specify cube_name
|
|
95
|
+
|
|
96
|
+
Returns:
|
|
97
|
+
list: Combined memory items.
|
|
98
|
+
"""
|
|
99
|
+
if memory_scope not in ["WorkingMemory", "LongTermMemory", "UserMemory"]:
|
|
100
|
+
raise ValueError(f"Unsupported memory scope: {memory_scope}")
|
|
101
|
+
|
|
102
|
+
graph_results = self._vector_recall(
|
|
103
|
+
query_embedding, memory_scope, top_k, cube_name=cube_name
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
for result_i in graph_results:
|
|
107
|
+
result_i.metadata.memory_type = "OuterMemory"
|
|
108
|
+
# Merge and deduplicate by ID
|
|
109
|
+
combined = {item.id: item for item in graph_results}
|
|
110
|
+
|
|
111
|
+
graph_ids = {item.id for item in graph_results}
|
|
112
|
+
combined_ids = set(combined.keys())
|
|
113
|
+
lost_ids = graph_ids - combined_ids
|
|
114
|
+
|
|
115
|
+
if lost_ids:
|
|
116
|
+
print(
|
|
117
|
+
f"[DEBUG] The following nodes were in graph_results but missing in combined: {lost_ids}"
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
return list(combined.values())
|
|
121
|
+
|
|
70
122
|
def _graph_recall(
|
|
71
123
|
self, parsed_goal: ParsedTaskGoal, memory_scope: str
|
|
72
124
|
) -> list[TextualMemoryItem]:
|
|
@@ -101,7 +153,7 @@ class GraphMemoryRetriever:
|
|
|
101
153
|
return []
|
|
102
154
|
|
|
103
155
|
# Load nodes and post-filter
|
|
104
|
-
node_dicts = self.graph_store.get_nodes(list(candidate_ids))
|
|
156
|
+
node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=True)
|
|
105
157
|
|
|
106
158
|
final_nodes = []
|
|
107
159
|
for node in node_dicts:
|
|
@@ -127,7 +179,8 @@ class GraphMemoryRetriever:
|
|
|
127
179
|
query_embedding: list[list[float]],
|
|
128
180
|
memory_scope: str,
|
|
129
181
|
top_k: int = 20,
|
|
130
|
-
max_num: int =
|
|
182
|
+
max_num: int = 3,
|
|
183
|
+
cube_name: str | None = None,
|
|
131
184
|
) -> list[TextualMemoryItem]:
|
|
132
185
|
"""
|
|
133
186
|
# TODO: tackle with post-filter and pre-filter(5.18+) better.
|
|
@@ -137,7 +190,9 @@ class GraphMemoryRetriever:
|
|
|
137
190
|
|
|
138
191
|
def search_single(vec):
|
|
139
192
|
return (
|
|
140
|
-
self.graph_store.search_by_embedding(
|
|
193
|
+
self.graph_store.search_by_embedding(
|
|
194
|
+
vector=vec, top_k=top_k, scope=memory_scope, cube_name=cube_name
|
|
195
|
+
)
|
|
141
196
|
or []
|
|
142
197
|
)
|
|
143
198
|
|
|
@@ -152,6 +207,8 @@ class GraphMemoryRetriever:
|
|
|
152
207
|
|
|
153
208
|
# Step 3: Extract matched IDs and retrieve full nodes
|
|
154
209
|
unique_ids = set({r["id"] for r in all_matches})
|
|
155
|
-
node_dicts = self.graph_store.get_nodes(
|
|
210
|
+
node_dicts = self.graph_store.get_nodes(
|
|
211
|
+
list(unique_ids), include_embedding=True, cube_name=cube_name
|
|
212
|
+
)
|
|
156
213
|
|
|
157
214
|
return [TextualMemoryItem.from_dict(record) for record in node_dicts]
|
|
@@ -78,7 +78,7 @@ class MemoryReranker:
|
|
|
78
78
|
embeddings = [item.metadata.embedding for item in items_with_embeddings]
|
|
79
79
|
|
|
80
80
|
if not embeddings:
|
|
81
|
-
return graph_results[:top_k]
|
|
81
|
+
return [(item, 0.5) for item in graph_results[:top_k]]
|
|
82
82
|
|
|
83
83
|
# Step 2: Compute cosine similarities
|
|
84
84
|
similarity_scores = batch_cosine_similarity(query_embedding, embeddings)
|