memoryos-1.0.0-py3-none-any.whl → memoryos-1.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of MemoryOS might be problematic.
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/METADATA +8 -2
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/RECORD +92 -69
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/WHEEL +1 -1
- memos/__init__.py +1 -1
- memos/api/client.py +109 -0
- memos/api/config.py +35 -8
- memos/api/context/dependencies.py +15 -66
- memos/api/middleware/request_context.py +63 -0
- memos/api/product_api.py +5 -2
- memos/api/product_models.py +107 -16
- memos/api/routers/product_router.py +62 -19
- memos/api/start_api.py +13 -0
- memos/configs/graph_db.py +4 -0
- memos/configs/mem_scheduler.py +38 -3
- memos/configs/memory.py +13 -0
- memos/configs/reranker.py +18 -0
- memos/context/context.py +255 -0
- memos/embedders/factory.py +2 -0
- memos/graph_dbs/base.py +4 -2
- memos/graph_dbs/nebular.py +368 -223
- memos/graph_dbs/neo4j.py +49 -13
- memos/graph_dbs/neo4j_community.py +13 -3
- memos/llms/factory.py +2 -0
- memos/llms/openai.py +74 -2
- memos/llms/vllm.py +2 -0
- memos/log.py +128 -4
- memos/mem_cube/general.py +3 -1
- memos/mem_os/core.py +89 -23
- memos/mem_os/main.py +3 -6
- memos/mem_os/product.py +418 -154
- memos/mem_os/utils/reference_utils.py +20 -0
- memos/mem_reader/factory.py +2 -0
- memos/mem_reader/simple_struct.py +204 -82
- memos/mem_scheduler/analyzer/__init__.py +0 -0
- memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +569 -0
- memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
- memos/mem_scheduler/base_scheduler.py +126 -56
- memos/mem_scheduler/general_modules/dispatcher.py +2 -2
- memos/mem_scheduler/general_modules/misc.py +99 -1
- memos/mem_scheduler/general_modules/scheduler_logger.py +17 -11
- memos/mem_scheduler/general_scheduler.py +40 -88
- memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
- memos/mem_scheduler/memory_manage_modules/memory_filter.py +308 -0
- memos/mem_scheduler/{general_modules → memory_manage_modules}/retriever.py +34 -7
- memos/mem_scheduler/monitors/dispatcher_monitor.py +9 -8
- memos/mem_scheduler/monitors/general_monitor.py +119 -39
- memos/mem_scheduler/optimized_scheduler.py +124 -0
- memos/mem_scheduler/orm_modules/__init__.py +0 -0
- memos/mem_scheduler/orm_modules/base_model.py +635 -0
- memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
- memos/mem_scheduler/scheduler_factory.py +2 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +96 -29
- memos/mem_scheduler/utils/config_utils.py +100 -0
- memos/mem_scheduler/utils/db_utils.py +33 -0
- memos/mem_scheduler/utils/filter_utils.py +1 -1
- memos/mem_scheduler/webservice_modules/__init__.py +0 -0
- memos/mem_user/mysql_user_manager.py +4 -2
- memos/memories/activation/kv.py +2 -1
- memos/memories/textual/item.py +96 -17
- memos/memories/textual/naive.py +1 -1
- memos/memories/textual/tree.py +57 -3
- memos/memories/textual/tree_text_memory/organize/handler.py +4 -2
- memos/memories/textual/tree_text_memory/organize/manager.py +28 -14
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +1 -2
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +75 -23
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +10 -6
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -2
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +2 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +119 -21
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +172 -44
- memos/memories/textual/tree_text_memory/retrieve/utils.py +6 -4
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +5 -4
- memos/memos_tools/notification_utils.py +46 -0
- memos/memos_tools/singleton.py +174 -0
- memos/memos_tools/thread_safe_dict.py +22 -0
- memos/memos_tools/thread_safe_dict_segment.py +382 -0
- memos/parsers/factory.py +2 -0
- memos/reranker/__init__.py +4 -0
- memos/reranker/base.py +24 -0
- memos/reranker/concat.py +59 -0
- memos/reranker/cosine_local.py +96 -0
- memos/reranker/factory.py +48 -0
- memos/reranker/http_bge.py +312 -0
- memos/reranker/noop.py +16 -0
- memos/templates/mem_reader_prompts.py +289 -40
- memos/templates/mem_scheduler_prompts.py +242 -0
- memos/templates/mos_prompts.py +133 -60
- memos/types.py +4 -1
- memos/api/context/context.py +0 -147
- memos/mem_scheduler/mos_for_test_scheduler.py +0 -146
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/entry_points.txt +0 -0
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info/licenses}/LICENSE +0 -0
- /memos/mem_scheduler/{general_modules → webservice_modules}/rabbitmq_service.py +0 -0
- /memos/mem_scheduler/{general_modules → webservice_modules}/redis_service.py +0 -0
memos/memories/textual/tree_text_memory/organize/reorganizer.py

```diff
@@ -4,19 +4,20 @@ import time
 import traceback
 
 from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
 from queue import PriorityQueue
 from typing import Literal
 
 import numpy as np
 
+from memos.context.context import ContextThreadPoolExecutor
 from memos.dependency import require_python_package
 from memos.embedders.factory import OllamaEmbedder
 from memos.graph_dbs.item import GraphDBEdge, GraphDBNode
 from memos.graph_dbs.neo4j import Neo4jGraphDB
 from memos.llms.base import BaseLLM
 from memos.log import get_logger
-from memos.memories.textual.item import TreeNodeTextualMemoryMetadata
+from memos.memories.textual.item import SourceMessage, TreeNodeTextualMemoryMetadata
 from memos.memories.textual.tree_text_memory.organize.handler import NodeHandler
 from memos.memories.textual.tree_text_memory.organize.relation_reason_detector import (
     RelationAndReasoningDetector,
```
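Several hunks in this file (and in the retrieval modules below) swap `concurrent.futures.ThreadPoolExecutor` for the new `ContextThreadPoolExecutor` from `memos/context/context.py` (+255 lines), whose body this diff does not display. A minimal sketch of the usual shape of such a class, assuming its purpose is to carry request-scoped `contextvars` across the thread boundary:

```python
# Hypothetical sketch only; the real implementation lives in memos/context/context.py,
# which is not shown in this diff.
import contextvars
from concurrent.futures import ThreadPoolExecutor


class ContextThreadPoolExecutor(ThreadPoolExecutor):
    """ThreadPoolExecutor that replays the submitting thread's contextvars in each task."""

    def submit(self, fn, /, *args, **kwargs):
        ctx = contextvars.copy_context()  # snapshot the caller's context
        return super().submit(ctx.run, fn, *args, **kwargs)  # replay it in the worker
```

This would explain why the hunks below can keep passing plain callables to `executor.submit(...)` while whatever per-request state the new `memos/api/middleware/request_context.py` attaches stays visible inside worker threads.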
```diff
@@ -27,6 +28,22 @@ from memos.templates.tree_reorganize_prompts import LOCAL_SUBCLUSTER_PROMPT, REO
 logger = get_logger(__name__)
 
 
+def build_summary_parent_node(cluster_nodes):
+    normalized_sources = []
+    for n in cluster_nodes:
+        sm = SourceMessage(
+            type="chat",
+            role=None,
+            chat_time=None,
+            message_id=None,
+            content=n.memory,
+            # extra
+            node_id=n.id,
+        )
+        normalized_sources.append(sm)
+    return normalized_sources
+
+
 class QueueMessage:
     def __init__(
         self,
@@ -51,6 +68,15 @@ class QueueMessage:
         return op_priority[self.op] < op_priority[other.op]
 
 
+def extract_first_to_last_brace(text: str):
+    start = text.find("{")
+    end = text.rfind("}")
+    if start == -1 or end == -1 or end < start:
+        return "", None
+    json_str = text[start : end + 1]
+    return json_str, json.loads(json_str)
+
+
 class GraphStructureReorganizer:
     def __init__(
         self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: OllamaEmbedder, is_reorganize: bool
```
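`extract_first_to_last_brace` trims an LLM response down to the span between its first `{` and last `}` before handing it to `json.loads`; `_parse_json_result` further down keeps only the parsed value (index `[1]`). A quick usage sketch:

```python
# extract_first_to_last_brace returns (json_str, parsed); it raises
# json.JSONDecodeError if the text between the braces is not valid JSON.
raw = 'Sure! Here is the result:\n{"clusters": [["n1", "n2"]]}\nHope that helps.'
json_str, parsed = extract_first_to_last_brace(raw)
print(json_str)              # {"clusters": [["n1", "n2"]]}
print(parsed["clusters"])    # [['n1', 'n2']]

assert extract_first_to_last_brace("no braces here") == ("", None)
```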
```diff
@@ -87,6 +113,7 @@ class GraphStructureReorganizer:
         1) queue is empty
         2) any running structure optimization is done
         """
+        deadline = time.time() + 600
         if not self.is_reorganize:
             return
 
@@ -96,6 +123,9 @@ class GraphStructureReorganizer:
 
         while any(self._is_optimizing.values()):
             logger.debug(f"Waiting for structure optimizer to finish... {self._is_optimizing}")
+            if time.time() > deadline:
+                logger.error(f"Wait timed out; flags={self._is_optimizing}")
+                break
             time.sleep(1)
         logger.debug("Structure optimizer is now idle.")
 
@@ -129,6 +159,9 @@ class GraphStructureReorganizer:
 
         logger.info("Structure optimizer schedule started.")
         while not getattr(self, "_stop_scheduler", False):
+            if any(self._is_optimizing.values()):
+                time.sleep(1)
+                continue
             if self._reorganize_needed:
                 logger.info("[Reorganizer] Triggering optimize_structure due to new nodes.")
                 self.optimize_structure(scope="LongTermMemory")
@@ -176,6 +209,7 @@ class GraphStructureReorganizer:
         local_tree_threshold: int = 10,
         min_cluster_size: int = 4,
         min_group_size: int = 20,
+        max_duration_sec: int = 600,
     ):
         """
         Periodically reorganize the graph:
@@ -183,8 +217,20 @@ class GraphStructureReorganizer:
         2. Summarize each cluster.
         3. Create parent nodes and build local PARENT trees.
         """
+        # --- Total time watch dog: check functions ---
+        start_ts = time.time()
+
+        def _check_deadline(where: str):
+            if time.time() - start_ts > max_duration_sec:
+                logger.error(
+                    f"[GraphStructureReorganize] {scope} surpass {max_duration_sec}s,time "
+                    f"over at {where}"
+                )
+                return True
+            return False
+
         if self._is_optimizing[scope]:
-            logger.info(f"Already optimizing for {scope}. Skipping.")
+            logger.info(f"[GraphStructureReorganize] Already optimizing for {scope}. Skipping.")
             return
 
         if self.graph_store.node_not_exist(scope):
@@ -198,32 +244,35 @@ class GraphStructureReorganizer:
         )
 
         logger.debug(
-            f"Num of scope in self.graph_store is {self.graph_store.get_memory_count(scope)}"
+            f"[GraphStructureReorganize] Num of scope in self.graph_store is"
+            f" {self.graph_store.get_memory_count(scope)}"
         )
         # Load candidate nodes
+        if _check_deadline("[GraphStructureReorganize] Before loading candidates"):
+            return
         raw_nodes = self.graph_store.get_structure_optimization_candidates(scope)
         nodes = [GraphDBNode(**n) for n in raw_nodes]
 
         if not nodes:
             logger.info("[GraphStructureReorganize] No nodes to optimize. Skipping.")
             return
-
         if len(nodes) < min_group_size:
             logger.info(
                 f"[GraphStructureReorganize] Only {len(nodes)} candidate nodes found. Not enough to reorganize. Skipping."
             )
             return
 
-        logger.info(f"[GraphStructureReorganize] Loaded {len(nodes)} nodes.")
-
         # Step 2: Partition nodes
+        if _check_deadline("[GraphStructureReorganize] Before partition"):
+            return
         partitioned_groups = self._partition(nodes)
-
         logger.info(
             f"[GraphStructureReorganize] Partitioned into {len(partitioned_groups)} clusters."
         )
 
-        with ThreadPoolExecutor(max_workers=4) as executor:
+        if _check_deadline("[GraphStructureReorganize] Before submit partition task"):
+            return
+        with ContextThreadPoolExecutor(max_workers=4) as executor:
             futures = []
             for cluster_nodes in partitioned_groups:
                 futures.append(
```
```diff
@@ -237,14 +286,17 @@ class GraphStructureReorganizer:
                 )
 
             for f in as_completed(futures):
+                if _check_deadline("[GraphStructureReorganize] Waiting clusters..."):
+                    for x in futures:
+                        x.cancel()
+                    return
                 try:
                     f.result()
                 except Exception as e:
                     logger.warning(
-                        f"[
-                        f"failed: {e}, cluster_nodes: {cluster_nodes}, trace: {traceback.format_exc()}"
+                        f"[GraphStructureReorganize] Cluster processing failed: {e}, trace: {traceback.format_exc()}"
                     )
-
+            logger.info("[GraphStructure Reorganize] Structure optimization finished.")
 
         finally:
             self._is_optimizing[scope] = False
```
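Note on the cancellation path above: `Future.cancel()` (standard `concurrent.futures` behavior) only prevents tasks that have not yet started. Cluster tasks already executing keep running, and because the `return` happens inside the `with ContextThreadPoolExecutor(...)` block, the executor's shutdown still waits for them to finish before the method actually exits.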
```diff
@@ -282,7 +334,7 @@ class GraphStructureReorganizer:
         nodes_to_check = cluster_nodes
         exclude_ids = [n.id for n in nodes_to_check]
 
-        with ThreadPoolExecutor(max_workers=4) as executor:
+        with ContextThreadPoolExecutor(max_workers=4) as executor:
             futures = []
             for node in nodes_to_check:
                 futures.append(
@@ -294,7 +346,7 @@ class GraphStructureReorganizer:
                     )
                 )
 
-            for f in as_completed(futures):
+            for f in as_completed(futures, timeout=300):
                 results = f.result()
 
                 # 1) Add pairwise relations
```
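`as_completed(futures, timeout=300)` raises `concurrent.futures.TimeoutError` if any future is still pending 300 seconds after the call. Nothing in this loop catches it, so a timeout propagates out of the relation-detection step rather than being swallowed.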
```diff
@@ -331,11 +383,11 @@ class GraphStructureReorganizer:
             for child_id in agg_node.metadata.sources:
                 self.graph_store.add_edge(agg_node.id, child_id, "AGGREGATE_TO")
 
-
+        logger.info("[Reorganizer] Cluster relation/reasoning done.")
 
     def _local_subcluster(
-        self, cluster_nodes: list[GraphDBNode], max_length: int =
-    ) ->
+        self, cluster_nodes: list[GraphDBNode], max_length: int = 15000
+    ) -> list[list[GraphDBNode]]:
         """
         Use LLM to split a large cluster into semantically coherent sub-clusters.
         """
@@ -350,7 +402,7 @@ class GraphStructureReorganizer:
 
         joined_scene = "\n".join(scene_lines)
         if len(joined_scene) > max_length:
-            logger.warning(
+            logger.warning("Sub-cluster too long")
         prompt = LOCAL_SUBCLUSTER_PROMPT.replace("{joined_scene}", joined_scene[:max_length])
 
         messages = [{"role": "user", "content": prompt}]
@@ -499,17 +551,17 @@ class GraphStructureReorganizer:
         parent_node = GraphDBNode(
             memory=parent_value,
             metadata=TreeNodeTextualMemoryMetadata(
-                user_id=
-                session_id=
+                user_id=None,
+                session_id=None,
                 memory_type=scope,
                 status="activated",
                 key=parent_key,
                 tags=parent_tags,
                 embedding=embedding,
                 usage=[],
-                sources=
+                sources=build_summary_parent_node(cluster_nodes),
                 background=parent_background,
-                confidence=0.
+                confidence=0.66,
                 type="topic",
             ),
         )
@@ -518,7 +570,7 @@ class GraphStructureReorganizer:
     def _parse_json_result(self, response_text):
        try:
            response_text = response_text.replace("```", "").replace("json", "")
-            response_json =
+            response_json = extract_first_to_last_brace(response_text)[1]
             return response_json
         except json.JSONDecodeError as e:
             logger.warning(
```
memos/memories/textual/tree_text_memory/retrieve/bochasearch.py

```diff
@@ -2,15 +2,17 @@
 
 import json
 
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
 from datetime import datetime
+from typing import Any
 
 import requests
 
+from memos.context.context import ContextThreadPoolExecutor
 from memos.embedders.factory import OllamaEmbedder
 from memos.log import get_logger
 from memos.mem_reader.base import BaseMemReader
-from memos.memories.textual.item import TextualMemoryItem
+from memos.memories.textual.item import SourceMessage, TextualMemoryItem
 
 
 logger = get_logger(__name__)
@@ -177,7 +179,7 @@ class BochaAISearchRetriever:
         if not info:
             info = {"user_id": "", "session_id": ""}
 
-        with ThreadPoolExecutor(max_workers=8) as executor:
+        with ContextThreadPoolExecutor(max_workers=8) as executor:
             futures = [
                 executor.submit(self._process_result, r, query, parsed_goal, info)
                 for r in search_results
@@ -193,7 +195,7 @@ class BochaAISearchRetriever:
         return list(unique_memory_items.values())
 
     def _process_result(
-        self, result: dict, query: str, parsed_goal: str, info:
+        self, result: dict, query: str, parsed_goal: str, info: dict[str, Any]
     ) -> list[TextualMemoryItem]:
         """Process one Bocha search result into TextualMemoryItem."""
         title = result.get("name", "")
@@ -218,12 +220,14 @@ class BochaAISearchRetriever:
         memory_items = []
         for read_item_i in read_items[0]:
             read_item_i.memory = (
-                f"Title: {title}\nNewsTime: {publish_time}\nSummary: {summary}\n"
+                f"[Outer internet view] Title: {title}\nNewsTime:"
+                f" {publish_time}\nSummary:"
+                f" {summary}\n"
                 f"Content: {read_item_i.memory}"
             )
             read_item_i.metadata.source = "web"
             read_item_i.metadata.memory_type = "OuterMemory"
-            read_item_i.metadata.sources = [url] if url else []
+            read_item_i.metadata.sources = [SourceMessage(type="web", url=url)] if url else []
             read_item_i.metadata.visibility = "public"
             memory_items.append(read_item_i)
         return memory_items
```
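Plain URL strings in `metadata.sources` are replaced throughout by structured `SourceMessage` objects, added in `memos/memories/textual/item.py` (+96 lines). The model definition itself is not part of this diff; judging only from the call sites (`SourceMessage(type="chat", content=..., node_id=...)` in the reorganizer, `SourceMessage(type="web", url=...)` here), a pydantic-style sketch might look like:

```python
# Hypothetical sketch inferred from call sites in this diff; the real model in
# memos/memories/textual/item.py may differ in field names and validation.
from pydantic import BaseModel, ConfigDict


class SourceMessage(BaseModel):
    model_config = ConfigDict(extra="allow")  # the reorganizer passes extras such as node_id

    type: str = "chat"             # "chat" for conversation turns, "web" for retrieved pages
    role: str | None = None
    chat_time: str | None = None
    message_id: str | None = None
    content: str | None = None
    url: str | None = None         # used by the web retrievers
```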
memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py

```diff
@@ -7,7 +7,11 @@ from datetime import datetime
 import requests
 
 from memos.embedders.factory import OllamaEmbedder
-from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
+from memos.memories.textual.item import (
+    SourceMessage,
+    TextualMemoryItem,
+    TreeNodeTextualMemoryMetadata,
+)
 
 
 class GoogleCustomSearchAPI:
@@ -172,7 +176,7 @@ class InternetGoogleRetriever:
                 visibility="public",
                 memory_type="LongTermMemory",  # Internet search results as working memory
                 key=title,
-                sources=[link] if link else [],
+                sources=[SourceMessage(type="web", url=link)] if link else [],
                 embedding=self.embedder.embed([memory_content])[0],  # Can add embedding later
                 created_at=datetime.now().isoformat(),
                 usage=[],
```
memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py

```diff
@@ -10,6 +10,7 @@ from memos.memories.textual.tree_text_memory.retrieve.internet_retriever import
     InternetGoogleRetriever,
 )
 from memos.memories.textual.tree_text_memory.retrieve.xinyusearch import XinyuSearchRetriever
+from memos.memos_tools.singleton import singleton_factory
 
 
 class InternetRetrieverFactory:
@@ -23,6 +24,7 @@ class InternetRetrieverFactory:
     }
 
     @classmethod
+    @singleton_factory()
     def from_config(
         cls, config_factory: InternetRetrieverConfigFactory, embedder: BaseEmbedder
     ) -> InternetGoogleRetriever | None:
```
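`singleton_factory` comes from the new `memos/memos_tools/singleton.py` (+174 lines), which this diff does not show. Its placement on `from_config` suggests it caches the constructed retriever per configuration, so repeated factory calls share one instance. A minimal sketch under that assumption:

```python
# Hypothetical sketch; the real decorator in memos/memos_tools/singleton.py is not
# shown in this diff and may key or lock differently.
import functools
import threading


def singleton_factory(key_fn=repr):
    """Cache a factory's return value per argument key, thread-safely."""

    def decorator(fn):
        cache: dict[str, object] = {}
        lock = threading.Lock()

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            key = key_fn((args, kwargs))
            with lock:
                if key not in cache:  # first call for this config builds the instance
                    cache[key] = fn(*args, **kwargs)
            return cache[key]

        return wrapper

    return decorator
```

With this ordering (`@classmethod` outermost), calls carrying an equal config would return the same retriever object.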
memos/memories/textual/tree_text_memory/retrieve/recall.py

```diff
@@ -1,11 +1,16 @@
 import concurrent.futures
 
+from memos.context.context import ContextThreadPoolExecutor
 from memos.embedders.factory import OllamaEmbedder
 from memos.graph_dbs.neo4j import Neo4jGraphDB
+from memos.log import get_logger
 from memos.memories.textual.item import TextualMemoryItem
 from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal
 
 
+logger = get_logger(__name__)
+
+
 class GraphMemoryRetriever:
     """
     Unified memory retriever that combines both graph-based and vector-based retrieval logic.
@@ -14,6 +19,8 @@ class GraphMemoryRetriever:
     def __init__(self, graph_store: Neo4jGraphDB, embedder: OllamaEmbedder):
         self.graph_store = graph_store
         self.embedder = embedder
+        self.max_workers = 10
+        self.filter_weight = 0.6
 
     def retrieve(
         self,
@@ -22,6 +29,7 @@ class GraphMemoryRetriever:
         top_k: int,
         memory_scope: str,
         query_embedding: list[list[float]] | None = None,
+        search_filter: dict | None = None,
     ) -> list[TextualMemoryItem]:
         """
         Perform hybrid memory retrieval:
@@ -35,7 +43,7 @@ class GraphMemoryRetriever:
             top_k (int): Number of candidates to return.
             memory_scope (str): One of ['working', 'long_term', 'user'].
             query_embedding(list of embedding): list of embedding of query
-
+            search_filter (dict, optional): Optional metadata filters for search results.
         Returns:
             list: Combined memory items.
         """
```
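An illustrative call with the new parameter; the parameters before `top_k` are not visible in this hunk, so the `query`/`parsed_goal` names below are assumptions, and the filter key is invented for the example:

```python
# Illustrative only; "user_id" as a filter key is an assumption, not taken from this diff.
items = retriever.retrieve(
    query="what did we decide about the launch?",
    parsed_goal=parsed_goal,                    # ParsedTaskGoal built by the query parser
    top_k=10,
    memory_scope="LongTermMemory",
    query_embedding=[embedding_vector],         # list of embeddings, per the docstring
    search_filter={"user_id": "u_123"},         # forwarded to _vector_recall's filtered path
)
```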
```diff
@@ -45,16 +53,20 @@ class GraphMemoryRetriever:
         if memory_scope == "WorkingMemory":
             # For working memory, retrieve all entries (no filtering)
             working_memories = self.graph_store.get_all_memory_items(
-                scope="WorkingMemory", include_embedding=True
+                scope="WorkingMemory", include_embedding=False
             )
             return [TextualMemoryItem.from_dict(record) for record in working_memories]
 
-        with
+        with ContextThreadPoolExecutor(max_workers=2) as executor:
             # Structured graph-based retrieval
             future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope)
             # Vector similarity search
             future_vector = executor.submit(
-                self._vector_recall,
+                self._vector_recall,
+                query_embedding or [],
+                memory_scope,
+                top_k,
+                search_filter=search_filter,
             )
 
             graph_results = future_graph.result()
```
```diff
@@ -74,6 +86,51 @@ class GraphMemoryRetriever:
 
         return list(combined.values())
 
+    def retrieve_from_cube(
+        self,
+        top_k: int,
+        memory_scope: str,
+        query_embedding: list[list[float]] | None = None,
+        cube_name: str = "memos_cube01",
+    ) -> list[TextualMemoryItem]:
+        """
+        Perform hybrid memory retrieval:
+        - Run graph-based lookup from dispatch plan.
+        - Run vector similarity search from embedded query.
+        - Merge and return combined result set.
+
+        Args:
+            top_k (int): Number of candidates to return.
+            memory_scope (str): One of ['working', 'long_term', 'user'].
+            query_embedding(list of embedding): list of embedding of query
+            cube_name: specify cube_name
+
+        Returns:
+            list: Combined memory items.
+        """
+        if memory_scope not in ["WorkingMemory", "LongTermMemory", "UserMemory"]:
+            raise ValueError(f"Unsupported memory scope: {memory_scope}")
+
+        graph_results = self._vector_recall(
+            query_embedding, memory_scope, top_k, cube_name=cube_name
+        )
+
+        for result_i in graph_results:
+            result_i.metadata.memory_type = "OuterMemory"
+        # Merge and deduplicate by ID
+        combined = {item.id: item for item in graph_results}
+
+        graph_ids = {item.id for item in graph_results}
+        combined_ids = set(combined.keys())
+        lost_ids = graph_ids - combined_ids
+
+        if lost_ids:
+            print(
+                f"[DEBUG] The following nodes were in graph_results but missing in combined: {lost_ids}"
+            )
+
+        return list(combined.values())
+
     def _graph_recall(
         self, parsed_goal: ParsedTaskGoal, memory_scope: str
     ) -> list[TextualMemoryItem]:
```
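Design observation on `retrieve_from_cube` above: it routes everything through `_vector_recall` with a `cube_name` and tags each hit as `OuterMemory`. Since `combined` is built directly from `graph_results`, `lost_ids` is always empty, so the `[DEBUG]` print can never fire as written.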
```diff
@@ -108,7 +165,7 @@ class GraphMemoryRetriever:
             return []
 
         # Load nodes and post-filter
-        node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=True)
+        node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False)
 
         final_nodes = []
         for node in node_dicts:
@@ -134,31 +191,72 @@ class GraphMemoryRetriever:
         query_embedding: list[list[float]],
         memory_scope: str,
         top_k: int = 20,
-        max_num: int =
+        max_num: int = 3,
+        cube_name: str | None = None,
+        search_filter: dict | None = None,
     ) -> list[TextualMemoryItem]:
         """
-        # TODO: tackle with post-filter and pre-filter(5.18+) better.
         Perform vector-based similarity retrieval using query embedding.
+        # TODO: tackle with post-filter and pre-filter(5.18+) better.
         """
-
+        if not query_embedding:
+            return []
 
-        def search_single(vec):
+        def search_single(vec, filt=None):
             return (
-                self.graph_store.search_by_embedding(
+                self.graph_store.search_by_embedding(
+                    vector=vec,
+                    top_k=top_k,
+                    scope=memory_scope,
+                    cube_name=cube_name,
+                    search_filter=filt,
+                )
                 or []
             )
 
-
-
-
-
-
+        def search_path_a():
+            """Path A: search without filter"""
+            path_a_hits = []
+            with ContextThreadPoolExecutor() as executor:
+                futures = [
+                    executor.submit(search_single, vec, None) for vec in query_embedding[:max_num]
+                ]
+                for f in concurrent.futures.as_completed(futures):
+                    path_a_hits.extend(f.result() or [])
+            return path_a_hits
 
-
-
+        def search_path_b():
+            """Path B: search with filter"""
+            if not search_filter:
+                return []
+            path_b_hits = []
+            with ContextThreadPoolExecutor() as executor:
+                futures = [
+                    executor.submit(search_single, vec, search_filter)
+                    for vec in query_embedding[:max_num]
+                ]
+                for f in concurrent.futures.as_completed(futures):
+                    path_b_hits.extend(f.result() or [])
+            return path_b_hits
+
+        # Execute both paths concurrently
+        all_hits = []
+        with ContextThreadPoolExecutor(max_workers=2) as executor:
+            path_a_future = executor.submit(search_path_a)
+            path_b_future = executor.submit(search_path_b)
 
-
-
-        node_dicts = self.graph_store.get_nodes(list(unique_ids), include_embedding=True)
+            all_hits.extend(path_a_future.result())
+            all_hits.extend(path_b_future.result())
 
-
+        if not all_hits:
+            return []
+
+        # merge and deduplicate
+        unique_ids = {r["id"] for r in all_hits if r.get("id")}
+        node_dicts = (
+            self.graph_store.get_nodes(
+                list(unique_ids), include_embedding=False, cube_name=cube_name
+            )
+            or []
+        )
+        return [TextualMemoryItem.from_dict(n) for n in node_dicts]
```
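Design note on the split above: Path A (unfiltered) and Path B (filtered) run concurrently and their hits are unioned rather than intersected, so `search_filter` behaves as a recall booster for matching metadata instead of a hard constraint. Deduplication happens afterwards via the `unique_ids` set, and the final `get_nodes` fetch now skips embeddings (`include_embedding=False`) that the old code loaded.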